1use std::{collections::HashMap, pin::Pin};
13
14use futures::Future;
15use serde::{Deserialize, Serialize};
16
17use crate::{
18 completion::{self, ToolDefinition},
19 embeddings::{embed::EmbedError, tool::ToolSchema},
20};
21
/// Error returned when a tool is invoked through its type-erased [`ToolDyn`] interface.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool's own `call` implementation failed; its concrete error is boxed here.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// Deserializing the JSON arguments or serializing the tool's output failed.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),
}
31
32pub trait Tool: Sized + Send + Sync {
88 const NAME: &'static str;
90
91 type Error: std::error::Error + Send + Sync + 'static;
93 type Args: for<'a> Deserialize<'a> + Send + Sync;
95 type Output: Serialize;
97
98 fn name(&self) -> String {
100 Self::NAME.to_string()
101 }
102
103 fn definition(&self, _prompt: String) -> impl Future<Output = ToolDefinition> + Send + Sync;
106
107 fn call(
111 &self,
112 args: Self::Args,
113 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + Sync;
114}
115
/// A [`Tool`] that also exposes texts for embedding (see `embedding_docs`) and a
/// serializable `Context` from which an instance can be constructed via `init`.
pub trait ToolEmbedding: Tool {
    /// Error returned by [`ToolEmbedding::init`].
    type InitError: std::error::Error + Send + Sync + 'static;

    /// Serializable description of this tool instance; `ToolEmbeddingDyn::context`
    /// converts it into a `serde_json::Value`.
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// Extra state passed to [`ToolEmbedding::init`] alongside the context.
    // NOTE(review): presumably runtime-only data (clients, handles) that is not part of
    // `Context` — confirm with implementors.
    type State: Send;

    /// Texts associated with this tool.
    // NOTE(review): presumably these are the documents embedded to make the tool
    // retrievable — confirm against the `ToolSchema` conversion.
    fn embedding_docs(&self) -> Vec<String>;

    /// Returns this instance's context value.
    fn context(&self) -> Self::Context;

    /// Constructs a tool from `state` and `context`, or fails with `InitError`.
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
142
/// Object-safe counterpart of [`Tool`]: it works on JSON strings and boxed futures, so
/// heterogeneous tools can be stored as `Box<dyn ToolDyn>`.
///
/// Implemented automatically for every `T: Tool` by the blanket impl in this module.
pub trait ToolDyn: Send + Sync {
    /// The tool's name; `ToolSet` uses it as the map key.
    fn name(&self) -> String;

    /// Returns a boxed future resolving to the tool's [`ToolDefinition`] for `prompt`.
    fn definition(
        &self,
        prompt: String,
    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>>;

    /// Runs the tool with JSON-encoded `args`, resolving to JSON-encoded output or a
    /// [`ToolError`].
    fn call(
        &self,
        args: String,
    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>>;
}
157
158impl<T: Tool> ToolDyn for T {
159 fn name(&self) -> String {
160 self.name()
161 }
162
163 fn definition(
164 &self,
165 prompt: String,
166 ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
167 Box::pin(<Self as Tool>::definition(self, prompt))
168 }
169
170 fn call(
171 &self,
172 args: String,
173 ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>> {
174 Box::pin(async move {
175 match serde_json::from_str(&args) {
176 Ok(args) => <Self as Tool>::call(self, args)
177 .await
178 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
179 .and_then(|output| {
180 serde_json::to_string(&output).map_err(ToolError::JsonError)
181 }),
182 Err(e) => Err(ToolError::JsonError(e)),
183 }
184 })
185 }
186}
187
/// Object-safe counterpart of [`ToolEmbedding`], implemented for every `T: ToolEmbedding`
/// by the blanket impl in this module.
pub trait ToolEmbeddingDyn: ToolDyn {
    /// The tool's `Context`, serialized into a JSON value.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Forwards to [`ToolEmbedding::embedding_docs`] in the blanket impl.
    fn embedding_docs(&self) -> Vec<String>;
}
194
195impl<T: ToolEmbedding> ToolEmbeddingDyn for T {
196 fn context(&self) -> serde_json::Result<serde_json::Value> {
197 serde_json::to_value(self.context())
198 }
199
200 fn embedding_docs(&self) -> Vec<String> {
201 self.embedding_docs()
202 }
203}
204
205pub(crate) enum ToolType {
206 Simple(Box<dyn ToolDyn>),
207 Embedding(Box<dyn ToolEmbeddingDyn>),
208}
209
210impl ToolType {
211 pub fn name(&self) -> String {
212 match self {
213 ToolType::Simple(tool) => tool.name(),
214 ToolType::Embedding(tool) => tool.name(),
215 }
216 }
217
218 pub async fn definition(&self, prompt: String) -> ToolDefinition {
219 match self {
220 ToolType::Simple(tool) => tool.definition(prompt).await,
221 ToolType::Embedding(tool) => tool.definition(prompt).await,
222 }
223 }
224
225 pub async fn call(&self, args: String) -> Result<String, ToolError> {
226 match self {
227 ToolType::Simple(tool) => tool.call(args).await,
228 ToolType::Embedding(tool) => tool.call(args).await,
229 }
230 }
231}
232
/// Errors returned by [`ToolSet`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// The selected tool failed: bad argument JSON, the tool's own error, or output
    /// serialization.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool with the requested name is registered; carries that name.
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    /// JSON serialization failed, e.g. while rendering tool definitions in
    /// `ToolSet::documents`.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),
}
246
247#[derive(Default)]
249pub struct ToolSet {
250 pub(crate) tools: HashMap<String, ToolType>,
251}
252
253impl ToolSet {
254 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
256 let mut toolset = Self::default();
257 tools.into_iter().for_each(|tool| {
258 toolset.add_tool(tool);
259 });
260 toolset
261 }
262
263 pub fn builder() -> ToolSetBuilder {
265 ToolSetBuilder::default()
266 }
267
268 pub fn contains(&self, toolname: &str) -> bool {
270 self.tools.contains_key(toolname)
271 }
272
273 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
275 self.tools
276 .insert(tool.name(), ToolType::Simple(Box::new(tool)));
277 }
278
279 pub fn add_tools(&mut self, toolset: ToolSet) {
281 self.tools.extend(toolset.tools);
282 }
283
284 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
285 self.tools.get(toolname)
286 }
287
288 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
290 if let Some(tool) = self.tools.get(toolname) {
291 tracing::info!(target: "bep",
292 "Calling tool {toolname} with args:\n{}",
293 serde_json::to_string_pretty(&args).unwrap_or_else(|_| args.clone())
294 );
295 Ok(tool.call(args).await?)
296 } else {
297 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
298 }
299 }
300
301 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
303 let mut docs = Vec::new();
304 for tool in self.tools.values() {
305 match tool {
306 ToolType::Simple(tool) => {
307 docs.push(completion::Document {
308 id: tool.name(),
309 text: format!(
310 "\
311 Tool: {}\n\
312 Definition: \n\
313 {}\
314 ",
315 tool.name(),
316 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
317 ),
318 additional_props: HashMap::new(),
319 });
320 }
321 ToolType::Embedding(tool) => {
322 docs.push(completion::Document {
323 id: tool.name(),
324 text: format!(
325 "\
326 Tool: {}\n\
327 Definition: \n\
328 {}\
329 ",
330 tool.name(),
331 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
332 ),
333 additional_props: HashMap::new(),
334 });
335 }
336 }
337 }
338 Ok(docs)
339 }
340
341 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
345 self.tools
346 .values()
347 .filter_map(|tool_type| {
348 if let ToolType::Embedding(tool) = tool_type {
349 Some(ToolSchema::try_from(&**tool))
350 } else {
351 None
352 }
353 })
354 .collect::<Result<Vec<_>, _>>()
355 }
356}
357
358#[derive(Default)]
359pub struct ToolSetBuilder {
360 tools: Vec<ToolType>,
361}
362
363impl ToolSetBuilder {
364 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
365 self.tools.push(ToolType::Simple(Box::new(tool)));
366 self
367 }
368
369 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
370 self.tools.push(ToolType::Embedding(Box::new(tool)));
371 self
372 }
373
374 pub fn build(self) -> ToolSet {
375 ToolSet {
376 tools: self
377 .tools
378 .into_iter()
379 .map(|tool| (tool.name(), tool))
380 .collect(),
381 }
382 }
383}