rig_s3vectors/
lib.rs

1#[macro_use]
2mod document;
3
4use aws_sdk_s3vectors::{
5    Client,
6    types::{PutInputVector, VectorData},
7};
8use aws_smithy_types::Document;
9use rig::{
10    embeddings::EmbeddingModel,
11    vector_store::{
12        InsertDocuments, VectorStoreError, VectorStoreIndex,
13        request::{SearchFilter, VectorSearchRequest},
14    },
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use uuid::Uuid;
20
21#[derive(Debug, Serialize, Deserialize)]
22pub struct CreateRecord {
23    document: serde_json::Value,
24    embedded_text: String,
25}
26
27#[derive(Clone, Debug)]
28pub struct S3SearchFilter(aws_smithy_types::Document);
29
30impl SearchFilter for S3SearchFilter {
31    type Value = aws_smithy_types::Document;
32
33    fn eq(key: String, value: Self::Value) -> Self {
34        Self(document!({ key: { "$eq": value } }))
35    }
36
37    fn gt(key: String, value: Self::Value) -> Self {
38        Self(document!({ key: { "$gt": value } }))
39    }
40
41    fn lt(key: String, value: Self::Value) -> Self {
42        Self(document!({ key: { "$lt": value } }))
43    }
44
45    fn and(self, rhs: Self) -> Self {
46        Self(document!({ "$and": [ self.0, rhs.0 ]}))
47    }
48
49    fn or(self, rhs: Self) -> Self {
50        Self(document!({ "$or": [ self.0, rhs.0 ]}))
51    }
52}
53
54impl S3SearchFilter {
55    pub fn inner(&self) -> &aws_smithy_types::Document {
56        &self.0
57    }
58
59    pub fn into_inner(self) -> aws_smithy_types::Document {
60        self.0
61    }
62
63    pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
64        Self(document!({ key: { "$gte": value } }))
65    }
66
67    pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
68        Self(document!({ key: { "$lte": value } }))
69    }
70
71    pub fn exists(key: String) -> Self {
72        Self(document!({ "$exists": { key: true } }))
73    }
74
75    #[allow(clippy::should_implement_trait)]
76    pub fn not(self) -> Self {
77        Self(document!({ "$not": self.0 }))
78    }
79}
80
81pub struct S3VectorsVectorStore<M> {
82    embedding_model: M,
83    client: Client,
84    bucket_name: String,
85    index_name: String,
86}
87
88impl<M> S3VectorsVectorStore<M>
89where
90    M: EmbeddingModel,
91{
92    pub fn new(
93        embedding_model: M,
94        client: aws_sdk_s3vectors::Client,
95        bucket_name: &str,
96        index_name: &str,
97    ) -> Self {
98        Self {
99            embedding_model,
100            client,
101            bucket_name: bucket_name.to_string(),
102            index_name: index_name.to_string(),
103        }
104    }
105
106    pub fn bucket_name(&self) -> &str {
107        &self.bucket_name
108    }
109
110    pub fn set_bucket_name(&mut self, bucket_name: &str) {
111        self.bucket_name = bucket_name.to_string();
112    }
113
114    pub fn index_name(&self) -> &str {
115        &self.index_name
116    }
117
118    pub fn set_index_name(&mut self, index_name: &str) {
119        self.index_name = index_name.to_string();
120    }
121
122    pub fn client(&self) -> &Client {
123        &self.client
124    }
125}
126
127impl<M> InsertDocuments for S3VectorsVectorStore<M>
128where
129    M: EmbeddingModel,
130{
131    async fn insert_documents<Doc: serde::Serialize + rig::Embed + Send>(
132        &self,
133        documents: Vec<(Doc, rig::OneOrMany<rig::embeddings::Embedding>)>,
134    ) -> Result<(), rig::vector_store::VectorStoreError> {
135        let docs: Vec<PutInputVector> = documents
136            .into_iter()
137            .map(|x| {
138                let json_value = serde_json::to_value(&x.0).map_err(VectorStoreError::JsonError)?;
139
140                x.1.into_iter()
141                    .map(|y| {
142                        let document = CreateRecord {
143                            document: json_value.clone(),
144                            embedded_text: y.document,
145                        };
146                        let document =
147                            serde_json::to_value(&document).map_err(VectorStoreError::JsonError)?;
148                        let document = json_value_to_document(&document);
149                        let vec = y.vec.into_iter().map(|item| item as f32).collect();
150                        PutInputVector::builder()
151                            .metadata(document.clone())
152                            .data(VectorData::Float32(vec))
153                            .key(Uuid::new_v4())
154                            .build()
155                            .map_err(|x| {
156                                VectorStoreError::DatastoreError(
157                                    format!("Couldn't build vector input: {x}").into(),
158                                )
159                            })
160                    })
161                    .collect()
162            })
163            .collect::<Result<Vec<Vec<PutInputVector>>, VectorStoreError>>()
164            .map_err(|x| {
165                VectorStoreError::DatastoreError(
166                    format!("Could not build vector store data: {x}").into(),
167                )
168            })?
169            .into_iter()
170            .flatten()
171            .collect();
172
173        self.client
174            .put_vectors()
175            .vector_bucket_name(self.bucket_name())
176            .set_vectors(Some(docs))
177            .set_index_name(Some(self.index_name.clone()))
178            .send()
179            .await
180            .map_err(|x| {
181                VectorStoreError::DatastoreError(
182                    format!("Error while submitting document insertion request: {x}").into(),
183                )
184            })?;
185
186        Ok(())
187    }
188}
189
190fn json_value_to_document(value: &Value) -> Document {
191    match value {
192        Value::Null => Document::Null,
193        Value::Bool(b) => Document::Bool(*b),
194        Value::Number(n) => {
195            if let Some(i) = n.as_i64() {
196                Document::Number(aws_smithy_types::Number::NegInt(i))
197            } else if let Some(u) = n.as_u64() {
198                Document::Number(aws_smithy_types::Number::PosInt(u))
199            } else if let Some(f) = n.as_f64() {
200                Document::Number(aws_smithy_types::Number::Float(f))
201            } else {
202                Document::Null // fallback, should never happen
203            }
204        }
205        Value::String(s) => Document::String(s.clone()),
206        Value::Array(arr) => Document::Array(arr.iter().map(json_value_to_document).collect()),
207        Value::Object(obj) => Document::Object(
208            obj.iter()
209                .map(|(k, v)| (k.clone(), json_value_to_document(v)))
210                .collect::<HashMap<_, _>>(),
211        ),
212    }
213}
214
215fn document_to_json_value(value: &Document) -> Value {
216    match value {
217        Document::Null => Value::Null,
218        Document::Bool(b) => Value::Bool(*b),
219        Document::Number(n) => {
220            let res = match n {
221                aws_smithy_types::Number::Float(f) => {
222                    serde_json::Number::from_f64(f.to_owned()).unwrap()
223                }
224                aws_smithy_types::Number::NegInt(i) => {
225                    serde_json::Number::from_i128(*i as i128).unwrap()
226                }
227                aws_smithy_types::Number::PosInt(u) => {
228                    serde_json::Number::from_u128(*u as u128).unwrap()
229                }
230            };
231
232            serde_json::Value::Number(res)
233        }
234        Document::String(s) => Value::String(s.clone()),
235        Document::Array(arr) => Value::Array(arr.iter().map(document_to_json_value).collect()),
236        Document::Object(obj) => {
237            let res = obj
238                .iter()
239                .map(|(k, v)| (k.clone(), document_to_json_value(v)))
240                .collect::<serde_json::Map<String, serde_json::Value>>();
241
242            serde_json::Value::Object(res)
243        }
244    }
245}
246
247impl<M> VectorStoreIndex for S3VectorsVectorStore<M>
248where
249    M: EmbeddingModel,
250{
251    type Filter = S3SearchFilter;
252
253    async fn top_n<T: for<'a> serde::Deserialize<'a> + Send>(
254        &self,
255        req: VectorSearchRequest<S3SearchFilter>,
256    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
257        if req.samples() > i32::MAX as u64 {
258            return Err(VectorStoreError::DatastoreError(format!("The number of samples to return with the `rig` AWS S3Vectors integration cannot be higher than {}", i32::MAX).into()));
259        }
260
261        let embedding = self
262            .embedding_model
263            .embed_text(req.query())
264            .await?
265            .vec
266            .into_iter()
267            .map(|x| x as f32)
268            .collect();
269
270        let mut query_builder = self
271            .client
272            .query_vectors()
273            .query_vector(VectorData::Float32(embedding))
274            .top_k(req.samples() as i32)
275            .return_distance(true)
276            .return_metadata(true)
277            .vector_bucket_name(self.bucket_name())
278            .index_name(self.index_name());
279
280        if let Some(filter) = req.filter() {
281            query_builder = query_builder.filter(filter.inner().clone())
282        }
283
284        let query = query_builder.send().await.unwrap();
285
286        let res: Vec<(f64, String, T)> = query
287            .vectors
288            .into_iter()
289            .filter(|vector| {
290                req.threshold().is_none_or(|threshold| {
291                    (vector
292                        .distance()
293                        .expect("vector distance should always exist") as f64)
294                        >= threshold
295                })
296            })
297            .map(|x| {
298                let distance = x.distance.expect("vector distance should always exist") as f64;
299                let val =
300                    document_to_json_value(&x.metadata.expect("metadata should always exist"));
301
302                let metadata: T = serde_json::from_value(val)
303                    .expect("converting JSON from S3Vectors to valid T should always work");
304
305                (distance, x.key, metadata)
306            })
307            .collect();
308
309        Ok(res)
310    }
311
312    async fn top_n_ids(
313        &self,
314        req: VectorSearchRequest<S3SearchFilter>,
315    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
316        if req.samples() > i32::MAX as u64 {
317            return Err(VectorStoreError::DatastoreError(format!("The number of samples to return with the `rig` AWS S3Vectors integration cannot be higher than {}", i32::MAX).into()));
318        }
319
320        let embedding = self
321            .embedding_model
322            .embed_text(req.query())
323            .await?
324            .vec
325            .into_iter()
326            .map(|x| x as f32)
327            .collect();
328
329        let mut query_builder = self
330            .client
331            .query_vectors()
332            .query_vector(VectorData::Float32(embedding))
333            .top_k(req.samples() as i32)
334            .return_distance(true)
335            .vector_bucket_name(self.bucket_name())
336            .index_name(self.index_name());
337
338        if let Some(filter) = req.filter() {
339            query_builder = query_builder.filter(filter.inner().clone())
340        }
341
342        let query = query_builder.send().await.unwrap();
343
344        let res: Vec<(f64, String)> = query
345            .vectors
346            .into_iter()
347            .filter(|vector| {
348                req.threshold().is_none_or(|threshold| {
349                    (vector
350                        .distance()
351                        .expect("vector distance should always exist") as f64)
352                        >= threshold
353                })
354            })
355            .map(|x| {
356                let distance = x.distance.expect("vector distance should always exist") as f64;
357
358                (distance, x.key)
359            })
360            .collect();
361
362        Ok(res)
363    }
364}