Skip to main content

adk_core/
tool.rs

1use crate::{CallbackContext, EventActions, MemoryEntry, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::sync::Arc;
6
7#[async_trait]
8pub trait Tool: Send + Sync {
9    fn name(&self) -> &str;
10    fn description(&self) -> &str;
11
12    /// Returns an enhanced description that may include additional notes.
13    /// For long-running tools, this includes a warning not to call the tool
14    /// again if it has already returned a pending status.
15    /// Default implementation returns the base description.
16    fn enhanced_description(&self) -> String {
17        self.description().to_string()
18    }
19
20    /// Indicates whether the tool is a long-running operation.
21    /// Long-running tools typically return a task ID immediately and
22    /// complete the operation asynchronously.
23    fn is_long_running(&self) -> bool {
24        false
25    }
26    fn parameters_schema(&self) -> Option<Value> {
27        None
28    }
29    fn response_schema(&self) -> Option<Value> {
30        None
31    }
32
33    /// Returns the scopes required to execute this tool.
34    ///
35    /// When non-empty, the framework can enforce that the calling user
36    /// possesses **all** listed scopes before dispatching `execute()`.
37    /// The default implementation returns an empty slice (no scopes required).
38    ///
39    /// # Example
40    ///
41    /// ```rust,ignore
42    /// fn required_scopes(&self) -> &[&str] {
43    ///     &["finance:write", "verified"]
44    /// }
45    /// ```
46    fn required_scopes(&self) -> &[&str] {
47        &[]
48    }
49
50    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value>;
51}
52
53#[async_trait]
54pub trait ToolContext: CallbackContext {
55    fn function_call_id(&self) -> &str;
56    /// Get the current event actions. Returns an owned copy for thread safety.
57    fn actions(&self) -> EventActions;
58    /// Set the event actions (e.g., to trigger escalation or skip summarization).
59    fn set_actions(&self, actions: EventActions);
60    async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>>;
61
62    /// Returns the scopes granted to the current user for this invocation.
63    ///
64    /// Implementations may resolve scopes from session state, JWT claims,
65    /// or an external identity provider. The default returns an empty set
66    /// (no scopes granted), which means scope-protected tools will be denied
67    /// unless the implementation is overridden.
68    fn user_scopes(&self) -> Vec<String> {
69        vec![]
70    }
71}
72
/// Configuration for automatic tool retry on failure.
///
/// Controls how many times a failed tool execution is retried before
/// propagating the error. Applied as a flat delay between attempts
/// (no exponential backoff in V1).
///
/// Two budgets compare equal when both `max_retries` and `delay` match.
///
/// # Example
///
/// ```rust
/// use std::time::Duration;
/// use adk_core::RetryBudget;
///
/// // Retry up to 2 times with 500ms between attempts (3 total attempts)
/// let budget = RetryBudget::new(2, Duration::from_millis(500));
/// assert_eq!(budget.max_retries, 2);
/// assert_eq!(budget.delay, Duration::from_millis(500));
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryBudget {
    /// Maximum number of retry attempts (not counting the initial attempt).
    /// E.g., `max_retries: 2` means up to 3 total attempts.
    pub max_retries: u32,
    /// Delay between retries. Applied as a flat delay (no backoff in V1).
    pub delay: std::time::Duration,
}

impl RetryBudget {
    /// Create a new retry budget.
    ///
    /// This is a `const fn`, so a budget can also be built in `const` and
    /// `static` items.
    ///
    /// # Arguments
    ///
    /// * `max_retries` - Maximum retry attempts (not counting the initial attempt)
    /// * `delay` - Flat delay between retry attempts
    #[must_use]
    pub const fn new(max_retries: u32, delay: std::time::Duration) -> Self {
        Self { max_retries, delay }
    }
}
109
110#[async_trait]
111pub trait Toolset: Send + Sync {
112    fn name(&self) -> &str;
113    async fn tools(&self, ctx: Arc<dyn crate::ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>>;
114}
115
116/// Controls how the framework handles skills/agents that request unavailable tools.
117#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
118pub enum ValidationMode {
119    /// Reject the operation entirely if any requested tool is missing from the registry.
120    #[default]
121    Strict,
122    /// Bind available tools, omit missing ones, and log a warning.
123    Permissive,
124}
125
126/// A registry that maps tool names to concrete tool instances.
127///
128/// Implementations resolve string identifiers (e.g. from a skill or config)
129/// into executable `Arc<dyn Tool>` instances.
130pub trait ToolRegistry: Send + Sync {
131    /// Resolve a tool name to a concrete tool instance.
132    /// Returns `None` if the tool is not available in this registry.
133    fn resolve(&self, tool_name: &str) -> Option<Arc<dyn Tool>>;
134
135    /// Returns a list of all tool names available in this registry.
136    fn available_tools(&self) -> Vec<String> {
137        vec![]
138    }
139}
140
141pub type ToolPredicate = Box<dyn Fn(&dyn Tool) -> bool + Send + Sync>;
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Content, EventActions, ReadonlyContext, RunConfig};
    use std::sync::Mutex;

    /// Minimal [`Tool`] implementation whose only state is its name.
    struct TestTool {
        name: String,
    }

    /// Fixed-value context used to drive [`Tool::execute`] in tests.
    // `config` is filled in by `new` but never read back in this module.
    #[allow(dead_code)]
    struct TestContext {
        content: Content,
        config: RunConfig,
        actions: Mutex<EventActions>,
    }

    impl TestContext {
        /// Builds a context with "user" content, default run configuration
        /// and default event actions.
        fn new() -> Self {
            let content = Content::new("user");
            let config = RunConfig::default();
            let actions = Mutex::new(EventActions::default());
            Self { content, config, actions }
        }
    }

    #[async_trait]
    impl ReadonlyContext for TestContext {
        fn invocation_id(&self) -> &str {
            "test"
        }

        fn agent_name(&self) -> &str {
            "test"
        }

        fn user_id(&self) -> &str {
            "user"
        }

        fn app_name(&self) -> &str {
            "app"
        }

        fn session_id(&self) -> &str {
            "session"
        }

        fn branch(&self) -> &str {
            ""
        }

        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl CallbackContext for TestContext {
        fn artifacts(&self) -> Option<Arc<dyn crate::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl ToolContext for TestContext {
        fn function_call_id(&self) -> &str {
            "call-123"
        }

        fn actions(&self) -> EventActions {
            // Clone out of the guard so the caller gets an owned snapshot.
            let guard = self.actions.lock().unwrap();
            EventActions::clone(&guard)
        }

        fn set_actions(&self, actions: EventActions) {
            let mut guard = self.actions.lock().unwrap();
            *guard = actions;
        }

        async fn search_memory(&self, _query: &str) -> Result<Vec<MemoryEntry>> {
            Ok(Vec::new())
        }
    }

    #[async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            "test tool"
        }

        async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
            Ok(Value::String(String::from("result")))
        }
    }

    /// The required accessors return what the tool was built with, and the
    /// default `is_long_running` is `false`.
    #[test]
    fn test_tool_trait() {
        let subject = TestTool { name: String::from("test") };
        assert_eq!(subject.name(), "test");
        assert_eq!(subject.description(), "test tool");
        assert!(!subject.is_long_running());
    }

    /// Executing through a `dyn ToolContext` yields the tool's JSON result.
    #[tokio::test]
    async fn test_tool_execute() {
        let subject = TestTool { name: String::from("test") };
        let ctx: Arc<dyn ToolContext> = Arc::new(TestContext::new());
        let outcome = subject.execute(ctx, Value::Null).await.unwrap();
        assert_eq!(outcome, Value::String(String::from("result")));
    }
}