Skip to main content

rs_adk/
toolset.rs

//! The [`Toolset`] trait for collections of tools that can be enumerated, plus
//! [`StaticToolset`], a fixed list of tools that can also be filtered by name.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use crate::tool::ToolFunction;
8
9/// A collection of tools that can be enumerated.
10#[async_trait]
11pub trait Toolset: Send + Sync {
12    /// Get all tools in this toolset.
13    fn get_tools(&self) -> Vec<Arc<dyn ToolFunction>>;
14
15    /// Clean up resources when the toolset is no longer needed.
16    async fn close(&self) {}
17}
18
19/// A simple toolset backed by a fixed list of tools.
20pub struct StaticToolset {
21    tools: Vec<Arc<dyn ToolFunction>>,
22}
23
24impl StaticToolset {
25    /// Create a new static toolset from a list of tools.
26    pub fn new(tools: Vec<Arc<dyn ToolFunction>>) -> Self {
27        Self { tools }
28    }
29
30    /// Create a new toolset containing only tools whose names are in `names`.
31    pub fn filter_by_name(&self, names: &[&str]) -> Self {
32        let filtered = self
33            .tools
34            .iter()
35            .filter(|t| names.contains(&t.name()))
36            .cloned()
37            .collect();
38        Self { tools: filtered }
39    }
40}
41
42#[async_trait]
43impl Toolset for StaticToolset {
44    fn get_tools(&self) -> Vec<Arc<dyn ToolFunction>> {
45        self.tools.clone()
46    }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolError;

    /// Minimal tool whose only distinguishing property is its name.
    struct DummyTool {
        name: &'static str,
    }

    #[async_trait]
    impl ToolFunction for DummyTool {
        fn name(&self) -> &str {
            self.name
        }
        fn description(&self) -> &str {
            "dummy"
        }
        fn parameters(&self) -> Option<serde_json::Value> {
            None
        }
        async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(serde_json::json!({"ok": true}))
        }
    }

    /// Wrap a `DummyTool` as the trait object a toolset stores.
    fn dummy(name: &'static str) -> Arc<dyn ToolFunction> {
        Arc::new(DummyTool { name })
    }

    /// Collect tool names so assertions can compare whole lists at once.
    fn names_of(tools: &[Arc<dyn ToolFunction>]) -> Vec<&str> {
        tools.iter().map(|t| t.name()).collect()
    }

    #[test]
    fn static_toolset_get_tools() {
        let toolset = StaticToolset::new(vec![dummy("a"), dummy("b")]);
        let tools = toolset.get_tools();
        assert_eq!(names_of(&tools), ["a", "b"]);
    }

    #[test]
    fn get_tools_returns_independent_vec() {
        let toolset = StaticToolset::new(vec![dummy("a"), dummy("b")]);
        let mut tools = toolset.get_tools();
        tools.clear();
        assert_eq!(toolset.get_tools().len(), 2);
    }

    #[test]
    fn filter_by_name() {
        let toolset = StaticToolset::new(vec![dummy("alpha"), dummy("beta"), dummy("gamma")]);

        let filtered = toolset.filter_by_name(&["alpha", "gamma"]);
        let tools = filtered.get_tools();
        assert_eq!(names_of(&tools), ["alpha", "gamma"]);
    }

    #[test]
    fn filter_preserves_toolset_order_not_argument_order() {
        let toolset = StaticToolset::new(vec![dummy("alpha"), dummy("beta"), dummy("gamma")]);

        let filtered = toolset.filter_by_name(&["gamma", "alpha"]);
        let tools = filtered.get_tools();
        assert_eq!(names_of(&tools), ["alpha", "gamma"]);
    }

    #[test]
    fn filter_keeps_every_tool_with_a_matching_name() {
        let toolset = StaticToolset::new(vec![dummy("a"), dummy("b"), dummy("a")]);

        let filtered = toolset.filter_by_name(&["a"]);
        let tools = filtered.get_tools();
        assert_eq!(names_of(&tools), ["a", "a"]);
    }

    #[test]
    fn filter_shares_tool_instances_and_leaves_original_untouched() {
        let toolset = StaticToolset::new(vec![dummy("a"), dummy("b")]);

        let filtered = toolset.filter_by_name(&["b"]);
        let original = toolset.get_tools();
        let kept = filtered.get_tools();
        assert_eq!(original.len(), 2);
        assert_eq!(kept.len(), 1);
        assert!(Arc::ptr_eq(&original[1], &kept[0]));
    }

    #[test]
    fn filter_with_no_names_is_empty() {
        let toolset = StaticToolset::new(vec![dummy("a")]);
        assert!(toolset.filter_by_name(&[]).get_tools().is_empty());
    }

    #[test]
    fn empty_toolset() {
        let toolset = StaticToolset::new(vec![]);
        assert!(toolset.get_tools().is_empty());
    }

    #[test]
    fn filter_by_nonexistent_name() {
        let toolset = StaticToolset::new(vec![dummy("a")]);
        let filtered = toolset.filter_by_name(&["nonexistent"]);
        assert!(filtered.get_tools().is_empty());
    }

    #[test]
    fn toolset_is_object_safe() {
        fn _assert(_: &dyn Toolset) {}

        // Actually coerce the concrete type into a trait object and use it.
        let boxed: Box<dyn Toolset> = Box::new(StaticToolset::new(vec![dummy("a")]));
        assert_eq!(boxed.get_tools().len(), 1);
    }
}