1pub use request::VectorSearchRequest;
2use reqwest::StatusCode;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use crate::{
7 Embed, OneOrMany,
8 completion::ToolDefinition,
9 embeddings::{Embedding, EmbeddingError},
10 tool::Tool,
11 vector_store::request::{Filter, FilterError, SearchFilter},
12 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
13};
14
15pub mod in_memory_store;
16pub mod request;
17
18#[derive(Debug, thiserror::Error)]
19pub enum VectorStoreError {
20 #[error("Embedding error: {0}")]
21 EmbeddingError(#[from] EmbeddingError),
22
23 #[error("Json error: {0}")]
25 JsonError(#[from] serde_json::Error),
26
27 #[cfg(not(target_family = "wasm"))]
28 #[error("Datastore error: {0}")]
29 DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
30
31 #[error("Filter error: {0}")]
32 FilterError(#[from] FilterError),
33
34 #[cfg(target_family = "wasm")]
35 #[error("Datastore error: {0}")]
36 DatastoreError(#[from] Box<dyn std::error::Error + 'static>),
37
38 #[error("Missing Id: {0}")]
39 MissingIdError(String),
40
41 #[error("HTTP request error: {0}")]
42 ReqwestError(#[from] reqwest::Error),
43
44 #[error("External call to API returned an error. Error code: {0} Message: {1}")]
45 ExternalAPIError(StatusCode, String),
46
47 #[error("Error while building VectorSearchRequest: {0}")]
48 BuilderError(String),
49}
50
51pub trait InsertDocuments: WasmCompatSend + WasmCompatSync {
53 fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
56 &self,
57 documents: Vec<(Doc, OneOrMany<Embedding>)>,
58 ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + WasmCompatSend;
59}
60
61pub trait VectorStoreIndex: WasmCompatSend + WasmCompatSync {
63 type Filter: SearchFilter + WasmCompatSend + WasmCompatSync;
64
65 fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
68 &self,
69 req: VectorSearchRequest<Self::Filter>,
70 ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>>
71 + WasmCompatSend;
72
73 fn top_n_ids(
75 &self,
76 req: VectorSearchRequest<Self::Filter>,
77 ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + WasmCompatSend;
78}
79
80pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
81
82pub trait VectorStoreIndexDyn: WasmCompatSend + WasmCompatSync {
83 fn top_n<'a>(
84 &'a self,
85 req: VectorSearchRequest<Filter<serde_json::Value>>,
86 ) -> WasmBoxedFuture<'a, TopNResults>;
87
88 fn top_n_ids<'a>(
89 &'a self,
90 req: VectorSearchRequest<Filter<serde_json::Value>>,
91 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
92}
93
94impl<I: VectorStoreIndex<Filter = F>, F> VectorStoreIndexDyn for I
95where
96 F: std::fmt::Debug
97 + Clone
98 + SearchFilter<Value = serde_json::Value>
99 + WasmCompatSend
100 + WasmCompatSync
101 + Serialize
102 + for<'de> Deserialize<'de>
103 + 'static,
104{
105 fn top_n<'a>(
106 &'a self,
107 req: VectorSearchRequest<Filter<serde_json::Value>>,
108 ) -> WasmBoxedFuture<'a, TopNResults> {
109 let req = req.map_filter(Filter::interpret);
110
111 Box::pin(async move {
112 Ok(self
113 .top_n::<serde_json::Value>(req)
114 .await?
115 .into_iter()
116 .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
117 .collect::<Vec<_>>())
118 })
119 }
120
121 fn top_n_ids<'a>(
122 &'a self,
123 req: VectorSearchRequest<Filter<serde_json::Value>>,
124 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
125 let req = req.map_filter(Filter::interpret);
126
127 Box::pin(self.top_n_ids(req))
128 }
129}
130
131fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
132 match document {
133 Value::Object(mut map) => {
134 let new_map = map
135 .iter_mut()
136 .filter_map(|(key, value)| {
137 prune_document(value.take()).map(|value| (key.clone(), value))
138 })
139 .collect::<serde_json::Map<_, _>>();
140
141 Some(Value::Object(new_map))
142 }
143 Value::Array(vec) if vec.len() > 400 => None,
144 Value::Array(vec) => Some(Value::Array(
145 vec.into_iter().filter_map(prune_document).collect(),
146 )),
147 Value::Number(num) => Some(Value::Number(num)),
148 Value::String(s) => Some(Value::String(s)),
149 Value::Bool(b) => Some(Value::Bool(b)),
150 Value::Null => Some(Value::Null),
151 }
152}
153
154#[derive(Serialize, Deserialize, Debug)]
155pub struct VectorStoreOutput {
156 pub score: f64,
157 pub id: String,
158 pub document: Value,
159}
160
161impl<T, F> Tool for T
162where
163 F: SearchFilter<Value = serde_json::Value>
164 + WasmCompatSend
165 + WasmCompatSync
166 + for<'de> Deserialize<'de>,
167 T: VectorStoreIndex<Filter = F>,
168{
169 const NAME: &'static str = "search_vector_store";
170
171 type Error = VectorStoreError;
172 type Args = VectorSearchRequest<F>;
173 type Output = Vec<VectorStoreOutput>;
174
175 async fn definition(&self, _prompt: String) -> ToolDefinition {
176 ToolDefinition {
177 name: Self::NAME.to_string(),
178 description:
179 "Retrieves the most relevant documents from a vector store based on a query."
180 .to_string(),
181 parameters: json!({
182 "type": "object",
183 "properties": {
184 "query": {
185 "type": "string",
186 "description": "The query string to search for relevant documents in the vector store."
187 },
188 "samples": {
189 "type": "integer",
190 "description": "The maxinum number of samples / documents to retrieve.",
191 "default": 5,
192 "minimum": 1
193 },
194 "threshold": {
195 "type": "number",
196 "description": "Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result."
197 }
198 },
199 "required": ["query", "samples"]
200 }),
201 }
202 }
203
204 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
205 let results = self.top_n(args).await?;
206 Ok(results
207 .into_iter()
208 .map(|(score, id, document)| VectorStoreOutput {
209 score,
210 id,
211 document,
212 })
213 .collect())
214 }
215}