rig/vector_store/mod.rs

1use futures::future::BoxFuture;
2pub use request::VectorSearchRequest;
3use reqwest::StatusCode;
4use serde::Deserialize;
5use serde::Serialize;
6use serde_json::Value;
7
8use crate::embeddings::EmbeddingError;
9use crate::{Embed, OneOrMany, embeddings::Embedding};
10
11pub mod in_memory_store;
12pub mod request;
13
14#[derive(Debug, thiserror::Error)]
15pub enum VectorStoreError {
16    #[error("Embedding error: {0}")]
17    EmbeddingError(#[from] EmbeddingError),
18
19    /// Json error (e.g.: serialization, deserialization, etc.)
20    #[error("Json error: {0}")]
21    JsonError(#[from] serde_json::Error),
22
23    #[error("Datastore error: {0}")]
24    DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
25
26    #[error("Missing Id: {0}")]
27    MissingIdError(String),
28
29    #[error("HTTP request error: {0}")]
30    ReqwestError(#[from] reqwest::Error),
31
32    #[error("External call to API returned an error. Error code: {0} Message: {1}")]
33    ExternalAPIError(StatusCode, String),
34
35    #[error("Error while building VectorSearchRequest: {0}")]
36    BuilderError(String),
37}
38
39/// Trait for inserting documents into a vector store.
40pub trait InsertDocuments: Send + Sync {
41    /// Insert documents into the vector store.
42    ///
43    fn insert_documents<Doc: Serialize + Embed + Send>(
44        &self,
45        documents: Vec<(Doc, OneOrMany<Embedding>)>,
46    ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + Send;
47}
48
49/// Trait for vector store indexes
50pub trait VectorStoreIndex: Send + Sync {
51    /// Get the top n documents based on the distance to the given query.
52    /// The result is a list of tuples of the form (score, id, document)
53    fn top_n<T: for<'a> Deserialize<'a> + Send>(
54        &self,
55        req: VectorSearchRequest,
56    ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>> + Send;
57
58    /// Same as `top_n` but returns the document ids only.
59    fn top_n_ids(
60        &self,
61        req: VectorSearchRequest,
62    ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + Send;
63}
64
65pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
66
67pub trait VectorStoreIndexDyn: Send + Sync {
68    fn top_n<'a>(&'a self, req: VectorSearchRequest) -> BoxFuture<'a, TopNResults>;
69
70    fn top_n_ids<'a>(
71        &'a self,
72        req: VectorSearchRequest,
73    ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
74}
75
76impl<I: VectorStoreIndex> VectorStoreIndexDyn for I {
77    fn top_n<'a>(
78        &'a self,
79        req: VectorSearchRequest,
80    ) -> BoxFuture<'a, Result<Vec<(f64, String, Value)>, VectorStoreError>> {
81        Box::pin(async move {
82            Ok(self
83                .top_n::<serde_json::Value>(req)
84                .await?
85                .into_iter()
86                .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
87                .collect::<Vec<_>>())
88        })
89    }
90
91    fn top_n_ids<'a>(
92        &'a self,
93        req: VectorSearchRequest,
94    ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
95        Box::pin(self.top_n_ids(req))
96    }
97}
98
99fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
100    match document {
101        Value::Object(mut map) => {
102            let new_map = map
103                .iter_mut()
104                .filter_map(|(key, value)| {
105                    prune_document(value.take()).map(|value| (key.clone(), value))
106                })
107                .collect::<serde_json::Map<_, _>>();
108
109            Some(Value::Object(new_map))
110        }
111        Value::Array(vec) if vec.len() > 400 => None,
112        Value::Array(vec) => Some(Value::Array(
113            vec.into_iter().filter_map(prune_document).collect(),
114        )),
115        Value::Number(num) => Some(Value::Number(num)),
116        Value::String(s) => Some(Value::String(s)),
117        Value::Bool(b) => Some(Value::Bool(b)),
118        Value::Null => Some(Value::Null),
119    }
120}