1#[macro_use]
2mod document;
3
4use aws_sdk_s3vectors::{
5 Client,
6 types::{PutInputVector, VectorData},
7};
8use aws_smithy_types::Document;
9use rig::{
10 embeddings::EmbeddingModel,
11 vector_store::{
12 InsertDocuments, VectorStoreError, VectorStoreIndex,
13 request::{SearchFilter, VectorSearchRequest},
14 },
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use uuid::Uuid;
20
21#[derive(Debug, Serialize, Deserialize)]
22pub struct CreateRecord {
23 document: serde_json::Value,
24 embedded_text: String,
25}
26
27#[derive(Clone, Debug)]
28pub struct S3SearchFilter(aws_smithy_types::Document);
29
30impl SearchFilter for S3SearchFilter {
31 type Value = aws_smithy_types::Document;
32
33 fn eq(key: String, value: Self::Value) -> Self {
34 Self(document!({ key: { "$eq": value } }))
35 }
36
37 fn gt(key: String, value: Self::Value) -> Self {
38 Self(document!({ key: { "$gt": value } }))
39 }
40
41 fn lt(key: String, value: Self::Value) -> Self {
42 Self(document!({ key: { "$lt": value } }))
43 }
44
45 fn and(self, rhs: Self) -> Self {
46 Self(document!({ "$and": [ self.0, rhs.0 ]}))
47 }
48
49 fn or(self, rhs: Self) -> Self {
50 Self(document!({ "$or": [ self.0, rhs.0 ]}))
51 }
52}
53
54impl S3SearchFilter {
55 pub fn inner(&self) -> &aws_smithy_types::Document {
56 &self.0
57 }
58
59 pub fn into_inner(self) -> aws_smithy_types::Document {
60 self.0
61 }
62
63 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
64 Self(document!({ key: { "$gte": value } }))
65 }
66
67 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
68 Self(document!({ key: { "$lte": value } }))
69 }
70
71 pub fn exists(key: String) -> Self {
72 Self(document!({ "$exists": { key: true } }))
73 }
74
75 #[allow(clippy::should_implement_trait)]
76 pub fn not(self) -> Self {
77 Self(document!({ "$not": self.0 }))
78 }
79}
80
81pub struct S3VectorsVectorStore<M> {
82 embedding_model: M,
83 client: Client,
84 bucket_name: String,
85 index_name: String,
86}
87
88impl<M> S3VectorsVectorStore<M>
89where
90 M: EmbeddingModel,
91{
92 pub fn new(
93 embedding_model: M,
94 client: aws_sdk_s3vectors::Client,
95 bucket_name: &str,
96 index_name: &str,
97 ) -> Self {
98 Self {
99 embedding_model,
100 client,
101 bucket_name: bucket_name.to_string(),
102 index_name: index_name.to_string(),
103 }
104 }
105
106 pub fn bucket_name(&self) -> &str {
107 &self.bucket_name
108 }
109
110 pub fn set_bucket_name(&mut self, bucket_name: &str) {
111 self.bucket_name = bucket_name.to_string();
112 }
113
114 pub fn index_name(&self) -> &str {
115 &self.index_name
116 }
117
118 pub fn set_index_name(&mut self, index_name: &str) {
119 self.index_name = index_name.to_string();
120 }
121
122 pub fn client(&self) -> &Client {
123 &self.client
124 }
125}
126
127impl<M> InsertDocuments for S3VectorsVectorStore<M>
128where
129 M: EmbeddingModel,
130{
131 async fn insert_documents<Doc: serde::Serialize + rig::Embed + Send>(
132 &self,
133 documents: Vec<(Doc, rig::OneOrMany<rig::embeddings::Embedding>)>,
134 ) -> Result<(), rig::vector_store::VectorStoreError> {
135 let docs: Vec<PutInputVector> = documents
136 .into_iter()
137 .map(|x| {
138 let json_value = serde_json::to_value(&x.0).map_err(VectorStoreError::JsonError)?;
139
140 x.1.into_iter()
141 .map(|y| {
142 let document = CreateRecord {
143 document: json_value.clone(),
144 embedded_text: y.document,
145 };
146 let document =
147 serde_json::to_value(&document).map_err(VectorStoreError::JsonError)?;
148 let document = json_value_to_document(&document);
149 let vec = y.vec.into_iter().map(|item| item as f32).collect();
150 PutInputVector::builder()
151 .metadata(document.clone())
152 .data(VectorData::Float32(vec))
153 .key(Uuid::new_v4())
154 .build()
155 .map_err(|x| {
156 VectorStoreError::DatastoreError(
157 format!("Couldn't build vector input: {x}").into(),
158 )
159 })
160 })
161 .collect()
162 })
163 .collect::<Result<Vec<Vec<PutInputVector>>, VectorStoreError>>()
164 .map_err(|x| {
165 VectorStoreError::DatastoreError(
166 format!("Could not build vector store data: {x}").into(),
167 )
168 })?
169 .into_iter()
170 .flatten()
171 .collect();
172
173 self.client
174 .put_vectors()
175 .vector_bucket_name(self.bucket_name())
176 .set_vectors(Some(docs))
177 .set_index_name(Some(self.index_name.clone()))
178 .send()
179 .await
180 .map_err(|x| {
181 VectorStoreError::DatastoreError(
182 format!("Error while submitting document insertion request: {x}").into(),
183 )
184 })?;
185
186 Ok(())
187 }
188}
189
190fn json_value_to_document(value: &Value) -> Document {
191 match value {
192 Value::Null => Document::Null,
193 Value::Bool(b) => Document::Bool(*b),
194 Value::Number(n) => {
195 if let Some(i) = n.as_i64() {
196 Document::Number(aws_smithy_types::Number::NegInt(i))
197 } else if let Some(u) = n.as_u64() {
198 Document::Number(aws_smithy_types::Number::PosInt(u))
199 } else if let Some(f) = n.as_f64() {
200 Document::Number(aws_smithy_types::Number::Float(f))
201 } else {
202 Document::Null }
204 }
205 Value::String(s) => Document::String(s.clone()),
206 Value::Array(arr) => Document::Array(arr.iter().map(json_value_to_document).collect()),
207 Value::Object(obj) => Document::Object(
208 obj.iter()
209 .map(|(k, v)| (k.clone(), json_value_to_document(v)))
210 .collect::<HashMap<_, _>>(),
211 ),
212 }
213}
214
215fn document_to_json_value(value: &Document) -> Value {
216 match value {
217 Document::Null => Value::Null,
218 Document::Bool(b) => Value::Bool(*b),
219 Document::Number(n) => {
220 let res = match n {
221 aws_smithy_types::Number::Float(f) => {
222 serde_json::Number::from_f64(f.to_owned()).unwrap()
223 }
224 aws_smithy_types::Number::NegInt(i) => {
225 serde_json::Number::from_i128(*i as i128).unwrap()
226 }
227 aws_smithy_types::Number::PosInt(u) => {
228 serde_json::Number::from_u128(*u as u128).unwrap()
229 }
230 };
231
232 serde_json::Value::Number(res)
233 }
234 Document::String(s) => Value::String(s.clone()),
235 Document::Array(arr) => Value::Array(arr.iter().map(document_to_json_value).collect()),
236 Document::Object(obj) => {
237 let res = obj
238 .iter()
239 .map(|(k, v)| (k.clone(), document_to_json_value(v)))
240 .collect::<serde_json::Map<String, serde_json::Value>>();
241
242 serde_json::Value::Object(res)
243 }
244 }
245}
246
247impl<M> VectorStoreIndex for S3VectorsVectorStore<M>
248where
249 M: EmbeddingModel,
250{
251 type Filter = S3SearchFilter;
252
253 async fn top_n<T: for<'a> serde::Deserialize<'a> + Send>(
254 &self,
255 req: VectorSearchRequest<S3SearchFilter>,
256 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
257 if req.samples() > i32::MAX as u64 {
258 return Err(VectorStoreError::DatastoreError(format!("The number of samples to return with the `rig` AWS S3Vectors integration cannot be higher than {}", i32::MAX).into()));
259 }
260
261 let embedding = self
262 .embedding_model
263 .embed_text(req.query())
264 .await?
265 .vec
266 .into_iter()
267 .map(|x| x as f32)
268 .collect();
269
270 let mut query_builder = self
271 .client
272 .query_vectors()
273 .query_vector(VectorData::Float32(embedding))
274 .top_k(req.samples() as i32)
275 .return_distance(true)
276 .return_metadata(true)
277 .vector_bucket_name(self.bucket_name())
278 .index_name(self.index_name());
279
280 if let Some(filter) = req.filter() {
281 query_builder = query_builder.filter(filter.inner().clone())
282 }
283
284 let query = query_builder.send().await.unwrap();
285
286 let res: Vec<(f64, String, T)> = query
287 .vectors
288 .into_iter()
289 .filter(|vector| {
290 req.threshold().is_none_or(|threshold| {
291 (vector
292 .distance()
293 .expect("vector distance should always exist") as f64)
294 >= threshold
295 })
296 })
297 .map(|x| {
298 let distance = x.distance.expect("vector distance should always exist") as f64;
299 let val =
300 document_to_json_value(&x.metadata.expect("metadata should always exist"));
301
302 let metadata: T = serde_json::from_value(val)
303 .expect("converting JSON from S3Vectors to valid T should always work");
304
305 (distance, x.key, metadata)
306 })
307 .collect();
308
309 Ok(res)
310 }
311
312 async fn top_n_ids(
313 &self,
314 req: VectorSearchRequest<S3SearchFilter>,
315 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
316 if req.samples() > i32::MAX as u64 {
317 return Err(VectorStoreError::DatastoreError(format!("The number of samples to return with the `rig` AWS S3Vectors integration cannot be higher than {}", i32::MAX).into()));
318 }
319
320 let embedding = self
321 .embedding_model
322 .embed_text(req.query())
323 .await?
324 .vec
325 .into_iter()
326 .map(|x| x as f32)
327 .collect();
328
329 let mut query_builder = self
330 .client
331 .query_vectors()
332 .query_vector(VectorData::Float32(embedding))
333 .top_k(req.samples() as i32)
334 .return_distance(true)
335 .vector_bucket_name(self.bucket_name())
336 .index_name(self.index_name());
337
338 if let Some(filter) = req.filter() {
339 query_builder = query_builder.filter(filter.inner().clone())
340 }
341
342 let query = query_builder.send().await.unwrap();
343
344 let res: Vec<(f64, String)> = query
345 .vectors
346 .into_iter()
347 .filter(|vector| {
348 req.threshold().is_none_or(|threshold| {
349 (vector
350 .distance()
351 .expect("vector distance should always exist") as f64)
352 >= threshold
353 })
354 })
355 .map(|x| {
356 let distance = x.distance.expect("vector distance should always exist") as f64;
357
358 (distance, x.key)
359 })
360 .collect();
361
362 Ok(res)
363 }
364}