Skip to main content

adk_tool/
function_tool.rs

1use adk_core::{Result, Tool, ToolContext};
2use async_trait::async_trait;
3use schemars::{
4    JsonSchema,
5    generate::{SchemaGenerator, SchemaSettings},
6};
7use serde::Serialize;
8use serde_json::Value;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13type AsyncHandler = Box<
14    dyn Fn(Arc<dyn ToolContext>, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>
15        + Send
16        + Sync,
17>;
18
19pub struct FunctionTool {
20    name: String,
21    description: String,
22    handler: AsyncHandler,
23    long_running: bool,
24    read_only: bool,
25    concurrency_safe: bool,
26    parameters_schema: Option<Value>,
27    response_schema: Option<Value>,
28    scopes: Vec<&'static str>,
29}
30
31impl FunctionTool {
32    pub fn new<F, Fut>(name: impl Into<String>, description: impl Into<String>, handler: F) -> Self
33    where
34        F: Fn(Arc<dyn ToolContext>, Value) -> Fut + Send + Sync + 'static,
35        Fut: Future<Output = Result<Value>> + Send + 'static,
36    {
37        Self {
38            name: name.into(),
39            description: description.into(),
40            handler: Box::new(move |ctx, args| Box::pin(handler(ctx, args))),
41            long_running: false,
42            read_only: false,
43            concurrency_safe: false,
44            parameters_schema: None,
45            response_schema: None,
46            scopes: Vec::new(),
47        }
48    }
49
50    pub fn with_long_running(mut self, long_running: bool) -> Self {
51        self.long_running = long_running;
52        self
53    }
54
55    pub fn with_read_only(mut self, read_only: bool) -> Self {
56        self.read_only = read_only;
57        self
58    }
59
60    pub fn with_concurrency_safe(mut self, concurrency_safe: bool) -> Self {
61        self.concurrency_safe = concurrency_safe;
62        self
63    }
64
65    pub fn with_parameters_schema<T>(mut self) -> Self
66    where
67        T: JsonSchema + Serialize,
68    {
69        self.parameters_schema = Some(generate_schema::<T>());
70        self
71    }
72
73    pub fn with_response_schema<T>(mut self) -> Self
74    where
75        T: JsonSchema + Serialize,
76    {
77        self.response_schema = Some(generate_schema::<T>());
78        self
79    }
80
81    /// Declare the scopes required to execute this tool.
82    ///
83    /// When set, the framework will enforce that the calling user possesses
84    /// **all** listed scopes before dispatching `execute()`.
85    ///
86    /// # Example
87    ///
88    /// ```rust,ignore
89    /// let tool = FunctionTool::new("transfer", "Transfer funds", handler)
90    ///     .with_scopes(&["finance:write", "verified"]);
91    /// ```
92    pub fn with_scopes(mut self, scopes: &[&'static str]) -> Self {
93        self.scopes = scopes.to_vec();
94        self
95    }
96
97    pub fn parameters_schema(&self) -> Option<&Value> {
98        self.parameters_schema.as_ref()
99    }
100
101    pub fn response_schema(&self) -> Option<&Value> {
102        self.response_schema.as_ref()
103    }
104}
105
/// Note appended to the description of long-running tools.
///
/// It tells the model not to re-issue a call that has already reported an intermediate or
/// pending status.
const LONG_RUNNING_NOTE: &str = concat!(
    "NOTE: This is a long-running operation. ",
    "Do not call this tool again if it has already returned some intermediate or pending status."
);
108
109#[async_trait]
110impl Tool for FunctionTool {
111    fn name(&self) -> &str {
112        &self.name
113    }
114
115    fn description(&self) -> &str {
116        &self.description
117    }
118
119    /// Returns an enhanced description for long-running tools that includes
120    /// a note warning the model not to call the tool again if it's already pending.
121    fn enhanced_description(&self) -> String {
122        if self.long_running {
123            if self.description.is_empty() {
124                LONG_RUNNING_NOTE.to_string()
125            } else {
126                format!("{}\n\n{}", self.description, LONG_RUNNING_NOTE)
127            }
128        } else {
129            self.description.clone()
130        }
131    }
132
133    fn is_long_running(&self) -> bool {
134        self.long_running
135    }
136
137    fn is_read_only(&self) -> bool {
138        self.read_only
139    }
140
141    fn is_concurrency_safe(&self) -> bool {
142        self.concurrency_safe
143    }
144
145    fn parameters_schema(&self) -> Option<Value> {
146        self.parameters_schema.clone()
147    }
148
149    fn response_schema(&self) -> Option<Value> {
150        self.response_schema.clone()
151    }
152
153    fn required_scopes(&self) -> &[&str] {
154        &self.scopes
155    }
156
157    #[adk_telemetry::instrument(
158        skip(self, ctx, args),
159        fields(
160            tool.name = %self.name,
161            tool.description = %self.description,
162            tool.long_running = %self.long_running,
163            function_call.id = %ctx.function_call_id()
164        )
165    )]
166    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
167        adk_telemetry::debug!("Executing tool");
168        (self.handler)(ctx, args).await
169    }
170}
171
172fn generate_schema<T>() -> Value
173where
174    T: JsonSchema + Serialize,
175{
176    let settings = SchemaSettings::openapi3().with(|s| {
177        s.inline_subschemas = true;
178        s.meta_schema = None;
179    });
180    let generator = SchemaGenerator::new(settings);
181    let mut schema = generator.into_root_schema_for::<T>();
182    if let Some(object) = schema.as_object_mut() {
183        object.remove("title");
184    }
185    serde_json::to_value(schema).unwrap()
186}