adk_tool/
stateful_tool.rs1use adk_core::{Result, Tool, ToolContext};
2use async_trait::async_trait;
3use schemars::{
4 JsonSchema,
5 generate::{SchemaGenerator, SchemaSettings},
6};
7use serde::Serialize;
8use serde_json::Value;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13type AsyncStatefulHandler<S> = Box<
14 dyn Fn(
15 Arc<S>,
16 Arc<dyn ToolContext>,
17 Value,
18 ) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>
19 + Send
20 + Sync,
21>;
22
23pub struct StatefulTool<S: Send + Sync + 'static> {
55 name: String,
56 description: String,
57 state: Arc<S>,
58 handler: AsyncStatefulHandler<S>,
59 long_running: bool,
60 read_only: bool,
61 concurrency_safe: bool,
62 parameters_schema: Option<Value>,
63 response_schema: Option<Value>,
64 scopes: Vec<&'static str>,
65}
66
67impl<S: Send + Sync + 'static> StatefulTool<S> {
68 pub fn new<F, Fut>(
77 name: impl Into<String>,
78 description: impl Into<String>,
79 state: Arc<S>,
80 handler: F,
81 ) -> Self
82 where
83 F: Fn(Arc<S>, Arc<dyn ToolContext>, Value) -> Fut + Send + Sync + 'static,
84 Fut: Future<Output = Result<Value>> + Send + 'static,
85 {
86 Self {
87 name: name.into(),
88 description: description.into(),
89 state,
90 handler: Box::new(move |s, ctx, args| Box::pin(handler(s, ctx, args))),
91 long_running: false,
92 read_only: false,
93 concurrency_safe: false,
94 parameters_schema: None,
95 response_schema: None,
96 scopes: Vec::new(),
97 }
98 }
99
100 pub fn with_long_running(mut self, long_running: bool) -> Self {
102 self.long_running = long_running;
103 self
104 }
105
106 pub fn with_read_only(mut self, read_only: bool) -> Self {
108 self.read_only = read_only;
109 self
110 }
111
112 pub fn with_concurrency_safe(mut self, concurrency_safe: bool) -> Self {
114 self.concurrency_safe = concurrency_safe;
115 self
116 }
117
118 pub fn with_parameters_schema<T>(mut self) -> Self
120 where
121 T: JsonSchema + Serialize,
122 {
123 self.parameters_schema = Some(generate_schema::<T>());
124 self
125 }
126
127 pub fn with_response_schema<T>(mut self) -> Self
129 where
130 T: JsonSchema + Serialize,
131 {
132 self.response_schema = Some(generate_schema::<T>());
133 self
134 }
135
136 pub fn with_scopes(mut self, scopes: &[&'static str]) -> Self {
141 self.scopes = scopes.to_vec();
142 self
143 }
144
145 pub fn parameters_schema(&self) -> Option<&Value> {
147 self.parameters_schema.as_ref()
148 }
149
150 pub fn response_schema(&self) -> Option<&Value> {
152 self.response_schema.as_ref()
153 }
154}
155
156const LONG_RUNNING_NOTE: &str = "NOTE: This is a long-running operation. Do not call this tool again if it has already returned some intermediate or pending status.";
158
159#[async_trait]
160impl<S: Send + Sync + 'static> Tool for StatefulTool<S> {
161 fn name(&self) -> &str {
162 &self.name
163 }
164
165 fn description(&self) -> &str {
166 &self.description
167 }
168
169 fn enhanced_description(&self) -> String {
170 if self.long_running {
171 if self.description.is_empty() {
172 LONG_RUNNING_NOTE.to_string()
173 } else {
174 format!("{}\n\n{}", self.description, LONG_RUNNING_NOTE)
175 }
176 } else {
177 self.description.clone()
178 }
179 }
180
181 fn is_long_running(&self) -> bool {
182 self.long_running
183 }
184
185 fn is_read_only(&self) -> bool {
186 self.read_only
187 }
188
189 fn is_concurrency_safe(&self) -> bool {
190 self.concurrency_safe
191 }
192
193 fn parameters_schema(&self) -> Option<Value> {
194 self.parameters_schema.clone()
195 }
196
197 fn response_schema(&self) -> Option<Value> {
198 self.response_schema.clone()
199 }
200
201 fn required_scopes(&self) -> &[&str] {
202 &self.scopes
203 }
204
205 #[adk_telemetry::instrument(
206 skip(self, ctx, args),
207 fields(
208 tool.name = %self.name,
209 tool.description = %self.description,
210 tool.long_running = %self.long_running,
211 function_call.id = %ctx.function_call_id()
212 )
213 )]
214 async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
215 adk_telemetry::debug!("Executing stateful tool");
216 let state = Arc::clone(&self.state);
217 (self.handler)(state, ctx, args).await
218 }
219}
220
221fn generate_schema<T>() -> Value
222where
223 T: JsonSchema + Serialize,
224{
225 let settings = SchemaSettings::openapi3().with(|s| {
226 s.inline_subschemas = true;
227 s.meta_schema = None;
228 });
229 let generator = SchemaGenerator::new(settings);
230 let mut schema = generator.into_root_schema_for::<T>();
231 if let Some(object) = schema.as_object_mut() {
232 object.remove("title");
233 }
234 serde_json::to_value(schema).unwrap()
235}