1use darling::{ast, FromDeriveInput, FromField};
30use proc_macro::TokenStream;
31use quote::{format_ident, quote};
32use syn::{parse::Parser, parse_macro_input, DeriveInput, FnArg, Ident, ItemFn, Meta, Type};
33
/// Struct-level options parsed by darling from `#[tool(...)]` on a `#[derive(Tool)]` type.
///
/// `supports(struct_named)` restricts the derive to structs with named fields.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(tool), supports(struct_named))]
struct ToolArgs {
    /// Identifier of the struct the derive is applied to.
    ident: Ident,
    /// The struct's fields; enum variants are not supported, hence the `()` variant type.
    data: ast::Data<(), ToolField>,

    /// `#[tool(name = "...")]`: explicit tool name; when absent `generate_tool_impl`
    /// derives one from the struct name.
    #[darling(default)]
    name: Option<String>,

    /// `#[tool(description = "...")]`: schema description; when absent
    /// `generate_tool_impl` uses `"<tool name> tool"`.
    #[darling(default)]
    description: Option<String>,

    /// `#[tool(dangerous)]`: copied verbatim into `ToolSchema::dangerous`.
    #[darling(default)]
    dangerous: bool,
}
53
/// Per-field options parsed by darling from `#[tool(...)]` on a field of a
/// `#[derive(Tool)]` struct.
#[derive(Debug, FromField)]
#[darling(attributes(tool))]
struct ToolField {
    /// Field name; always `Some` because `ToolArgs` only accepts named structs.
    ident: Option<Ident>,
    /// Declared field type; drives both the JSON schema and the deserialization target.
    ty: Type,

    /// `#[tool(param)]`: expose this field as a JSON parameter of the tool.
    #[darling(default)]
    param: bool,

    /// `#[tool(required)]`: list the parameter under `required` in the schema and, for
    /// non-`Option` types, fail execution when it is missing or has the wrong type.
    #[darling(default)]
    required: bool,

    /// `#[tool(description = "...")]`: schema description; defaults to
    /// `"Parameter: <field name>"`.
    #[darling(default)]
    description: Option<String>,

