1pub use request::VectorSearchRequest;
14use reqwest::StatusCode;
15use serde::{Deserialize, Serialize};
16use serde_json::{Value, json};
17
18use crate::{
19 Embed, OneOrMany,
20 completion::ToolDefinition,
21 embeddings::{Embedding, EmbeddingError},
22 tool::Tool,
23 vector_store::request::{Filter, FilterError, SearchFilter},
24 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
25};
26
27pub mod builder;
28pub mod in_memory_store;
29pub mod lsh;
30pub mod request;
31
32#[derive(Debug, thiserror::Error)]
34pub enum VectorStoreError {
35 #[error("Embedding error: {0}")]
37 EmbeddingError(#[from] EmbeddingError),
38
39 #[error("Json error: {0}")]
41 JsonError(#[from] serde_json::Error),
42
43 #[cfg(not(target_family = "wasm"))]
44 #[error("Datastore error: {0}")]
46 DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
47
48 #[error("Filter error: {0}")]
50 FilterError(#[from] FilterError),
51
52 #[cfg(target_family = "wasm")]
53 #[error("Datastore error: {0}")]
55 DatastoreError(#[from] Box<dyn std::error::Error + 'static>),
56
57 #[error("Missing Id: {0}")]
59 MissingIdError(String),
60
61 #[error("HTTP request error: {0}")]
63 ReqwestError(#[from] reqwest::Error),
64
65 #[error("External call to API returned an error. Error code: {0} Message: {1}")]
67 ExternalAPIError(StatusCode, String),
68
69 #[error("Error while building VectorSearchRequest: {0}")]
71 BuilderError(String),
72}
73
74pub trait InsertDocuments: WasmCompatSend + WasmCompatSync {
76 fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
78 &self,
79 documents: Vec<(Doc, OneOrMany<Embedding>)>,
80 ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + WasmCompatSend;
81}
82
83pub trait VectorStoreIndex: WasmCompatSend + WasmCompatSync {
85 type Filter: SearchFilter + WasmCompatSend + WasmCompatSync;
87
88 fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
90 &self,
91 req: VectorSearchRequest<Self::Filter>,
92 ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>>
93 + WasmCompatSend;
94
95 fn top_n_ids(
97 &self,
98 req: VectorSearchRequest<Self::Filter>,
99 ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + WasmCompatSend;
100}
101
102pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
104
105pub trait VectorStoreIndexDyn: WasmCompatSend + WasmCompatSync {
107 fn top_n<'a>(
109 &'a self,
110 req: VectorSearchRequest<Filter<serde_json::Value>>,
111 ) -> WasmBoxedFuture<'a, TopNResults>;
112
113 fn top_n_ids<'a>(
115 &'a self,
116 req: VectorSearchRequest<Filter<serde_json::Value>>,
117 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
118}
119
120impl<I: VectorStoreIndex<Filter = F>, F> VectorStoreIndexDyn for I
121where
122 F: std::fmt::Debug
123 + Clone
124 + SearchFilter<Value = serde_json::Value>
125 + WasmCompatSend
126 + WasmCompatSync
127 + Serialize
128 + for<'de> Deserialize<'de>
129 + 'static,
130{
131 fn top_n<'a>(
132 &'a self,
133 req: VectorSearchRequest<Filter<serde_json::Value>>,
134 ) -> WasmBoxedFuture<'a, TopNResults> {
135 let req = req.map_filter(Filter::interpret);
136
137 Box::pin(async move {
138 Ok(self
139 .top_n::<serde_json::Value>(req)
140 .await?
141 .into_iter()
142 .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
143 .collect::<Vec<_>>())
144 })
145 }
146
147 fn top_n_ids<'a>(
148 &'a self,
149 req: VectorSearchRequest<Filter<serde_json::Value>>,
150 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
151 let req = req.map_filter(Filter::interpret);
152
153 Box::pin(self.top_n_ids(req))
154 }
155}
156
157fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
158 match document {
159 Value::Object(mut map) => {
160 let new_map = map
161 .iter_mut()
162 .filter_map(|(key, value)| {
163 prune_document(value.take()).map(|value| (key.clone(), value))
164 })
165 .collect::<serde_json::Map<_, _>>();
166
167 Some(Value::Object(new_map))
168 }
169 Value::Array(vec) if vec.len() > 400 => None,
170 Value::Array(vec) => Some(Value::Array(
171 vec.into_iter().filter_map(prune_document).collect(),
172 )),
173 Value::Number(num) => Some(Value::Number(num)),
174 Value::String(s) => Some(Value::String(s)),
175 Value::Bool(b) => Some(Value::Bool(b)),
176 Value::Null => Some(Value::Null),
177 }
178}
179
180#[derive(Serialize, Deserialize, Debug)]
182pub struct VectorStoreOutput {
183 pub score: f64,
185 pub id: String,
187 pub document: Value,
189}
190
191impl<T, F> Tool for T
192where
193 F: SearchFilter<Value = serde_json::Value>
194 + WasmCompatSend
195 + WasmCompatSync
196 + for<'de> Deserialize<'de>,
197 T: VectorStoreIndex<Filter = F>,
198{
199 const NAME: &'static str = "search_vector_store";
200
201 type Error = VectorStoreError;
202 type Args = VectorSearchRequest<F>;
203 type Output = Vec<VectorStoreOutput>;
204
205 async fn definition(&self, _prompt: String) -> ToolDefinition {
206 ToolDefinition {
207 name: Self::NAME.to_string(),
208 description:
209 "Retrieves the most relevant documents from a vector store based on a query."
210 .to_string(),
211 parameters: json!({
212 "type": "object",
213 "properties": {
214 "query": {
215 "type": "string",
216 "description": "The query string to search for relevant documents in the vector store."
217 },
218 "samples": {
219 "type": "integer",
220 "description": "The maximum number of samples / documents to retrieve.",
221 "default": 5,
222 "minimum": 1
223 },
224 "threshold": {
225 "type": "number",
226 "description": "Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result."
227 }
228 },
229 "required": ["query", "samples"]
230 }),
231 }
232 }
233
234 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
235 let results = self.top_n(args).await?;
236 Ok(results
237 .into_iter()
238 .map(|(score, id, document)| VectorStoreOutput {
239 score,
240 id,
241 document,
242 })
243 .collect())
244 }
245}
246
/// Strategy selector for how an index is searched.
#[derive(Debug, Default, Clone)]
pub enum IndexStrategy {
    /// The default strategy.
    #[default]
    BruteForce,

    /// LSH-based strategy.
    ///
    /// NOTE(review): presumably locality-sensitive hashing as implemented in
    /// the sibling `lsh` module; confirm there.
    LSH {
        /// Number of tables to use.
        num_tables: usize,
        /// Number of hyperplanes to use.
        num_hyperplanes: usize,
    },
}
261}