1use inflector::Inflector;
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{Attribute, FnArg, ItemFn, Pat, PatIdent, Type, parse_macro_input};
5
6fn path_last_ident_is(path: &syn::Path, name: &str) -> bool {
7 path.segments
8 .last()
9 .map(|seg| seg.ident == name)
10 .unwrap_or(false)
11}
12
13fn path_ends_with(path: &syn::Path, segments: &[&str]) -> bool {
14 if segments.is_empty() {
15 return false;
16 }
17 let pathlen = path.segments.len();
18 if pathlen < segments.len() {
19 return false;
20 }
21 path.segments
22 .iter()
23 .skip(pathlen - segments.len())
24 .zip(segments.iter())
25 .all(|(a, b)| a.ident == *b)
26}
27
28fn first_generic_arg<'a>(tp: &'a syn::TypePath) -> Option<&'a Type> {
29 tp.path.segments.last().and_then(|seg| {
30 if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
31 ab.args.iter().find_map(|ga| {
32 if let syn::GenericArgument::Type(t) = ga {
33 Some(t)
34 } else {
35 None
36 }
37 })
38 } else {
39 None
40 }
41 })
42}
43
44fn primitive_json_type_name(path: &syn::Path) -> Option<&'static str> {
45 let ident = path.segments.last()?.ident.to_string();
46 Some(match ident.as_str() {
47 "i8" | "i16" | "i32" | "i64" | "i128" | "u8" | "u16" | "u32" | "u64" | "u128" | "isize"
48 | "usize" => "integer",
49 "f32" | "f64" => "number",
50 "bool" => "boolean",
51 "String" | "str" => "string",
52 _ => return None,
53 })
54}
55
56fn ty_to_schema(ty: &Type) -> serde_json::Value {
57 match ty {
58 Type::Reference(r) => ty_to_schema(&r.elem),
60
61 Type::Array(arr) => serde_json::json!({
63 "type": "array",
64 "items": ty_to_schema(&arr.elem),
65 }),
66
67 Type::Tuple(t) => {
69 let items: Vec<serde_json::Value> = t.elems.iter().map(ty_to_schema).collect();
70 serde_json::json!({
71 "type": "array",
72 "items": items,
73 "minItems": items.len(),
74 "maxItems": items.len(),
75 })
76 }
77
78 Type::Path(tp) => {
80 if path_last_ident_is(&tp.path, "Vec") {
81 if let Some(inner) = first_generic_arg(tp) {
82 return serde_json::json!({
83 "type": "array",
84 "items": ty_to_schema(inner),
85 });
86 }
87 return serde_json::json!({
88 "type": "array",
89 "items": { "type": "string" },
90 });
91 }
92
93 if path_last_ident_is(&tp.path, "Option") {
94 if let Some(inner) = first_generic_arg(tp) {
95 return ty_to_schema(inner);
96 }
97 return serde_json::Value::Object(serde_json::Map::new());
98 }
99
100 if let Some(json_ty) = primitive_json_type_name(&tp.path) {
101 return serde_json::json!({"type": json_ty});
102 }
103
104 if path_ends_with(&tp.path, &["serde_json", "Value"])
105 || path_last_ident_is(&tp.path, "Value")
106 {
107 return serde_json::Value::Object(serde_json::Map::new());
108 }
109
110 serde_json::json!({"type": "string"})
111 }
112
113 _ => serde_json::json!({"type": "string"}),
114 }
115}
116
117fn collect_doc(attrs: &[Attribute]) -> String {
118 attrs
119 .iter()
120 .filter(|attr| attr.path().is_ident("doc"))
121 .filter_map(|attr| {
122 let Ok(nv) = attr.meta.require_name_value() else {
123 return None;
124 };
125 let syn::Expr::Lit(syn::ExprLit {
126 lit: syn::Lit::Str(s),
127 ..
128 }) = &nv.value
129 else {
130 return None;
131 };
132 Some(s.value())
133 })
134 .collect::<Vec<_>>()
135 .join("\n")
136}
137
138#[proc_macro_attribute]
139pub fn aitool(_attr: TokenStream, code: TokenStream) -> TokenStream {
140 let func: ItemFn = parse_macro_input!(code);
141 let funcsig = &func.sig;
142 let func_name = funcsig.ident.to_string();
143 let doc = collect_doc(&func.attrs);
144
145 let is_async = funcsig.asyncness.is_some();
146
147 let mut fields = Vec::new();
148 let mut field_names = Vec::new();
149 let mut field_idents: Vec<syn::Ident> = Vec::new();
150 let mut args = serde_json::Map::new();
151 let mut required_args = Vec::new();
152 let mut errors: Vec<syn::Error> = Vec::new();
153
154 for input in &funcsig.inputs {
155 match input {
156 FnArg::Typed(pat_ty) => {
157 let param_ident = match &*pat_ty.pat {
159 Pat::Ident(PatIdent { ident, .. }) => ident.clone(),
160 _ => {
161 errors.push(syn::Error::new_spanned(
162 &pat_ty.pat,
163 "unsupported parameter pattern. expected a simple identifier like `arg: T`.\n\
164 Examples of unsupported patterns: `(_: T)`, `(a, b): (T, U)`, `S { x, y }: S`."
165 ));
166 continue;
167 }
168 };
169
170 let pat_ty_ty = &pat_ty.ty;
171 let param_name = param_ident.to_string();
172 fields.push(quote!(pub #param_ident: #pat_ty_ty));
173 field_names.push(param_name.clone());
174 field_idents.push(param_ident.clone());
175 let schema = ty_to_schema(pat_ty_ty);
176 args.insert(param_name.clone(), schema);
177 let mut is_optional = false;
178 if let Type::Path(tp) = &*pat_ty.ty {
179 if path_last_ident_is(&tp.path, "Option") {
180 is_optional = true;
181 }
182 }
183 if !is_optional {
184 required_args.push(param_name);
185 }
186 }
187 FnArg::Receiver(recv) => {
188 errors.push(syn::Error::new_spanned(
189 recv,
190 "#[aitool] must be placed on a free-standing function (no `self`).\
191 Move the function out of the `impl` block or remove the receiver.",
192 ));
193 }
194 }
195 }
196
197 if !errors.is_empty() {
198 let compile_errors = errors.into_iter().map(|err| err.to_compile_error());
199 return quote! { #(#compile_errors)* }.into();
200 }
201
202 let args_struct_ident = syn::Ident::new(
203 &format!("{}Args", func_name.to_table_case().to_pascal_case()),
204 funcsig.ident.span(),
205 );
206
207 let fields_tokens = quote!(#(#fields),*);
208 let required_array = serde_json::Value::Array(
209 required_args
210 .iter()
211 .map(|arg| serde_json::Value::String(arg.clone()))
212 .collect(),
213 );
214
215 let mut schema = serde_json::Map::new();
216 schema.insert(
217 "name".to_string(),
218 serde_json::Value::String(func_name.clone()),
219 );
220 schema.insert(
221 "description".to_string(),
222 serde_json::Value::String(doc.clone()),
223 );
224
225 let mut parameters = serde_json::Map::new();
226 parameters.insert(
227 "type".to_string(),
228 serde_json::Value::String("object".to_string()),
229 );
230 parameters.insert("properties".to_string(), serde_json::Value::Object(args));
231 parameters.insert("required".to_string(), required_array);
232
233 schema.insert(
234 "parameters".to_string(),
235 serde_json::Value::Object(parameters),
236 );
237 let json_schema = serde_json::to_string(&schema).unwrap();
238 let name_lit = syn::LitStr::new(&func_name, funcsig.ident.span());
239 let desc_lit = syn::LitStr::new(&doc, funcsig.ident.span());
240 let json_schema_lit = syn::LitStr::new(&json_schema, funcsig.ident.span());
241
242 let func_wrapper_name = syn::Ident::new(
243 &format!("__invoke_{}", func_name.clone()),
244 funcsig.ident.span(),
245 );
246 let reg_name = syn::Ident::new(
247 &format!("__REG_{}", func_name.clone().to_screaming_snake_case()),
248 funcsig.ident.span(),
249 );
250
251 let ident = &funcsig.ident;
252 let invoke_fn = if is_async {
253 quote! {
254 fn #func_wrapper_name(args: ::serde_json::Value) -> ::reductool::InvokeFuture {
255 Box::pin(async move {
256 let parsed: #args_struct_ident = ::serde_json::from_value(args)?;
257 let out = #ident(#(parsed.#field_idents),*).await;
258 ::serde_json::to_value(out).map_err(Into::into)
259 })
260 }
261 }
262 } else {
263 quote! {
264 fn #func_wrapper_name(args: ::serde_json::Value) -> ::reductool::InvokeFuture {
265 Box::pin(async move {
266 let parsed: #args_struct_ident = ::serde_json::from_value(args)?;
267 let out = #ident(#(parsed.#field_idents),*);
268 ::serde_json::to_value(out).map_err(Into::into)
269 })
270 }
271 }
272 };
273
274 let expanded = quote! {
275 #func
276
277 #[derive(::serde::Deserialize)]
278 struct #args_struct_ident {
279 #fields_tokens
280 }
281
282 #invoke_fn
283
284 #[::reductool::__linkme::distributed_slice(::reductool::ALL_TOOLS)]
285 static #reg_name: ::reductool::ToolDefinition = ::reductool::ToolDefinition {
286 name: #name_lit,
287 description: #desc_lit,
288 json_schema: #json_schema_lit,
289 invoke: #func_wrapper_name,
290 };
291 };
292 expanded.into()
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};
    use syn::{FnArg, ItemFn, Pat, PatIdent, Type, parse_str};

    /// Mirrors the per-parameter schema logic of `aitool`. Returns the `properties`
    /// map and the parameter names that would be listed under `required`.
    fn build_props_and_required(func: &ItemFn) -> (serde_json::Map<String, Value>, Vec<String>) {
        let mut props = serde_json::Map::new();
        let mut required = Vec::new();

        for input in &func.sig.inputs {
            let FnArg::Typed(pat_ty) = input else {
                continue;
            };
            let Pat::Ident(PatIdent { ident, .. }) = &*pat_ty.pat else {
                continue;
            };
            let name = ident.to_string();

            props.insert(name.clone(), ty_to_schema(&pat_ty.ty));

            let is_optional = matches!(
                &*pat_ty.ty,
                Type::Path(tp) if path_last_ident_is(&tp.path, "Option")
            );
            if !is_optional {
                required.push(name);
            }
        }
        (props, required)
    }

    fn parse_fn(src: &str) -> ItemFn {
        parse_str::<ItemFn>(src).expect("failed to parse function")
    }

    #[test]
    fn test_primitives_and_refs() {
        let func = parse_fn("fn f(a: i32, b: f64, c: bool, d: String, e: &str) {}");
        let (props, required) = build_props_and_required(&func);

        assert_eq!(props.get("a").unwrap(), &json!({ "type": "integer" }));
        assert_eq!(props.get("b").unwrap(), &json!({ "type": "number" }));
        assert_eq!(props.get("c").unwrap(), &json!({ "type": "boolean" }));
        assert_eq!(props.get("d").unwrap(), &json!({ "type": "string" }));
        assert_eq!(props.get("e").unwrap(), &json!({ "type": "string" }));

        assert_eq!(required, vec!["a", "b", "c", "d", "e"]);
    }

    #[test]
    fn test_array_and_tuple() {
        let func = parse_fn("fn g(x: [i32; 3], y: (i32, String, bool)) {}");
        let (props, required) = build_props_and_required(&func);

        assert_eq!(
            props.get("x").unwrap(),
            &json!({ "type": "array", "items": { "type": "integer" } })
        );
        assert_eq!(
            props.get("y").unwrap(),
            &json!({
                "type": "array",
                "items": [
                    { "type": "integer" },
                    { "type": "string" },
                    { "type": "boolean" }
                ],
                "minItems": 3,
                "maxItems": 3
            })
        );

        assert_eq!(required, vec!["x", "y"]);
    }

    #[test]
    fn test_vec_and_option() {
        let func = parse_fn("fn h(a: Vec<i32>, b: Option<String>, c: Option<Vec<bool>>) {}");
        let (props, required) = build_props_and_required(&func);

        assert_eq!(
            props.get("a").unwrap(),
            &json!({ "type": "array", "items": { "type": "integer" } })
        );
        assert_eq!(props.get("b").unwrap(), &json!({ "type": "string" }));
        assert_eq!(
            props.get("c").unwrap(),
            &json!({ "type": "array", "items": { "type": "boolean" } })
        );

        assert_eq!(required, vec!["a"]);
    }

    #[test]
    fn test_generic_wrappers_without_arguments() {
        let func = parse_fn("fn m(a: Vec, b: Option) {}");
        let (props, required) = build_props_and_required(&func);

        assert_eq!(
            props.get("a").unwrap(),
            &json!({ "type": "array", "items": { "type": "string" } })
        );
        assert_eq!(props.get("b").unwrap(), &json!({}));
        assert_eq!(required, vec!["a"]);
    }

    #[test]
    fn test_json_value_and_bare_value() {
        let func = parse_fn("fn j(a: serde_json::Value, b: Value) {}");
        let (props, required) = build_props_and_required(&func);

        assert_eq!(props.get("a").unwrap(), &json!({}));
        assert_eq!(props.get("b").unwrap(), &json!({}));

        assert_eq!(required, vec!["a", "b"]);
    }

    #[test]
    fn test_custom_type_and_ref_of_option() {
        let func = parse_fn("fn k(a: MyType, b: &Option<String>) {}");
        let (props, required) = build_props_and_required(&func);

        assert_eq!(props.get("a").unwrap(), &json!({ "type": "string" }));
        assert_eq!(props.get("b").unwrap(), &json!({ "type": "string" }));
        assert_eq!(required, vec!["a", "b"]);
    }

    #[test]
    fn test_path_helpers() {
        let path: syn::Path = parse_str("serde_json::Value").unwrap();

        assert!(path_last_ident_is(&path, "Value"));
        assert!(!path_last_ident_is(&path, "serde_json"));

        assert!(path_ends_with(&path, &["Value"]));
        assert!(path_ends_with(&path, &["serde_json", "Value"]));
        assert!(!path_ends_with(&path, &["other", "Value"]));
        assert!(!path_ends_with(&path, &["a", "serde_json", "Value"]));
        assert!(!path_ends_with(&path, &[]));
    }

    #[test]
    fn test_primitive_json_type_name() {
        let name_of =
            |src: &str| primitive_json_type_name(&parse_str::<syn::Path>(src).unwrap());

        assert_eq!(name_of("u128"), Some("integer"));
        assert_eq!(name_of("f32"), Some("number"));
        assert_eq!(name_of("bool"), Some("boolean"));
        assert_eq!(name_of("std::string::String"), Some("string"));
        assert_eq!(name_of("MyType"), None);
    }

    #[test]
    fn test_first_generic_arg_skips_lifetimes() {
        let ty: Type = parse_str("Cow<'static, str>").unwrap();
        let Type::Path(tp) = &ty else {
            panic!("expected a path type");
        };
        let inner = first_generic_arg(tp).expect("one type argument");
        assert!(matches!(inner, Type::Path(p) if path_last_ident_is(&p.path, "str")));

        let bare: Type = parse_str("Vec").unwrap();
        let Type::Path(tp) = &bare else {
            panic!("expected a path type");
        };
        assert!(first_generic_arg(tp).is_none());
    }
}