// rig_bedrock/types/message.rs

1use aws_sdk_bedrockruntime::types as aws_bedrock;
2
3use rig::{
4    OneOrMany,
5    completion::CompletionError,
6    message::{AssistantContent, Message, UserContent},
7};
8
9use super::{assistant_content::RigAssistantContent, user_content::RigUserContent};
10
11pub struct RigMessage(pub Message);
12
13impl TryFrom<RigMessage> for aws_bedrock::Message {
14    type Error = CompletionError;
15
16    fn try_from(value: RigMessage) -> Result<Self, Self::Error> {
17        let result = match value.0 {
18            Message::User { content } => {
19                let message_content = content
20                    .into_iter()
21                    .map(|user_content| RigUserContent(user_content).try_into())
22                    .collect::<Result<Vec<Vec<_>>, _>>()
23                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
24                    .map(|nested| nested.into_iter().flatten().collect())?;
25
26                aws_bedrock::Message::builder()
27                    .role(aws_bedrock::ConversationRole::User)
28                    .set_content(Some(message_content))
29                    .build()
30                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?
31            }
32            Message::Assistant { content, .. } => aws_bedrock::Message::builder()
33                .role(aws_bedrock::ConversationRole::Assistant)
34                .set_content(Some(
35                    content
36                        .into_iter()
37                        .map(|content| RigAssistantContent(content).try_into())
38                        .collect::<Result<Vec<aws_bedrock::ContentBlock>, _>>()?,
39                ))
40                .build()
41                .map_err(|e| CompletionError::RequestError(Box::new(e)))?,
42        };
43        Ok(result)
44    }
45}
46
47impl TryFrom<aws_bedrock::Message> for RigMessage {
48    type Error = CompletionError;
49
50    fn try_from(message: aws_bedrock::Message) -> Result<Self, Self::Error> {
51        match message.role {
52            aws_bedrock::ConversationRole::Assistant => {
53                let assistant_content = message
54                    .content
55                    .into_iter()
56                    .map(|c| c.try_into())
57                    .collect::<Result<Vec<RigAssistantContent>, _>>()?
58                    .into_iter()
59                    .map(|rig_assistant_content| rig_assistant_content.0)
60                    .collect::<Vec<AssistantContent>>();
61
62                let content = OneOrMany::many(assistant_content)
63                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
64
65                Ok(RigMessage(Message::Assistant { content, id: None }))
66            }
67            aws_bedrock::ConversationRole::User => {
68                let user_content = message
69                    .content
70                    .into_iter()
71                    .map(|c| c.try_into())
72                    .collect::<Result<Vec<RigUserContent>, _>>()?
73                    .into_iter()
74                    .map(|user_content| user_content.0)
75                    .collect::<Vec<UserContent>>();
76
77                let content = OneOrMany::many(user_content)
78                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
79                Ok(RigMessage(Message::User { content }))
80            }
81            _ => Err(CompletionError::ProviderError(
82                "AWS Bedrock returned unsupported ConversationRole".into(),
83            )),
84        }
85    }
86}
87
88impl TryFrom<super::converse_output::Message> for RigMessage {
89    type Error = CompletionError;
90
91    fn try_from(message: super::converse_output::Message) -> Result<Self, Self::Error> {
92        let message = aws_bedrock::Message::try_from(message)
93            .map_err(|x| CompletionError::ProviderError(format!("Type conversion error: {x}")))?;
94
95        Self::try_from(message)
96    }
97}
98
#[cfg(test)]
mod tests {
    use crate::types::message::RigMessage;
    use aws_sdk_bedrockruntime::types as aws_bedrock;
    use rig::{
        OneOrMany,
        completion::CompletionError,
        message::{Message, UserContent},
    };

    // rig -> Bedrock: a single user text item becomes a single text content block.
    #[test]
    fn message_to_aws_message() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text("text".into())),
        };
        let aws_message: aws_bedrock::Message = RigMessage(message)
            .try_into()
            .expect("user text message should convert to a Bedrock message");
        assert_eq!(aws_message.role, aws_bedrock::ConversationRole::User);
        assert_eq!(
            aws_message.content,
            vec![aws_bedrock::ContentBlock::Text("text".into())]
        );
    }

    // Bedrock -> rig: a user text block comes back as a rig user text message.
    #[test]
    fn aws_user_message_to_message() {
        let aws_message = aws_bedrock::Message::builder()
            .role(aws_bedrock::ConversationRole::User)
            .content(aws_bedrock::ContentBlock::Text("text".into()))
            .build()
            .expect("role and content are both set");
        let RigMessage(message) = RigMessage::try_from(aws_message)
            .expect("user text block should convert to a rig message");
        assert_eq!(
            message,
            Message::User {
                content: OneOrMany::one(UserContent::Text("text".into())),
            }
        );
    }

    // Roles other than User/Assistant are rejected as provider errors.
    #[test]
    fn aws_message_with_unsupported_role_is_rejected() {
        let aws_message = aws_bedrock::Message::builder()
            .role(aws_bedrock::ConversationRole::from("system"))
            .content(aws_bedrock::ContentBlock::Text("text".into()))
            .build()
            .expect("role and content are both set");
        let result = RigMessage::try_from(aws_message);
        assert!(matches!(result, Err(CompletionError::ProviderError(_))));
    }
}