1use futures::future::BoxFuture;
2use serde::Deserialize;
3use serde_json::Value;
4
5use crate::embeddings::EmbeddingError;
6
7pub mod in_memory_store;
8
9#[derive(Debug, thiserror::Error)]
10pub enum VectorStoreError {
11 #[error("Embedding error: {0}")]
12 EmbeddingError(#[from] EmbeddingError),
13
14 #[error("Json error: {0}")]
16 JsonError(#[from] serde_json::Error),
17
18 #[error("Datastore error: {0}")]
19 DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
20
21 #[error("Missing Id: {0}")]
22 MissingIdError(String),
23}
24
25pub trait VectorStoreIndex: Send + Sync {
27 fn top_n<T: for<'a> Deserialize<'a> + Send>(
30 &self,
31 query: &str,
32 n: usize,
33 ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>> + Send;
34
35 fn top_n_ids(
37 &self,
38 query: &str,
39 n: usize,
40 ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + Send;
41}
42
43pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
44
45pub trait VectorStoreIndexDyn: Send + Sync {
46 fn top_n<'a>(&'a self, query: &'a str, n: usize) -> BoxFuture<'a, TopNResults>;
47
48 fn top_n_ids<'a>(
49 &'a self,
50 query: &'a str,
51 n: usize,
52 ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
53}
54
55impl<I: VectorStoreIndex> VectorStoreIndexDyn for I {
56 fn top_n<'a>(
57 &'a self,
58 query: &'a str,
59 n: usize,
60 ) -> BoxFuture<'a, Result<Vec<(f64, String, Value)>, VectorStoreError>> {
61 Box::pin(async move {
62 Ok(self
63 .top_n::<serde_json::Value>(query, n)
64 .await?
65 .into_iter()
66 .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
67 .collect::<Vec<_>>())
68 })
69 }
70
71 fn top_n_ids<'a>(
72 &'a self,
73 query: &'a str,
74 n: usize,
75 ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
76 Box::pin(self.top_n_ids(query, n))
77 }
78}
79
80fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
81 match document {
82 Value::Object(mut map) => {
83 let new_map = map
84 .iter_mut()
85 .filter_map(|(key, value)| {
86 prune_document(value.take()).map(|value| (key.clone(), value))
87 })
88 .collect::<serde_json::Map<_, _>>();
89
90 Some(Value::Object(new_map))
91 }
92 Value::Array(vec) if vec.len() > 400 => None,
93 Value::Array(vec) => Some(Value::Array(
94 vec.into_iter().filter_map(prune_document).collect(),
95 )),
96 Value::Number(num) => Some(Value::Number(num)),
97 Value::String(s) => Some(Value::String(s)),
98 Value::Bool(b) => Some(Value::Bool(b)),
99 Value::Null => Some(Value::Null),
100 }
101}