1pub use request::VectorSearchRequest;
2use reqwest::StatusCode;
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5
6use crate::{
7 Embed, OneOrMany,
8 completion::ToolDefinition,
9 embeddings::{Embedding, EmbeddingError},
10 tool::Tool,
11 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
12};
13
14pub mod in_memory_store;
15pub mod request;
16
17#[derive(Debug, thiserror::Error)]
18pub enum VectorStoreError {
19 #[error("Embedding error: {0}")]
20 EmbeddingError(#[from] EmbeddingError),
21
22 #[error("Json error: {0}")]
24 JsonError(#[from] serde_json::Error),
25
26 #[cfg(not(target_family = "wasm"))]
27 #[error("Datastore error: {0}")]
28 DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
29
30 #[cfg(target_family = "wasm")]
31 #[error("Datastore error: {0}")]
32 DatastoreError(#[from] Box<dyn std::error::Error + 'static>),
33
34 #[error("Missing Id: {0}")]
35 MissingIdError(String),
36
37 #[error("HTTP request error: {0}")]
38 ReqwestError(#[from] reqwest::Error),
39
40 #[error("External call to API returned an error. Error code: {0} Message: {1}")]
41 ExternalAPIError(StatusCode, String),
42
43 #[error("Error while building VectorSearchRequest: {0}")]
44 BuilderError(String),
45}
46
47pub trait InsertDocuments: WasmCompatSend + WasmCompatSync {
49 fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
52 &self,
53 documents: Vec<(Doc, OneOrMany<Embedding>)>,
54 ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + WasmCompatSend;
55}
56
57pub trait VectorStoreIndex: WasmCompatSend + WasmCompatSync {
59 fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
62 &self,
63 req: VectorSearchRequest,
64 ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>>
65 + WasmCompatSend;
66
67 fn top_n_ids(
69 &self,
70 req: VectorSearchRequest,
71 ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + WasmCompatSend;
72}
73
74pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
75
76pub trait VectorStoreIndexDyn: WasmCompatSend + WasmCompatSync {
77 fn top_n<'a>(&'a self, req: VectorSearchRequest) -> WasmBoxedFuture<'a, TopNResults>;
78
79 fn top_n_ids<'a>(
80 &'a self,
81 req: VectorSearchRequest,
82 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
83}
84
85impl<I: VectorStoreIndex> VectorStoreIndexDyn for I {
86 fn top_n<'a>(&'a self, req: VectorSearchRequest) -> WasmBoxedFuture<'a, TopNResults> {
87 Box::pin(async move {
88 Ok(self
89 .top_n::<serde_json::Value>(req)
90 .await?
91 .into_iter()
92 .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
93 .collect::<Vec<_>>())
94 })
95 }
96
97 fn top_n_ids<'a>(
98 &'a self,
99 req: VectorSearchRequest,
100 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
101 Box::pin(self.top_n_ids(req))
102 }
103}
104
105fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
106 match document {
107 Value::Object(mut map) => {
108 let new_map = map
109 .iter_mut()
110 .filter_map(|(key, value)| {
111 prune_document(value.take()).map(|value| (key.clone(), value))
112 })
113 .collect::<serde_json::Map<_, _>>();
114
115 Some(Value::Object(new_map))
116 }
117 Value::Array(vec) if vec.len() > 400 => None,
118 Value::Array(vec) => Some(Value::Array(
119 vec.into_iter().filter_map(prune_document).collect(),
120 )),
121 Value::Number(num) => Some(Value::Number(num)),
122 Value::String(s) => Some(Value::String(s)),
123 Value::Bool(b) => Some(Value::Bool(b)),
124 Value::Null => Some(Value::Null),
125 }
126}
127
128#[derive(Serialize, Deserialize, Debug)]
129pub struct VectorStoreOutput {
130 pub score: f64,
131 pub id: String,
132 pub document: Value,
133}
134
135impl<T> Tool for T
136where
137 T: VectorStoreIndex,
138{
139 const NAME: &'static str = "search_vector_store";
140
141 type Error = VectorStoreError;
142 type Args = VectorSearchRequest;
143 type Output = Vec<VectorStoreOutput>;
144
145 async fn definition(&self, _prompt: String) -> ToolDefinition {
146 ToolDefinition {
147 name: Self::NAME.to_string(),
148 description:
149 "Retrieves the most relevant documents from a vector store based on a query."
150 .to_string(),
151 parameters: json!({
152 "type": "object",
153 "properties": {
154 "query": {
155 "type": "string",
156 "description": "The query string to search for relevant documents in the vector store."
157 },
158 "samples": {
159 "type": "integer",
160 "description": "The maxinum number of samples / documents to retrieve.",
161 "default": 5,
162 "minimum": 1
163 },
164 "threshold": {
165 "type": "number",
166 "description": "Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result."
167 }
168 },
169 "required": ["query", "samples"]
170 }),
171 }
172 }
173
174 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
175 let results = self.top_n(args).await?;
176 Ok(results
177 .into_iter()
178 .map(|(score, id, document)| VectorStoreOutput {
179 score,
180 id,
181 document,
182 })
183 .collect())
184 }
185}