rig_bedrock/types/
document.rs

1use aws_sdk_bedrockruntime::types as aws_bedrock;
2use rig::{
3    completion::CompletionError,
4    message::{Document, DocumentSourceKind},
5};
6
7pub(crate) use crate::types::media_types::RigDocumentMediaType;
8use base64::{Engine, prelude::BASE64_STANDARD};
9use uuid::Uuid;
10
11#[derive(Clone)]
12pub struct RigDocument(pub Document);
13
14impl TryFrom<RigDocument> for aws_bedrock::DocumentBlock {
15    type Error = CompletionError;
16
17    fn try_from(
18        RigDocument(Document {
19            data, media_type, ..
20        }): RigDocument,
21    ) -> Result<Self, Self::Error> {
22        let document_media_type = media_type.map(|doc| RigDocumentMediaType(doc).try_into());
23
24        let document_media_type = match document_media_type {
25            Some(Ok(document_format)) => Ok(Some(document_format)),
26            Some(Err(err)) => Err(err),
27            None => Ok(None),
28        }?;
29
30        let document_source = match data {
31            DocumentSourceKind::Base64(blob) => {
32                let bytes = BASE64_STANDARD
33                    .decode(blob)
34                    .map_err(|e| CompletionError::RequestError(e.into()))?;
35
36                aws_bedrock::DocumentSource::Bytes(aws_smithy_types::Blob::new(bytes))
37            }
38            // NOTE: until [aws-sdk-bedrockruntime DocumentSource bug #1365](https://github.com/awslabs/aws-sdk-rust/issues/1365)
39            // is resolved we will use this as a workaround
40            // DocumentSourceKind::String(str) => aws_bedrock::DocumentSource::Text(str),
41            DocumentSourceKind::String(str) => {
42                aws_bedrock::DocumentSource::Bytes(aws_smithy_types::Blob::new(str.as_bytes()))
43            }
44            doc => {
45                return Err(CompletionError::RequestError(
46                    format!("Unsupported document kind: {doc}").into(),
47                ));
48            }
49        };
50
51        let random_string = Uuid::new_v4().simple().to_string();
52        let document_name = format!("document-{random_string}");
53        let result = aws_bedrock::DocumentBlock::builder()
54            .source(document_source)
55            .name(document_name)
56            .set_format(document_media_type)
57            .build()
58            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
59        Ok(result)
60    }
61}
62
63impl TryFrom<aws_bedrock::DocumentBlock> for RigDocument {
64    type Error = CompletionError;
65
66    fn try_from(value: aws_bedrock::DocumentBlock) -> Result<Self, Self::Error> {
67        let media_type: RigDocumentMediaType = value.format.try_into()?;
68        let media_type = media_type.0;
69
70        let data = match value.source {
71            Some(aws_bedrock::DocumentSource::Bytes(blob)) => {
72                let encoded_data = BASE64_STANDARD.encode(blob.into_inner());
73                Ok(DocumentSourceKind::Base64(encoded_data))
74            }
75            Some(aws_bedrock::DocumentSource::Text(str)) => Ok(DocumentSourceKind::String(str)),
76            doc => Err(CompletionError::ProviderError(format!(
77                "Unsupported document type: {doc:?}"
78            ))),
79        }?;
80
81        Ok(RigDocument(Document {
82            data,
83            media_type: Some(media_type),
84            additional_params: None,
85        }))
86    }
87}
88
/// Unit tests for the conversions between [`RigDocument`] and the Bedrock
/// `DocumentBlock` type, in both directions.
#[cfg(test)]
mod tests {
    use aws_sdk_bedrockruntime::types as aws_bedrock;
    use base64::{Engine, prelude::BASE64_STANDARD};
    use rig::{
        completion::CompletionError,
        message::{Document, DocumentMediaType, DocumentSourceKind},
    };

    use crate::types::document::RigDocument;

    /// Wraps `data` in a rig document with the given media type and no
    /// additional parameters.
    fn rig_document(data: DocumentSourceKind, media_type: DocumentMediaType) -> RigDocument {
        RigDocument(Document {
            data,
            media_type: Some(media_type),
            additional_params: None,
        })
    }

