1pub use request::VectorSearchRequest;
2use reqwest::StatusCode;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use crate::{
7 Embed, OneOrMany,
8 completion::ToolDefinition,
9 embeddings::{Embedding, EmbeddingError},
10 tool::Tool,
11 vector_store::request::{Filter, FilterError, SearchFilter},
12 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
13};
14
15pub mod builder;
16pub mod in_memory_store;
17pub mod lsh;
18pub mod request;
19
20#[derive(Debug, thiserror::Error)]
21pub enum VectorStoreError {
22 #[error("Embedding error: {0}")]
23 EmbeddingError(#[from] EmbeddingError),
24
25 #[error("Json error: {0}")]
27 JsonError(#[from] serde_json::Error),
28
29 #[cfg(not(target_family = "wasm"))]
30 #[error("Datastore error: {0}")]
31 DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
32
33 #[error("Filter error: {0}")]
34 FilterError(#[from] FilterError),
35
36 #[cfg(target_family = "wasm")]
37 #[error("Datastore error: {0}")]
38 DatastoreError(#[from] Box<dyn std::error::Error + 'static>),
39
40 #[error("Missing Id: {0}")]
41 MissingIdError(String),
42
43 #[error("HTTP request error: {0}")]
44 ReqwestError(#[from] reqwest::Error),
45
46 #[error("External call to API returned an error. Error code: {0} Message: {1}")]
47 ExternalAPIError(StatusCode, String),
48
49 #[error("Error while building VectorSearchRequest: {0}")]
50 BuilderError(String),
51}
52
53pub trait InsertDocuments: WasmCompatSend + WasmCompatSync {
55 fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
58 &self,
59 documents: Vec<(Doc, OneOrMany<Embedding>)>,
60 ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + WasmCompatSend;
61}
62
63pub trait VectorStoreIndex: WasmCompatSend + WasmCompatSync {
65 type Filter: SearchFilter + WasmCompatSend + WasmCompatSync;
66
67 fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
70 &self,
71 req: VectorSearchRequest<Self::Filter>,
72 ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>>
73 + WasmCompatSend;
74
75 fn top_n_ids(
77 &self,
78 req: VectorSearchRequest<Self::Filter>,
79 ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + WasmCompatSend;
80}
81
82pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
83
84pub trait VectorStoreIndexDyn: WasmCompatSend + WasmCompatSync {
85 fn top_n<'a>(
86 &'a self,
87 req: VectorSearchRequest<Filter<serde_json::Value>>,
88 ) -> WasmBoxedFuture<'a, TopNResults>;
89
90 fn top_n_ids<'a>(
91 &'a self,
92 req: VectorSearchRequest<Filter<serde_json::Value>>,
93 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
94}
95
96impl<I: VectorStoreIndex<Filter = F>, F> VectorStoreIndexDyn for I
97where
98 F: std::fmt::Debug
99 + Clone
100 + SearchFilter<Value = serde_json::Value>
101 + WasmCompatSend
102 + WasmCompatSync
103 + Serialize
104 + for<'de> Deserialize<'de>
105 + 'static,
106{
107 fn top_n<'a>(
108 &'a self,
109 req: VectorSearchRequest<Filter<serde_json::Value>>,
110 ) -> WasmBoxedFuture<'a, TopNResults> {
111 let req = req.map_filter(Filter::interpret);
112
113 Box::pin(async move {
114 Ok(self
115 .top_n::<serde_json::Value>(req)
116 .await?
117 .into_iter()
118 .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
119 .collect::<Vec<_>>())
120 })
121 }
122
123 fn top_n_ids<'a>(
124 &'a self,
125 req: VectorSearchRequest<Filter<serde_json::Value>>,
126 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
127 let req = req.map_filter(Filter::interpret);
128
129 Box::pin(self.top_n_ids(req))
130 }
131}
132
133fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
134 match document {
135 Value::Object(mut map) => {
136 let new_map = map
137 .iter_mut()
138 .filter_map(|(key, value)| {
139 prune_document(value.take()).map(|value| (key.clone(), value))
140 })
141 .collect::<serde_json::Map<_, _>>();
142
143 Some(Value::Object(new_map))
144 }
145 Value::Array(vec) if vec.len() > 400 => None,
146 Value::Array(vec) => Some(Value::Array(
147 vec.into_iter().filter_map(prune_document).collect(),
148 )),
149 Value::Number(num) => Some(Value::Number(num)),
150 Value::String(s) => Some(Value::String(s)),
151 Value::Bool(b) => Some(Value::Bool(b)),
152 Value::Null => Some(Value::Null),
153 }
154}
155
156#[derive(Serialize, Deserialize, Debug)]
157pub struct VectorStoreOutput {
158 pub score: f64,
159 pub id: String,
160 pub document: Value,
161}
162
163impl<T, F> Tool for T
164where
165 F: SearchFilter<Value = serde_json::Value>
166 + WasmCompatSend
167 + WasmCompatSync
168 + for<'de> Deserialize<'de>,
169 T: VectorStoreIndex<Filter = F>,
170{
171 const NAME: &'static str = "search_vector_store";
172
173 type Error = VectorStoreError;
174 type Args = VectorSearchRequest<F>;
175 type Output = Vec<VectorStoreOutput>;
176
177 async fn definition(&self, _prompt: String) -> ToolDefinition {
178 ToolDefinition {
179 name: Self::NAME.to_string(),
180 description:
181 "Retrieves the most relevant documents from a vector store based on a query."
182 .to_string(),
183 parameters: json!({
184 "type": "object",
185 "properties": {
186 "query": {
187 "type": "string",
188 "description": "The query string to search for relevant documents in the vector store."
189 },
190 "samples": {
191 "type": "integer",
192 "description": "The maxinum number of samples / documents to retrieve.",
193 "default": 5,
194 "minimum": 1
195 },
196 "threshold": {
197 "type": "number",
198 "description": "Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result."
199 }
200 },
201 "required": ["query", "samples"]
202 }),
203 }
204 }
205
206 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
207 let results = self.top_n(args).await?;
208 Ok(results
209 .into_iter()
210 .map(|(score, id, document)| VectorStoreOutput {
211 score,
212 id,
213 document,
214 })
215 .collect())
216 }
217}
218
/// Strategy used to index embeddings for similarity search.
///
/// Defaults to [`IndexStrategy::BruteForce`].
#[derive(Clone, Debug, Default)]
pub enum IndexStrategy {
    /// Exhaustive (brute-force) search. This is the default strategy.
    #[default]
    BruteForce,

    /// Locality-sensitive hashing index.
    ///
    /// NOTE(review): the exact hashing scheme lives in the `lsh` module; confirm the
    /// semantics of the parameters there.
    LSH {
        /// Number of hash tables.
        num_tables: usize,
        /// Number of hyperplanes used per table.
        num_hyperplanes: usize,
    },
}