1use proc_macro::TokenStream;
6use quote::quote;
7use syn::parse::{Parse, ParseStream};
8use syn::{parse_macro_input, FnArg, ItemFn, Lit, Pat, PatType, Token};
9
/// Arguments accepted inside `#[agentic(...)]`.
///
/// Populated by the `Parse` impl from comma-separated `key = literal` pairs.
struct AgenticArgs {
    /// Model identifier from `model = "..."`; `None` when the key is omitted.
    model: Option<String>,
    /// API key from `api_key = "..."`; `None` when the key is omitted.
    api_key: Option<String>,
    /// Value of `autobind = <bool>`; `true` when the key is omitted.
    autobind: bool,
    /// Value of `native_tools = <bool>`; `true` when the key is omitted.
    native_tools: bool,
}
17
18impl Parse for AgenticArgs {
19 fn parse(input: ParseStream) -> syn::Result<Self> {
20 let mut model = None;
21 let mut api_key = None;
22 let mut autobind = true;
23 let mut native_tools = true; while !input.is_empty() {
26 let ident: syn::Ident = input.parse()?;
27 input.parse::<Token![=]>()?;
28
29 match ident.to_string().as_str() {
30 "model" => {
31 let lit: Lit = input.parse()?;
32 if let Lit::Str(s) = lit {
33 model = Some(s.value());
34 }
35 }
36 "api_key" => {
37 let lit: Lit = input.parse()?;
38 if let Lit::Str(s) = lit {
39 api_key = Some(s.value());
40 }
41 }
42 "autobind" => {
43 let lit: Lit = input.parse()?;
44 if let Lit::Bool(b) = lit {
45 autobind = b.value();
46 }
47 }
48 "native_tools" => {
49 let lit: Lit = input.parse()?;
50 if let Lit::Bool(b) = lit {
51 native_tools = b.value();
52 }
53 }
54 _ => {
55 return Err(syn::Error::new(ident.span(), "Unknown attribute"));
56 }
57 }
58
59 if !input.is_empty() {
60 input.parse::<Token![,]>()?;
61 }
62 }
63
64 Ok(Self {
65 model,
66 api_key,
67 autobind,
68 native_tools,
69 })
70 }
71}
72
73#[proc_macro_attribute]
103pub fn agentic(attr: TokenStream, item: TokenStream) -> TokenStream {
104 let args = parse_macro_input!(attr as AgenticArgs);
105 let input_fn = parse_macro_input!(item as ItemFn);
106
107 let model = args.model;
108 let api_key = args.api_key;
109 let _autobind = args.autobind;
110 let _native_tools = args.native_tools;
111
112 let fn_name = &input_fn.sig.ident;
114 let fn_vis = &input_fn.vis;
115 let fn_generics = &input_fn.sig.generics;
116 let fn_output = &input_fn.sig.output;
117 let fn_block = &input_fn.block;
118 let fn_attrs = &input_fn.attrs;
119 let fn_asyncness = &input_fn.sig.asyncness;
120
121 let mut other_params = Vec::new();
123 let mut has_runtime = false;
124
125 for param in input_fn.sig.inputs.iter() {
126 match param {
127 FnArg::Typed(PatType { pat, .. }) => {
128 if let Pat::Ident(pat_ident) = pat.as_ref() {
129 if pat_ident.ident == "runtime" {
130 has_runtime = true;
131 continue; }
133 }
134 other_params.push(param.clone());
135 }
136 FnArg::Receiver(_) => {
137 other_params.push(param.clone());
138 }
139 }
140 }
141
142 if !has_runtime {
143 return syn::Error::new_spanned(
144 &input_fn.sig.ident,
145 "#[agentic] function must have a `runtime: Runtime` parameter",
146 )
147 .to_compile_error()
148 .into();
149 }
150
151 let model_setup = if let Some(m) = model {
153 quote! { Some(#m.to_string()) }
154 } else {
155 quote! { None }
156 };
157
158 let api_key_setup = if let Some(k) = api_key {
159 quote! { Some(#k.to_string()) }
160 } else {
161 quote! { None }
162 };
163
164 let expanded = quote! {
169 #(#fn_attrs)*
170 #fn_vis #fn_asyncness fn #fn_name #fn_generics(#(#other_params),*) #fn_output {
171 let mut runtime = ::reson_agentic::runtime::Runtime::with_config(
173 #model_setup,
174 #api_key_setup,
175 );
176
177 let result = {
179 #fn_block
181 };
182
183 if !runtime.used {
185 panic!(
186 "agentic function '{}' completed without calling runtime.run() or runtime.run_stream()",
187 stringify!(#fn_name)
188 );
189 }
190
191 result
192 }
193 };
194
195 TokenStream::from(expanded)
196}
197
198#[proc_macro_attribute]
216pub fn agentic_generator(attr: TokenStream, item: TokenStream) -> TokenStream {
217 agentic(attr, item)
220}
221
222#[proc_macro_derive(Tool, attributes(tool))]
240pub fn derive_tool(input: TokenStream) -> TokenStream {
241 let input = parse_macro_input!(input as syn::DeriveInput);
242 let name = &input.ident;
243 let name_str = name.to_string();
244
245 let tool_name = convert_to_snake_case(&name_str);
247
248 let struct_description = extract_doc_comments(&input.attrs);
250
251 let fields = match &input.data {
253 syn::Data::Struct(data_struct) => match &data_struct.fields {
254 syn::Fields::Named(fields) => &fields.named,
255 _ => {
256 return syn::Error::new_spanned(
257 name,
258 "Tool derive only supports structs with named fields",
259 )
260 .to_compile_error()
261 .into();
262 }
263 },
264 _ => {
265 return syn::Error::new_spanned(name, "Tool derive only supports structs")
266 .to_compile_error()
267 .into();
268 }
269 };
270
271 let mut schema_properties = Vec::new();
273 let mut required_fields = Vec::new();
274
275 for field in fields {
276 let field_name = field.ident.as_ref().unwrap();
277 let field_name_str = field_name.to_string();
278
279 let field_desc = extract_doc_comments(&field.attrs);
281
282 let json_type = get_json_type(&field.ty);
284
285 let is_optional = is_option_type(&field.ty);
287
288 if !is_optional {
289 required_fields.push(field_name_str.clone());
290 }
291
292 let array_item_info = get_array_item_type(&field.ty);
294
295 match array_item_info {
296 Some(ArrayItemType::Primitive(item_type)) => {
297 schema_properties.push(quote! {
299 properties.insert(
300 #field_name_str.to_string(),
301 serde_json::json!({
302 "type": #json_type,
303 "description": #field_desc,
304 "items": {
305 "type": #item_type
306 }
307 })
308 );
309 });
310 }
311 Some(ArrayItemType::Complex(inner_ty)) => {
312 schema_properties.push(quote! {
314 {
315 let mut arr_schema = serde_json::json!({
316 "type": #json_type,
317 "description": #field_desc
318 });
319 arr_schema["items"] = #inner_ty::schema();
320 properties.insert(#field_name_str.to_string(), arr_schema);
321 }
322 });
323 }
324 None => {
325 schema_properties.push(quote! {
327 properties.insert(
328 #field_name_str.to_string(),
329 serde_json::json!({
330 "type": #json_type,
331 "description": #field_desc
332 })
333 );
334 });
335 }
336 }
337 }
338
339 let required_array = if required_fields.is_empty() {
340 quote! { serde_json::json!([]) }
341 } else {
342 let req_fields = required_fields.iter();
343 quote! { serde_json::json!([#(#req_fields),*]) }
344 };
345
346 let expanded = quote! {
347 impl #name {
348 pub fn tool_name() -> &'static str {
350 #tool_name
351 }
352
353 pub fn description() -> &'static str {
355 #struct_description
356 }
357
358 pub fn schema() -> serde_json::Value {
360 let mut properties = serde_json::Map::new();
361 #(#schema_properties)*
362
363 serde_json::json!({
364 "type": "object",
365 "properties": serde_json::Value::Object(properties),
366 "required": #required_array
367 })
368 }
369
370 pub fn tool_schema(generator: &dyn ::reson_agentic::schema::SchemaGenerator) -> serde_json::Value {
372 generator.generate_schema(
373 #tool_name,
374 #struct_description,
375 Self::schema()
376 )
377 }
378 }
379 };
380
381 TokenStream::from(expanded)
382}
383
384#[proc_macro_derive(Deserializable)]
402pub fn derive_deserializable(input: TokenStream) -> TokenStream {
403 let input = parse_macro_input!(input as syn::DeriveInput);
404 let name = &input.ident;
405
406 let fields = match &input.data {
408 syn::Data::Struct(data_struct) => match &data_struct.fields {
409 syn::Fields::Named(fields) => &fields.named,
410 _ => {
411 return syn::Error::new_spanned(
412 name,
413 "Deserializable derive only supports structs with named fields",
414 )
415 .to_compile_error()
416 .into();
417 }
418 },
419 _ => {
420 return syn::Error::new_spanned(name, "Deserializable derive only supports structs")
421 .to_compile_error()
422 .into();
423 }
424 };
425
426 let mut field_desc_tokens = Vec::new();
428 let mut validation_checks = Vec::new();
429
430 for field in fields {
431 let field_name = field.ident.as_ref().unwrap();
432 let field_name_str = field_name.to_string();
433 let field_desc = extract_doc_comments(&field.attrs);
434 let field_type = &field.ty;
435 let is_optional = is_option_type(&field.ty);
436 let is_required = !is_optional;
437
438 field_desc_tokens.push(quote! {
439 ::reson_agentic::parsers::FieldDescription {
440 name: #field_name_str.to_string(),
441 field_type: ::std::any::type_name::<#field_type>().to_string(),
442 description: #field_desc.to_string(),
443 required: #is_required,
444 }
445 });
446
447 if is_required {
449 validation_checks.push(quote! {
450 if let serde_json::Value::Null = serde_json::to_value(&self.#field_name)
451 .map_err(|e| ::reson_agentic::error::Error::NonRetryable(e.to_string()))? {
452 return Err(::reson_agentic::error::Error::NonRetryable(
453 format!("Required field '{}' is missing or null", #field_name_str)
454 ));
455 }
456 });
457 }
458 }
459
460 let validation_logic = if validation_checks.is_empty() {
461 quote! { Ok(()) }
462 } else {
463 quote! {
464 #(#validation_checks)*
465 Ok(())
466 }
467 };
468
469 let expanded = quote! {
470 impl ::reson_agentic::parsers::Deserializable for #name {
471 fn from_partial(partial: serde_json::Value) -> ::reson_agentic::error::Result<Self> {
472 serde_json::from_value(partial).map_err(|e| {
473 ::reson_agentic::error::Error::NonRetryable(format!("Failed to parse {}: {}", stringify!(#name), e))
474 })
475 }
476
477 fn validate_complete(&self) -> ::reson_agentic::error::Result<()> {
478 #validation_logic
479 }
480
481 fn field_descriptions() -> Vec<::reson_agentic::parsers::FieldDescription> {
482 vec![
483 #(#field_desc_tokens),*
484 ]
485 }
486 }
487 };
488
489 TokenStream::from(expanded)
490}
491
/// Converts an UpperCamelCase identifier into snake_case.
///
/// Every uppercase character starts a new word, so runs of capitals are split
/// one letter at a time (`HTTPClient` -> `h_t_t_p_client`). No underscore is
/// inserted before the first character.
///
/// Lowercasing uses the full Unicode mapping. A character whose lowercase form
/// spans several code points (e.g. `İ` -> `i̇`) is therefore kept whole rather
/// than truncated to its first code point.
fn convert_to_snake_case(s: &str) -> String {
    // Reserve some headroom for the inserted underscores.
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
507
508fn is_option_type(ty: &syn::Type) -> bool {
510 if let syn::Type::Path(type_path) = ty {
511 if let Some(segment) = type_path.path.segments.last() {
512 return segment.ident == "Option";
513 }
514 }
515 false
516}
517
518fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
520 let mut docs = Vec::new();
521 for attr in attrs {
522 if attr.path().is_ident("doc") {
523 if let syn::Meta::NameValue(meta) = &attr.meta {
524 if let syn::Expr::Lit(expr_lit) = &meta.value {
525 if let syn::Lit::Str(lit_str) = &expr_lit.lit {
526 docs.push(lit_str.value().trim().to_string());
527 }
528 }
529 }
530 }
531 }
532 docs.join(" ")
533}
534
535fn get_json_type(ty: &syn::Type) -> String {
537 if let syn::Type::Path(type_path) = ty {
538 if let Some(segment) = type_path.path.segments.last() {
539 let ident = segment.ident.to_string();
540
541 if ident == "Option" {
543 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
544 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
545 return get_json_type(inner_ty);
546 }
547 }
548 }
549
550 return match ident.as_str() {
552 "String" | "str" => "string",
553 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64"
554 | "u128" | "usize" => "integer",
555 "f32" | "f64" => "number",
556 "bool" => "boolean",
557 "Vec" => "array",
558 "HashMap" | "BTreeMap" => "object",
559 _ => "object", }
561 .to_string();
562 }
563 }
564 "object".to_string()
565}
566
/// How `derive_tool` describes the `"items"` of an array-typed field.
enum ArrayItemType {
    /// Items are described inline as `{"type": <this string>}`.
    Primitive(String),
    /// Items are described by calling this item type's `schema()` associated
    /// function in the generated code.
    Complex(syn::Type),
}
574
575fn get_array_item_type(ty: &syn::Type) -> Option<ArrayItemType> {
577 if let syn::Type::Path(type_path) = ty {
578 if let Some(segment) = type_path.path.segments.last() {
579 let ident = segment.ident.to_string();
580
581 if ident == "Option" {
583 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
584 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
585 return get_array_item_type(inner_ty);
586 }
587 }
588 }
589
590 if ident == "Vec" {
592 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
593 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
594 let json_type = get_json_type(inner_ty);
595 if matches!(json_type.as_str(), "string" | "integer" | "number" | "boolean") {
597 return Some(ArrayItemType::Primitive(json_type));
598 }
599 return Some(ArrayItemType::Complex(inner_ty.clone()));
601 }
602 }
603 return Some(ArrayItemType::Primitive("string".to_string()));
605 }
606 }
607 }
608 None
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    /// Table of (UpperCamelCase input, expected snake_case output) pairs.
    const SNAKE_CASE_CASES: [(&str, &str); 4] = [
        ("CalculatorTool", "calculator_tool"),
        ("GetWeather", "get_weather"),
        ("HTTPClient", "h_t_t_p_client"),
        ("Simple", "simple"),
    ];

    #[test]
    fn test_snake_case_conversion() {
        for &(input, expected) in SNAKE_CASE_CASES.iter() {
            assert_eq!(convert_to_snake_case(input), expected, "input: {}", input);
        }
    }
}