adk_tool/
stateful_tool.rs1use adk_core::{Result, Tool, ToolContext};
2use async_trait::async_trait;
3use schemars::{
4 JsonSchema,
5 generate::{SchemaGenerator, SchemaSettings},
6};
7use serde::Serialize;
8use serde_json::Value;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13type AsyncStatefulHandler<S> = Box<
14 dyn Fn(
15 Arc<S>,
16 Arc<dyn ToolContext>,
17 Value,
18 ) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>
19 + Send
20 + Sync,
21>;
22
23pub struct StatefulTool<S: Send + Sync + 'static> {
55 name: String,
56 description: String,
57 state: Arc<S>,
58 handler: AsyncStatefulHandler<S>,
59 long_running: bool,
60 read_only: bool,
61 concurrency_safe: bool,
62 parameters_schema: Option<Value>,
63 response_schema: Option<Value>,
64 scopes: Vec<&'static str>,
65}
66
67impl<S: Send + Sync + 'static> StatefulTool<S> {
68 pub fn new<F, Fut>(
77 name: impl Into<String>,
78 description: impl Into<String>,
79 state: Arc<S>,
80 handler: F,
81 ) -> Self
82 where
83 F: Fn(Arc<S>, Arc<dyn ToolContext>, Value) -> Fut + Send + Sync + 'static,
84 Fut: Future<Output = Result<Value>> + Send + 'static,
85 {
86 Self {
87 name: name.into(),
88 description: description.into(),
89 state,
90 handler: Box::new(move |s, ctx, args| Box::pin(handler(s, ctx, args))),
91 long_running: false,
92 read_only: false,
93 concurrency_safe: false,
94 parameters_schema: None,
95 response_schema: None,
96 scopes: Vec::new(),
97 }
98 }
99
100 pub fn with_long_running(mut self, long_running: bool) -> Self {
101 self.long_running = long_running;
102 self
103 }
104
105 pub fn with_read_only(mut self, read_only: bool) -> Self {
106 self.read_only = read_only;
107 self
108 }
109
110 pub fn with_concurrency_safe(mut self, concurrency_safe: bool) -> Self {
111 self.concurrency_safe = concurrency_safe;
112 self
113 }
114
115 pub fn with_parameters_schema<T>(mut self) -> Self
116 where
117 T: JsonSchema + Serialize,
118 {
119 self.parameters_schema = Some(generate_schema::<T>());
120 self
121 }
122
123 pub fn with_response_schema<T>(mut self) -> Self
124 where
125 T: JsonSchema + Serialize,
126 {
127 self.response_schema = Some(generate_schema::<T>());
128 self
129 }
130
131 pub fn with_scopes(mut self, scopes: &[&'static str]) -> Self {
136 self.scopes = scopes.to_vec();
137 self
138 }
139
140 pub fn parameters_schema(&self) -> Option<&Value> {
141 self.parameters_schema.as_ref()
142 }
143
144 pub fn response_schema(&self) -> Option<&Value> {
145 self.response_schema.as_ref()
146 }
147}
148
/// Note appended to a long-running tool's description.
///
/// It discourages the model from calling the tool again once it has reported
/// an intermediate or pending status.
const LONG_RUNNING_NOTE: &str = concat!(
    "NOTE: This is a long-running operation. ",
    "Do not call this tool again if it has already returned some intermediate or pending status."
);
151
152#[async_trait]
153impl<S: Send + Sync + 'static> Tool for StatefulTool<S> {
154 fn name(&self) -> &str {
155 &self.name
156 }
157
158 fn description(&self) -> &str {
159 &self.description
160 }
161
162 fn enhanced_description(&self) -> String {
163 if self.long_running {
164 if self.description.is_empty() {
165 LONG_RUNNING_NOTE.to_string()
166 } else {
167 format!("{}\n\n{}", self.description, LONG_RUNNING_NOTE)
168 }
169 } else {
170 self.description.clone()
171 }
172 }
173
174 fn is_long_running(&self) -> bool {
175 self.long_running
176 }
177
178 fn is_read_only(&self) -> bool {
179 self.read_only
180 }
181
182 fn is_concurrency_safe(&self) -> bool {
183 self.concurrency_safe
184 }
185
186 fn parameters_schema(&self) -> Option<Value> {
187 self.parameters_schema.clone()
188 }
189
190 fn response_schema(&self) -> Option<Value> {
191 self.response_schema.clone()
192 }
193
194 fn required_scopes(&self) -> &[&str] {
195 &self.scopes
196 }
197
198 #[adk_telemetry::instrument(
199 skip(self, ctx, args),
200 fields(
201 tool.name = %self.name,
202 tool.description = %self.description,
203 tool.long_running = %self.long_running,
204 function_call.id = %ctx.function_call_id()
205 )
206 )]
207 async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
208 adk_telemetry::debug!("Executing stateful tool");
209 let state = Arc::clone(&self.state);
210 (self.handler)(state, ctx, args).await
211 }
212}
213
214fn generate_schema<T>() -> Value
215where
216 T: JsonSchema + Serialize,
217{
218 let settings = SchemaSettings::openapi3().with(|s| {
219 s.inline_subschemas = true;
220 s.meta_schema = None;
221 });
222 let generator = SchemaGenerator::new(settings);
223 let mut schema = generator.into_root_schema_for::<T>();
224 if let Some(object) = schema.as_object_mut() {
225 object.remove("title");
226 }
227 serde_json::to_value(schema).unwrap()
228}