1use std::marker::PhantomData;
5
6use async_trait::async_trait;
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use thiserror::Error;
9
10use crate::{
11 tools::{Describe, Format, Tool, ToolDescription, ToolError},
12 traits::{Embeddings, EmbeddingsError, VectorStore, VectorStoreError},
13};
14
15pub struct VectorStoreTool<E, M, V>
16where
17 E: Embeddings,
18 V: VectorStore<E, M>,
19 M: Serialize + DeserializeOwned,
20{
21 pub store: V,
22 pub topic: String,
23 pub topic_context: String,
24 _data1: PhantomData<E>,
25 _data2: PhantomData<M>,
26}
27
28impl<E, M, V> VectorStoreTool<E, M, V>
29where
30 E: Embeddings,
31 M: Serialize + DeserializeOwned,
32 V: VectorStore<E, M>,
33{
34 pub fn new(store: V, topic: &str, topic_context: &str) -> Self {
35 Self {
36 store,
37 topic: topic.to_string(),
38 topic_context: topic_context.to_string(),
39 _data1: Default::default(),
40 _data2: Default::default(),
41 }
42 }
43}
44
45#[derive(Debug, Error)]
46pub enum VectorStoreToolError<V, E>
47where
48 V: std::fmt::Debug + std::error::Error + VectorStoreError,
49 E: std::fmt::Debug + std::error::Error + EmbeddingsError,
50{
51 #[error(transparent)]
52 YamlError(#[from] serde_yaml::Error),
53 #[error(transparent)]
54 VectorStoreError(#[from] V),
55 #[error(transparent)]
56 Embeddings(E),
57}
58
59impl<V, E> ToolError for VectorStoreToolError<V, E>
60where
61 V: std::fmt::Debug + std::error::Error + VectorStoreError,
62 E: std::fmt::Debug + std::error::Error + EmbeddingsError,
63{
64}
65
66#[derive(Serialize, Deserialize)]
67pub struct VectorStoreToolInput {
68 query: String,
69 limit: u32,
70}
71
72#[derive(Serialize, Deserialize)]
73pub struct VectorStoreToolOutput {
74 texts: Vec<String>,
75}
76
77impl Describe for VectorStoreToolInput {
78 fn describe() -> Format {
79 vec![
80 (
81 "query",
82 "You can search for texts similar to this one in the vector database.",
83 )
84 .into(),
85 (
86 "limit",
87 "The number of texts that will be returned from the vector database.",
88 )
89 .into(),
90 ]
91 .into()
92 }
93}
94
95impl Describe for VectorStoreToolOutput {
96 fn describe() -> Format {
97 vec![
98 ("texts", "List of texts similar to the query.").into(),
99 (
100 "error_msg",
101 "Error message received when trying to search in the vector database.",
102 )
103 .into(),
104 ]
105 .into()
106 }
107}
108
109#[async_trait]
110impl<E, M, V> Tool for VectorStoreTool<E, M, V>
111where
112 E: Embeddings + Sync + Send,
113 V: VectorStore<E, M> + Sync + Send,
114 M: Sync + Send + serde::Serialize + serde::de::DeserializeOwned,
115 Self: 'static,
116{
117 type Input = VectorStoreToolInput;
118 type Output = VectorStoreToolOutput;
119 type Error = VectorStoreToolError<<V as VectorStore<E, M>>::Error, <E as Embeddings>::Error>;
120
121 async fn invoke_typed(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
122 match self
123 .store
124 .similarity_search(input.query.clone(), input.limit)
125 .await
126 {
127 Ok(o) => Ok(VectorStoreToolOutput {
128 texts: o.into_iter().map(|d| d.page_content).collect(),
129 }),
130 Err(e) => Err(<<V as VectorStore<E, M>>::Error as Into<Self::Error>>::into(e)),
131 }
132 }
133
134 fn description(&self) -> crate::tools::ToolDescription {
135 ToolDescription::new(
136 "VectorStoreTool",
137 "A tool that retrieves documents based on similarity to a given query.",
138 &format!(
139 "Useful for when you need to answer questions about {}.
140 Whenever you need information about {}
141 you should ALWAYS use this.
142 Input should be a fully formed question.",
143 self.topic, self.topic_context
144 ),
145 Self::Input::describe(),
146 Self::Output::describe(),
147 )
148 }
149}