rig/vector_store/mod.rs

1use futures::future::BoxFuture;
2use reqwest::StatusCode;
3use serde::Deserialize;
4use serde::Serialize;
5use serde_json::Value;
6
7use crate::embeddings::EmbeddingError;
8use crate::{Embed, OneOrMany, embeddings::Embedding};
9
10pub mod in_memory_store;
11
12#[derive(Debug, thiserror::Error)]
13pub enum VectorStoreError {
14    #[error("Embedding error: {0}")]
15    EmbeddingError(#[from] EmbeddingError),
16
17    /// Json error (e.g.: serialization, deserialization, etc.)
18    #[error("Json error: {0}")]
19    JsonError(#[from] serde_json::Error),
20
21    #[error("Datastore error: {0}")]
22    DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
23
24    #[error("Missing Id: {0}")]
25    MissingIdError(String),
26
27    #[error("HTTP request error: {0}")]
28    ReqwestError(#[from] reqwest::Error),
29
30    #[error("External call to API returned an error. Error code: {0} Message: {1}")]
31    ExternalAPIError(StatusCode, String),
32}
33
34/// Trait for inserting documents into a vector store.
35pub trait InsertDocuments: Send + Sync {
36    /// Insert documents into the vector store.
37    ///
38    fn insert_documents<Doc: Serialize + Embed + Send>(
39        &self,
40        documents: Vec<(Doc, OneOrMany<Embedding>)>,
41    ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + Send;
42}
43
44/// Trait for vector store indexes
45pub trait VectorStoreIndex: Send + Sync {
46    /// Get the top n documents based on the distance to the given query.
47    /// The result is a list of tuples of the form (score, id, document)
48    fn top_n<T: for<'a> Deserialize<'a> + Send>(
49        &self,
50        query: &str,
51        n: usize,
52    ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>> + Send;
53
54    /// Same as `top_n` but returns the document ids only.
55    fn top_n_ids(
56        &self,
57        query: &str,
58        n: usize,
59    ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + Send;
60}
61
62pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
63
64pub trait VectorStoreIndexDyn: Send + Sync {
65    fn top_n<'a>(&'a self, query: &'a str, n: usize) -> BoxFuture<'a, TopNResults>;
66
67    fn top_n_ids<'a>(
68        &'a self,
69        query: &'a str,
70        n: usize,
71    ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
72}
73
74impl<I: VectorStoreIndex> VectorStoreIndexDyn for I {
75    fn top_n<'a>(
76        &'a self,
77        query: &'a str,
78        n: usize,
79    ) -> BoxFuture<'a, Result<Vec<(f64, String, Value)>, VectorStoreError>> {
80        Box::pin(async move {
81            Ok(self
82                .top_n::<serde_json::Value>(query, n)
83                .await?
84                .into_iter()
85                .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
86                .collect::<Vec<_>>())
87        })
88    }
89
90    fn top_n_ids<'a>(
91        &'a self,
92        query: &'a str,
93        n: usize,
94    ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
95        Box::pin(self.top_n_ids(query, n))
96    }
97}
98
99fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
100    match document {
101        Value::Object(mut map) => {
102            let new_map = map
103                .iter_mut()
104                .filter_map(|(key, value)| {
105                    prune_document(value.take()).map(|value| (key.clone(), value))
106                })
107                .collect::<serde_json::Map<_, _>>();
108
109            Some(Value::Object(new_map))
110        }
111        Value::Array(vec) if vec.len() > 400 => None,
112        Value::Array(vec) => Some(Value::Array(
113            vec.into_iter().filter_map(prune_document).collect(),
114        )),
115        Value::Number(num) => Some(Value::Number(num)),
116        Value::String(s) => Some(Value::String(s)),
117        Value::Bool(b) => Some(Value::Bool(b)),
118        Value::Null => Some(Value::Null),
119    }
120}