Skip to main content

adk_tool/
stateful_tool.rs

1use adk_core::{Result, Tool, ToolContext};
2use async_trait::async_trait;
3use schemars::{JsonSchema, schema::RootSchema};
4use serde::Serialize;
5use serde_json::Value;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10type AsyncStatefulHandler<S> = Box<
11    dyn Fn(
12            Arc<S>,
13            Arc<dyn ToolContext>,
14            Value,
15        ) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>
16        + Send
17        + Sync,
18>;
19
20/// A generic tool wrapper that manages shared state for stateful closures.
21///
22/// `StatefulTool<S>` accepts an `Arc<S>` and a handler closure that receives
23/// the state alongside the tool context and arguments. The `Arc<S>` is cloned
24/// (cheap reference count bump) on each invocation, so all executions share
25/// the same underlying state.
26///
27/// # Example
28///
29/// ```rust,ignore
30/// use adk_tool::StatefulTool;
31/// use adk_core::{ToolContext, Result};
32/// use serde_json::{json, Value};
33/// use std::sync::Arc;
34/// use tokio::sync::RwLock;
35///
36/// struct Counter { count: RwLock<u64> }
37///
38/// let state = Arc::new(Counter { count: RwLock::new(0) });
39///
40/// let tool = StatefulTool::new(
41///     "increment",
42///     "Increment a counter",
43///     state,
44///     |s, _ctx, _args| async move {
45///         let mut count = s.count.write().await;
46///         *count += 1;
47///         Ok(json!({ "count": *count }))
48///     },
49/// );
50/// ```
51pub struct StatefulTool<S: Send + Sync + 'static> {
52    name: String,
53    description: String,
54    state: Arc<S>,
55    handler: AsyncStatefulHandler<S>,
56    long_running: bool,
57    read_only: bool,
58    concurrency_safe: bool,
59    parameters_schema: Option<Value>,
60    response_schema: Option<Value>,
61    scopes: Vec<&'static str>,
62}
63
64impl<S: Send + Sync + 'static> StatefulTool<S> {
65    /// Create a new stateful tool.
66    ///
67    /// # Arguments
68    ///
69    /// * `name` - Tool name exposed to the LLM
70    /// * `description` - Human-readable description of what the tool does
71    /// * `state` - Shared state wrapped in `Arc<S>`
72    /// * `handler` - Async closure receiving `(Arc<S>, Arc<dyn ToolContext>, Value)`
73    pub fn new<F, Fut>(
74        name: impl Into<String>,
75        description: impl Into<String>,
76        state: Arc<S>,
77        handler: F,
78    ) -> Self
79    where
80        F: Fn(Arc<S>, Arc<dyn ToolContext>, Value) -> Fut + Send + Sync + 'static,
81        Fut: Future<Output = Result<Value>> + Send + 'static,
82    {
83        Self {
84            name: name.into(),
85            description: description.into(),
86            state,
87            handler: Box::new(move |s, ctx, args| Box::pin(handler(s, ctx, args))),
88            long_running: false,
89            read_only: false,
90            concurrency_safe: false,
91            parameters_schema: None,
92            response_schema: None,
93            scopes: Vec::new(),
94        }
95    }
96
97    pub fn with_long_running(mut self, long_running: bool) -> Self {
98        self.long_running = long_running;
99        self
100    }
101
102    pub fn with_read_only(mut self, read_only: bool) -> Self {
103        self.read_only = read_only;
104        self
105    }
106
107    pub fn with_concurrency_safe(mut self, concurrency_safe: bool) -> Self {
108        self.concurrency_safe = concurrency_safe;
109        self
110    }
111
112    pub fn with_parameters_schema<T>(mut self) -> Self
113    where
114        T: JsonSchema + Serialize,
115    {
116        self.parameters_schema = Some(generate_schema::<T>());
117        self
118    }
119
120    pub fn with_response_schema<T>(mut self) -> Self
121    where
122        T: JsonSchema + Serialize,
123    {
124        self.response_schema = Some(generate_schema::<T>());
125        self
126    }
127
128    /// Declare the scopes required to execute this tool.
129    ///
130    /// When set, the framework will enforce that the calling user possesses
131    /// **all** listed scopes before dispatching `execute()`.
132    pub fn with_scopes(mut self, scopes: &[&'static str]) -> Self {
133        self.scopes = scopes.to_vec();
134        self
135    }
136
137    pub fn parameters_schema(&self) -> Option<&Value> {
138        self.parameters_schema.as_ref()
139    }
140
141    pub fn response_schema(&self) -> Option<&Value> {
142        self.response_schema.as_ref()
143    }
144}
145
/// Text appended to a long-running tool's description so the LLM does not
/// issue a duplicate call while an earlier one is still pending.
const LONG_RUNNING_NOTE: &str = concat!(
    "NOTE: This is a long-running operation. ",
    "Do not call this tool again if it has already returned some intermediate or pending status."
);
148
149#[async_trait]
150impl<S: Send + Sync + 'static> Tool for StatefulTool<S> {
151    fn name(&self) -> &str {
152        &self.name
153    }
154
155    fn description(&self) -> &str {
156        &self.description
157    }
158
159    fn enhanced_description(&self) -> String {
160        if self.long_running {
161            if self.description.is_empty() {
162                LONG_RUNNING_NOTE.to_string()
163            } else {
164                format!("{}\n\n{}", self.description, LONG_RUNNING_NOTE)
165            }
166        } else {
167            self.description.clone()
168        }
169    }
170
171    fn is_long_running(&self) -> bool {
172        self.long_running
173    }
174
175    fn is_read_only(&self) -> bool {
176        self.read_only
177    }
178
179    fn is_concurrency_safe(&self) -> bool {
180        self.concurrency_safe
181    }
182
183    fn parameters_schema(&self) -> Option<Value> {
184        self.parameters_schema.clone()
185    }
186
187    fn response_schema(&self) -> Option<Value> {
188        self.response_schema.clone()
189    }
190
191    fn required_scopes(&self) -> &[&str] {
192        &self.scopes
193    }
194
195    #[adk_telemetry::instrument(
196        skip(self, ctx, args),
197        fields(
198            tool.name = %self.name,
199            tool.description = %self.description,
200            tool.long_running = %self.long_running,
201            function_call.id = %ctx.function_call_id()
202        )
203    )]
204    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
205        adk_telemetry::debug!("Executing stateful tool");
206        let state = Arc::clone(&self.state);
207        (self.handler)(state, ctx, args).await
208    }
209}
210
211fn generate_schema<T>() -> Value
212where
213    T: JsonSchema + Serialize,
214{
215    let settings = schemars::r#gen::SchemaSettings::openapi3().with(|s| {
216        s.inline_subschemas = true;
217        s.meta_schema = None;
218    });
219    let generator = schemars::r#gen::SchemaGenerator::new(settings);
220    let mut schema: RootSchema = generator.into_root_schema_for::<T>();
221    schema.schema.metadata().title = None;
222    serde_json::to_value(schema).unwrap()
223}