rig/vector_store/
mod.rs

1pub use request::VectorSearchRequest;
2use reqwest::StatusCode;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use crate::{
7    Embed, OneOrMany,
8    completion::ToolDefinition,
9    embeddings::{Embedding, EmbeddingError},
10    tool::Tool,
11    vector_store::request::{Filter, FilterError, SearchFilter},
12    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
13};
14
15pub mod builder;
16pub mod in_memory_store;
17pub mod lsh;
18pub mod request;
19
20#[derive(Debug, thiserror::Error)]
21pub enum VectorStoreError {
22    #[error("Embedding error: {0}")]
23    EmbeddingError(#[from] EmbeddingError),
24
25    /// Json error (e.g.: serialization, deserialization, etc.)
26    #[error("Json error: {0}")]
27    JsonError(#[from] serde_json::Error),
28
29    #[cfg(not(target_family = "wasm"))]
30    #[error("Datastore error: {0}")]
31    DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
32
33    #[error("Filter error: {0}")]
34    FilterError(#[from] FilterError),
35
36    #[cfg(target_family = "wasm")]
37    #[error("Datastore error: {0}")]
38    DatastoreError(#[from] Box<dyn std::error::Error + 'static>),
39
40    #[error("Missing Id: {0}")]
41    MissingIdError(String),
42
43    #[error("HTTP request error: {0}")]
44    ReqwestError(#[from] reqwest::Error),
45
46    #[error("External call to API returned an error. Error code: {0} Message: {1}")]
47    ExternalAPIError(StatusCode, String),
48
49    #[error("Error while building VectorSearchRequest: {0}")]
50    BuilderError(String),
51}
52
53/// Trait for inserting documents into a vector store.
54pub trait InsertDocuments: WasmCompatSend + WasmCompatSync {
55    /// Insert documents into the vector store.
56    ///
57    fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
58        &self,
59        documents: Vec<(Doc, OneOrMany<Embedding>)>,
60    ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + WasmCompatSend;
61}
62
63/// Trait for vector store indexes
64pub trait VectorStoreIndex: WasmCompatSend + WasmCompatSync {
65    type Filter: SearchFilter + WasmCompatSend + WasmCompatSync;
66
67    /// Get the top n documents based on the distance to the given query.
68    /// The result is a list of tuples of the form (score, id, document)
69    fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
70        &self,
71        req: VectorSearchRequest<Self::Filter>,
72    ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>>
73    + WasmCompatSend;
74
75    /// Same as `top_n` but returns the document ids only.
76    fn top_n_ids(
77        &self,
78        req: VectorSearchRequest<Self::Filter>,
79    ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + WasmCompatSend;
80}
81
82pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
83
84pub trait VectorStoreIndexDyn: WasmCompatSend + WasmCompatSync {
85    fn top_n<'a>(
86        &'a self,
87        req: VectorSearchRequest<Filter<serde_json::Value>>,
88    ) -> WasmBoxedFuture<'a, TopNResults>;
89
90    fn top_n_ids<'a>(
91        &'a self,
92        req: VectorSearchRequest<Filter<serde_json::Value>>,
93    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
94}
95
96impl<I: VectorStoreIndex<Filter = F>, F> VectorStoreIndexDyn for I
97where
98    F: std::fmt::Debug
99        + Clone
100        + SearchFilter<Value = serde_json::Value>
101        + WasmCompatSend
102        + WasmCompatSync
103        + Serialize
104        + for<'de> Deserialize<'de>
105        + 'static,
106{
107    fn top_n<'a>(
108        &'a self,
109        req: VectorSearchRequest<Filter<serde_json::Value>>,
110    ) -> WasmBoxedFuture<'a, TopNResults> {
111        let req = req.map_filter(Filter::interpret);
112
113        Box::pin(async move {
114            Ok(self
115                .top_n::<serde_json::Value>(req)
116                .await?
117                .into_iter()
118                .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
119                .collect::<Vec<_>>())
120        })
121    }
122
123    fn top_n_ids<'a>(
124        &'a self,
125        req: VectorSearchRequest<Filter<serde_json::Value>>,
126    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
127        let req = req.map_filter(Filter::interpret);
128
129        Box::pin(self.top_n_ids(req))
130    }
131}
132
133fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
134    match document {
135        Value::Object(mut map) => {
136            let new_map = map
137                .iter_mut()
138                .filter_map(|(key, value)| {
139                    prune_document(value.take()).map(|value| (key.clone(), value))
140                })
141                .collect::<serde_json::Map<_, _>>();
142
143            Some(Value::Object(new_map))
144        }
145        Value::Array(vec) if vec.len() > 400 => None,
146        Value::Array(vec) => Some(Value::Array(
147            vec.into_iter().filter_map(prune_document).collect(),
148        )),
149        Value::Number(num) => Some(Value::Number(num)),
150        Value::String(s) => Some(Value::String(s)),
151        Value::Bool(b) => Some(Value::Bool(b)),
152        Value::Null => Some(Value::Null),
153    }
154}
155
156#[derive(Serialize, Deserialize, Debug)]
157pub struct VectorStoreOutput {
158    pub score: f64,
159    pub id: String,
160    pub document: Value,
161}
162
163impl<T, F> Tool for T
164where
165    F: SearchFilter<Value = serde_json::Value>
166        + WasmCompatSend
167        + WasmCompatSync
168        + for<'de> Deserialize<'de>,
169    T: VectorStoreIndex<Filter = F>,
170{
171    const NAME: &'static str = "search_vector_store";
172
173    type Error = VectorStoreError;
174    type Args = VectorSearchRequest<F>;
175    type Output = Vec<VectorStoreOutput>;
176
177    async fn definition(&self, _prompt: String) -> ToolDefinition {
178        ToolDefinition {
179            name: Self::NAME.to_string(),
180            description:
181                "Retrieves the most relevant documents from a vector store based on a query."
182                    .to_string(),
183            parameters: json!({
184                "type": "object",
185                "properties": {
186                    "query": {
187                        "type": "string",
188                        "description": "The query string to search for relevant documents in the vector store."
189                    },
190                    "samples": {
191                        "type": "integer",
192                        "description": "The maxinum number of samples / documents to retrieve.",
193                        "default": 5,
194                        "minimum": 1
195                    },
196                    "threshold": {
197                        "type": "number",
198                        "description": "Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result."
199                    }
200                },
201                "required": ["query", "samples"]
202            }),
203        }
204    }
205
206    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
207        let results = self.top_n(args).await?;
208        Ok(results
209            .into_iter()
210            .map(|(score, id, document)| VectorStoreOutput {
211                score,
212                id,
213                document,
214            })
215            .collect())
216    }
217}
218
/// Index strategy for the in-memory vector store (`InMemoryVectorStore`).
///
/// Defaults to [`IndexStrategy::BruteForce`].
#[derive(Clone, Debug, Default)]
pub enum IndexStrategy {
    /// Checks all documents in the vector store to find the most relevant documents.
    #[default]
    BruteForce,

    /// Uses LSH (locality-sensitive hashing) to find candidates, then computes
    /// exact distances.
    LSH {
        /// Number of tables to use for LSH.
        num_tables: usize,
        /// Number of hyperplanes to use for LSH.
        num_hyperplanes: usize,
    },
}