Skip to main content

adk_tool/
stateful_tool.rs

1use adk_core::{Result, Tool, ToolContext};
2use async_trait::async_trait;
3use schemars::{
4    JsonSchema,
5    generate::{SchemaGenerator, SchemaSettings},
6};
7use serde::Serialize;
8use serde_json::Value;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13type AsyncStatefulHandler<S> = Box<
14    dyn Fn(
15            Arc<S>,
16            Arc<dyn ToolContext>,
17            Value,
18        ) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>
19        + Send
20        + Sync,
21>;
22
23/// A generic tool wrapper that manages shared state for stateful closures.
24///
25/// `StatefulTool<S>` accepts an `Arc<S>` and a handler closure that receives
26/// the state alongside the tool context and arguments. The `Arc<S>` is cloned
27/// (cheap reference count bump) on each invocation, so all executions share
28/// the same underlying state.
29///
30/// # Example
31///
32/// ```rust,ignore
33/// use adk_tool::StatefulTool;
34/// use adk_core::{ToolContext, Result};
35/// use serde_json::{json, Value};
36/// use std::sync::Arc;
37/// use tokio::sync::RwLock;
38///
39/// struct Counter { count: RwLock<u64> }
40///
41/// let state = Arc::new(Counter { count: RwLock::new(0) });
42///
43/// let tool = StatefulTool::new(
44///     "increment",
45///     "Increment a counter",
46///     state,
47///     |s, _ctx, _args| async move {
48///         let mut count = s.count.write().await;
49///         *count += 1;
50///         Ok(json!({ "count": *count }))
51///     },
52/// );
53/// ```
54pub struct StatefulTool<S: Send + Sync + 'static> {
55    name: String,
56    description: String,
57    state: Arc<S>,
58    handler: AsyncStatefulHandler<S>,
59    long_running: bool,
60    read_only: bool,
61    concurrency_safe: bool,
62    parameters_schema: Option<Value>,
63    response_schema: Option<Value>,
64    scopes: Vec<&'static str>,
65}
66
67impl<S: Send + Sync + 'static> StatefulTool<S> {
68    /// Create a new stateful tool.
69    ///
70    /// # Arguments
71    ///
72    /// * `name` - Tool name exposed to the LLM
73    /// * `description` - Human-readable description of what the tool does
74    /// * `state` - Shared state wrapped in `Arc<S>`
75    /// * `handler` - Async closure receiving `(Arc<S>, Arc<dyn ToolContext>, Value)`
76    pub fn new<F, Fut>(
77        name: impl Into<String>,
78        description: impl Into<String>,
79        state: Arc<S>,
80        handler: F,
81    ) -> Self
82    where
83        F: Fn(Arc<S>, Arc<dyn ToolContext>, Value) -> Fut + Send + Sync + 'static,
84        Fut: Future<Output = Result<Value>> + Send + 'static,
85    {
86        Self {
87            name: name.into(),
88            description: description.into(),
89            state,
90            handler: Box::new(move |s, ctx, args| Box::pin(handler(s, ctx, args))),
91            long_running: false,
92            read_only: false,
93            concurrency_safe: false,
94            parameters_schema: None,
95            response_schema: None,
96            scopes: Vec::new(),
97        }
98    }
99
100    /// Mark this tool as long-running (prevents duplicate invocations).
101    pub fn with_long_running(mut self, long_running: bool) -> Self {
102        self.long_running = long_running;
103        self
104    }
105
106    /// Mark this tool as read-only (safe for parallel dispatch).
107    pub fn with_read_only(mut self, read_only: bool) -> Self {
108        self.read_only = read_only;
109        self
110    }
111
112    /// Mark this tool as concurrency-safe (can run in parallel with other tools).
113    pub fn with_concurrency_safe(mut self, concurrency_safe: bool) -> Self {
114        self.concurrency_safe = concurrency_safe;
115        self
116    }
117
118    /// Derive the parameters JSON Schema from a type implementing `JsonSchema`.
119    pub fn with_parameters_schema<T>(mut self) -> Self
120    where
121        T: JsonSchema + Serialize,
122    {
123        self.parameters_schema = Some(generate_schema::<T>());
124        self
125    }
126
127    /// Derive the response JSON Schema from a type implementing `JsonSchema`.
128    pub fn with_response_schema<T>(mut self) -> Self
129    where
130        T: JsonSchema + Serialize,
131    {
132        self.response_schema = Some(generate_schema::<T>());
133        self
134    }
135
136    /// Declare the scopes required to execute this tool.
137    ///
138    /// When set, the framework will enforce that the calling user possesses
139    /// **all** listed scopes before dispatching `execute()`.
140    pub fn with_scopes(mut self, scopes: &[&'static str]) -> Self {
141        self.scopes = scopes.to_vec();
142        self
143    }
144
145    /// Get the parameters schema, if set.
146    pub fn parameters_schema(&self) -> Option<&Value> {
147        self.parameters_schema.as_ref()
148    }
149
150    /// Get the response schema, if set.
151    pub fn response_schema(&self) -> Option<&Value> {
152        self.response_schema.as_ref()
153    }
154}
155
/// Sentence that `Tool::enhanced_description` attaches to long-running tools.
///
/// It asks the model not to re-invoke a tool that has already reported an
/// intermediate or pending status.
const LONG_RUNNING_NOTE: &str = concat!(
    "NOTE: This is a long-running operation. ",
    "Do not call this tool again if it has already returned some intermediate or pending status."
);
158
159#[async_trait]
160impl<S: Send + Sync + 'static> Tool for StatefulTool<S> {
161    fn name(&self) -> &str {
162        &self.name
163    }
164
165    fn description(&self) -> &str {
166        &self.description
167    }
168
169    fn enhanced_description(&self) -> String {
170        if self.long_running {
171            if self.description.is_empty() {
172                LONG_RUNNING_NOTE.to_string()
173            } else {
174                format!("{}\n\n{}", self.description, LONG_RUNNING_NOTE)
175            }
176        } else {
177            self.description.clone()
178        }
179    }
180
181    fn is_long_running(&self) -> bool {
182        self.long_running
183    }
184
185    fn is_read_only(&self) -> bool {
186        self.read_only
187    }
188
189    fn is_concurrency_safe(&self) -> bool {
190        self.concurrency_safe
191    }
192
193    fn parameters_schema(&self) -> Option<Value> {
194        self.parameters_schema.clone()
195    }
196
197    fn response_schema(&self) -> Option<Value> {
198        self.response_schema.clone()
199    }
200
201    fn required_scopes(&self) -> &[&str] {
202        &self.scopes
203    }
204
205    #[adk_telemetry::instrument(
206        skip(self, ctx, args),
207        fields(
208            tool.name = %self.name,
209            tool.description = %self.description,
210            tool.long_running = %self.long_running,
211            function_call.id = %ctx.function_call_id()
212        )
213    )]
214    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
215        adk_telemetry::debug!("Executing stateful tool");
216        let state = Arc::clone(&self.state);
217        (self.handler)(state, ctx, args).await
218    }
219}
220
221fn generate_schema<T>() -> Value
222where
223    T: JsonSchema + Serialize,
224{
225    let settings = SchemaSettings::openapi3().with(|s| {
226        s.inline_subschemas = true;
227        s.meta_schema = None;
228    });
229    let generator = SchemaGenerator::new(settings);
230    let mut schema = generator.into_root_schema_for::<T>();
231    if let Some(object) = schema.as_object_mut() {
232        object.remove("title");
233    }
234    serde_json::to_value(schema).unwrap()
235}