rig/vector_store/
mod.rs

1use futures::future::BoxFuture;
2pub use request::VectorSearchRequest;
3use reqwest::StatusCode;
4use serde::{Deserialize, Serialize};
5use serde_json::{Value, json};
6
7use crate::{
8    Embed, OneOrMany,
9    completion::ToolDefinition,
10    embeddings::{Embedding, EmbeddingError},
11    tool::Tool,
12};
13
14pub mod in_memory_store;
15pub mod request;
16
/// Errors that can be produced while inserting into or querying a vector store.
///
/// Variants marked `#[from]` can be created with `?` from their source error type.
#[derive(Debug, thiserror::Error)]
pub enum VectorStoreError {
    /// An [`EmbeddingError`], converted automatically through `From`.
    #[error("Embedding error: {0}")]
    EmbeddingError(#[from] EmbeddingError),

    /// Json error (e.g.: serialization, deserialization, etc.)
    #[error("Json error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Any boxed, thread-safe error, reported as a datastore error.
    #[error("Datastore error: {0}")]
    DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// An id was missing; the string payload is interpolated into the message.
    #[error("Missing Id: {0}")]
    MissingIdError(String),

    /// A [`reqwest::Error`], converted automatically through `From`.
    #[error("HTTP request error: {0}")]
    ReqwestError(#[from] reqwest::Error),

    /// An external API call returned an error; carries the status code and a message.
    #[error("External call to API returned an error. Error code: {0} Message: {1}")]
    ExternalAPIError(StatusCode, String),

    /// Building a [`VectorSearchRequest`] failed; the payload describes the failure.
    #[error("Error while building VectorSearchRequest: {0}")]
    BuilderError(String),
}
41
42/// Trait for inserting documents into a vector store.
43pub trait InsertDocuments: Send + Sync {
44    /// Insert documents into the vector store.
45    ///
46    fn insert_documents<Doc: Serialize + Embed + Send>(
47        &self,
48        documents: Vec<(Doc, OneOrMany<Embedding>)>,
49    ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + Send;
50}
51
52/// Trait for vector store indexes
53pub trait VectorStoreIndex: Send + Sync {
54    /// Get the top n documents based on the distance to the given query.
55    /// The result is a list of tuples of the form (score, id, document)
56    fn top_n<T: for<'a> Deserialize<'a> + Send>(
57        &self,
58        req: VectorSearchRequest,
59    ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>> + Send;
60
61    /// Same as `top_n` but returns the document ids only.
62    fn top_n_ids(
63        &self,
64        req: VectorSearchRequest,
65    ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + Send;
66}
67
68pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>;
69
70pub trait VectorStoreIndexDyn: Send + Sync {
71    fn top_n<'a>(&'a self, req: VectorSearchRequest) -> BoxFuture<'a, TopNResults>;
72
73    fn top_n_ids<'a>(
74        &'a self,
75        req: VectorSearchRequest,
76    ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>;
77}
78
79impl<I: VectorStoreIndex> VectorStoreIndexDyn for I {
80    fn top_n<'a>(
81        &'a self,
82        req: VectorSearchRequest,
83    ) -> BoxFuture<'a, Result<Vec<(f64, String, Value)>, VectorStoreError>> {
84        Box::pin(async move {
85            Ok(self
86                .top_n::<serde_json::Value>(req)
87                .await?
88                .into_iter()
89                .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default()))
90                .collect::<Vec<_>>())
91        })
92    }
93
94    fn top_n_ids<'a>(
95        &'a self,
96        req: VectorSearchRequest,
97    ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
98        Box::pin(self.top_n_ids(req))
99    }
100}
101
102fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> {
103    match document {
104        Value::Object(mut map) => {
105            let new_map = map
106                .iter_mut()
107                .filter_map(|(key, value)| {
108                    prune_document(value.take()).map(|value| (key.clone(), value))
109                })
110                .collect::<serde_json::Map<_, _>>();
111
112            Some(Value::Object(new_map))
113        }
114        Value::Array(vec) if vec.len() > 400 => None,
115        Value::Array(vec) => Some(Value::Array(
116            vec.into_iter().filter_map(prune_document).collect(),
117        )),
118        Value::Number(num) => Some(Value::Number(num)),
119        Value::String(s) => Some(Value::String(s)),
120        Value::Bool(b) => Some(Value::Bool(b)),
121        Value::Null => Some(Value::Null),
122    }
123}
124
/// A single search hit, as returned by the vector-store search tool.
#[derive(Serialize, Deserialize, Debug)]
pub struct VectorStoreOutput {
    /// Score reported by the index for this document.
    pub score: f64,
    /// Identifier of the document.
    pub id: String,
    /// The document itself, as a JSON value.
    pub document: Value,
}
131
132impl<T> Tool for T
133where
134    T: VectorStoreIndex,
135{
136    const NAME: &'static str = "search_vector_store";
137
138    type Error = VectorStoreError;
139    type Args = VectorSearchRequest;
140    type Output = Vec<VectorStoreOutput>;
141
142    async fn definition(&self, _prompt: String) -> ToolDefinition {
143        ToolDefinition {
144            name: Self::NAME.to_string(),
145            description:
146                "Retrieves the most relevant documents from a vector store based on a query."
147                    .to_string(),
148            parameters: json!({
149                "type": "object",
150                "properties": {
151                    "query": {
152                        "type": "string",
153                        "description": "The query string to search for relevant documents in the vector store."
154                    },
155                    "samples": {
156                        "type": "integer",
157                        "description": "The maxinum number of samples / documents to retrieve.",
158                        "default": 5,
159                        "minimum": 1
160                    },
161                    "threshold": {
162                        "type": "number",
163                        "description": "Similarity search threshold. If present, any result with a distance less than this may be omitted from the final result."
164                    }
165                },
166                "required": ["query", "samples"]
167            }),
168        }
169    }
170
171    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
172        let results = self.top_n(args).await?;
173        Ok(results
174            .into_iter()
175            .map(|(score, id, document)| VectorStoreOutput {
176                score,
177                id,
178                document,
179            })
180            .collect())
181    }
182}