rig_bedrock/types/
message.rs1use aws_sdk_bedrockruntime::types as aws_bedrock;
2
3use rig::{
4 OneOrMany,
5 completion::CompletionError,
6 message::{AssistantContent, Message, UserContent},
7};
8
9use super::{assistant_content::RigAssistantContent, user_content::RigUserContent};
10
11pub struct RigMessage(pub Message);
12
13impl TryFrom<RigMessage> for aws_bedrock::Message {
14 type Error = CompletionError;
15
16 fn try_from(value: RigMessage) -> Result<Self, Self::Error> {
17 let result = match value.0 {
18 Message::System { .. } => {
19 return Err(CompletionError::ProviderError(
20 "System messages must be sent via Bedrock system blocks".to_string(),
21 ));
22 }
23 Message::User { content } => {
24 let message_content = content
25 .into_iter()
26 .map(|user_content| RigUserContent(user_content).try_into())
27 .collect::<Result<Vec<Vec<_>>, _>>()
28 .map_err(|e| CompletionError::RequestError(Box::new(e)))
29 .map(|nested| nested.into_iter().flatten().collect())?;
30
31 aws_bedrock::Message::builder()
32 .role(aws_bedrock::ConversationRole::User)
33 .set_content(Some(message_content))
34 .build()
35 .map_err(|e| CompletionError::RequestError(Box::new(e)))?
36 }
37 Message::Assistant { content, .. } => aws_bedrock::Message::builder()
38 .role(aws_bedrock::ConversationRole::Assistant)
39 .set_content(Some(
40 content
41 .into_iter()
42 .map(|content| RigAssistantContent(content).try_into())
43 .collect::<Result<Vec<aws_bedrock::ContentBlock>, _>>()?,
44 ))
45 .build()
46 .map_err(|e| CompletionError::RequestError(Box::new(e)))?,
47 };
48 Ok(result)
49 }
50}
51
52impl TryFrom<aws_bedrock::Message> for RigMessage {
53 type Error = CompletionError;
54
55 fn try_from(message: aws_bedrock::Message) -> Result<Self, Self::Error> {
56 match message.role {
57 aws_bedrock::ConversationRole::Assistant => {
58 let assistant_content = message
59 .content
60 .into_iter()
61 .map(|c| c.try_into())
62 .collect::<Result<Vec<RigAssistantContent>, _>>()?
63 .into_iter()
64 .map(|rig_assistant_content| rig_assistant_content.0)
65 .collect::<Vec<AssistantContent>>();
66
67 let content = OneOrMany::many(assistant_content)
68 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
69
70 Ok(RigMessage(Message::Assistant { content, id: None }))
71 }
72 aws_bedrock::ConversationRole::User => {
73 let user_content = message
74 .content
75 .into_iter()
76 .map(|c| c.try_into())
77 .collect::<Result<Vec<RigUserContent>, _>>()?
78 .into_iter()
79 .map(|user_content| user_content.0)
80 .collect::<Vec<UserContent>>();
81
82 let content = OneOrMany::many(user_content)
83 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
84 Ok(RigMessage(Message::User { content }))
85 }
86 _ => Err(CompletionError::ProviderError(
87 "AWS Bedrock returned unsupported ConversationRole".into(),
88 )),
89 }
90 }
91}
92
93impl TryFrom<super::converse_output::Message> for RigMessage {
94 type Error = CompletionError;
95
96 fn try_from(message: super::converse_output::Message) -> Result<Self, Self::Error> {
97 let message = aws_bedrock::Message::try_from(message)
98 .map_err(|x| CompletionError::ProviderError(format!("Type conversion error: {x}")))?;
99
100 Self::try_from(message)
101 }
102}
103
#[cfg(test)]
mod tests {
    use crate::types::message::RigMessage;
    use aws_sdk_bedrockruntime::types as aws_bedrock;
    use rig::{
        OneOrMany,
        message::{Message, UserContent},
    };

    /// A user message holding one text item converts into a Bedrock message
    /// with the `User` role and a single text content block.
    #[test]
    fn message_to_aws_message() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text("text".into())),
        };

        let converted = aws_bedrock::Message::try_from(RigMessage(user_message));
        assert!(converted.is_ok());

        let converted = converted.unwrap();
        let expected_blocks = vec![aws_bedrock::ContentBlock::Text("text".into())];
        assert_eq!(converted.role, aws_bedrock::ConversationRole::User);
        assert_eq!(converted.content, expected_blocks);
    }
}