rig_bedrock/types/
message.rs1use aws_sdk_bedrockruntime::types as aws_bedrock;
2
3use rig::{
4 OneOrMany,
5 completion::CompletionError,
6 message::{AssistantContent, Message, UserContent},
7};
8
9use super::{assistant_content::RigAssistantContent, user_content::RigUserContent};
10
11pub struct RigMessage(pub Message);
12
13impl TryFrom<RigMessage> for aws_bedrock::Message {
14 type Error = CompletionError;
15
16 fn try_from(value: RigMessage) -> Result<Self, Self::Error> {
17 let result = match value.0 {
18 Message::User { content } => {
19 let message_content = content
20 .into_iter()
21 .map(|user_content| RigUserContent(user_content).try_into())
22 .collect::<Result<Vec<Vec<_>>, _>>()
23 .map_err(|e| CompletionError::RequestError(Box::new(e)))
24 .map(|nested| nested.into_iter().flatten().collect())?;
25
26 aws_bedrock::Message::builder()
27 .role(aws_bedrock::ConversationRole::User)
28 .set_content(Some(message_content))
29 .build()
30 .map_err(|e| CompletionError::RequestError(Box::new(e)))?
31 }
32 Message::Assistant { content, .. } => aws_bedrock::Message::builder()
33 .role(aws_bedrock::ConversationRole::Assistant)
34 .set_content(Some(
35 content
36 .into_iter()
37 .map(|content| RigAssistantContent(content).try_into())
38 .collect::<Result<Vec<aws_bedrock::ContentBlock>, _>>()?,
39 ))
40 .build()
41 .map_err(|e| CompletionError::RequestError(Box::new(e)))?,
42 };
43 Ok(result)
44 }
45}
46
47impl TryFrom<aws_bedrock::Message> for RigMessage {
48 type Error = CompletionError;
49
50 fn try_from(message: aws_bedrock::Message) -> Result<Self, Self::Error> {
51 match message.role {
52 aws_bedrock::ConversationRole::Assistant => {
53 let assistant_content = message
54 .content
55 .into_iter()
56 .map(|c| c.try_into())
57 .collect::<Result<Vec<RigAssistantContent>, _>>()?
58 .into_iter()
59 .map(|rig_assistant_content| rig_assistant_content.0)
60 .collect::<Vec<AssistantContent>>();
61
62 let content = OneOrMany::many(assistant_content)
63 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
64
65 Ok(RigMessage(Message::Assistant { content, id: None }))
66 }
67 aws_bedrock::ConversationRole::User => {
68 let user_content = message
69 .content
70 .into_iter()
71 .map(|c| c.try_into())
72 .collect::<Result<Vec<RigUserContent>, _>>()?
73 .into_iter()
74 .map(|user_content| user_content.0)
75 .collect::<Vec<UserContent>>();
76
77 let content = OneOrMany::many(user_content)
78 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
79 Ok(RigMessage(Message::User { content }))
80 }
81 _ => Err(CompletionError::ProviderError(
82 "AWS Bedrock returned unsupported ConversationRole".into(),
83 )),
84 }
85 }
86}
87
88impl TryFrom<super::converse_output::Message> for RigMessage {
89 type Error = CompletionError;
90
91 fn try_from(message: super::converse_output::Message) -> Result<Self, Self::Error> {
92 let message = aws_bedrock::Message::try_from(message)
93 .map_err(|x| CompletionError::ProviderError(format!("Type conversion error: {x}")))?;
94
95 Self::try_from(message)
96 }
97}
98
#[cfg(test)]
mod tests {
    use crate::types::message::RigMessage;
    use aws_sdk_bedrockruntime::types as aws_bedrock;
    use rig::{
        OneOrMany,
        completion::CompletionError,
        message::{Message, UserContent},
    };

    /// A text-only rig user message becomes a single Bedrock text block with the user role.
    #[test]
    fn message_to_aws_message() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text("text".into())),
        };
        // `expect` surfaces the actual conversion error if this ever regresses.
        let aws_message = aws_bedrock::Message::try_from(RigMessage(message))
            .expect("a text-only user message should convert");
        assert_eq!(aws_message.role, aws_bedrock::ConversationRole::User);
        assert_eq!(
            aws_message.content,
            vec![aws_bedrock::ContentBlock::Text("text".into())]
        );
    }

    /// The reverse direction: a Bedrock user text block becomes a rig user text message.
    #[test]
    fn aws_message_to_message() {
        let aws_message = aws_bedrock::Message::builder()
            .role(aws_bedrock::ConversationRole::User)
            .content(aws_bedrock::ContentBlock::Text("text".into()))
            .build()
            .expect("role and content are both set");

        let RigMessage(message) =
            RigMessage::try_from(aws_message).expect("a text-only user message should convert");

        assert_eq!(
            message,
            Message::User {
                content: OneOrMany::one(UserContent::Text("text".into())),
            }
        );
    }

    /// A Bedrock message with no content blocks cannot become a rig message.
    #[test]
    fn aws_message_without_content_is_rejected() {
        let aws_message = aws_bedrock::Message::builder()
            .role(aws_bedrock::ConversationRole::Assistant)
            .set_content(Some(Vec::new()))
            .build()
            .expect("role and content are both set");

        assert!(RigMessage::try_from(aws_message).is_err());
    }

    /// Roles other than user/assistant are rejected as provider errors.
    #[test]
    fn aws_message_with_unsupported_role_is_rejected() {
        let aws_message = aws_bedrock::Message::builder()
            .role(aws_bedrock::ConversationRole::from("system"))
            .content(aws_bedrock::ContentBlock::Text("text".into()))
            .build()
            .expect("role and content are both set");

        let result = RigMessage::try_from(aws_message);
        assert!(matches!(result, Err(CompletionError::ProviderError(_))));
    }
}