rig/vector_store/
mod.rs

1//! Vector store abstractions for semantic search and retrieval.
2//!
3//! # Core Traits
4//!
5//! - [`VectorStoreIndex`]: Query a vector store for similar documents.
6//! - [`InsertDocuments`]: Insert documents and their embeddings.
7//! - [`VectorStoreIndexDyn`]: Type-erased version for dynamic contexts.
8//!
9//! Use [`VectorSearchRequest`] to build queries. See [`request`] for filtering.
10//!
//! Types implementing [`VectorStoreIndex`] whose filter is deserializable and operates on
//! JSON values automatically implement [`Tool`].
12
13pub use request::VectorSearchRequest;
14use reqwest::StatusCode;
15use serde::{Deserialize, Serialize};
16use serde_json::{Value, json};
17
18use crate::{
19    Embed, OneOrMany,
20    completion::ToolDefinition,
21    embeddings::{Embedding, EmbeddingError},
22    tool::Tool,
23    vector_store::request::{Filter, FilterError, SearchFilter},
24    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
25};
26
27pub mod builder;
28pub mod in_memory_store;
29pub mod lsh;
30pub mod request;
31
32/// Errors from vector store operations.
33#[derive(Debug, thiserror::Error)]
34pub enum VectorStoreError {
35    #[error("Embedding error: {0}")]
36    EmbeddingError(#[from] EmbeddingError),
37
38    #[error("Json error: {0}")]
39    JsonError(#[from] serde_json::Error),
40
41    #[cfg(not(target_family = "wasm"))]
42    #[error("Datastore error: {0}")]
43    DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
44
45    #[error("Filter error: {0}")]
46    FilterError(#[from] FilterError),
47
48    #[cfg(target_family = "wasm")]
49    #[error("Datastore error: {0}")]
50    DatastoreError(#[from] Box<dyn std::error::Error + 'static>),
51
52    #[error("Missing Id: {0}")]
53    MissingIdError(String),
54
55    #[error("HTTP request error: {0}")]
56    ReqwestError(#[from] reqwest::Error),
57
58    #[error("External call to API returned an error. Error code: {0} Message: {1}")]
59    ExternalAPIError(StatusCode, String),
60
61    #[error("Error while building VectorSearchRequest: {0}")]
62    BuilderError(String),
63}
64
65/// Trait for inserting documents and embeddings into a vector store.
66pub trait InsertDocuments: WasmCompatSend + WasmCompatSync {
67    fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
68        &self,
69        documents: Vec<(Doc, OneOrMany<Embedding>)>,
70    ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + WasmCompatSend;
71}
72
73/// Trait for querying a vector store by similarity.
74pub trait VectorStoreIndex: WasmCompatSend + WasmCompatSync {
75    /// The filter type for this backend.
76    type Filter: SearchFilter + WasmCompatSend + WasmCompatSync;
77
78    /// Returns the top N most similar documents as `(score, id, document)` tuples.
79    fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
80        &self,
81        req: VectorSearchRequest<Self::Filter>,
82    ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>>
83    + WasmCompatSend;
84
85    /// Returns the top N most similar document IDs as `(score, id)` tuples.
86    fn top_n_ids(
87        &self,
88        req: VectorSearchRequest<Self::Filter>,
89    ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + WasmCompatSend;
90}
91
92pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
93
94/// Type-erased [`VectorStoreIndex`] for dynamic dispatch.
95pub trait VectorStoreIndexDyn: WasmCompatSend + WasmCompatSync {
96    fn top_n<'a>(
97        &'a self,
98        req: VectorSearchRequest<Filter<serde_json::Value>>,
99    ) -> WasmBoxedFuture<'a, TopNResults>;
100
101    fn top_n_ids<'a>(
102        &'a self,
103        req: VectorSearchRequest<Filter<serde_json::Value>>,
104    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
105}
106
107impl<I: VectorStoreIndex<Filter = F>, F> VectorStoreIndexDyn for I
108where
109    F: std::fmt::Debug
110        + Clone
111        + SearchFilter<Value = serde_json::Value>
112        + WasmCompatSend
113        + WasmCompatSync
114        + Serialize
115        + for<'de> Deserialize<'de>
116        + 'static,
117{
118    fn top_n<'a>(
119        &'a self,
120        req: VectorSearchRequest<Filter<serde_json::Value>>,
121    ) -> WasmBoxedFuture<'a, TopNResults> {
122        let req = req.map_filter(Filter::interpret);
123
124        Box::pin(async move {
125            Ok(self
126                .top_n::<serde_json::Value>(req)
127                .await?
128                .into_iter()
129                .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
130                .collect::<Vec<_>>())
131        })
132    }
133
134    fn top_n_ids<'a>(
135        &'a self,
136        req: VectorSearchRequest<Filter<serde_json::Value>>,
137    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
138        let req = req.map_filter(Filter::interpret);
139
140        Box::pin(self.top_n_ids(req))
141    }
142}
143
144fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
145    match document {
146        Value::Object(mut map) => {
147            let new_map = map
148                .iter_mut()
149                .filter_map(|(key, value)| {
150                    prune_document(value.take()).map(|value| (key.clone(), value))
151                })
152                .collect::<serde_json::Map<_, _>>();
153
154            Some(Value::Object(new_map))
155        }
156        Value::Array(vec) if vec.len() > 400 => None,
157        Value::Array(vec) => Some(Value::Array(
158            vec.into_iter().filter_map(prune_document).collect(),
159        )),
160        Value::Number(num) => Some(Value::Number(num)),
161        Value::String(s) => Some(Value::String(s)),
162        Value::Bool(b) => Some(Value::Bool(b)),
163        Value::Null => Some(Value::Null),
164    }
165}
166
167/// The output of vector store queries invoked via [`Tool`]
168#[derive(Serialize, Deserialize, Debug)]
169pub struct VectorStoreOutput {
170    pub score: f64,
171    pub id: String,
172    pub document: Value,
173}
174
175impl<T, F> Tool for T
176where
177    F: SearchFilter<Value = serde_json::Value>
178        + WasmCompatSend
179        + WasmCompatSync
180        + for<'de> Deserialize<'de>,
181    T: VectorStoreIndex<Filter = F>,
182{
183    const NAME: &'static str = "search_vector_store";
184
185    type Error = VectorStoreError;
186    type Args = VectorSearchRequest<F>;
187    type Output = Vec<VectorStoreOutput>;
188
189    async fn definition(&self, _prompt: String) -> ToolDefinition {
190        ToolDefinition {
191            name: Self::NAME.to_string(),
192            description:
193                "Retrieves the most relevant documents from a vector store based on a query."
194                    .to_string(),
195            parameters: json!({
196                "type": "object",
197                "properties": {
198                    "query": {
199                        "type": "string",
200                        "description": "The query string to search for relevant documents in the vector store."
201                    },
202                    "samples": {
203                        "type": "integer",
204                        "description": "The maxinum number of samples / documents to retrieve.",
205                        "default": 5,
206                        "minimum": 1
207                    },
208                    "threshold": {
209                        "type": "number",
210                        "description": "Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result."
211                    }
212                },
213                "required": ["query", "samples"]
214            }),
215        }
216    }
217
218    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
219        let results = self.top_n(args).await?;
220        Ok(results
221            .into_iter()
222            .map(|(score, id, document)| VectorStoreOutput {
223                score,
224                id,
225                document,
226            })
227            .collect())
228    }
229}
230
/// Index strategy for the in-memory vector store (`InMemoryVectorStore`).
///
/// The default strategy is [`IndexStrategy::BruteForce`].
#[derive(Clone, Debug, Default)]
pub enum IndexStrategy {
    /// Checks all documents in the vector store to find the most relevant documents.
    #[default]
    BruteForce,

    /// Uses LSH to find candidates then computes exact distances.
    LSH {
        /// Number of tables to use for LSH.
        num_tables: usize,
        /// Number of hyperplanes to use for LSH.
        num_hyperplanes: usize,
    },
}