1pub use request::VectorSearchRequest;
14use reqwest::StatusCode;
15use serde::{Deserialize, Serialize};
16use serde_json::{Value, json};
17
18use crate::{
19 Embed, OneOrMany,
20 completion::ToolDefinition,
21 embeddings::{Embedding, EmbeddingError},
22 tool::Tool,
23 vector_store::request::{Filter, FilterError, SearchFilter},
24 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
25};
26
27pub mod builder;
28pub mod in_memory_store;
29pub mod lsh;
30pub mod request;
31
32#[derive(Debug, thiserror::Error)]
34pub enum VectorStoreError {
35 #[error("Embedding error: {0}")]
36 EmbeddingError(#[from] EmbeddingError),
37
38 #[error("Json error: {0}")]
39 JsonError(#[from] serde_json::Error),
40
41 #[cfg(not(target_family = "wasm"))]
42 #[error("Datastore error: {0}")]
43 DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
44
45 #[error("Filter error: {0}")]
46 FilterError(#[from] FilterError),
47
48 #[cfg(target_family = "wasm")]
49 #[error("Datastore error: {0}")]
50 DatastoreError(#[from] Box<dyn std::error::Error + 'static>),
51
52 #[error("Missing Id: {0}")]
53 MissingIdError(String),
54
55 #[error("HTTP request error: {0}")]
56 ReqwestError(#[from] reqwest::Error),
57
58 #[error("External call to API returned an error. Error code: {0} Message: {1}")]
59 ExternalAPIError(StatusCode, String),
60
61 #[error("Error while building VectorSearchRequest: {0}")]
62 BuilderError(String),
63}
64
65pub trait InsertDocuments: WasmCompatSend + WasmCompatSync {
67 fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
68 &self,
69 documents: Vec<(Doc, OneOrMany<Embedding>)>,
70 ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + WasmCompatSend;
71}
72
73pub trait VectorStoreIndex: WasmCompatSend + WasmCompatSync {
75 type Filter: SearchFilter + WasmCompatSend + WasmCompatSync;
77
78 fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
80 &self,
81 req: VectorSearchRequest<Self::Filter>,
82 ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>>
83 + WasmCompatSend;
84
85 fn top_n_ids(
87 &self,
88 req: VectorSearchRequest<Self::Filter>,
89 ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + WasmCompatSend;
90}
91
/// Result type returned by [`VectorStoreIndexDyn::top_n`]: on success, a list
/// of `(f64, String, Value)` tuples (used in this module as
/// `(score, id, document)`), otherwise a [`VectorStoreError`].
pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
93
94pub trait VectorStoreIndexDyn: WasmCompatSend + WasmCompatSync {
96 fn top_n<'a>(
97 &'a self,
98 req: VectorSearchRequest<Filter<serde_json::Value>>,
99 ) -> WasmBoxedFuture<'a, TopNResults>;
100
101 fn top_n_ids<'a>(
102 &'a self,
103 req: VectorSearchRequest<Filter<serde_json::Value>>,
104 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
105}
106
107impl<I: VectorStoreIndex<Filter = F>, F> VectorStoreIndexDyn for I
108where
109 F: std::fmt::Debug
110 + Clone
111 + SearchFilter<Value = serde_json::Value>
112 + WasmCompatSend
113 + WasmCompatSync
114 + Serialize
115 + for<'de> Deserialize<'de>
116 + 'static,
117{
118 fn top_n<'a>(
119 &'a self,
120 req: VectorSearchRequest<Filter<serde_json::Value>>,
121 ) -> WasmBoxedFuture<'a, TopNResults> {
122 let req = req.map_filter(Filter::interpret);
123
124 Box::pin(async move {
125 Ok(self
126 .top_n::<serde_json::Value>(req)
127 .await?
128 .into_iter()
129 .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
130 .collect::<Vec<_>>())
131 })
132 }
133
134 fn top_n_ids<'a>(
135 &'a self,
136 req: VectorSearchRequest<Filter<serde_json::Value>>,
137 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
138 let req = req.map_filter(Filter::interpret);
139
140 Box::pin(self.top_n_ids(req))
141 }
142}
143
144fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
145 match document {
146 Value::Object(mut map) => {
147 let new_map = map
148 .iter_mut()
149 .filter_map(|(key, value)| {
150 prune_document(value.take()).map(|value| (key.clone(), value))
151 })
152 .collect::<serde_json::Map<_, _>>();
153
154 Some(Value::Object(new_map))
155 }
156 Value::Array(vec) if vec.len() > 400 => None,
157 Value::Array(vec) => Some(Value::Array(
158 vec.into_iter().filter_map(prune_document).collect(),
159 )),
160 Value::Number(num) => Some(Value::Number(num)),
161 Value::String(s) => Some(Value::String(s)),
162 Value::Bool(b) => Some(Value::Bool(b)),
163 Value::Null => Some(Value::Null),
164 }
165}
166
/// A single search hit as returned by the `search_vector_store` tool.
#[derive(Serialize, Deserialize, Debug)]
pub struct VectorStoreOutput {
    /// The `f64` score reported by the index for this hit.
    pub score: f64,
    /// The identifier reported by the index for this hit.
    pub id: String,
    /// The matched document as a JSON value.
    pub document: Value,
}
174
175impl<T, F> Tool for T
176where
177 F: SearchFilter<Value = serde_json::Value>
178 + WasmCompatSend
179 + WasmCompatSync
180 + for<'de> Deserialize<'de>,
181 T: VectorStoreIndex<Filter = F>,
182{
183 const NAME: &'static str = "search_vector_store";
184
185 type Error = VectorStoreError;
186 type Args = VectorSearchRequest<F>;
187 type Output = Vec<VectorStoreOutput>;
188
189 async fn definition(&self, _prompt: String) -> ToolDefinition {
190 ToolDefinition {
191 name: Self::NAME.to_string(),
192 description:
193 "Retrieves the most relevant documents from a vector store based on a query."
194 .to_string(),
195 parameters: json!({
196 "type": "object",
197 "properties": {
198 "query": {
199 "type": "string",
200 "description": "The query string to search for relevant documents in the vector store."
201 },
202 "samples": {
203 "type": "integer",
204 "description": "The maxinum number of samples / documents to retrieve.",
205 "default": 5,
206 "minimum": 1
207 },
208 "threshold": {
209 "type": "number",
210 "description": "Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result."
211 }
212 },
213 "required": ["query", "samples"]
214 }),
215 }
216 }
217
218 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
219 let results = self.top_n(args).await?;
220 Ok(results
221 .into_iter()
222 .map(|(score, id, document)| VectorStoreOutput {
223 score,
224 id,
225 document,
226 })
227 .collect())
228 }
229}
230
/// Selects how an index is searched.
///
/// The default is [`IndexStrategy::BruteForce`].
#[derive(Clone, Debug, Default)]
pub enum IndexStrategy {
    /// Brute-force strategy (presumably compares the query against every
    /// stored entry — confirm against the index implementation).
    #[default]
    BruteForce,

    /// LSH strategy (presumably locality-sensitive hashing, cf. the `lsh`
    /// module — confirm), configured by the two parameters below.
    LSH {
        /// Number of tables to use.
        num_tables: usize,
        /// Number of hyperplanes to use.
        num_hyperplanes: usize,
    },
}