rig/vector_store/
mod.rs

1pub use request::VectorSearchRequest;
2use reqwest::StatusCode;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use crate::{
7    Embed, OneOrMany,
8    completion::ToolDefinition,
9    embeddings::{Embedding, EmbeddingError},
10    tool::Tool,
11    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
12};
13
14pub mod in_memory_store;
15pub mod request;
16
17#[derive(Debug, thiserror::Error)]
18pub enum VectorStoreError {
19    #[error("Embedding error: {0}")]
20    EmbeddingError(#[from] EmbeddingError),
21
22    /// Json error (e.g.: serialization, deserialization, etc.)
23    #[error("Json error: {0}")]
24    JsonError(#[from] serde_json::Error),
25
26    #[cfg(not(target_family = "wasm"))]
27    #[error("Datastore error: {0}")]
28    DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
29
30    #[cfg(target_family = "wasm")]
31    #[error("Datastore error: {0}")]
32    DatastoreError(#[from] Box<dyn std::error::Error + 'static>),
33
34    #[error("Missing Id: {0}")]
35    MissingIdError(String),
36
37    #[error("HTTP request error: {0}")]
38    ReqwestError(#[from] reqwest::Error),
39
40    #[error("External call to API returned an error. Error code: {0} Message: {1}")]
41    ExternalAPIError(StatusCode, String),
42
43    #[error("Error while building VectorSearchRequest: {0}")]
44    BuilderError(String),
45}
46
47/// Trait for inserting documents into a vector store.
48pub trait InsertDocuments: WasmCompatSend + WasmCompatSync {
49    /// Insert documents into the vector store.
50    ///
51    fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
52        &self,
53        documents: Vec<(Doc, OneOrMany<Embedding>)>,
54    ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + WasmCompatSend;
55}
56
57/// Trait for vector store indexes
58pub trait VectorStoreIndex: WasmCompatSend + WasmCompatSync {
59    /// Get the top n documents based on the distance to the given query.
60    /// The result is a list of tuples of the form (score, id, document)
61    fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
62        &self,
63        req: VectorSearchRequest,
64    ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>>
65    + WasmCompatSend;
66
67    /// Same as `top_n` but returns the document ids only.
68    fn top_n_ids(
69        &self,
70        req: VectorSearchRequest,
71    ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + WasmCompatSend;
72}
73
74pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
75
76pub trait VectorStoreIndexDyn: WasmCompatSend + WasmCompatSync {
77    fn top_n<'a>(&'a self, req: VectorSearchRequest) -> WasmBoxedFuture<'a, TopNResults>;
78
79    fn top_n_ids<'a>(
80        &'a self,
81        req: VectorSearchRequest,
82    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
83}
84
85impl<I: VectorStoreIndex> VectorStoreIndexDyn for I {
86    fn top_n<'a>(&'a self, req: VectorSearchRequest) -> WasmBoxedFuture<'a, TopNResults> {
87        Box::pin(async move {
88            Ok(self
89                .top_n::<serde_json::Value>(req)
90                .await?
91                .into_iter()
92                .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
93                .collect::<Vec<_>>())
94        })
95    }
96
97    fn top_n_ids<'a>(
98        &'a self,
99        req: VectorSearchRequest,
100    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
101        Box::pin(self.top_n_ids(req))
102    }
103}
104
105fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
106    match document {
107        Value::Object(mut map) => {
108            let new_map = map
109                .iter_mut()
110                .filter_map(|(key, value)| {
111                    prune_document(value.take()).map(|value| (key.clone(), value))
112                })
113                .collect::<serde_json::Map<_, _>>();
114
115            Some(Value::Object(new_map))
116        }
117        Value::Array(vec) if vec.len() > 400 => None,
118        Value::Array(vec) => Some(Value::Array(
119            vec.into_iter().filter_map(prune_document).collect(),
120        )),
121        Value::Number(num) => Some(Value::Number(num)),
122        Value::String(s) => Some(Value::String(s)),
123        Value::Bool(b) => Some(Value::Bool(b)),
124        Value::Null => Some(Value::Null),
125    }
126}
127
128#[derive(Serialize, Deserialize, Debug)]
129pub struct VectorStoreOutput {
130    pub score: f64,
131    pub id: String,
132    pub document: Value,
133}
134
135impl<T> Tool for T
136where
137    T: VectorStoreIndex,
138{
139    const NAME: &'static str = "search_vector_store";
140
141    type Error = VectorStoreError;
142    type Args = VectorSearchRequest;
143    type Output = Vec<VectorStoreOutput>;
144
145    async fn definition(&self, _prompt: String) -> ToolDefinition {
146        ToolDefinition {
147            name: Self::NAME.to_string(),
148            description:
149                "Retrieves the most relevant documents from a vector store based on a query."
150                    .to_string(),
151            parameters: json!({
152                "type": "object",
153                "properties": {
154                    "query": {
155                        "type": "string",
156                        "description": "The query string to search for relevant documents in the vector store."
157                    },
158                    "samples": {
159                        "type": "integer",
160                        "description": "The maxinum number of samples / documents to retrieve.",
161                        "default": 5,
162                        "minimum": 1
163                    },
164                    "threshold": {
165                        "type": "number",
166                        "description": "Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result."
167                    }
168                },
169                "required": ["query", "samples"]
170            }),
171        }
172    }
173
174    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
175        let results = self.top_n(args).await?;
176        Ok(results
177            .into_iter()
178            .map(|(score, id, document)| VectorStoreOutput {
179                score,
180                id,
181                document,
182            })
183            .collect())
184    }
185}