1use futures::future::BoxFuture;
2pub use request::VectorSearchRequest;
3use reqwest::StatusCode;
4use serde::{Deserialize, Serialize};
5use serde_json::{Value, json};
6
7use crate::{
8 Embed, OneOrMany,
9 completion::ToolDefinition,
10 embeddings::{Embedding, EmbeddingError},
11 tool::Tool,
12};
13
14pub mod in_memory_store;
15pub mod request;
16
/// Errors returned by the vector store traits defined in this module.
#[derive(Debug, thiserror::Error)]
pub enum VectorStoreError {
    /// An embedding operation failed; converted from `EmbeddingError` via `From`.
    #[error("Embedding error: {0}")]
    EmbeddingError(#[from] EmbeddingError),

    /// JSON (de)serialization failed; converted from `serde_json::Error` via `From`.
    #[error("Json error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A type-erased error, converted from any boxed `Error + Send + Sync`.
    // NOTE(review): presumably raised by the backing datastore implementation — confirm with callers.
    #[error("Datastore error: {0}")]
    DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// An id was expected but missing; the payload is a free-form message.
    #[error("Missing Id: {0}")]
    MissingIdError(String),

    /// An HTTP request failed; converted from `reqwest::Error` via `From`.
    #[error("HTTP request error: {0}")]
    ReqwestError(#[from] reqwest::Error),

    /// An external API answered with an error; carries the HTTP status code and a message.
    #[error("External call to API returned an error. Error code: {0} Message: {1}")]
    ExternalAPIError(StatusCode, String),

    /// Building a `VectorSearchRequest` failed; the payload describes why.
    #[error("Error while building VectorSearchRequest: {0}")]
    BuilderError(String),
}
41
42pub trait InsertDocuments: Send + Sync {
44 fn insert_documents<Doc: Serialize + Embed + Send>(
47 &self,
48 documents: Vec<(Doc, OneOrMany<Embedding>)>,
49 ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + Send;
50}
51
52pub trait VectorStoreIndex: Send + Sync {
54 fn top_n<T: for<'a> Deserialize<'a> + Send>(
57 &self,
58 req: VectorSearchRequest,
59 ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>> + Send;
60
61 fn top_n_ids(
63 &self,
64 req: VectorSearchRequest,
65 ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + Send;
66}
67
/// Result of a type-erased top-n search: `(score, id, document)` tuples whose
/// documents are raw JSON values, or a [`VectorStoreError`].
pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
69
70pub trait VectorStoreIndexDyn: Send + Sync {
71 fn top_n<'a>(&'a self, req: VectorSearchRequest) -> BoxFuture<'a, TopNResults>;
72
73 fn top_n_ids<'a>(
74 &'a self,
75 req: VectorSearchRequest,
76 ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
77}
78
79impl<I: VectorStoreIndex> VectorStoreIndexDyn for I {
80 fn top_n<'a>(
81 &'a self,
82 req: VectorSearchRequest,
83 ) -> BoxFuture<'a, Result<Vec<(f64, String, Value)>, VectorStoreError>> {
84 Box::pin(async move {
85 Ok(self
86 .top_n::<serde_json::Value>(req)
87 .await?
88 .into_iter()
89 .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
90 .collect::<Vec<_>>())
91 })
92 }
93
94 fn top_n_ids<'a>(
95 &'a self,
96 req: VectorSearchRequest,
97 ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
98 Box::pin(self.top_n_ids(req))
99 }
100}
101
102fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
103 match document {
104 Value::Object(mut map) => {
105 let new_map = map
106 .iter_mut()
107 .filter_map(|(key, value)| {
108 prune_document(value.take()).map(|value| (key.clone(), value))
109 })
110 .collect::<serde_json::Map<_, _>>();
111
112 Some(Value::Object(new_map))
113 }
114 Value::Array(vec) if vec.len() > 400 => None,
115 Value::Array(vec) => Some(Value::Array(
116 vec.into_iter().filter_map(prune_document).collect(),
117 )),
118 Value::Number(num) => Some(Value::Number(num)),
119 Value::String(s) => Some(Value::String(s)),
120 Value::Bool(b) => Some(Value::Bool(b)),
121 Value::Null => Some(Value::Null),
122 }
123}
124
/// One search hit as returned by the vector store search tool.
#[derive(Serialize, Deserialize, Debug)]
pub struct VectorStoreOutput {
    /// Score reported by the index for this hit.
    pub score: f64,
    /// Identifier of the matching document.
    pub id: String,
    /// The matching document as a raw JSON value.
    pub document: Value,
}
131
132impl<T> Tool for T
133where
134 T: VectorStoreIndex,
135{
136 const NAME: &'static str = "search_vector_store";
137
138 type Error = VectorStoreError;
139 type Args = VectorSearchRequest;
140 type Output = Vec<VectorStoreOutput>;
141
142 async fn definition(&self, _prompt: String) -> ToolDefinition {
143 ToolDefinition {
144 name: Self::NAME.to_string(),
145 description:
146 "Retrieves the most relevant documents from a vector store based on a query."
147 .to_string(),
148 parameters: json!({
149 "type": "object",
150 "properties": {
151 "query": {
152 "type": "string",
153 "description": "The query string to search for relevant documents in the vector store."
154 },
155 "samples": {
156 "type": "integer",
157 "description": "The maxinum number of samples / documents to retrieve.",
158 "default": 5,
159 "minimum": 1
160 },
161 "threshold": {
162 "type": "number",
163 "description": "Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result."
164 }
165 },
166 "required": ["query", "samples"]
167 }),
168 }
169 }
170
171 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
172 let results = self.top_n(args).await?;
173 Ok(results
174 .into_iter()
175 .map(|(score, id, document)| VectorStoreOutput {
176 score,
177 id,
178 document,
179 })
180 .collect())
181 }
182}