// adk_tool/stateful_tool.rs

1use adk_core::{Result, Tool, ToolContext};
2use async_trait::async_trait;
3use schemars::{
4    JsonSchema,
5    generate::{SchemaGenerator, SchemaSettings},
6};
7use serde::Serialize;
8use serde_json::Value;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13type AsyncStatefulHandler<S> = Box<
14    dyn Fn(
15            Arc<S>,
16            Arc<dyn ToolContext>,
17            Value,
18        ) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>
19        + Send
20        + Sync,
21>;
22
23/// A generic tool wrapper that manages shared state for stateful closures.
24///
25/// `StatefulTool<S>` accepts an `Arc<S>` and a handler closure that receives
26/// the state alongside the tool context and arguments. The `Arc<S>` is cloned
27/// (cheap reference count bump) on each invocation, so all executions share
28/// the same underlying state.
29///
30/// # Example
31///
32/// ```rust,ignore
33/// use adk_tool::StatefulTool;
34/// use adk_core::{ToolContext, Result};
35/// use serde_json::{json, Value};
36/// use std::sync::Arc;
37/// use tokio::sync::RwLock;
38///
39/// struct Counter { count: RwLock<u64> }
40///
41/// let state = Arc::new(Counter { count: RwLock::new(0) });
42///
43/// let tool = StatefulTool::new(
44///     "increment",
45///     "Increment a counter",
46///     state,
47///     |s, _ctx, _args| async move {
48///         let mut count = s.count.write().await;
49///         *count += 1;
50///         Ok(json!({ "count": *count }))
51///     },
52/// );
53/// ```
54pub struct StatefulTool<S: Send + Sync + 'static> {
55    name: String,
56    description: String,
57    state: Arc<S>,
58    handler: AsyncStatefulHandler<S>,
59    long_running: bool,
60    read_only: bool,
61    concurrency_safe: bool,
62    parameters_schema: Option<Value>,
63    response_schema: Option<Value>,
64    scopes: Vec<&'static str>,
65}
66
67impl<S: Send + Sync + 'static> StatefulTool<S> {
68    /// Create a new stateful tool.
69    ///
70    /// # Arguments
71    ///
72    /// * `name` - Tool name exposed to the LLM
73    /// * `description` - Human-readable description of what the tool does
74    /// * `state` - Shared state wrapped in `Arc<S>`
75    /// * `handler` - Async closure receiving `(Arc<S>, Arc<dyn ToolContext>, Value)`
76    pub fn new<F, Fut>(
77        name: impl Into<String>,
78        description: impl Into<String>,
79        state: Arc<S>,
80        handler: F,
81    ) -> Self
82    where
83        F: Fn(Arc<S>, Arc<dyn ToolContext>, Value) -> Fut + Send + Sync + 'static,
84        Fut: Future<Output = Result<Value>> + Send + 'static,
85    {
86        Self {
87            name: name.into(),
88            description: description.into(),
89            state,
90            handler: Box::new(move |s, ctx, args| Box::pin(handler(s, ctx, args))),
91            long_running: false,
92            read_only: false,
93            concurrency_safe: false,
94            parameters_schema: None,
95            response_schema: None,
96            scopes: Vec::new(),
97        }
98    }
99
100    pub fn with_long_running(mut self, long_running: bool) -> Self {
101        self.long_running = long_running;
102        self
103    }
104
105    pub fn with_read_only(mut self, read_only: bool) -> Self {
106        self.read_only = read_only;
107        self
108    }
109
110    pub fn with_concurrency_safe(mut self, concurrency_safe: bool) -> Self {
111        self.concurrency_safe = concurrency_safe;
112        self
113    }
114
115    pub fn with_parameters_schema<T>(mut self) -> Self
116    where
117        T: JsonSchema + Serialize,
118    {
119        self.parameters_schema = Some(generate_schema::<T>());
120        self
121    }
122
123    pub fn with_response_schema<T>(mut self) -> Self
124    where
125        T: JsonSchema + Serialize,
126    {
127        self.response_schema = Some(generate_schema::<T>());
128        self
129    }
130
131    /// Declare the scopes required to execute this tool.
132    ///
133    /// When set, the framework will enforce that the calling user possesses
134    /// **all** listed scopes before dispatching `execute()`.
135    pub fn with_scopes(mut self, scopes: &[&'static str]) -> Self {
136        self.scopes = scopes.to_vec();
137        self
138    }
139
140    pub fn parameters_schema(&self) -> Option<&Value> {
141        self.parameters_schema.as_ref()
142    }
143
144    pub fn response_schema(&self) -> Option<&Value> {
145        self.response_schema.as_ref()
146    }
147}
148
/// Text that `enhanced_description` adds to the description of long-running tools.
///
/// It asks the model not to re-invoke a tool that has already reported an
/// intermediate or pending status, which avoids duplicate calls.
const LONG_RUNNING_NOTE: &str = concat!(
    "NOTE: This is a long-running operation. ",
    "Do not call this tool again if it has already returned ",
    "some intermediate or pending status."
);
151
152#[async_trait]
153impl<S: Send + Sync + 'static> Tool for StatefulTool<S> {
154    fn name(&self) -> &str {
155        &self.name
156    }
157
158    fn description(&self) -> &str {
159        &self.description
160    }
161
162    fn enhanced_description(&self) -> String {
163        if self.long_running {
164            if self.description.is_empty() {
165                LONG_RUNNING_NOTE.to_string()
166            } else {
167                format!("{}\n\n{}", self.description, LONG_RUNNING_NOTE)
168            }
169        } else {
170            self.description.clone()
171        }
172    }
173
174    fn is_long_running(&self) -> bool {
175        self.long_running
176    }
177
178    fn is_read_only(&self) -> bool {
179        self.read_only
180    }
181
182    fn is_concurrency_safe(&self) -> bool {
183        self.concurrency_safe
184    }
185
186    fn parameters_schema(&self) -> Option<Value> {
187        self.parameters_schema.clone()
188    }
189
190    fn response_schema(&self) -> Option<Value> {
191        self.response_schema.clone()
192    }
193
194    fn required_scopes(&self) -> &[&str] {
195        &self.scopes
196    }
197
198    #[adk_telemetry::instrument(
199        skip(self, ctx, args),
200        fields(
201            tool.name = %self.name,
202            tool.description = %self.description,
203            tool.long_running = %self.long_running,
204            function_call.id = %ctx.function_call_id()
205        )
206    )]
207    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
208        adk_telemetry::debug!("Executing stateful tool");
209        let state = Arc::clone(&self.state);
210        (self.handler)(state, ctx, args).await
211    }
212}
213
214fn generate_schema<T>() -> Value
215where
216    T: JsonSchema + Serialize,
217{
218    let settings = SchemaSettings::openapi3().with(|s| {
219        s.inline_subschemas = true;
220        s.meta_schema = None;
221    });
222    let generator = SchemaGenerator::new(settings);
223    let mut schema = generator.into_root_schema_for::<T>();
224    if let Some(object) = schema.as_object_mut() {
225        object.remove("title");
226    }
227    serde_json::to_value(schema).unwrap()
228}