// rig/vector_store/mod.rs

1use futures::future::BoxFuture;
2use serde::Deserialize;
3use serde_json::Value;
4
5use crate::embeddings::EmbeddingError;
6
7pub mod in_memory_store;
8
9#[derive(Debug, thiserror::Error)]
10pub enum VectorStoreError {
11    #[error("Embedding error: {0}")]
12    EmbeddingError(#[from] EmbeddingError),
13
14    /// Json error (e.g.: serialization, deserialization, etc.)
15    #[error("Json error: {0}")]
16    JsonError(#[from] serde_json::Error),
17
18    #[error("Datastore error: {0}")]
19    DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
20
21    #[error("Missing Id: {0}")]
22    MissingIdError(String),
23}
24
25/// Trait for vector store indexes
26pub trait VectorStoreIndex: Send + Sync {
27    /// Get the top n documents based on the distance to the given query.
28    /// The result is a list of tuples of the form (score, id, document)
29    fn top_n<T: for<'a> Deserialize<'a> + Send>(
30        &self,
31        query: &str,
32        n: usize,
33    ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>> + Send;
34
35    /// Same as `top_n` but returns the document ids only.
36    fn top_n_ids(
37        &self,
38        query: &str,
39        n: usize,
40    ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + Send;
41}
42
43pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
44
45pub trait VectorStoreIndexDyn: Send + Sync {
46    fn top_n<'a>(&'a self, query: &'a str, n: usize) -> BoxFuture<'a, TopNResults>;
47
48    fn top_n_ids<'a>(
49        &'a self,
50        query: &'a str,
51        n: usize,
52    ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
53}
54
55impl<I: VectorStoreIndex> VectorStoreIndexDyn for I {
56    fn top_n<'a>(
57        &'a self,
58        query: &'a str,
59        n: usize,
60    ) -> BoxFuture<'a, Result<Vec<(f64, String, Value)>, VectorStoreError>> {
61        Box::pin(async move {
62            Ok(self
63                .top_n::<serde_json::Value>(query, n)
64                .await?
65                .into_iter()
66                .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
67                .collect::<Vec<_>>())
68        })
69    }
70
71    fn top_n_ids<'a>(
72        &'a self,
73        query: &'a str,
74        n: usize,
75    ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
76        Box::pin(self.top_n_ids(query, n))
77    }
78}
79
80fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
81    match document {
82        Value::Object(mut map) => {
83            let new_map = map
84                .iter_mut()
85                .filter_map(|(key, value)| {
86                    prune_document(value.take()).map(|value| (key.clone(), value))
87                })
88                .collect::<serde_json::Map<_, _>>();
89
90            Some(Value::Object(new_map))
91        }
92        Value::Array(vec) if vec.len() > 400 => None,
93        Value::Array(vec) => Some(Value::Array(
94            vec.into_iter().filter_map(prune_document).collect(),
95        )),
96        Value::Number(num) => Some(Value::Number(num)),
97        Value::String(s) => Some(Value::String(s)),
98        Value::Bool(b) => Some(Value::Bool(b)),
99        Value::Null => Some(Value::Null),
100    }
101}