//! `agentlib_macros/lib.rs`: the `#[agent]`, `#[tool]` and `#[arg]` attribute macros.

use std::collections::HashMap;

use heck::ToUpperCamelCase;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
    parse::Parser, parse_macro_input, punctuated::Punctuated, Expr, ExprLit, FnArg, ImplItem,
    ItemImpl, ItemStruct, Lit, Meta, Pat, Token,
};
11fn get_meta_map(meta: &Punctuated<Meta, Token![, ]>) -> HashMap<String, String> {
12    let mut map = HashMap::new();
13    for m in meta {
14        if let Meta::NameValue(nv) = m {
15            if let Some(ident) = nv.path.get_ident() {
16                if let Expr::Lit(ExprLit { lit: Lit::Str(ls), .. }) = &nv.value {
17                    map.insert(ident.to_string(), ls.value());
18                }
19            }
20        }
21    }
22    map
23}
25#[proc_macro_attribute]
26pub fn agent(args: TokenStream, input: TokenStream) -> TokenStream {
27    let args_parsed = Punctuated::<Meta, Token![,]>::parse_terminated.parse(args).unwrap_or_default();
28    let attr_map = get_meta_map(&args_parsed);
29    
30    let input = parse_macro_input!(input as ItemStruct);
31    let ident = &input.ident;
32    let name = attr_map.get("name").cloned().unwrap_or_else(|| ident.to_string());
33    let system_prompt = attr_map.get("system_prompt").cloned();
34
35    let system_prompt_tokens = if let Some(prompt) = system_prompt {
36        quote! { Some(#prompt) }
37    } else {
38        quote! { None }
39    };
40
41    let expanded = quote! {
42        #[derive(Clone)]
43        #input
44
45        impl agentlib_core::Agent for #ident {
46            fn name(&self) -> &str { #name }
47            fn system_prompt(&self) -> Option<&str> { #system_prompt_tokens }
48            fn register_tools(self: std::sync::Arc<Self>, registry: &mut agentlib_core::ToolRegistry) {
49                self.register_tools_internal(registry);
50            }
51        }
52    };
53
54    TokenStream::from(expanded)
55}
57#[proc_macro_attribute]
58pub fn tool(_args: TokenStream, input: TokenStream) -> TokenStream {
59    match syn::parse::<ItemImpl>(input.clone()) {
60        Ok(mut item_impl) => {
61            let mut tool_structs = Vec::new();
62            let mut registration_calls = Vec::new();
63            let self_ty = &item_impl.self_ty;
64
65            for item in &mut item_impl.items {
66                if let ImplItem::Fn(method) = item {
67                    let mut is_tool = false;
68                    let mut tool_meta = HashMap::new();
69                    let mut tool_attr_idx = None;
70
71                    for (idx, attr) in method.attrs.iter().enumerate() {
72                        if attr.path().is_ident("tool") {
73                            is_tool = true;
74                            tool_attr_idx = Some(idx);
75                            if let Meta::List(meta_list) = &attr.meta {
76                                if let Ok(nested) = meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated) {
77                                    tool_meta = get_meta_map(&nested);
78                                }
79                            }
80                        }
81                    }
82
83                    if is_tool {
84                        if let Some(idx) = tool_attr_idx {
85                            method.attrs.remove(idx);
86                        }
87
88                        method.vis = syn::Visibility::Public(syn::token::Pub::default());
89
90                        let method_name = &method.sig.ident;
91                        let tool_name = tool_meta.get("name").cloned().unwrap_or_else(|| method_name.to_string());
92                        let tool_desc = tool_meta.get("description").cloned();
93                        let tool_desc_tokens = if let Some(d) = tool_desc { quote! { Some(#d) } } else { quote! { None } };
94                        
95                        let struct_name_str = method_name.to_string().to_upper_camel_case() + "Tool";
96                        let tool_struct_name = format_ident!("{}", struct_name_str);
97
98                        let mut props = Vec::new();
99                        let mut arg_deserialization = Vec::new();
100                        let mut call_args = Vec::new();
101                        let mut required_args = Vec::new();
102
103                        for arg in &mut method.sig.inputs {
104                            if let FnArg::Typed(pat_type) = arg {
105                                if let Pat::Ident(pat_id) = &*pat_type.pat {
106                                    let arg_name = &pat_id.ident;
107                                    let arg_name_str = arg_name.to_string();
108                                    
109                                    let mut arg_desc = String::new();
110                                    let mut arg_attr_idx = None;
111                                    for (idx, attr) in pat_type.attrs.iter().enumerate() {
112                                        if attr.path().is_ident("arg") {
113                                            arg_attr_idx = Some(idx);
114                                            if let Meta::List(ml) = &attr.meta {
115                                                if let Ok(nested) = ml.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated) {
116                                                    let arg_meta = get_meta_map(&nested);
117                                                    if let Some(desc) = arg_meta.get("description") {
118                                                        arg_desc = desc.clone();
119                                                    }
120                                                }
121                                            }
122                                        }
123                                    }
124
125                                    if let Some(idx) = arg_attr_idx {
126                                        pat_type.attrs.remove(idx);
127                                    }
128
129                                    props.push(quote! {
130                                        #arg_name_str: agentlib_core::serde_json::json!({
131                                            "type": "string",
132                                            "description": #arg_desc
133                                        })
134                                    });
135
136                                    required_args.push(arg_name_str.clone());
137
138                                    arg_deserialization.push(quote! {
139                                        let #arg_name = arguments.get(#arg_name_str)
140                                            .and_then(|v| agentlib_core::serde_json::from_value(v.clone()).ok())
141                                            .ok_or_else(|| anyhow::anyhow!("Missing or invalid argument: {}", #arg_name_str))?;
142                                    });
143
144                                    call_args.push(quote! { #arg_name });
145                                }
146                            }
147                        }
148
149                        tool_structs.push(quote! {
150                            pub struct #tool_struct_name {
151                                pub agent: std::sync::Arc<#self_ty>,
152                            }
153
154                            #[agentlib_core::async_trait]
155                            impl agentlib_core::Tool for #tool_struct_name {
156                                fn name(&self) -> &str { #tool_name }
157                                fn description(&self) -> Option<&str> { #tool_desc_tokens }
158                                fn parameters(&self) -> agentlib_core::serde_json::Value {
159                                    agentlib_core::serde_json::json!({
160                                        "type": "object",
161                                        "properties": {
162                                            #(#props),*
163                                        },
164                                        "required": [ #(#required_args),* ]
165                                    })
166                                }
167
168                                async fn call(&self, arguments: agentlib_core::serde_json::Value) -> anyhow::Result<agentlib_core::serde_json::Value> {
169                                    #(#arg_deserialization)*
170                                    let result = self.agent.#method_name(#(#call_args),*).await;
171                                    Ok(agentlib_core::serde_json::to_value(result)?)
172                                }
173                            }
174                        });
175
176                        registration_calls.push(quote! {
177                            registry.register(std::sync::Arc::new(#tool_struct_name { agent: self.clone() }));
178                        });
179                    }
180                }
181            }
182
183            let expanded = quote! {
184                #item_impl
185
186                impl #self_ty {
187                    pub fn register_tools_internal(self: std::sync::Arc<Self>, registry: &mut agentlib_core::ToolRegistry) {
188                        #(#registration_calls)*
189                    }
190                }
191
192                #(#tool_structs)*
193            };
194            TokenStream::from(expanded)
195        }
196        Err(_) => input,
197    }
198}
200#[proc_macro_attribute]
201pub fn arg(_args: TokenStream, input: TokenStream) -> TokenStream {
202    input
203}