ai_chain/tools/tools/vectorstore.rs

//! The vector store tool accesses information from vector stores.
//!
//! Use it to give your LLM memory or access to semantically searchable information.
4use std::marker::PhantomData;
5
6use async_trait::async_trait;
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use thiserror::Error;
9
10use crate::{
11    tools::{Describe, Format, Tool, ToolDescription, ToolError},
12    traits::{Embeddings, EmbeddingsError, VectorStore, VectorStoreError},
13};
14
/// A [`Tool`] that lets an LLM run similarity searches against a vector store.
///
/// `E` is the embeddings backend (`E: Embeddings`) and `V` the vector store
/// (`V: VectorStore<E, M>`). `M` is the second type parameter of `VectorStore`
/// (presumably per-document metadata — confirm against the trait).
///
/// `E` and `M` are only needed to name `V`'s trait, so they are carried with
/// [`PhantomData`]. Trait bounds live on the `impl` blocks that need them rather
/// than on the struct itself, which keeps the type usable in more contexts.
pub struct VectorStoreTool<E, M, V> {
    /// The vector store that is queried by the tool.
    pub store: V,
    /// Subject of the stored documents; interpolated into the tool description.
    pub topic: String,
    /// Further context about the topic; interpolated into the tool description.
    pub topic_context: String,
    _data1: PhantomData<E>,
    _data2: PhantomData<M>,
}
27
28impl<E, M, V> VectorStoreTool<E, M, V>
29where
30    E: Embeddings,
31    M: Serialize + DeserializeOwned,
32    V: VectorStore<E, M>,
33{
34    pub fn new(store: V, topic: &str, topic_context: &str) -> Self {
35        Self {
36            store,
37            topic: topic.to_string(),
38            topic_context: topic_context.to_string(),
39            _data1: Default::default(),
40            _data2: Default::default(),
41        }
42    }
43}
44
45#[derive(Debug, Error)]
46pub enum VectorStoreToolError<V, E>
47where
48    V: std::fmt::Debug + std::error::Error + VectorStoreError,
49    E: std::fmt::Debug + std::error::Error + EmbeddingsError,
50{
51    #[error(transparent)]
52    YamlError(#[from] serde_yaml::Error),
53    #[error(transparent)]
54    VectorStoreError(#[from] V),
55    #[error(transparent)]
56    Embeddings(E),
57}
58
59impl<V, E> ToolError for VectorStoreToolError<V, E>
60where
61    V: std::fmt::Debug + std::error::Error + VectorStoreError,
62    E: std::fmt::Debug + std::error::Error + EmbeddingsError,
63{
64}
65
66#[derive(Serialize, Deserialize)]
67pub struct VectorStoreToolInput {
68    query: String,
69    limit: u32,
70}
71
72#[derive(Serialize, Deserialize)]
73pub struct VectorStoreToolOutput {
74    texts: Vec<String>,
75}
76
77impl Describe for VectorStoreToolInput {
78    fn describe() -> Format {
79        vec![
80            (
81                "query",
82                "You can search for texts similar to this one in the vector database.",
83            )
84                .into(),
85            (
86                "limit",
87                "The number of texts that will be returned from the vector database.",
88            )
89                .into(),
90        ]
91        .into()
92    }
93}
94
95impl Describe for VectorStoreToolOutput {
96    fn describe() -> Format {
97        vec![
98            ("texts", "List of texts similar to the query.").into(),
99            (
100                "error_msg",
101                "Error message received when trying to search in the vector database.",
102            )
103                .into(),
104        ]
105        .into()
106    }
107}
108
109#[async_trait]
110impl<E, M, V> Tool for VectorStoreTool<E, M, V>
111where
112    E: Embeddings + Sync + Send,
113    V: VectorStore<E, M> + Sync + Send,
114    M: Sync + Send + serde::Serialize + serde::de::DeserializeOwned,
115    Self: 'static,
116{
117    type Input = VectorStoreToolInput;
118    type Output = VectorStoreToolOutput;
119    type Error = VectorStoreToolError<<V as VectorStore<E, M>>::Error, <E as Embeddings>::Error>;
120
121    async fn invoke_typed(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
122        match self
123            .store
124            .similarity_search(input.query.clone(), input.limit)
125            .await
126        {
127            Ok(o) => Ok(VectorStoreToolOutput {
128                texts: o.into_iter().map(|d| d.page_content).collect(),
129            }),
130            Err(e) => Err(<<V as VectorStore<E, M>>::Error as Into<Self::Error>>::into(e)),
131        }
132    }
133
134    fn description(&self) -> crate::tools::ToolDescription {
135        ToolDescription::new(
136            "VectorStoreTool",
137            "A tool that retrieves documents based on similarity to a given query.",
138            &format!(
139                "Useful for when you need to answer questions about {}. 
140            Whenever you need information about {} 
141            you should ALWAYS use this. 
142            Input should be a fully formed question.",
143                self.topic, self.topic_context
144            ),
145            Self::Input::describe(),
146            Self::Output::describe(),
147        )
148    }
149}