    /// `#[tool(skip)]`: never expose this field, even if it is also marked `param`.
    #[darling(default)]
    skip: bool,
}
77
78#[proc_macro_derive(Tool, attributes(tool))]
101pub fn derive_tool(input: TokenStream) -> TokenStream {
102 let input = parse_macro_input!(input as DeriveInput);
103
104 let args = match ToolArgs::from_derive_input(&input) {
105 Ok(args) => args,
106 Err(e) => return e.write_errors().into(),
107 };
108
109 let expanded = generate_tool_impl(&args);
110
111 TokenStream::from(expanded)
112}
113
114fn generate_tool_impl(args: &ToolArgs) -> proc_macro2::TokenStream {
115 let struct_name = &args.ident;
116
117 let tool_name = args.name.clone().unwrap_or_else(|| {
119 let name = struct_name.to_string();
120 let name = name.strip_suffix("Tool").unwrap_or(&name);
121 to_snake_case(name)
122 });
123
124 let description = args
125 .description
126 .clone()
127 .unwrap_or_else(|| format!("{} tool", tool_name));
128
129 let dangerous = args.dangerous;
130
131 let fields = match &args.data {
133 ast::Data::Struct(fields) => fields,
134 _ => panic!("Tool derive only supports structs"),
135 };
136
137 let params: Vec<_> = fields
138 .fields
139 .iter()
140 .filter(|f| f.param && !f.skip)
141 .collect();
142
143 let param_properties = generate_param_properties(¶ms);
145 let required_params = generate_required_params(¶ms);
146
147 let arg_extractions = generate_arg_extractions(¶ms);
149
150 let param_names: Vec<_> = params.iter().map(|f| f.ident.as_ref().unwrap()).collect();
152
153 quote! {
154 #[async_trait::async_trait]
155 impl cortexai_core::tool::Tool for #struct_name {
156 fn schema(&self) -> cortexai_core::tool::ToolSchema {
157 let mut __properties = serde_json::Map::new();
158 #(#param_properties)*
159
160 cortexai_core::tool::ToolSchema {
161 name: #tool_name.to_string(),
162 description: #description.to_string(),
163 parameters: serde_json::json!({
164 "type": "object",
165 "properties": serde_json::Value::Object(__properties),
166 "required": [#(#required_params),*]
167 }),
168 dangerous: #dangerous,
169 metadata: std::collections::HashMap::new(),
170 required_scopes: vec![],
171 }
172 }
173
174 async fn execute(
175 &self,
176 _context: &cortexai_core::tool::ExecutionContext,
177 arguments: serde_json::Value,
178 ) -> Result<serde_json::Value, cortexai_core::errors::ToolError> {
179 #(#arg_extractions)*
180
181 self.run(#(#param_names),*).await
182 .map_err(|e| cortexai_core::errors::ToolError::ExecutionFailed(e.to_string()))
183 }
184 }
185 }
186}
187
188fn generate_param_properties(params: &[&ToolField]) -> Vec<proc_macro2::TokenStream> {
189 params
190 .iter()
191 .map(|field| {
192 let name = field.ident.as_ref().unwrap().to_string();
193 let description = field
194 .description
195 .clone()
196 .unwrap_or_else(|| format!("Parameter: {}", name));
197 let effective_ty = unwrap_option_type(&field.ty).unwrap_or(&field.ty);
198 let schema_tokens = type_to_json_schema(effective_ty);
199
200 quote! {
201 {
202 let mut __prop_schema = #schema_tokens;
203 if let serde_json::Value::Object(ref mut m) = __prop_schema {
204 m.insert("description".to_string(), serde_json::Value::String(#description.to_string()));
205 }
206 __properties.insert(#name.to_string(), __prop_schema);
207 }
208 }
209 })
210 .collect()
211}
212
213fn generate_required_params(params: &[&ToolField]) -> Vec<proc_macro2::TokenStream> {
214 params
215 .iter()
216 .filter(|f| f.required)
217 .map(|field| {
218 let name = field.ident.as_ref().unwrap().to_string();
219 quote! { #name }
220 })
221 .collect()
222}
223
224fn generate_arg_extractions(params: &[&ToolField]) -> Vec<proc_macro2::TokenStream> {
225 params
226 .iter()
227 .map(|field| {
228 let ident = field.ident.as_ref().unwrap();
229 let name = ident.to_string();
230 let ty = &field.ty;
231
232 if is_option_type(ty) {
233 quote! {
234 let #ident: #ty = arguments.get(#name)
235 .and_then(|v| serde_json::from_value(v.clone()).ok());
236 }
237 } else if field.required {
238 quote! {
239 let #ident: #ty = {
240 let val = arguments.get(#name)
241 .ok_or_else(|| cortexai_core::errors::ToolError::InvalidArguments(
242 format!("Missing required parameter: {}", #name)
243 ))?
244 .clone();
245 serde_json::from_value(val)
246 .map_err(|e| cortexai_core::errors::ToolError::InvalidArguments(
247 format!("Invalid type for {}: {}", #name, e)
248 ))?
249 };
250 }
251 } else {
252 quote! {
254 let #ident: #ty = arguments.get(#name)
255 .and_then(|v| serde_json::from_value(v.clone()).ok())
256 .unwrap_or_default();
257 }
258 }
259 })
260 .collect()
261}
262
263fn type_to_json_schema(ty: &Type) -> proc_macro2::TokenStream {
266 let type_name = extract_type_name(ty);
267
268 match type_name.as_str() {
269 "String" | "&str" | "str" => quote! { serde_json::json!({"type": "string"}) },
270 "bool" => quote! { serde_json::json!({"type": "boolean"}) },
271 "f32" | "f64" => quote! { serde_json::json!({"type": "number"}) },
272 "i8" | "i16" | "i32" | "i64" | "i128" | "isize"
273 | "u8" | "u16" | "u32" | "u64" | "u128" | "usize" => {
274 quote! { serde_json::json!({"type": "integer"}) }
275 }
276 "Vec" | "Array" => {
277 let items_schema = extract_first_generic_arg(ty)
278 .map(|inner| type_to_json_schema(inner))
279 .unwrap_or_else(|| quote! { serde_json::json!({}) });
280 quote! {
281 {
282 let __items = #items_schema;
283 serde_json::json!({"type": "array", "items": __items})
284 }
285 }
286 }
287 "HashMap" | "BTreeMap" => {
288 let value_schema = extract_second_generic_arg(ty)
289 .map(|inner| type_to_json_schema(inner))
290 .unwrap_or_else(|| quote! { serde_json::json!({}) });
291 quote! {
292 {
293 let __additional = #value_schema;
294 serde_json::json!({"type": "object", "additionalProperties": __additional})
295 }
296 }
297 }
298 "Option" => {
299 extract_first_generic_arg(ty)
301 .map(|inner| type_to_json_schema(inner))
302 .unwrap_or_else(|| quote! { serde_json::json!({"type": "object"}) })
303 }
304 _ => quote! { serde_json::json!({"type": "object"}) },
305 }
306}
307
308fn extract_type_name(ty: &Type) -> String {
310 if let Type::Path(type_path) = ty {
311 if let Some(segment) = type_path.path.segments.last() {
312 return segment.ident.to_string();
313 }
314 }
315 String::new()
316}
317
318fn extract_first_generic_arg(ty: &Type) -> Option<&Type> {
320 if let Type::Path(type_path) = ty {
321 if let Some(segment) = type_path.path.segments.last() {
322 if let syn::PathArguments::AngleBracketed(ref args) = segment.arguments {
323 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
324 return Some(inner);
325 }
326 }
327 }
328 }
329 None
330}
331
332fn extract_second_generic_arg(ty: &Type) -> Option<&Type> {
334 if let Type::Path(type_path) = ty {
335 if let Some(segment) = type_path.path.segments.last() {
336 if let syn::PathArguments::AngleBracketed(ref args) = segment.arguments {
337 if let Some(syn::GenericArgument::Type(inner)) = args.args.iter().nth(1) {
338 return Some(inner);
339 }
340 }
341 }
342 }
343 None
344}
345
346fn unwrap_option_type(ty: &Type) -> Option<&Type> {
348 if extract_type_name(ty) == "Option" {
349 extract_first_generic_arg(ty)
350 } else {
351 None
352 }
353}
354
355fn is_option_type(ty: &Type) -> bool {
356 extract_type_name(ty) == "Option"
357}
358
/// Converts a PascalCase / camelCase identifier to snake_case.
///
/// A `_` is inserted before an uppercase character in two cases:
/// * it follows a lowercase letter or an ASCII digit;
/// * it ends an acronym run, i.e. it follows an uppercase character and is followed by a
///   lowercase one.
///
/// So `HTTPRequest` becomes `http_request` (not `h_t_t_p_request`), and `Foo_Bar` becomes
/// `foo_bar` (no doubled underscore). Uppercase characters are lowercased with full
/// Unicode rules, so `Ä` becomes `ä` instead of being left unchanged.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + s.len() / 2);

    for (i, &c) in chars.iter().enumerate() {
        if !c.is_uppercase() {
            result.push(c);
            continue;
        }

        let prev = i.checked_sub(1).and_then(|j| chars.get(j)).copied();
        let next = chars.get(i + 1).copied();
        let starts_word = match prev {
            Some(p) if p.is_lowercase() || p.is_ascii_digit() => true,
            // Inside an acronym: only the last capital before a lowercase letter starts
            // a new word (`HTTPRequest` -> `http_request`).
            Some(p) if p.is_uppercase() => next.map_or(false, char::is_lowercase),
            _ => false,
        };

        if starts_word {
            result.push('_');
        }
        result.extend(c.to_lowercase());
    }

    result
}
373
/// Converts a snake_case identifier to PascalCase by uppercasing the first character of
/// every `_`-separated word and dropping the underscores (empty words vanish).
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            out.extend(head.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
388
389struct FnParam {
391 ident: Ident,
392 ty: Type,
393 description: Option<String>,
394 required: bool,
395 name_override: Option<String>,
396 is_option: bool,
397}
398
399fn parse_param_attrs(attrs: &[syn::Attribute]) -> (Option<String>, bool, Option<String>) {
400 let mut description = None;
401 let mut required = false;
402 let mut name_override = None;
403
404 for attr in attrs {
405 if !attr.path().is_ident("param") {
406 continue;
407 }
408 if let Meta::List(meta_list) = &attr.meta {
409 let tokens = meta_list.tokens.clone();
410 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
411 if let Ok(nested) = parser.parse2(tokens) {
412 for meta in &nested {
413 match meta {
414 Meta::Path(path) if path.is_ident("required") => {
415 required = true;
416 }
417 Meta::NameValue(nv) if nv.path.is_ident("description") => {
418 if let syn::Expr::Lit(syn::ExprLit {
419 lit: syn::Lit::Str(s),
420 ..
421 }) = &nv.value
422 {
423 description = Some(s.value());
424 }
425 }
426 Meta::NameValue(nv) if nv.path.is_ident("name") => {
427 if let syn::Expr::Lit(syn::ExprLit {
428 lit: syn::Lit::Str(s),
429 ..
430 }) = &nv.value
431 {
432 name_override = Some(s.value());
433 }
434 }
435 _ => {}
436 }
437 }
438 }
439 }
440 }
441 (description, required, name_override)
442}
443
444fn extract_fn_params(sig: &syn::Signature) -> Vec<FnParam> {
445 sig.inputs
446 .iter()
447 .filter_map(|arg| {
448 if let FnArg::Typed(pat_type) = arg {
449 if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
450 let ident = pat_ident.ident.clone();
451 let ty = *pat_type.ty.clone();
452 let is_option = is_option_type(&ty);
453 let (description, explicit_required, name_override) =
454 parse_param_attrs(&pat_type.attrs);
455 let required = explicit_required || !is_option;
457 return Some(FnParam {
458 ident,
459 ty,
460 description,
461 required,
462 name_override,
463 is_option,
464 });
465 }
466 }
467 None
468 })
469 .collect()
470}
471
472#[proc_macro_attribute]
477pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
478 let attr_args = proc_macro2::TokenStream::from(attr);
479 let input_fn = parse_macro_input!(item as ItemFn);
480
481 match generate_fn_tool(attr_args, &input_fn) {
482 Ok(tokens) => tokens.into(),
483 Err(err) => err.to_compile_error().into(),
484 }
485}
486
487fn parse_tool_attr_args(
488 tokens: proc_macro2::TokenStream,
489) -> syn::Result<(Option<String>, Option<String>)> {
490 let mut description = None;
491 let mut name_override = None;
492
493 if tokens.is_empty() {
494 return Ok((description, name_override));
495 }
496
497 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
498 let metas = parser.parse2(tokens)?;
499
500 for meta in &metas {
501 if let Meta::NameValue(nv) = meta {
502 if nv.path.is_ident("description") {
503 if let syn::Expr::Lit(syn::ExprLit {
504 lit: syn::Lit::Str(s),
505 ..
506 }) = &nv.value
507 {
508 description = Some(s.value());
509 }
510 } else if nv.path.is_ident("name") {
511 if let syn::Expr::Lit(syn::ExprLit {
512 lit: syn::Lit::Str(s),
513 ..
514 }) = &nv.value
515 {
516 name_override = Some(s.value());
517 }
518 }
519 }
520 }
521
522 Ok((description, name_override))
523}
524
525fn generate_fn_tool(
526 attr_args: proc_macro2::TokenStream,
527 input_fn: &ItemFn,
528) -> syn::Result<proc_macro2::TokenStream> {
529 let (attr_description, attr_name) = parse_tool_attr_args(attr_args)?;
530
531 let fn_name = &input_fn.sig.ident;
532 let fn_name_str = fn_name.to_string();
533
534 let struct_name = format_ident!("{}Tool", to_pascal_case(&fn_name_str));
535 let tool_name = attr_name.unwrap_or_else(|| fn_name_str.clone());
536 let description = attr_description.unwrap_or_else(|| format!("{} tool", tool_name));
537
538 let params = extract_fn_params(&input_fn.sig);
539
540 let param_properties: Vec<proc_macro2::TokenStream> = params
542 .iter()
543 .map(|p| {
544 let name = p
545 .name_override
546 .clone()
547 .unwrap_or_else(|| p.ident.to_string());
548 let desc = p
549 .description
550 .clone()
551 .unwrap_or_else(|| format!("Parameter: {}", name));
552 let effective_ty = unwrap_option_type(&p.ty).unwrap_or(&p.ty);
553 let schema_tokens = type_to_json_schema(effective_ty);
554
555 quote! {
556 {
557 let mut __prop_schema = #schema_tokens;
558 if let serde_json::Value::Object(ref mut m) = __prop_schema {
559 m.insert("description".to_string(), serde_json::Value::String(#desc.to_string()));
560 }
561 __properties.insert(#name.to_string(), __prop_schema);
562 }
563 }
564 })
565 .collect();
566
567 let required_params: Vec<proc_macro2::TokenStream> = params
568 .iter()
569 .filter(|p| p.required)
570 .map(|p| {
571 let name = p
572 .name_override
573 .clone()
574 .unwrap_or_else(|| p.ident.to_string());
575 quote! { #name }
576 })
577 .collect();
578
579 let arg_extractions: Vec<proc_macro2::TokenStream> = params
581 .iter()
582 .map(|p| {
583 let ident = &p.ident;
584 let name = p
585 .name_override
586 .clone()
587 .unwrap_or_else(|| p.ident.to_string());
588 let ty = &p.ty;
589
590 if p.is_option {
591 quote! {
592 let #ident: #ty = arguments.get(#name)
593 .and_then(|v| serde_json::from_value(v.clone()).ok());
594 }
595 } else if p.required {
596 quote! {
597 let #ident: #ty = {
598 let val = arguments.get(#name)
599 .ok_or_else(|| cortexai_core::errors::ToolError::InvalidArguments(
600 format!("Missing required parameter: {}", #name)
601 ))?
602 .clone();
603 serde_json::from_value(val)
604 .map_err(|e| cortexai_core::errors::ToolError::InvalidArguments(
605 format!("Invalid type for {}: {}", #name, e)
606 ))?
607 };
608 }
609 } else {
610 quote! {
611 let #ident: #ty = arguments.get(#name)
612 .and_then(|v| serde_json::from_value(v.clone()).ok())
613 .unwrap_or_default();
614 }
615 }
616 })
617 .collect();
618
619 let param_idents: Vec<&Ident> = params.iter().map(|p| &p.ident).collect();
620
621 let mut clean_fn = input_fn.clone();
623 for arg in &mut clean_fn.sig.inputs {
624 if let FnArg::Typed(pat_type) = arg {
625 pat_type.attrs.retain(|a| !a.path().is_ident("param"));
626 }
627 }
628
629 Ok(quote! {
630 #clean_fn
631
632 #[derive(Default)]
633 pub struct #struct_name;
634
635 #[async_trait::async_trait]
636 impl cortexai_core::tool::Tool for #struct_name {
637 fn schema(&self) -> cortexai_core::tool::ToolSchema {
638 let mut __properties = serde_json::Map::new();
639 #(#param_properties)*
640
641 cortexai_core::tool::ToolSchema {
642 name: #tool_name.to_string(),
643 description: #description.to_string(),
644 parameters: serde_json::json!({
645 "type": "object",
646 "properties": serde_json::Value::Object(__properties),
647 "required": [#(#required_params),*]
648 }),
649 dangerous: false,
650 metadata: std::collections::HashMap::new(),
651 required_scopes: Vec::new(),
652 }
653 }
654
655 async fn execute(
656 &self,
657 _context: &cortexai_core::tool::ExecutionContext,
658 arguments: serde_json::Value,
659 ) -> Result<serde_json::Value, cortexai_core::errors::ToolError> {
660 #(#arg_extractions)*
661
662 #fn_name(#(#param_idents),*).await
663 .map_err(|e| cortexai_core::errors::ToolError::ExecutionFailed(e.to_string()))
664 }
665 }
666 })
667}