// rig_bedrock/types/document.rs
use aws_sdk_bedrockruntime::types as aws_bedrock;
2use rig::{
3 completion::CompletionError,
4 message::{Document, DocumentSourceKind},
5};
6
7pub(crate) use crate::types::media_types::RigDocumentMediaType;
8use base64::{Engine, prelude::BASE64_STANDARD};
9use uuid::Uuid;
10
11#[derive(Clone)]
12pub struct RigDocument(pub Document);
13
14impl TryFrom<RigDocument> for aws_bedrock::DocumentBlock {
15 type Error = CompletionError;
16
17 fn try_from(
18 RigDocument(Document {
19 data, media_type, ..
20 }): RigDocument,
21 ) -> Result<Self, Self::Error> {
22 let document_media_type = media_type.map(|doc| RigDocumentMediaType(doc).try_into());
23
24 let document_media_type = match document_media_type {
25 Some(Ok(document_format)) => Ok(Some(document_format)),
26 Some(Err(err)) => Err(err),
27 None => Ok(None),
28 }?;
29
30 let document_source = match data {
31 DocumentSourceKind::Base64(blob) => {
32 let bytes = BASE64_STANDARD
33 .decode(blob)
34 .map_err(|e| CompletionError::RequestError(e.into()))?;
35
36 aws_bedrock::DocumentSource::Bytes(aws_smithy_types::Blob::new(bytes))
37 }
38 DocumentSourceKind::String(str) => {
42 aws_bedrock::DocumentSource::Bytes(aws_smithy_types::Blob::new(str.as_bytes()))
43 }
44 doc => {
45 return Err(CompletionError::RequestError(
46 format!("Unsupported document kind: {doc}").into(),
47 ));
48 }
49 };
50
51 let random_string = Uuid::new_v4().simple().to_string();
52 let document_name = format!("document-{random_string}");
53 let result = aws_bedrock::DocumentBlock::builder()
54 .source(document_source)
55 .name(document_name)
56 .set_format(document_media_type)
57 .build()
58 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
59 Ok(result)
60 }
61}
62
63impl TryFrom<aws_bedrock::DocumentBlock> for RigDocument {
64 type Error = CompletionError;
65
66 fn try_from(value: aws_bedrock::DocumentBlock) -> Result<Self, Self::Error> {
67 let media_type: RigDocumentMediaType = value.format.try_into()?;
68 let media_type = media_type.0;
69
70 let data = match value.source {
71 Some(aws_bedrock::DocumentSource::Bytes(blob)) => {
72 let encoded_data = BASE64_STANDARD.encode(blob.into_inner());
73 Ok(DocumentSourceKind::Base64(encoded_data))
74 }
75 Some(aws_bedrock::DocumentSource::Text(str)) => Ok(DocumentSourceKind::String(str)),
76 doc => Err(CompletionError::ProviderError(format!(
77 "Unsupported document type: {doc:?}"
78 ))),
79 }?;
80
81 Ok(RigDocument(Document {
82 data,
83 media_type: Some(media_type),
84 additional_params: None,
85 }))
86 }
87}
88
#[cfg(test)]
mod tests {
    use aws_sdk_bedrockruntime::types as aws_bedrock;
    use base64::{Engine, prelude::BASE64_STANDARD};
    use rig::{
        completion::CompletionError,
        message::{Document, DocumentMediaType, DocumentSourceKind},
    };

    use crate::types::document::RigDocument;

    /// Rig document fixture whose payload is the base64 string `"data"`.
    fn base64_document(media_type: DocumentMediaType) -> RigDocument {
        RigDocument(Document {
            data: DocumentSourceKind::Base64("data".into()),
            media_type: Some(media_type),
            additional_params: None,
        })
    }

    /// Bedrock document fixture named `"Document"` holding the bytes of `"document_data"`.
    fn bedrock_document(format: aws_bedrock::DocumentFormat) -> aws_bedrock::DocumentBlock {
        let payload = aws_smithy_types::Blob::new("document_data");

        aws_bedrock::DocumentBlock::builder()
            .format(format)
            .name("Document")
            .source(aws_bedrock::DocumentSource::Bytes(payload))
            .build()
            .unwrap()
    }

    /// Copies the raw bytes out of a block; panics unless the source is a byte source.
    fn source_bytes(block: &aws_bedrock::DocumentBlock) -> Vec<u8> {
        block
            .source()
            .unwrap()
            .as_bytes()
            .unwrap()
            .as_ref()
            .to_owned()
    }

    #[test]
    fn test_document_to_aws_document() {
        let rig_document = base64_document(DocumentMediaType::PDF);

        let aws_document: aws_bedrock::DocumentBlock = rig_document.clone().try_into().unwrap();
        assert_eq!(aws_document.format, aws_bedrock::DocumentFormat::Pdf);

        let encoded = rig_document.0.data.try_into_inner().unwrap();
        let expected_bytes = BASE64_STANDARD.decode(encoded.as_bytes()).unwrap();

        assert!(aws_document.name.starts_with("document-"));
        assert_eq!(source_bytes(&aws_document), expected_bytes);
    }

    #[test]
    fn test_base64_document_to_aws_document() {
        let rig_document = base64_document(DocumentMediaType::PDF);

        let aws_document: aws_bedrock::DocumentBlock = rig_document.clone().try_into().unwrap();

        let encoded = rig_document.0.data.try_into_inner().unwrap();
        let expected_bytes = BASE64_STANDARD.decode(encoded).unwrap();

        assert_eq!(source_bytes(&aws_document), expected_bytes);
    }

    #[test]
    fn test_unsupported_document_to_aws_document() {
        let rig_document = base64_document(DocumentMediaType::Javascript);

        let conversion: Result<aws_bedrock::DocumentBlock, _> = rig_document.try_into();
        let expected =
            CompletionError::ProviderError("Unsupported media type application/x-javascript".into());

        assert_eq!(conversion.err().unwrap().to_string(), expected.to_string());
    }

    #[test]
    fn test_aws_document_to_rig_document() {
        let aws_document = bedrock_document(aws_bedrock::DocumentFormat::Pdf);

        let conversion: Result<RigDocument, _> = aws_document.try_into();
        let RigDocument(rig_document) = conversion.unwrap();

        assert_eq!(rig_document.media_type, Some(DocumentMediaType::PDF));
    }

    #[test]
    fn test_unsupported_aws_document_to_rig_document() {
        let aws_document = bedrock_document(aws_bedrock::DocumentFormat::Xlsx);

        let conversion: Result<RigDocument, _> = aws_document.try_into();
        assert!(conversion.is_err());

        let expected = CompletionError::ProviderError("Unsupported media type xlsx".into());
        assert_eq!(conversion.err().unwrap().to_string(), expected.to_string());
    }
}