Skip to main content

rig_s3vectors/
lib.rs

1#[macro_use]
2mod document;
3
4use aws_sdk_s3vectors::{
5    Client,
6    types::{PutInputVector, VectorData},
7};
8use aws_smithy_types::Document;
9use rig::{
10    embeddings::EmbeddingModel,
11    vector_store::{
12        InsertDocuments, VectorStoreError, VectorStoreIndex,
13        request::{SearchFilter, VectorSearchRequest},
14    },
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use uuid::Uuid;
20
21#[derive(Debug, Serialize, Deserialize)]
22pub struct CreateRecord {
23    document: serde_json::Value,
24    embedded_text: String,
25}
26
27// NOTE: Cannot be used in dynamic store due to aws_smithy_types::Document not impl'ing Serialize or Deserialize
28#[derive(Clone, Debug)]
29pub struct S3SearchFilter(aws_smithy_types::Document);
30
31impl SearchFilter for S3SearchFilter {
32    type Value = aws_smithy_types::Document;
33
34    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
35        let key = key.as_ref().to_owned();
36        Self(document!({ key: { "$eq": value } }))
37    }
38
39    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
40        let key = key.as_ref().to_owned();
41        Self(document!({ key: { "$gt": value } }))
42    }
43
44    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
45        let key = key.as_ref().to_owned();
46        Self(document!({ key: { "$lt": value } }))
47    }
48
49    fn and(self, rhs: Self) -> Self {
50        Self(document!({ "$and": [ self.0, rhs.0 ]}))
51    }
52
53    fn or(self, rhs: Self) -> Self {
54        Self(document!({ "$or": [ self.0, rhs.0 ]}))
55    }
56}
57
58impl S3SearchFilter {
59    pub fn inner(&self) -> &aws_smithy_types::Document {
60        &self.0
61    }
62
63    pub fn into_inner(self) -> aws_smithy_types::Document {
64        self.0
65    }
66
67    pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
68        Self(document!({ key: { "$gte": value } }))
69    }
70
71    pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
72        Self(document!({ key: { "$lte": value } }))
73    }
74
75    pub fn exists(key: String) -> Self {
76        Self(document!({ "$exists": { key: true } }))
77    }
78
79    #[allow(clippy::should_implement_trait)]
80    pub fn not(self) -> Self {
81        Self(document!({ "$not": self.0 }))
82    }
83}
84
85pub struct S3VectorsVectorStore<M> {
86    embedding_model: M,
87    client: Client,
88    bucket_name: String,
89    index_name: String,
90}
91
92impl<M> S3VectorsVectorStore<M>
93where
94    M: EmbeddingModel,
95{
96    pub fn new(
97        embedding_model: M,
98        client: aws_sdk_s3vectors::Client,
99        bucket_name: &str,
100        index_name: &str,
101    ) -> Self {
102        Self {
103            embedding_model,
104            client,
105            bucket_name: bucket_name.to_string(),
106            index_name: index_name.to_string(),
107        }
108    }
109
110    pub fn bucket_name(&self) -> &str {
111        &self.bucket_name
112    }
113
114    pub fn set_bucket_name(&mut self, bucket_name: &str) {
115        self.bucket_name = bucket_name.to_string();
116    }
117
118    pub fn index_name(&self) -> &str {
119        &self.index_name
120    }
121
122    pub fn set_index_name(&mut self, index_name: &str) {
123        self.index_name = index_name.to_string();
124    }
125
126    pub fn client(&self) -> &Client {
127        &self.client
128    }
129}
130
131impl<M> InsertDocuments for S3VectorsVectorStore<M>
132where
133    M: EmbeddingModel,
134{
135    async fn insert_documents<Doc: serde::Serialize + rig::Embed + Send>(
136        &self,
137        documents: Vec<(Doc, rig::OneOrMany<rig::embeddings::Embedding>)>,
138    ) -> Result<(), rig::vector_store::VectorStoreError> {
139        let docs: Vec<PutInputVector> = documents
140            .into_iter()
141            .map(|x| {
142                let json_value = serde_json::to_value(&x.0).map_err(VectorStoreError::JsonError)?;
143
144                x.1.into_iter()
145                    .map(|y| {
146                        let document = CreateRecord {
147                            document: json_value.clone(),
148                            embedded_text: y.document,
149                        };
150                        let document =
151                            serde_json::to_value(&document).map_err(VectorStoreError::JsonError)?;
152                        let document = json_value_to_document(&document);
153                        let vec = y.vec.into_iter().map(|item| item as f32).collect();
154                        PutInputVector::builder()
155                            .metadata(document.clone())
156                            .data(VectorData::Float32(vec))
157                            .key(Uuid::new_v4())
158                            .build()
159                            .map_err(|x| {
160                                VectorStoreError::DatastoreError(
161                                    format!("Couldn't build vector input: {x}").into(),
162                                )
163                            })
164                    })
165                    .collect()
166            })
167            .collect::<Result<Vec<Vec<PutInputVector>>, VectorStoreError>>()
168            .map_err(|x| {
169                VectorStoreError::DatastoreError(
170                    format!("Could not build vector store data: {x}").into(),
171                )
172            })?
173            .into_iter()
174            .flatten()
175            .collect();
176
177        self.client
178            .put_vectors()
179            .vector_bucket_name(self.bucket_name())
180            .set_vectors(Some(docs))
181            .set_index_name(Some(self.index_name.clone()))
182            .send()
183            .await
184            .map_err(|x| {
185                VectorStoreError::DatastoreError(
186                    format!("Error while submitting document insertion request: {x}").into(),
187                )
188            })?;
189
190        Ok(())
191    }
192}
193
194fn json_value_to_document(value: &Value) -> Document {
195    match value {
196        Value::Null => Document::Null,
197        Value::Bool(b) => Document::Bool(*b),
198        Value::Number(n) => {
199            if let Some(i) = n.as_i64() {
200                Document::Number(aws_smithy_types::Number::NegInt(i))
201            } else if let Some(u) = n.as_u64() {
202                Document::Number(aws_smithy_types::Number::PosInt(u))
203            } else if let Some(f) = n.as_f64() {
204                Document::Number(aws_smithy_types::Number::Float(f))
205            } else {
206                Document::Null // fallback, should never happen
207            }
208        }
209        Value::String(s) => Document::String(s.clone()),
210        Value::Array(arr) => Document::Array(arr.iter().map(json_value_to_document).collect()),
211        Value::Object(obj) => Document::Object(
212            obj.iter()
213                .map(|(k, v)| (k.clone(), json_value_to_document(v)))
214                .collect::<HashMap<_, _>>(),
215        ),
216    }
217}
218
219fn document_to_json_value(value: &Document) -> Value {
220    match value {
221        Document::Null => Value::Null,
222        Document::Bool(b) => Value::Bool(*b),
223        Document::Number(n) => match n {
224            aws_smithy_types::Number::Float(f) => serde_json::Number::from_f64(*f)
225                .map(Value::Number)
226                .unwrap_or_else(|| Value::String(f.to_string())),
227            aws_smithy_types::Number::NegInt(i) => {
228                serde_json::Value::Number(serde_json::Number::from(*i))
229            }
230            aws_smithy_types::Number::PosInt(u) => {
231                serde_json::Value::Number(serde_json::Number::from(*u))
232            }
233        },
234        Document::String(s) => Value::String(s.clone()),
235        Document::Array(arr) => Value::Array(arr.iter().map(document_to_json_value).collect()),
236        Document::Object(obj) => {
237            let res = obj
238                .iter()
239                .map(|(k, v)| (k.clone(), document_to_json_value(v)))
240                .collect::<serde_json::Map<String, serde_json::Value>>();
241
242            serde_json::Value::Object(res)
243        }
244    }
245}
246
247impl<M> VectorStoreIndex for S3VectorsVectorStore<M>
248where
249    M: EmbeddingModel,
250{
251    type Filter = S3SearchFilter;
252
253    async fn top_n<T: for<'a> serde::Deserialize<'a> + Send>(
254        &self,
255        req: VectorSearchRequest<S3SearchFilter>,
256    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
257        if req.samples() > i32::MAX as u64 {
258            return Err(VectorStoreError::DatastoreError(format!("The number of samples to return with the `rig` AWS S3Vectors integration cannot be higher than {}", i32::MAX).into()));
259        }
260
261        let embedding = self
262            .embedding_model
263            .embed_text(req.query())
264            .await?
265            .vec
266            .into_iter()
267            .map(|x| x as f32)
268            .collect();
269
270        let mut query_builder = self
271            .client
272            .query_vectors()
273            .query_vector(VectorData::Float32(embedding))
274            .top_k(req.samples() as i32)
275            .return_distance(true)
276            .return_metadata(true)
277            .vector_bucket_name(self.bucket_name())
278            .index_name(self.index_name());
279
280        if let Some(filter) = req.filter() {
281            query_builder = query_builder.filter(filter.inner().clone())
282        }
283
284        let query = query_builder
285            .send()
286            .await
287            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
288
289        let res: Vec<(f64, String, T)> = query
290            .vectors
291            .into_iter()
292            .map(|x| {
293                let distance = x.distance.ok_or_else(|| {
294                    VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
295                        "S3Vectors response missing distance",
296                    )))
297                })? as f64;
298
299                if req
300                    .threshold()
301                    .is_some_and(|threshold| distance < threshold)
302                {
303                    return Ok(None);
304                }
305
306                let metadata_document = x.metadata.ok_or_else(|| {
307                    VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
308                        "S3Vectors response missing metadata",
309                    )))
310                })?;
311                let val = document_to_json_value(&metadata_document);
312                let metadata: T = serde_json::from_value(val)?;
313
314                Ok(Some((distance, x.key, metadata)))
315            })
316            .collect::<Result<Vec<_>, VectorStoreError>>()?
317            .into_iter()
318            .flatten()
319            .collect();
320
321        Ok(res)
322    }
323
324    async fn top_n_ids(
325        &self,
326        req: VectorSearchRequest<S3SearchFilter>,
327    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
328        if req.samples() > i32::MAX as u64 {
329            return Err(VectorStoreError::DatastoreError(format!("The number of samples to return with the `rig` AWS S3Vectors integration cannot be higher than {}", i32::MAX).into()));
330        }
331
332        let embedding = self
333            .embedding_model
334            .embed_text(req.query())
335            .await?
336            .vec
337            .into_iter()
338            .map(|x| x as f32)
339            .collect();
340
341        let mut query_builder = self
342            .client
343            .query_vectors()
344            .query_vector(VectorData::Float32(embedding))
345            .top_k(req.samples() as i32)
346            .return_distance(true)
347            .vector_bucket_name(self.bucket_name())
348            .index_name(self.index_name());
349
350        if let Some(filter) = req.filter() {
351            query_builder = query_builder.filter(filter.inner().clone())
352        }
353
354        let query = query_builder
355            .send()
356            .await
357            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
358
359        let res: Vec<(f64, String)> = query
360            .vectors
361            .into_iter()
362            .map(|x| {
363                let distance = x.distance.ok_or_else(|| {
364                    VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
365                        "S3Vectors response missing distance",
366                    )))
367                })? as f64;
368
369                if req
370                    .threshold()
371                    .is_some_and(|threshold| distance < threshold)
372                {
373                    return Ok(None);
374                }
375
376                Ok(Some((distance, x.key)))
377            })
378            .collect::<Result<Vec<_>, VectorStoreError>>()?
379            .into_iter()
380            .flatten()
381            .collect();
382
383        Ok(res)
384    }
385}