Skip to main content

rig_bedrock/types/
message.rs

1use aws_sdk_bedrockruntime::types as aws_bedrock;
2
3use rig::{
4    OneOrMany,
5    completion::CompletionError,
6    message::{AssistantContent, Message, UserContent},
7};
8
9use super::{assistant_content::RigAssistantContent, user_content::RigUserContent};
10
/// Wrapper around rig's [`Message`].
///
/// The conversions between rig messages and `aws_bedrock::Message` in this
/// module are implemented on this type rather than on [`Message`] directly.
pub struct RigMessage(pub Message);
12
13impl TryFrom<RigMessage> for aws_bedrock::Message {
14    type Error = CompletionError;
15
16    fn try_from(value: RigMessage) -> Result<Self, Self::Error> {
17        let result = match value.0 {
18            Message::System { .. } => {
19                return Err(CompletionError::ProviderError(
20                    "System messages must be sent via Bedrock system blocks".to_string(),
21                ));
22            }
23            Message::User { content } => {
24                let message_content = content
25                    .into_iter()
26                    .map(|user_content| RigUserContent(user_content).try_into())
27                    .collect::<Result<Vec<Vec<_>>, _>>()
28                    .map_err(|e| CompletionError::RequestError(Box::new(e)))
29                    .map(|nested| nested.into_iter().flatten().collect())?;
30
31                aws_bedrock::Message::builder()
32                    .role(aws_bedrock::ConversationRole::User)
33                    .set_content(Some(message_content))
34                    .build()
35                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?
36            }
37            Message::Assistant { content, .. } => aws_bedrock::Message::builder()
38                .role(aws_bedrock::ConversationRole::Assistant)
39                .set_content(Some(
40                    content
41                        .into_iter()
42                        .map(|content| RigAssistantContent(content).try_into())
43                        .collect::<Result<Vec<aws_bedrock::ContentBlock>, _>>()?,
44                ))
45                .build()
46                .map_err(|e| CompletionError::RequestError(Box::new(e)))?,
47        };
48        Ok(result)
49    }
50}
51
52impl TryFrom<aws_bedrock::Message> for RigMessage {
53    type Error = CompletionError;
54
55    fn try_from(message: aws_bedrock::Message) -> Result<Self, Self::Error> {
56        match message.role {
57            aws_bedrock::ConversationRole::Assistant => {
58                let assistant_content = message
59                    .content
60                    .into_iter()
61                    .map(|c| c.try_into())
62                    .collect::<Result<Vec<RigAssistantContent>, _>>()?
63                    .into_iter()
64                    .map(|rig_assistant_content| rig_assistant_content.0)
65                    .collect::<Vec<AssistantContent>>();
66
67                let content = OneOrMany::many(assistant_content)
68                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
69
70                Ok(RigMessage(Message::Assistant { content, id: None }))
71            }
72            aws_bedrock::ConversationRole::User => {
73                let user_content = message
74                    .content
75                    .into_iter()
76                    .map(|c| c.try_into())
77                    .collect::<Result<Vec<RigUserContent>, _>>()?
78                    .into_iter()
79                    .map(|user_content| user_content.0)
80                    .collect::<Vec<UserContent>>();
81
82                let content = OneOrMany::many(user_content)
83                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
84                Ok(RigMessage(Message::User { content }))
85            }
86            _ => Err(CompletionError::ProviderError(
87                "AWS Bedrock returned unsupported ConversationRole".into(),
88            )),
89        }
90    }
91}
92
93impl TryFrom<super::converse_output::Message> for RigMessage {
94    type Error = CompletionError;
95
96    fn try_from(message: super::converse_output::Message) -> Result<Self, Self::Error> {
97        let message = aws_bedrock::Message::try_from(message)
98            .map_err(|x| CompletionError::ProviderError(format!("Type conversion error: {x}")))?;
99
100        Self::try_from(message)
101    }
102}
103
#[cfg(test)]
mod tests {
    use crate::types::message::RigMessage;
    use aws_sdk_bedrockruntime::types as aws_bedrock;
    use rig::{
        OneOrMany,
        message::{Message, UserContent},
    };

    /// A user message holding a single text item converts into a Bedrock
    /// message with the `User` role and exactly one text block.
    #[test]
    fn message_to_aws_message() {
        let wrapped = RigMessage(Message::User {
            content: OneOrMany::one(UserContent::Text("text".into())),
        });

        // `unwrap` fails the test if the conversion returns an error.
        let converted = aws_bedrock::Message::try_from(wrapped).unwrap();

        let expected_blocks = vec![aws_bedrock::ContentBlock::Text("text".into())];
        assert_eq!(converted.role, aws_bedrock::ConversationRole::User);
        assert_eq!(converted.content, expected_blocks);
    }
}
127}