1use proc_macro::TokenStream;
2use quote::{quote, format_ident};
3use syn::{
4 parse_macro_input, ItemStruct, ItemImpl, ImplItem, Meta, Expr, ExprLit, Lit,
5 FnArg, Pat, punctuated::Punctuated, Token,
6 parse::Parser,
7};
8use heck::ToUpperCamelCase;
9use std::collections::HashMap;
10
11fn get_meta_map(meta: &Punctuated<Meta, Token![, ]>) -> HashMap<String, String> {
12 let mut map = HashMap::new();
13 for m in meta {
14 if let Meta::NameValue(nv) = m {
15 if let Some(ident) = nv.path.get_ident() {
16 if let Expr::Lit(ExprLit { lit: Lit::Str(ls), .. }) = &nv.value {
17 map.insert(ident.to_string(), ls.value());
18 }
19 }
20 }
21 }
22 map
23}
24
25#[proc_macro_attribute]
26pub fn agent(args: TokenStream, input: TokenStream) -> TokenStream {
27 let args_parsed = Punctuated::<Meta, Token![,]>::parse_terminated.parse(args).unwrap_or_default();
28 let attr_map = get_meta_map(&args_parsed);
29
30 let input = parse_macro_input!(input as ItemStruct);
31 let ident = &input.ident;
32 let name = attr_map.get("name").cloned().unwrap_or_else(|| ident.to_string());
33 let system_prompt = attr_map.get("system_prompt").cloned();
34
35 let system_prompt_tokens = if let Some(prompt) = system_prompt {
36 quote! { Some(#prompt) }
37 } else {
38 quote! { None }
39 };
40
41 let expanded = quote! {
42 #[derive(Clone)]
43 #input
44
45 impl agentlib_core::Agent for #ident {
46 fn name(&self) -> &str { #name }
47 fn system_prompt(&self) -> Option<&str> { #system_prompt_tokens }
48 fn register_tools(self: std::sync::Arc<Self>, registry: &mut agentlib_core::ToolRegistry) {
49 self.register_tools_internal(registry);
50 }
51 }
52 };
53
54 TokenStream::from(expanded)
55}
56
57#[proc_macro_attribute]
58pub fn tool(_args: TokenStream, input: TokenStream) -> TokenStream {
59 match syn::parse::<ItemImpl>(input.clone()) {
60 Ok(mut item_impl) => {
61 let mut tool_structs = Vec::new();
62 let mut registration_calls = Vec::new();
63 let self_ty = &item_impl.self_ty;
64
65 for item in &mut item_impl.items {
66 if let ImplItem::Fn(method) = item {
67 let mut is_tool = false;
68 let mut tool_meta = HashMap::new();
69 let mut tool_attr_idx = None;
70
71 for (idx, attr) in method.attrs.iter().enumerate() {
72 if attr.path().is_ident("tool") {
73 is_tool = true;
74 tool_attr_idx = Some(idx);
75 if let Meta::List(meta_list) = &attr.meta {
76 if let Ok(nested) = meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated) {
77 tool_meta = get_meta_map(&nested);
78 }
79 }
80 }
81 }
82
83 if is_tool {
84 if let Some(idx) = tool_attr_idx {
85 method.attrs.remove(idx);
86 }
87
88 method.vis = syn::Visibility::Public(syn::token::Pub::default());
89
90 let method_name = &method.sig.ident;
91 let tool_name = tool_meta.get("name").cloned().unwrap_or_else(|| method_name.to_string());
92 let tool_desc = tool_meta.get("description").cloned();
93 let tool_desc_tokens = if let Some(d) = tool_desc { quote! { Some(#d) } } else { quote! { None } };
94
95 let struct_name_str = method_name.to_string().to_upper_camel_case() + "Tool";
96 let tool_struct_name = format_ident!("{}", struct_name_str);
97
98 let mut props = Vec::new();
99 let mut arg_deserialization = Vec::new();
100 let mut call_args = Vec::new();
101 let mut required_args = Vec::new();
102
103 for arg in &mut method.sig.inputs {
104 if let FnArg::Typed(pat_type) = arg {
105 if let Pat::Ident(pat_id) = &*pat_type.pat {
106 let arg_name = &pat_id.ident;
107 let arg_name_str = arg_name.to_string();
108
109 let mut arg_desc = String::new();
110 let mut arg_attr_idx = None;
111 for (idx, attr) in pat_type.attrs.iter().enumerate() {
112 if attr.path().is_ident("arg") {
113 arg_attr_idx = Some(idx);
114 if let Meta::List(ml) = &attr.meta {
115 if let Ok(nested) = ml.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated) {
116 let arg_meta = get_meta_map(&nested);
117 if let Some(desc) = arg_meta.get("description") {
118 arg_desc = desc.clone();
119 }
120 }
121 }
122 }
123 }
124
125 if let Some(idx) = arg_attr_idx {
126 pat_type.attrs.remove(idx);
127 }
128
129 props.push(quote! {
130 #arg_name_str: agentlib_core::serde_json::json!({
131 "type": "string",
132 "description": #arg_desc
133 })
134 });
135
136 required_args.push(arg_name_str.clone());
137
138 arg_deserialization.push(quote! {
139 let #arg_name = arguments.get(#arg_name_str)
140 .and_then(|v| agentlib_core::serde_json::from_value(v.clone()).ok())
141 .ok_or_else(|| anyhow::anyhow!("Missing or invalid argument: {}", #arg_name_str))?;
142 });
143
144 call_args.push(quote! { #arg_name });
145 }
146 }
147 }
148
149 tool_structs.push(quote! {
150 pub struct #tool_struct_name {
151 pub agent: std::sync::Arc<#self_ty>,
152 }
153
154 #[agentlib_core::async_trait]
155 impl agentlib_core::Tool for #tool_struct_name {
156 fn name(&self) -> &str { #tool_name }
157 fn description(&self) -> Option<&str> { #tool_desc_tokens }
158 fn parameters(&self) -> agentlib_core::serde_json::Value {
159 agentlib_core::serde_json::json!({
160 "type": "object",
161 "properties": {
162 #(#props),*
163 },
164 "required": [ #(#required_args),* ]
165 })
166 }
167
168 async fn call(&self, arguments: agentlib_core::serde_json::Value) -> anyhow::Result<agentlib_core::serde_json::Value> {
169 #(#arg_deserialization)*
170 let result = self.agent.#method_name(#(#call_args),*).await;
171 Ok(agentlib_core::serde_json::to_value(result)?)
172 }
173 }
174 });
175
176 registration_calls.push(quote! {
177 registry.register(std::sync::Arc::new(#tool_struct_name { agent: self.clone() }));
178 });
179 }
180 }
181 }
182
183 let expanded = quote! {
184 #item_impl
185
186 impl #self_ty {
187 pub fn register_tools_internal(self: std::sync::Arc<Self>, registry: &mut agentlib_core::ToolRegistry) {
188 #(#registration_calls)*
189 }
190 }
191
192 #(#tool_structs)*
193 };
194 TokenStream::from(expanded)
195 }
196 Err(_) => input,
197 }
198}
199
200#[proc_macro_attribute]
201pub fn arg(_args: TokenStream, input: TokenStream) -> TokenStream {
202 input
203}