adk_core/
tool.rs

1use crate::{CallbackContext, EventActions, MemoryEntry, Result};
2use async_trait::async_trait;
3use serde_json::Value;
4use std::sync::Arc;
5
6#[async_trait]
7pub trait Tool: Send + Sync {
8    fn name(&self) -> &str;
9    fn description(&self) -> &str;
10
11    /// Returns an enhanced description that may include additional notes.
12    /// For long-running tools, this includes a warning not to call the tool
13    /// again if it has already returned a pending status.
14    /// Default implementation returns the base description.
15    fn enhanced_description(&self) -> String {
16        self.description().to_string()
17    }
18
19    /// Indicates whether the tool is a long-running operation.
20    /// Long-running tools typically return a task ID immediately and
21    /// complete the operation asynchronously.
22    fn is_long_running(&self) -> bool {
23        false
24    }
25    fn parameters_schema(&self) -> Option<Value> {
26        None
27    }
28    fn response_schema(&self) -> Option<Value> {
29        None
30    }
31    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value>;
32}
33
34#[async_trait]
35pub trait ToolContext: CallbackContext {
36    fn function_call_id(&self) -> &str;
37    /// Get the current event actions. Returns an owned copy for thread safety.
38    fn actions(&self) -> EventActions;
39    /// Set the event actions (e.g., to trigger escalation or skip summarization).
40    fn set_actions(&self, actions: EventActions);
41    async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>>;
42}
43
44#[async_trait]
45pub trait Toolset: Send + Sync {
46    fn name(&self) -> &str;
47    async fn tools(&self, ctx: Arc<dyn crate::ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>>;
48}
49
50pub type ToolPredicate = Box<dyn Fn(&dyn Tool) -> bool + Send + Sync>;
51
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Content, EventActions, ReadonlyContext};
    use std::sync::Mutex;

    /// Minimal `Tool` that overrides only the required methods, so every
    /// default trait method can be exercised against it.
    struct TestTool {
        name: String,
    }

    /// In-memory `ToolContext` with fixed identifiers and mutable actions.
    ///
    /// Only fields that are actually read are kept. The former unused
    /// `config: RunConfig` field needed `#[allow(dead_code)]` to compile
    /// without warnings.
    struct TestContext {
        content: Content,
        actions: Mutex<EventActions>,
    }

    impl TestContext {
        fn new() -> Self {
            Self {
                content: Content::new("user"),
                actions: Mutex::new(EventActions::default()),
            }
        }
    }

    #[async_trait]
    impl ReadonlyContext for TestContext {
        fn invocation_id(&self) -> &str {
            "test"
        }
        fn agent_name(&self) -> &str {
            "test"
        }
        fn user_id(&self) -> &str {
            "user"
        }
        fn app_name(&self) -> &str {
            "app"
        }
        fn session_id(&self) -> &str {
            "session"
        }
        fn branch(&self) -> &str {
            ""
        }
        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl CallbackContext for TestContext {
        fn artifacts(&self) -> Option<Arc<dyn crate::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl ToolContext for TestContext {
        fn function_call_id(&self) -> &str {
            "call-123"
        }
        fn actions(&self) -> EventActions {
            self.actions.lock().unwrap().clone()
        }
        fn set_actions(&self, actions: EventActions) {
            *self.actions.lock().unwrap() = actions;
        }
        async fn search_memory(&self, _query: &str) -> Result<Vec<crate::MemoryEntry>> {
            Ok(vec![])
        }
    }

    #[async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "test tool"
        }

        async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
            Ok(Value::String("result".to_string()))
        }
    }

    #[test]
    fn test_tool_trait() {
        let tool = TestTool { name: "test".to_string() };
        assert_eq!(tool.name(), "test");
        assert_eq!(tool.description(), "test tool");
        assert!(!tool.is_long_running());
    }

    #[test]
    fn test_tool_defaults() {
        let tool = TestTool { name: "test".to_string() };
        // The default enhanced description is the base description, verbatim.
        assert_eq!(tool.enhanced_description(), "test tool");
        assert!(tool.parameters_schema().is_none());
        assert!(tool.response_schema().is_none());
    }

    #[tokio::test]
    async fn test_tool_execute() {
        let tool = TestTool { name: "test".to_string() };
        let ctx = Arc::new(TestContext::new()) as Arc<dyn ToolContext>;
        let result = tool.execute(ctx, Value::Null).await.unwrap();
        assert_eq!(result, Value::String("result".to_string()));
    }
}