// rig/vector_store/mod.rs

1pub use request::VectorSearchRequest;
2use reqwest::StatusCode;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use crate::{
7    Embed, OneOrMany,
8    completion::ToolDefinition,
9    embeddings::{Embedding, EmbeddingError},
10    tool::Tool,
11    vector_store::request::{Filter, FilterError, SearchFilter},
12    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
13};
14
15pub mod in_memory_store;
16pub mod request;
17
18#[derive(Debug, thiserror::Error)]
19pub enum VectorStoreError {
20    #[error("Embedding error: {0}")]
21    EmbeddingError(#[from] EmbeddingError),
22
23    /// Json error (e.g.: serialization, deserialization, etc.)
24    #[error("Json error: {0}")]
25    JsonError(#[from] serde_json::Error),
26
27    #[cfg(not(target_family = "wasm"))]
28    #[error("Datastore error: {0}")]
29    DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
30
31    #[error("Filter error: {0}")]
32    FilterError(#[from] FilterError),
33
34    #[cfg(target_family = "wasm")]
35    #[error("Datastore error: {0}")]
36    DatastoreError(#[from] Box<dyn std::error::Error + 'static>),
37
38    #[error("Missing Id: {0}")]
39    MissingIdError(String),
40
41    #[error("HTTP request error: {0}")]
42    ReqwestError(#[from] reqwest::Error),
43
44    #[error("External call to API returned an error. Error code: {0} Message: {1}")]
45    ExternalAPIError(StatusCode, String),
46
47    #[error("Error while building VectorSearchRequest: {0}")]
48    BuilderError(String),
49}
50
51/// Trait for inserting documents into a vector store.
52pub trait InsertDocuments: WasmCompatSend + WasmCompatSync {
53    /// Insert documents into the vector store.
54    ///
55    fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
56        &self,
57        documents: Vec<(Doc, OneOrMany<Embedding>)>,
58    ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + WasmCompatSend;
59}
60
61/// Trait for vector store indexes
62pub trait VectorStoreIndex: WasmCompatSend + WasmCompatSync {
63    type Filter: SearchFilter + WasmCompatSend + WasmCompatSync;
64
65    /// Get the top n documents based on the distance to the given query.
66    /// The result is a list of tuples of the form (score, id, document)
67    fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
68        &self,
69        req: VectorSearchRequest<Self::Filter>,
70    ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>>
71    + WasmCompatSend;
72
73    /// Same as `top_n` but returns the document ids only.
74    fn top_n_ids(
75        &self,
76        req: VectorSearchRequest<Self::Filter>,
77    ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + WasmCompatSend;
78}
79
80pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
81
82pub trait VectorStoreIndexDyn: WasmCompatSend + WasmCompatSync {
83    fn top_n<'a>(
84        &'a self,
85        req: VectorSearchRequest<Filter<serde_json::Value>>,
86    ) -> WasmBoxedFuture<'a, TopNResults>;
87
88    fn top_n_ids<'a>(
89        &'a self,
90        req: VectorSearchRequest<Filter<serde_json::Value>>,
91    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
92}
93
94impl<I: VectorStoreIndex<Filter = F>, F> VectorStoreIndexDyn for I
95where
96    F: std::fmt::Debug
97        + Clone
98        + SearchFilter<Value = serde_json::Value>
99        + WasmCompatSend
100        + WasmCompatSync
101        + Serialize
102        + for<'de> Deserialize<'de>
103        + 'static,
104{
105    fn top_n<'a>(
106        &'a self,
107        req: VectorSearchRequest<Filter<serde_json::Value>>,
108    ) -> WasmBoxedFuture<'a, TopNResults> {
109        let req = req.map_filter(Filter::interpret);
110
111        Box::pin(async move {
112            Ok(self
113                .top_n::<serde_json::Value>(req)
114                .await?
115                .into_iter()
116                .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
117                .collect::<Vec<_>>())
118        })
119    }
120
121    fn top_n_ids<'a>(
122        &'a self,
123        req: VectorSearchRequest<Filter<serde_json::Value>>,
124    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
125        let req = req.map_filter(Filter::interpret);
126
127        Box::pin(self.top_n_ids(req))
128    }
129}
130
131fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
132    match document {
133        Value::Object(mut map) => {
134            let new_map = map
135                .iter_mut()
136                .filter_map(|(key, value)| {
137                    prune_document(value.take()).map(|value| (key.clone(), value))
138                })
139                .collect::<serde_json::Map<_, _>>();
140
141            Some(Value::Object(new_map))
142        }
143        Value::Array(vec) if vec.len() > 400 => None,
144        Value::Array(vec) => Some(Value::Array(
145            vec.into_iter().filter_map(prune_document).collect(),
146        )),
147        Value::Number(num) => Some(Value::Number(num)),
148        Value::String(s) => Some(Value::String(s)),
149        Value::Bool(b) => Some(Value::Bool(b)),
150        Value::Null => Some(Value::Null),
151    }
152}
153
154#[derive(Serialize, Deserialize, Debug)]
155pub struct VectorStoreOutput {
156    pub score: f64,
157    pub id: String,
158    pub document: Value,
159}
160
161impl<T, F> Tool for T
162where
163    F: SearchFilter<Value = serde_json::Value>
164        + WasmCompatSend
165        + WasmCompatSync
166        + for<'de> Deserialize<'de>,
167    T: VectorStoreIndex<Filter = F>,
168{
169    const NAME: &'static str = "search_vector_store";
170
171    type Error = VectorStoreError;
172    type Args = VectorSearchRequest<F>;
173    type Output = Vec<VectorStoreOutput>;
174
175    async fn definition(&self, _prompt: String) -> ToolDefinition {
176        ToolDefinition {
177            name: Self::NAME.to_string(),
178            description:
179                "Retrieves the most relevant documents from a vector store based on a query."
180                    .to_string(),
181            parameters: json!({
182                "type": "object",
183                "properties": {
184                    "query": {
185                        "type": "string",
186                        "description": "The query string to search for relevant documents in the vector store."
187                    },
188                    "samples": {
189                        "type": "integer",
190                        "description": "The maxinum number of samples / documents to retrieve.",
191                        "default": 5,
192                        "minimum": 1
193                    },
194                    "threshold": {
195                        "type": "number",
196                        "description": "Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result."
197                    }
198                },
199                "required": ["query", "samples"]
200            }),
201        }
202    }
203
204    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
205        let results = self.top_n(args).await?;
206        Ok(results
207            .into_iter()
208            .map(|(score, id, document)| VectorStoreOutput {
209                score,
210                id,
211                document,
212            })
213            .collect())
214    }
215}