Skip to main content

adk_rs_macros/
lib.rs

//! `#[adk_rs::tool]` proc-macro.
2//!
3//! Usage:
4//!
5//! ```ignore
6//! use adk_rs::ToolContext;
7//! use adk_rs::Result;
8//! use serde::{Deserialize, Serialize};
9//! use schemars::JsonSchema;
10//!
11//! #[derive(Deserialize, JsonSchema)]
12//! struct GetWeatherArgs {
13//!     /// City name.
14//!     city: String,
15//! }
16//!
17//! #[derive(Serialize)]
18//! struct WeatherReport { temp_c: f32 }
19//!
20//! #[adk_rs::tool]
21//! /// Look up the weather in `args.city`.
22//! async fn get_weather(args: GetWeatherArgs, _ctx: &mut ToolContext) -> Result<WeatherReport> {
23//!     Ok(WeatherReport { temp_c: 22.0 })
24//! }
25//! ```
26//!
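//! The macro emits a unit struct named after the function (converted to
//! `UpperCamelCase`, e.g. `GetWeather`) and a free constructor with the
//! function's own name, `get_weather() -> Arc<dyn adk_rs::Tool>`. The args
//! struct must implement `serde::Deserialize` and `schemars::JsonSchema`, and
//! the success type of the returned `Result` must implement `serde::Serialize`.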
30
31#![forbid(unsafe_code)]
32
33use proc_macro::TokenStream;
34use proc_macro2::Span;
35use quote::quote;
36use syn::{FnArg, ItemFn, Pat, PatType, Type, parse_macro_input, parse_quote};
37
/// Converts a `snake_case` identifier into `UpperCamelCase`
/// (`get_weather` -> `GetWeather`).
///
/// Underscores are dropped. The first character, and the character following
/// each run of underscores, is uppercased; every other character is copied
/// unchanged.
///
/// A leading raw-identifier prefix `r#` is stripped first. `Ident::to_string()`
/// renders raw identifiers as `r#name`, and `#` can never appear in a valid
/// generated type name. Uppercasing is Unicode-aware, so non-ASCII identifiers
/// are camel-cased too.
fn upper_camel(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);
    let mut out = String::with_capacity(s.len());
    let mut capitalize_next = true;
    for c in s.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            // `to_uppercase` may yield several chars (e.g. 'ß' -> "SS").
            out.extend(c.to_uppercase());
            capitalize_next = false;
        } else {
            out.push(c);
        }
    }
    out
}
53
54fn doc_comment(attrs: &[syn::Attribute]) -> String {
55    let mut out = String::new();
56    for a in attrs {
57        if a.path().is_ident("doc") {
58            if let syn::Meta::NameValue(nv) = a.meta.clone() {
59                if let syn::Expr::Lit(lit) = nv.value {
60                    if let syn::Lit::Str(s) = lit.lit {
61                        if !out.is_empty() {
62                            out.push('\n');
63                        }
64                        out.push_str(s.value().trim());
65                    }
66                }
67            }
68        }
69    }
70    out
71}
72
73#[proc_macro_attribute]
74/// `#[adk::tool]` — see crate docs.
75#[allow(clippy::match_on_vec_items)] // `inputs.len() != 2` is checked first
76pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
77    let f = parse_macro_input!(item as ItemFn);
78    let vis = &f.vis;
79    let name_ident = &f.sig.ident;
80    let name_str = name_ident.to_string();
81    let pascal = upper_camel(&name_str);
82    let struct_ident = syn::Ident::new(&pascal, Span::call_site());
83    let description = doc_comment(&f.attrs);
84
85    if f.sig.asyncness.is_none() {
86        return TokenStream::from(quote! {
87            compile_error!("#[adk_rs::tool] requires an async fn");
88        });
89    }
90
91    // Inspect arguments: expect (args: ArgsType, ctx: &mut ToolContext)
92    let inputs: Vec<&FnArg> = f.sig.inputs.iter().collect();
93    if inputs.len() != 2 {
94        return TokenStream::from(quote! {
95            compile_error!("#[adk_rs::tool] requires exactly two args: (args: T, ctx: &mut ToolContext)");
96        });
97    }
98    let PatType {
99        pat: arg_pat,
100        ty: arg_ty,
101        ..
102    } = match inputs[0] {
103        FnArg::Typed(p) => p.clone(),
104        FnArg::Receiver(_) => {
105            return TokenStream::from(quote! {
106                compile_error!("#[adk_rs::tool] doesn't support receivers");
107            });
108        }
109    };
110    let arg_ident = match *arg_pat {
111        Pat::Ident(ref id) => id.ident.clone(),
112        _ => {
113            return TokenStream::from(
114                quote! { compile_error!("first arg must be a simple identifier"); },
115            );
116        }
117    };
118
119    let PatType {
120        pat: ctx_pat,
121        ty: _ctx_ty,
122        ..
123    } = match inputs[1] {
124        FnArg::Typed(p) => p.clone(),
125        FnArg::Receiver(_) => {
126            return TokenStream::from(quote! {
127                compile_error!("#[adk_rs::tool] doesn't support receivers");
128            });
129        }
130    };
131    let ctx_ident = match *ctx_pat {
132        Pat::Ident(ref id) => id.ident.clone(),
133        _ => {
134            return TokenStream::from(
135                quote! { compile_error!("second arg must be a simple identifier"); },
136            );
137        }
138    };
139
140    // We need the args type to implement schemars::JsonSchema; we'll call
141    // `schemars::schema_for!` on it at runtime via a helper.
142    let arg_ty_owned: Type = parse_quote!(#arg_ty);
143
144    let body = &f.block;
145    let ret_ty = &f.sig.output;
146
147    let constructor_name = name_ident.clone();
148
149    let expanded = quote! {
150        #[doc = #description]
151        #[derive(Debug, Default, Clone, Copy)]
152        #vis struct #struct_ident;
153
154        #[::async_trait::async_trait]
155        impl ::adk_rs::__private::DynTool for #struct_ident {
156            fn name(&self) -> &str { #name_str }
157            fn description(&self) -> &str { #description }
158            fn declaration(&self) -> ::std::option::Option<::adk_rs::__private::FunctionDeclaration> {
159                let root = ::schemars::schema_for!(#arg_ty_owned);
160                let schema = ::adk_rs::__private::Schema::from_schemars(&root)
161                    .unwrap_or_else(|_| ::adk_rs::__private::Schema::object());
162                ::std::option::Option::Some(
163                    ::adk_rs::__private::FunctionDeclaration::new(#name_str, #description)
164                        .with_parameters(schema),
165                )
166            }
167            async fn run(
168                &self,
169                args: ::serde_json::Value,
170                #ctx_ident: &mut ::adk_rs::__private::ToolContext,
171            ) -> ::adk_rs::__private::Result<::serde_json::Value> {
172                async fn __inner(#arg_ident: #arg_ty_owned, #ctx_ident: &mut ::adk_rs::__private::ToolContext) #ret_ty #body
173                let typed: #arg_ty_owned = ::serde_json::from_value(args).map_err(|e| {
174                    ::adk_rs::__private::Error::Tool(::adk_rs::__private::ToolError::InvalidArgs {
175                        tool: #name_str.to_string(),
176                        message: e.to_string(),
177                    })
178                })?;
179                let r = __inner(typed, #ctx_ident).await?;
180                ::serde_json::to_value(r).map_err(|e| {
181                    ::adk_rs::__private::Error::Tool(::adk_rs::__private::ToolError::Execution {
182                        tool: #name_str.to_string(),
183                        message: e.to_string(),
184                    })
185                })
186            }
187        }
188
189        /// Construct the tool.
190        #vis fn #constructor_name() -> ::std::sync::Arc<dyn ::adk_rs::__private::DynTool> {
191            ::std::sync::Arc::new(#struct_ident)
192        }
193    };
194
195    TokenStream::from(expanded)
196}