1use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use crate::tool::ToolFunction;
8
9#[async_trait]
11pub trait Toolset: Send + Sync {
12 fn get_tools(&self) -> Vec<Arc<dyn ToolFunction>>;
14
15 async fn close(&self) {}
17}
18
19pub struct StaticToolset {
21 tools: Vec<Arc<dyn ToolFunction>>,
22}
23
24impl StaticToolset {
25 pub fn new(tools: Vec<Arc<dyn ToolFunction>>) -> Self {
27 Self { tools }
28 }
29
30 pub fn filter_by_name(&self, names: &[&str]) -> Self {
32 let filtered = self
33 .tools
34 .iter()
35 .filter(|t| names.contains(&t.name()))
36 .cloned()
37 .collect();
38 Self { tools: filtered }
39 }
40}
41
42#[async_trait]
43impl Toolset for StaticToolset {
44 fn get_tools(&self) -> Vec<Arc<dyn ToolFunction>> {
45 self.tools.clone()
46 }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolError;

    /// Minimal tool whose only distinguishing feature is its name.
    struct DummyTool {
        name: &'static str,
    }

    #[async_trait]
    impl ToolFunction for DummyTool {
        fn name(&self) -> &str {
            self.name
        }
        fn description(&self) -> &str {
            "dummy"
        }
        fn parameters(&self) -> Option<serde_json::Value> {
            None
        }
        async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(serde_json::json!({"ok": true}))
        }
    }

    /// Collects the names of a toolset's tools so whole lists can be compared at once.
    fn names_of(toolset: &dyn Toolset) -> Vec<String> {
        toolset
            .get_tools()
            .iter()
            .map(|tool| tool.name().to_string())
            .collect()
    }

    #[test]
    fn static_toolset_get_tools() {
        let toolset = StaticToolset::new(vec![
            Arc::new(DummyTool { name: "a" }),
            Arc::new(DummyTool { name: "b" }),
        ]);
        let tools = toolset.get_tools();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].name(), "a");
        assert_eq!(tools[1].name(), "b");
    }

    #[test]
    fn filter_by_name() {
        let toolset = StaticToolset::new(vec![
            Arc::new(DummyTool { name: "alpha" }),
            Arc::new(DummyTool { name: "beta" }),
            Arc::new(DummyTool { name: "gamma" }),
        ]);

        let filtered = toolset.filter_by_name(&["alpha", "gamma"]);
        let tools = filtered.get_tools();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].name(), "alpha");
        assert_eq!(tools[1].name(), "gamma");
    }

    #[test]
    fn filter_by_name_keeps_toolset_order_not_name_order() {
        let toolset = StaticToolset::new(vec![
            Arc::new(DummyTool { name: "alpha" }),
            Arc::new(DummyTool { name: "beta" }),
            Arc::new(DummyTool { name: "gamma" }),
        ]);
        // Names are given in reverse order; the result must still follow the toolset.
        let filtered = toolset.filter_by_name(&["gamma", "alpha"]);
        assert_eq!(names_of(&filtered), ["alpha", "gamma"]);
    }

    #[test]
    fn filter_by_name_ignores_duplicate_names_and_leaves_source_intact() {
        let toolset = StaticToolset::new(vec![
            Arc::new(DummyTool { name: "a" }),
            Arc::new(DummyTool { name: "b" }),
        ]);
        let filtered = toolset.filter_by_name(&["a", "a"]);
        assert_eq!(names_of(&filtered), ["a"]);
        assert_eq!(names_of(&toolset), ["a", "b"]);
    }

    #[test]
    fn filter_by_empty_names() {
        let toolset = StaticToolset::new(vec![Arc::new(DummyTool { name: "a" })]);
        let filtered = toolset.filter_by_name(&[]);
        assert!(filtered.get_tools().is_empty());
    }

    #[test]
    fn empty_toolset() {
        let toolset = StaticToolset::new(vec![]);
        assert!(toolset.get_tools().is_empty());
    }

    #[test]
    fn filter_by_nonexistent_name() {
        let toolset = StaticToolset::new(vec![Arc::new(DummyTool { name: "a" })]);
        let filtered = toolset.filter_by_name(&["nonexistent"]);
        assert!(filtered.get_tools().is_empty());
    }

    #[test]
    fn get_tools_shares_tool_instances() {
        let tool: Arc<dyn ToolFunction> = Arc::new(DummyTool { name: "a" });
        let toolset = StaticToolset::new(vec![Arc::clone(&tool)]);
        let tools = toolset.get_tools();
        // `tool`, the toolset's copy and the returned copy all point at one instance.
        assert_eq!(Arc::strong_count(&tool), 3);
        assert_eq!(tools[0].name(), "a");
    }

    #[test]
    fn toolset_is_object_safe() {
        // Compiles only if `Toolset` can be used as a trait object.
        fn _assert(_: &dyn Toolset) {}
    }
}