    /// Builds a Bedrock block named "Document" holding `document_data` as bytes.
    fn aws_document(format: aws_bedrock::DocumentFormat) -> aws_bedrock::DocumentBlock {
        let data = aws_smithy_types::Blob::new("document_data");
        aws_bedrock::DocumentBlock::builder()
            .format(format)
            .name("Document")
            .source(aws_bedrock::DocumentSource::Bytes(data))
            .build()
            .unwrap()
    }

    /// Copies the payload out of a block; panics unless its source is `Bytes`.
    fn source_bytes(block: &aws_bedrock::DocumentBlock) -> Vec<u8> {
        block
            .source()
            .unwrap()
            .as_bytes()
            .unwrap()
            .as_ref()
            .to_owned()
    }

    #[test]
    fn test_document_to_aws_document() {
        let document = rig_document(
            DocumentSourceKind::Base64("data".into()),
            DocumentMediaType::PDF,
        );

        let aws_document: aws_bedrock::DocumentBlock = document.try_into().unwrap();

        assert_eq!(aws_document.format, aws_bedrock::DocumentFormat::Pdf);
        assert!(aws_document.name.starts_with("document-"));
        assert_eq!(
            source_bytes(&aws_document),
            BASE64_STANDARD.decode("data").unwrap()
        )
    }

    #[test]
    fn test_base64_document_to_aws_document() {
        let document = rig_document(
            DocumentSourceKind::Base64("data".into()),
            DocumentMediaType::PDF,
        );

        // The clone is needed: the original payload is read back below.
        let aws_document: aws_bedrock::DocumentBlock = document.clone().try_into().unwrap();
        let expected = BASE64_STANDARD
            .decode(document.0.data.try_into_inner().unwrap())
            .unwrap();

        assert_eq!(source_bytes(&aws_document), expected)
    }

    #[test]
    fn test_string_document_to_aws_document() {
        // Plain-text documents are currently forwarded as raw bytes.
        let document = rig_document(
            DocumentSourceKind::String("plain text".into()),
            DocumentMediaType::PDF,
        );

        let aws_document: aws_bedrock::DocumentBlock = document.try_into().unwrap();

        assert_eq!(source_bytes(&aws_document), b"plain text".to_vec())
    }

    #[test]
    fn test_invalid_base64_document_to_aws_document() {
        // '!' and ' ' are outside the standard base64 alphabet.
        let document = rig_document(
            DocumentSourceKind::Base64("not base64!".into()),
            DocumentMediaType::PDF,
        );

        let aws_document: Result<aws_bedrock::DocumentBlock, _> = document.try_into();

        assert!(matches!(
            aws_document,
            Err(CompletionError::RequestError(_))
        ))
    }

    #[test]
    fn test_unsupported_document_to_aws_document() {
        let document = rig_document(
            DocumentSourceKind::Base64("data".into()),
            DocumentMediaType::Javascript,
        );

        let aws_document: Result<aws_bedrock::DocumentBlock, _> = document.try_into();

        assert_eq!(
            aws_document.err().unwrap().to_string(),
            CompletionError::ProviderError(
                "Unsupported media type application/x-javascript".into()
            )
            .to_string()
        )
    }

    #[test]
    fn test_aws_document_to_rig_document() {
        let rig_document: Result<RigDocument, _> =
            aws_document(aws_bedrock::DocumentFormat::Pdf).try_into();
        assert!(rig_document.is_ok());

        let Document {
            data, media_type, ..
        } = rig_document.unwrap().0;
        assert_eq!(media_type.unwrap(), DocumentMediaType::PDF);

        // Byte sources come back base64-encoded; decoding must yield the input.
        let decoded = BASE64_STANDARD
            .decode(data.try_into_inner().unwrap())
            .unwrap();
        assert_eq!(decoded, b"document_data".to_vec())
    }

    #[test]
    fn test_unsupported_aws_document_to_rig_document() {
        let rig_document: Result<RigDocument, _> =
            aws_document(aws_bedrock::DocumentFormat::Xlsx).try_into();

        assert!(rig_document.is_err());
        assert_eq!(
            rig_document.err().unwrap().to_string(),
            CompletionError::ProviderError("Unsupported media type xlsx".into()).to_string()
        )
    }
}
209}