pub use request::VectorSearchRequest;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use crate::{
Embed, OneOrMany,
completion::ToolDefinition,
embeddings::{Embedding, EmbeddingError},
tool::Tool,
vector_store::request::{Filter, FilterError, SearchFilter},
wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
};
pub mod builder;
pub mod in_memory_store;
pub mod lsh;
pub mod request;
/// Errors that can be returned by the vector store traits and helpers in this module.
#[derive(Debug, thiserror::Error)]
pub enum VectorStoreError {
/// An embedding operation failed; converted automatically from [`EmbeddingError`].
#[error("Embedding error: {0}")]
EmbeddingError(#[from] EmbeddingError),
/// A `serde_json` error; converted automatically from [`serde_json::Error`].
#[error("Json error: {0}")]
JsonError(#[from] serde_json::Error),
/// A boxed error reported by the underlying datastore.
///
/// On non-wasm targets the boxed error must be `Send + Sync`.
#[cfg(not(target_family = "wasm"))]
#[error("Datastore error: {0}")]
DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
/// A search filter error; converted automatically from [`FilterError`].
#[error("Filter error: {0}")]
FilterError(#[from] FilterError),
/// A boxed error reported by the underlying datastore.
///
/// On wasm targets the boxed error carries no `Send`/`Sync` requirement.
#[cfg(target_family = "wasm")]
#[error("Datastore error: {0}")]
DatastoreError(#[from] Box<dyn std::error::Error + 'static>),
/// An identifier was missing; the payload is included in the error message.
#[error("Missing Id: {0}")]
MissingIdError(String),
/// An HTTP request failed; converted automatically from [`reqwest::Error`].
#[error("HTTP request error: {0}")]
ReqwestError(#[from] reqwest::Error),
/// An external API call returned an error, carrying the HTTP status code and a message.
#[error("External call to API returned an error. Error code: {0} Message: {1}")]
ExternalAPIError(StatusCode, String),
/// Building a `VectorSearchRequest` failed; the payload describes why.
#[error("Error while building VectorSearchRequest: {0}")]
BuilderError(String),
}
/// Trait for stores that accept documents together with their embeddings.
pub trait InsertDocuments: WasmCompatSend + WasmCompatSync {
/// Inserts `documents` into the store.
///
/// Each entry pairs a serializable, embeddable document with one or more
/// [`Embedding`]s for it.
///
/// # Errors
///
/// Resolves to a [`VectorStoreError`] when the insertion fails.
fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
&self,
documents: Vec<(Doc, OneOrMany<Embedding>)>,
) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + WasmCompatSend;
}
/// Trait for indexes that can be searched with a [`VectorSearchRequest`].
pub trait VectorStoreIndex: WasmCompatSend + WasmCompatSync {
/// The filter type accepted by this index in search requests.
type Filter: SearchFilter + WasmCompatSend + WasmCompatSync;
/// Runs the search described by `req` and resolves to a list of
/// `(score, id, document)` tuples, with each document deserialized into `T`.
///
/// # Errors
///
/// Resolves to a [`VectorStoreError`] when the search fails.
fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
&self,
req: VectorSearchRequest<Self::Filter>,
) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>>
+ WasmCompatSend;
/// Same as [`VectorStoreIndex::top_n`], but resolves to `(score, id)` pairs only,
/// without the document payload.
///
/// # Errors
///
/// Resolves to a [`VectorStoreError`] when the search fails.
fn top_n_ids(
&self,
req: VectorSearchRequest<Self::Filter>,
) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + WasmCompatSend;
}
/// Result of a type-erased top-n search: `(score, id, document)` tuples with the
/// document kept as raw JSON, or a [`VectorStoreError`].
pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
/// Non-generic counterpart of [`VectorStoreIndex`]: requests always use
/// `Filter<serde_json::Value>`, documents are returned as raw JSON, and the
/// futures are boxed.
///
/// NOTE(review): the name and the absence of generic methods suggest this is
/// meant to be used as a trait object — confirm against callers.
pub trait VectorStoreIndexDyn: WasmCompatSend + WasmCompatSync {
/// Runs the search described by `req` and resolves to [`TopNResults`].
fn top_n<'a>(
&'a self,
req: VectorSearchRequest<Filter<serde_json::Value>>,
) -> WasmBoxedFuture<'a, TopNResults>;
/// Runs the search described by `req` and resolves to `(score, id)` pairs only.
fn top_n_ids<'a>(
&'a self,
req: VectorSearchRequest<Filter<serde_json::Value>>,
) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
}
impl<I: VectorStoreIndex<Filter = F>, F> VectorStoreIndexDyn for I
where
F: std::fmt::Debug
+ Clone
+ SearchFilter<Value = serde_json::Value>
+ WasmCompatSend
+ WasmCompatSync
+ Serialize
+ for<'de> Deserialize<'de>
+ 'static,
{
fn top_n<'a>(
&'a self,
req: VectorSearchRequest<Filter<serde_json::Value>>,
) -> WasmBoxedFuture<'a, TopNResults> {
let req = req.map_filter(Filter::interpret);
Box::pin(async move {
Ok(self
.top_n::<serde_json::Value>(req)
.await?
.into_iter()
.map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
.collect::<Vec<_>>())
})
}
fn top_n_ids<'a>(
&'a self,
req: VectorSearchRequest<Filter<serde_json::Value>>,
) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
let req = req.map_filter(Filter::interpret);
Box::pin(self.top_n_ids(req))
}
}
/// Recursively removes oversized arrays from a JSON document.
///
/// * Objects are rebuilt from their entries; an entry whose value is pruned away is
///   dropped together with its key.
/// * Arrays longer than 400 elements are removed entirely; shorter arrays keep every
///   element that survives pruning.
/// * Numbers, strings, booleans and `null` are returned unchanged.
///
/// Returns `None` only when `document` itself is an array longer than the limit.
///
/// NOTE(review): the limit presumably exists to keep large numeric vectors (such as
/// embeddings) out of search results — confirm the intent.
fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
    /// Arrays with more elements than this are dropped.
    const MAX_ARRAY_LEN: usize = 400;

    match document {
        // The map is owned, so it can be consumed directly: no key clones and no
        // `take()` on borrowed values are needed.
        Value::Object(map) => {
            let pruned = map
                .into_iter()
                .filter_map(|(key, value)| prune_document(value).map(|value| (key, value)))
                .collect::<serde_json::Map<_, _>>();
            Some(Value::Object(pruned))
        }
        Value::Array(items) if items.len() > MAX_ARRAY_LEN => None,
        Value::Array(items) => Some(Value::Array(
            items.into_iter().filter_map(prune_document).collect(),
        )),
        scalar @ (Value::Number(_) | Value::String(_) | Value::Bool(_) | Value::Null) => {
            Some(scalar)
        }
    }
}
/// A single search hit as returned by the `search_vector_store` tool.
#[derive(Serialize, Deserialize, Debug)]
pub struct VectorStoreOutput {
/// Score reported by the index for this hit.
pub score: f64,
/// Identifier of the matched document.
pub id: String,
/// The matched document as raw JSON.
pub document: Value,
}
/// Exposes every [`VectorStoreIndex`] with a deserializable filter type as a [`Tool`]
/// named `search_vector_store`.
///
/// The tool takes a [`VectorSearchRequest`] as its arguments and returns the hits as a
/// list of [`VectorStoreOutput`] values.
impl<T, F> Tool for T
where
    F: SearchFilter<Value = serde_json::Value>
        + WasmCompatSend
        + WasmCompatSync
        + for<'de> Deserialize<'de>,
    T: VectorStoreIndex<Filter = F>,
{
    const NAME: &'static str = "search_vector_store";
    type Error = VectorStoreError;
    type Args = VectorSearchRequest<F>;
    type Output = Vec<VectorStoreOutput>;

    /// Returns the tool definition: name, description and the JSON schema of the
    /// `query`, `samples` and `threshold` parameters. The prompt is not used.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: Self::NAME.to_string(),
            description:
                "Retrieves the most relevant documents from a vector store based on a query."
                    .to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The query string to search for relevant documents in the vector store."
                    },
                    "samples": {
                        "type": "integer",
                        "description": "The maximum number of samples / documents to retrieve.",
                        "default": 5,
                        "minimum": 1
                    },
                    "threshold": {
                        "type": "number",
                        "description": "Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result."
                    }
                },
                "required": ["query", "samples"]
            }),
        }
    }

    /// Runs the search described by `args` against this index and wraps every
    /// `(score, id, document)` hit in a [`VectorStoreOutput`].
    ///
    /// # Errors
    ///
    /// Propagates any [`VectorStoreError`] returned by the underlying search.
    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        // The document type is inferred as JSON from `VectorStoreOutput::document`.
        let results = self.top_n(args).await?;
        Ok(results
            .into_iter()
            .map(|(score, id, document)| VectorStoreOutput {
                score,
                id,
                document,
            })
            .collect())
    }
}
/// Strategy used to index vectors for search.
///
/// The default strategy is [`IndexStrategy::BruteForce`].
#[derive(Clone, Debug, Default)]
pub enum IndexStrategy {
    /// Brute-force strategy (no parameters).
    #[default]
    BruteForce,
    /// LSH-based strategy.
    ///
    /// NOTE(review): presumably locality-sensitive hashing as implemented in the `lsh`
    /// module — confirm.
    LSH {
        /// Number of tables to use.
        num_tables: usize,
        /// Number of hyperplanes to use.
        num_hyperplanes: usize,
    },
}