1#![forbid(unsafe_code)]
32
33use proc_macro::TokenStream;
34use proc_macro2::Span;
35use quote::quote;
36use syn::{FnArg, ItemFn, Pat, PatType, Type, parse_macro_input, parse_quote};
37
/// Converts a `snake_case` name into `UpperCamelCase`.
///
/// Underscores are dropped, and the first character following the start of
/// the string or any run of underscores is ASCII-uppercased. All other
/// characters are copied through unchanged (non-ASCII characters are never
/// case-converted).
fn upper_camel(s: &str) -> String {
    s.split('_')
        .fold(String::with_capacity(s.len()), |mut acc, segment| {
            let mut rest = segment.chars();
            // Empty segments come from leading, trailing or doubled
            // underscores and contribute nothing.
            if let Some(head) = rest.next() {
                acc.push(head.to_ascii_uppercase());
                acc.push_str(rest.as_str());
            }
            acc
        })
}
53
54fn doc_comment(attrs: &[syn::Attribute]) -> String {
55 let mut out = String::new();
56 for a in attrs {
57 if a.path().is_ident("doc") {
58 if let syn::Meta::NameValue(nv) = a.meta.clone() {
59 if let syn::Expr::Lit(lit) = nv.value {
60 if let syn::Lit::Str(s) = lit.lit {
61 if !out.is_empty() {
62 out.push('\n');
63 }
64 out.push_str(s.value().trim());
65 }
66 }
67 }
68 }
69 }
70 out
71}
72
73#[proc_macro_attribute]
74#[allow(clippy::match_on_vec_items)] pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
77 let f = parse_macro_input!(item as ItemFn);
78 let vis = &f.vis;
79 let name_ident = &f.sig.ident;
80 let name_str = name_ident.to_string();
81 let pascal = upper_camel(&name_str);
82 let struct_ident = syn::Ident::new(&pascal, Span::call_site());
83 let description = doc_comment(&f.attrs);
84
85 if f.sig.asyncness.is_none() {
86 return TokenStream::from(quote! {
87 compile_error!("#[adk_rs::tool] requires an async fn");
88 });
89 }
90
91 let inputs: Vec<&FnArg> = f.sig.inputs.iter().collect();
93 if inputs.len() != 2 {
94 return TokenStream::from(quote! {
95 compile_error!("#[adk_rs::tool] requires exactly two args: (args: T, ctx: &mut ToolContext)");
96 });
97 }
98 let PatType {
99 pat: arg_pat,
100 ty: arg_ty,
101 ..
102 } = match inputs[0] {
103 FnArg::Typed(p) => p.clone(),
104 FnArg::Receiver(_) => {
105 return TokenStream::from(quote! {
106 compile_error!("#[adk_rs::tool] doesn't support receivers");
107 });
108 }
109 };
110 let arg_ident = match *arg_pat {
111 Pat::Ident(ref id) => id.ident.clone(),
112 _ => {
113 return TokenStream::from(
114 quote! { compile_error!("first arg must be a simple identifier"); },
115 );
116 }
117 };
118
119 let PatType {
120 pat: ctx_pat,
121 ty: _ctx_ty,
122 ..
123 } = match inputs[1] {
124 FnArg::Typed(p) => p.clone(),
125 FnArg::Receiver(_) => {
126 return TokenStream::from(quote! {
127 compile_error!("#[adk_rs::tool] doesn't support receivers");
128 });
129 }
130 };
131 let ctx_ident = match *ctx_pat {
132 Pat::Ident(ref id) => id.ident.clone(),
133 _ => {
134 return TokenStream::from(
135 quote! { compile_error!("second arg must be a simple identifier"); },
136 );
137 }
138 };
139
140 let arg_ty_owned: Type = parse_quote!(#arg_ty);
143
144 let body = &f.block;
145 let ret_ty = &f.sig.output;
146
147 let constructor_name = name_ident.clone();
148
149 let expanded = quote! {
150 #[doc = #description]
151 #[derive(Debug, Default, Clone, Copy)]
152 #vis struct #struct_ident;
153
154 #[::async_trait::async_trait]
155 impl ::adk_rs::__private::DynTool for #struct_ident {
156 fn name(&self) -> &str { #name_str }
157 fn description(&self) -> &str { #description }
158 fn declaration(&self) -> ::std::option::Option<::adk_rs::__private::FunctionDeclaration> {
159 let root = ::schemars::schema_for!(#arg_ty_owned);
160 let schema = ::adk_rs::__private::Schema::from_schemars(&root)
161 .unwrap_or_else(|_| ::adk_rs::__private::Schema::object());
162 ::std::option::Option::Some(
163 ::adk_rs::__private::FunctionDeclaration::new(#name_str, #description)
164 .with_parameters(schema),
165 )
166 }
167 async fn run(
168 &self,
169 args: ::serde_json::Value,
170 #ctx_ident: &mut ::adk_rs::__private::ToolContext,
171 ) -> ::adk_rs::__private::Result<::serde_json::Value> {
172 async fn __inner(#arg_ident: #arg_ty_owned, #ctx_ident: &mut ::adk_rs::__private::ToolContext) #ret_ty #body
173 let typed: #arg_ty_owned = ::serde_json::from_value(args).map_err(|e| {
174 ::adk_rs::__private::Error::Tool(::adk_rs::__private::ToolError::InvalidArgs {
175 tool: #name_str.to_string(),
176 message: e.to_string(),
177 })
178 })?;
179 let r = __inner(typed, #ctx_ident).await?;
180 ::serde_json::to_value(r).map_err(|e| {
181 ::adk_rs::__private::Error::Tool(::adk_rs::__private::ToolError::Execution {
182 tool: #name_str.to_string(),
183 message: e.to_string(),
184 })
185 })
186 }
187 }
188
189 #vis fn #constructor_name() -> ::std::sync::Arc<dyn ::adk_rs::__private::DynTool> {
191 ::std::sync::Arc::new(#struct_ident)
192 }
193 };
194
195 TokenStream::from(expanded)
196}