1#[macro_use]
2mod document;
3
4use aws_sdk_s3vectors::{
5 Client,
6 types::{PutInputVector, VectorData},
7};
8use aws_smithy_types::Document;
9use rig::{
10 embeddings::EmbeddingModel,
11 vector_store::{
12 InsertDocuments, VectorStoreError, VectorStoreIndex,
13 request::{SearchFilter, VectorSearchRequest},
14 },
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use uuid::Uuid;
20
/// Metadata record stored alongside every vector written by `insert_documents`.
///
/// A document with several embeddings yields several records: each one carries the full
/// serialized document plus the specific text that was embedded for that vector.
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateRecord {
    /// The source document, serialized to JSON.
    document: serde_json::Value,
    /// The text that was embedded to produce this particular vector.
    embedded_text: String,
}
26
/// Metadata filter for S3 Vectors queries, wrapping the raw filter document.
///
/// Built through the [`SearchFilter`] combinators (plus a few extra inherent constructors)
/// and sent as the `filter` of a `QueryVectors` request.
#[derive(Clone, Debug)]
pub struct S3SearchFilter(aws_smithy_types::Document);
30
/// Maps `rig`'s generic filter combinators onto S3 Vectors operator documents.
///
/// Comparison constructors emit `{ key: { "$op": value } }`; the logical combinators wrap
/// both operands in a two-element `"$and"` / `"$or"` array.
///
/// NOTE(review): assumes the local `document!` macro evaluates a bare identifier key such as
/// `key` as an expression (like `serde_json::json!`) — confirm in the `document` module.
impl SearchFilter for S3SearchFilter {
    /// Filter operands are raw Smithy documents.
    type Value = aws_smithy_types::Document;

    /// Matches records whose `key` field equals `value`: `{ key: { "$eq": value } }`.
    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
        // Own the key so it can be moved into the document.
        let key = key.as_ref().to_owned();
        Self(document!({ key: { "$eq": value } }))
    }

    /// Matches records whose `key` field is greater than `value`: `{ key: { "$gt": value } }`.
    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
        let key = key.as_ref().to_owned();
        Self(document!({ key: { "$gt": value } }))
    }

    /// Matches records whose `key` field is less than `value`: `{ key: { "$lt": value } }`.
    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
        let key = key.as_ref().to_owned();
        Self(document!({ key: { "$lt": value } }))
    }

    /// Matches records satisfying both filters: `{ "$and": [self, rhs] }`.
    fn and(self, rhs: Self) -> Self {
        Self(document!({ "$and": [ self.0, rhs.0 ]}))
    }

    /// Matches records satisfying either filter: `{ "$or": [self, rhs] }`.
    fn or(self, rhs: Self) -> Self {
        Self(document!({ "$or": [ self.0, rhs.0 ]}))
    }
}
57
58impl S3SearchFilter {
59 pub fn inner(&self) -> &aws_smithy_types::Document {
60 &self.0
61 }
62
63 pub fn into_inner(self) -> aws_smithy_types::Document {
64 self.0
65 }
66
67 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
68 Self(document!({ key: { "$gte": value } }))
69 }
70
71 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
72 Self(document!({ key: { "$lte": value } }))
73 }
74
75 pub fn exists(key: String) -> Self {
76 Self(document!({ "$exists": { key: true } }))
77 }
78
79 #[allow(clippy::should_implement_trait)]
80 pub fn not(self) -> Self {
81 Self(document!({ "$not": self.0 }))
82 }
83}
84
/// A `rig` vector store backed by an AWS S3 Vectors index.
///
/// Query text is embedded with `embedding_model`; vectors are written to and queried from
/// the index `index_name` inside the vector bucket `bucket_name` through `client`.
pub struct S3VectorsVectorStore<M> {
    /// Model used to embed search queries.
    embedding_model: M,
    /// AWS S3 Vectors client used for every request.
    client: Client,
    /// Name of the target vector bucket.
    bucket_name: String,
    /// Name of the target vector index within the bucket.
    index_name: String,
}
91
92impl<M> S3VectorsVectorStore<M>
93where
94 M: EmbeddingModel,
95{
96 pub fn new(
97 embedding_model: M,
98 client: aws_sdk_s3vectors::Client,
99 bucket_name: &str,
100 index_name: &str,
101 ) -> Self {
102 Self {
103 embedding_model,
104 client,
105 bucket_name: bucket_name.to_string(),
106 index_name: index_name.to_string(),
107 }
108 }
109
110 pub fn bucket_name(&self) -> &str {
111 &self.bucket_name
112 }
113
114 pub fn set_bucket_name(&mut self, bucket_name: &str) {
115 self.bucket_name = bucket_name.to_string();
116 }
117
118 pub fn index_name(&self) -> &str {
119 &self.index_name
120 }
121
122 pub fn set_index_name(&mut self, index_name: &str) {
123 self.index_name = index_name.to_string();
124 }
125
126 pub fn client(&self) -> &Client {
127 &self.client
128 }
129}
130
131impl<M> InsertDocuments for S3VectorsVectorStore<M>
132where
133 M: EmbeddingModel,
134{
135 async fn insert_documents<Doc: serde::Serialize + rig::Embed + Send>(
136 &self,
137 documents: Vec<(Doc, rig::OneOrMany<rig::embeddings::Embedding>)>,
138 ) -> Result<(), rig::vector_store::VectorStoreError> {
139 let docs: Vec<PutInputVector> = documents
140 .into_iter()
141 .map(|x| {
142 let json_value = serde_json::to_value(&x.0).map_err(VectorStoreError::JsonError)?;
143
144 x.1.into_iter()
145 .map(|y| {
146 let document = CreateRecord {
147 document: json_value.clone(),
148 embedded_text: y.document,
149 };
150 let document =
151 serde_json::to_value(&document).map_err(VectorStoreError::JsonError)?;
152 let document = json_value_to_document(&document);
153 let vec = y.vec.into_iter().map(|item| item as f32).collect();
154 PutInputVector::builder()
155 .metadata(document.clone())
156 .data(VectorData::Float32(vec))
157 .key(Uuid::new_v4())
158 .build()
159 .map_err(|x| {
160 VectorStoreError::DatastoreError(
161 format!("Couldn't build vector input: {x}").into(),
162 )
163 })
164 })
165 .collect()
166 })
167 .collect::<Result<Vec<Vec<PutInputVector>>, VectorStoreError>>()
168 .map_err(|x| {
169 VectorStoreError::DatastoreError(
170 format!("Could not build vector store data: {x}").into(),
171 )
172 })?
173 .into_iter()
174 .flatten()
175 .collect();
176
177 self.client
178 .put_vectors()
179 .vector_bucket_name(self.bucket_name())
180 .set_vectors(Some(docs))
181 .set_index_name(Some(self.index_name.clone()))
182 .send()
183 .await
184 .map_err(|x| {
185 VectorStoreError::DatastoreError(
186 format!("Error while submitting document insertion request: {x}").into(),
187 )
188 })?;
189
190 Ok(())
191 }
192}
193
194fn json_value_to_document(value: &Value) -> Document {
195 match value {
196 Value::Null => Document::Null,
197 Value::Bool(b) => Document::Bool(*b),
198 Value::Number(n) => {
199 if let Some(i) = n.as_i64() {
200 Document::Number(aws_smithy_types::Number::NegInt(i))
201 } else if let Some(u) = n.as_u64() {
202 Document::Number(aws_smithy_types::Number::PosInt(u))
203 } else if let Some(f) = n.as_f64() {
204 Document::Number(aws_smithy_types::Number::Float(f))
205 } else {
206 Document::Null }
208 }
209 Value::String(s) => Document::String(s.clone()),
210 Value::Array(arr) => Document::Array(arr.iter().map(json_value_to_document).collect()),
211 Value::Object(obj) => Document::Object(
212 obj.iter()
213 .map(|(k, v)| (k.clone(), json_value_to_document(v)))
214 .collect::<HashMap<_, _>>(),
215 ),
216 }
217}
218
219fn document_to_json_value(value: &Document) -> Value {
220 match value {
221 Document::Null => Value::Null,
222 Document::Bool(b) => Value::Bool(*b),
223 Document::Number(n) => match n {
224 aws_smithy_types::Number::Float(f) => serde_json::Number::from_f64(*f)
225 .map(Value::Number)
226 .unwrap_or_else(|| Value::String(f.to_string())),
227 aws_smithy_types::Number::NegInt(i) => {
228 serde_json::Value::Number(serde_json::Number::from(*i))
229 }
230 aws_smithy_types::Number::PosInt(u) => {
231 serde_json::Value::Number(serde_json::Number::from(*u))
232 }
233 },
234 Document::String(s) => Value::String(s.clone()),
235 Document::Array(arr) => Value::Array(arr.iter().map(document_to_json_value).collect()),
236 Document::Object(obj) => {
237 let res = obj
238 .iter()
239 .map(|(k, v)| (k.clone(), document_to_json_value(v)))
240 .collect::<serde_json::Map<String, serde_json::Value>>();
241
242 serde_json::Value::Object(res)
243 }
244 }
245}
246
247impl<M> VectorStoreIndex for S3VectorsVectorStore<M>
248where
249 M: EmbeddingModel,
250{
251 type Filter = S3SearchFilter;
252
253 async fn top_n<T: for<'a> serde::Deserialize<'a> + Send>(
254 &self,
255 req: VectorSearchRequest<S3SearchFilter>,
256 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
257 if req.samples() > i32::MAX as u64 {
258 return Err(VectorStoreError::DatastoreError(format!("The number of samples to return with the `rig` AWS S3Vectors integration cannot be higher than {}", i32::MAX).into()));
259 }
260
261 let embedding = self
262 .embedding_model
263 .embed_text(req.query())
264 .await?
265 .vec
266 .into_iter()
267 .map(|x| x as f32)
268 .collect();
269
270 let mut query_builder = self
271 .client
272 .query_vectors()
273 .query_vector(VectorData::Float32(embedding))
274 .top_k(req.samples() as i32)
275 .return_distance(true)
276 .return_metadata(true)
277 .vector_bucket_name(self.bucket_name())
278 .index_name(self.index_name());
279
280 if let Some(filter) = req.filter() {
281 query_builder = query_builder.filter(filter.inner().clone())
282 }
283
284 let query = query_builder
285 .send()
286 .await
287 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
288
289 let res: Vec<(f64, String, T)> = query
290 .vectors
291 .into_iter()
292 .map(|x| {
293 let distance = x.distance.ok_or_else(|| {
294 VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
295 "S3Vectors response missing distance",
296 )))
297 })? as f64;
298
299 if req
300 .threshold()
301 .is_some_and(|threshold| distance < threshold)
302 {
303 return Ok(None);
304 }
305
306 let metadata_document = x.metadata.ok_or_else(|| {
307 VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
308 "S3Vectors response missing metadata",
309 )))
310 })?;
311 let val = document_to_json_value(&metadata_document);
312 let metadata: T = serde_json::from_value(val)?;
313
314 Ok(Some((distance, x.key, metadata)))
315 })
316 .collect::<Result<Vec<_>, VectorStoreError>>()?
317 .into_iter()
318 .flatten()
319 .collect();
320
321 Ok(res)
322 }
323
324 async fn top_n_ids(
325 &self,
326 req: VectorSearchRequest<S3SearchFilter>,
327 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
328 if req.samples() > i32::MAX as u64 {
329 return Err(VectorStoreError::DatastoreError(format!("The number of samples to return with the `rig` AWS S3Vectors integration cannot be higher than {}", i32::MAX).into()));
330 }
331
332 let embedding = self
333 .embedding_model
334 .embed_text(req.query())
335 .await?
336 .vec
337 .into_iter()
338 .map(|x| x as f32)
339 .collect();
340
341 let mut query_builder = self
342 .client
343 .query_vectors()
344 .query_vector(VectorData::Float32(embedding))
345 .top_k(req.samples() as i32)
346 .return_distance(true)
347 .vector_bucket_name(self.bucket_name())
348 .index_name(self.index_name());
349
350 if let Some(filter) = req.filter() {
351 query_builder = query_builder.filter(filter.inner().clone())
352 }
353
354 let query = query_builder
355 .send()
356 .await
357 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
358
359 let res: Vec<(f64, String)> = query
360 .vectors
361 .into_iter()
362 .map(|x| {
363 let distance = x.distance.ok_or_else(|| {
364 VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
365 "S3Vectors response missing distance",
366 )))
367 })? as f64;
368
369 if req
370 .threshold()
371 .is_some_and(|threshold| distance < threshold)
372 {
373 return Ok(None);
374 }
375
376 Ok(Some((distance, x.key)))
377 })
378 .collect::<Result<Vec<_>, VectorStoreError>>()?
379 .into_iter()
380 .flatten()
381 .collect();
382
383 Ok(res)
384 }
385}