adk_tool/
stateful_tool.rs1use adk_core::{Result, Tool, ToolContext};
2use async_trait::async_trait;
3use schemars::{JsonSchema, schema::RootSchema};
4use serde::Serialize;
5use serde_json::Value;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10type AsyncStatefulHandler<S> = Box<
11 dyn Fn(
12 Arc<S>,
13 Arc<dyn ToolContext>,
14 Value,
15 ) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>
16 + Send
17 + Sync,
18>;
19
20pub struct StatefulTool<S: Send + Sync + 'static> {
52 name: String,
53 description: String,
54 state: Arc<S>,
55 handler: AsyncStatefulHandler<S>,
56 long_running: bool,
57 read_only: bool,
58 concurrency_safe: bool,
59 parameters_schema: Option<Value>,
60 response_schema: Option<Value>,
61 scopes: Vec<&'static str>,
62}
63
64impl<S: Send + Sync + 'static> StatefulTool<S> {
65 pub fn new<F, Fut>(
74 name: impl Into<String>,
75 description: impl Into<String>,
76 state: Arc<S>,
77 handler: F,
78 ) -> Self
79 where
80 F: Fn(Arc<S>, Arc<dyn ToolContext>, Value) -> Fut + Send + Sync + 'static,
81 Fut: Future<Output = Result<Value>> + Send + 'static,
82 {
83 Self {
84 name: name.into(),
85 description: description.into(),
86 state,
87 handler: Box::new(move |s, ctx, args| Box::pin(handler(s, ctx, args))),
88 long_running: false,
89 read_only: false,
90 concurrency_safe: false,
91 parameters_schema: None,
92 response_schema: None,
93 scopes: Vec::new(),
94 }
95 }
96
97 pub fn with_long_running(mut self, long_running: bool) -> Self {
98 self.long_running = long_running;
99 self
100 }
101
102 pub fn with_read_only(mut self, read_only: bool) -> Self {
103 self.read_only = read_only;
104 self
105 }
106
107 pub fn with_concurrency_safe(mut self, concurrency_safe: bool) -> Self {
108 self.concurrency_safe = concurrency_safe;
109 self
110 }
111
112 pub fn with_parameters_schema<T>(mut self) -> Self
113 where
114 T: JsonSchema + Serialize,
115 {
116 self.parameters_schema = Some(generate_schema::<T>());
117 self
118 }
119
120 pub fn with_response_schema<T>(mut self) -> Self
121 where
122 T: JsonSchema + Serialize,
123 {
124 self.response_schema = Some(generate_schema::<T>());
125 self
126 }
127
128 pub fn with_scopes(mut self, scopes: &[&'static str]) -> Self {
133 self.scopes = scopes.to_vec();
134 self
135 }
136
137 pub fn parameters_schema(&self) -> Option<&Value> {
138 self.parameters_schema.as_ref()
139 }
140
141 pub fn response_schema(&self) -> Option<&Value> {
142 self.response_schema.as_ref()
143 }
144}
145
/// Note appended to a long-running tool's description so the model does not
/// re-invoke the tool after an intermediate or pending result.
const LONG_RUNNING_NOTE: &str = concat!(
    "NOTE: This is a long-running operation. ",
    "Do not call this tool again if it has already returned ",
    "some intermediate or pending status."
);
148
149#[async_trait]
150impl<S: Send + Sync + 'static> Tool for StatefulTool<S> {
151 fn name(&self) -> &str {
152 &self.name
153 }
154
155 fn description(&self) -> &str {
156 &self.description
157 }
158
159 fn enhanced_description(&self) -> String {
160 if self.long_running {
161 if self.description.is_empty() {
162 LONG_RUNNING_NOTE.to_string()
163 } else {
164 format!("{}\n\n{}", self.description, LONG_RUNNING_NOTE)
165 }
166 } else {
167 self.description.clone()
168 }
169 }
170
171 fn is_long_running(&self) -> bool {
172 self.long_running
173 }
174
175 fn is_read_only(&self) -> bool {
176 self.read_only
177 }
178
179 fn is_concurrency_safe(&self) -> bool {
180 self.concurrency_safe
181 }
182
183 fn parameters_schema(&self) -> Option<Value> {
184 self.parameters_schema.clone()
185 }
186
187 fn response_schema(&self) -> Option<Value> {
188 self.response_schema.clone()
189 }
190
191 fn required_scopes(&self) -> &[&str] {
192 &self.scopes
193 }
194
195 #[adk_telemetry::instrument(
196 skip(self, ctx, args),
197 fields(
198 tool.name = %self.name,
199 tool.description = %self.description,
200 tool.long_running = %self.long_running,
201 function_call.id = %ctx.function_call_id()
202 )
203 )]
204 async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
205 adk_telemetry::debug!("Executing stateful tool");
206 let state = Arc::clone(&self.state);
207 (self.handler)(state, ctx, args).await
208 }
209}
210
211fn generate_schema<T>() -> Value
212where
213 T: JsonSchema + Serialize,
214{
215 let settings = schemars::r#gen::SchemaSettings::openapi3().with(|s| {
216 s.inline_subschemas = true;
217 s.meta_schema = None;
218 });
219 let generator = schemars::r#gen::SchemaGenerator::new(settings);
220 let mut schema: RootSchema = generator.into_root_schema_for::<T>();
221 schema.schema.metadata().title = None;
222 serde_json::to_value(schema).unwrap()
223}