mini_langchain_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{parse_macro_input, AttributeArgs, ItemFn, NestedMeta, Meta, Lit, Pat, FnArg};
4use proc_macro_crate::{crate_name, FoundCrate};
5
6/// Minimal attribute macro #[tool(...)]
7/// Supports: name (optional), description (required), params(...)
8/// params syntax: params( city = "desc", units = "desc" )
9#[proc_macro_attribute]
10pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
11    // parse the attribute arguments
12    let args = parse_macro_input!(attr as AttributeArgs);
13    let input_fn = parse_macro_input!(item as ItemFn);
14    // collect name and description and params map
15    let mut name_override: Option<String> = None;
16    let mut description: Option<String> = None;
17    let mut params_meta: Vec<(String, String)> = Vec::new();
18
19    for nested in args.into_iter() {
20        match nested {
21            NestedMeta::Meta(Meta::NameValue(nv)) => {
22                let ident = nv.path.get_ident().map(|i| i.to_string());
23                if let Some(key) = ident {
24                    match nv.lit {
25                        Lit::Str(s) => {
26                            if key == "name" {
27                                name_override = Some(s.value());
28                            } else if key == "description" {
29                                description = Some(s.value());
30                            }
31                        }
32                        _ => {}
33                    }
34                }
35            }
36            NestedMeta::Meta(Meta::List(list)) => {
37                if list.path.is_ident("params") {
38                    for nm in list.nested.iter() {
39                        match nm {
40                            NestedMeta::Meta(Meta::NameValue(nv)) => {
41                                if let Some(ident) = nv.path.get_ident() {
42                                    if let Lit::Str(s) = &nv.lit {
43                                        params_meta.push((ident.to_string(), s.value()));
44                                    }
45                                }
46                            }
47                            _ => {}
48                        }
49                    }
50                }
51            }
52            _ => {}
53        }
54    }
55
56    // ensure description exists
57    if description.is_none() {
58        return syn::Error::new_spanned(&input_fn.sig.ident, "tool attribute requires a 'description' = \"...\"")
59            .to_compile_error()
60            .into();
61    }
62
63    let fn_name = input_fn.sig.ident.to_string();
64    let tool_name = name_override.unwrap_or(fn_name.clone());
65    let description = description.unwrap();
66
67    // build Params struct fields from function signature
68    let mut fields = Vec::new();
69    let mut param_names = Vec::new();
70    for input in input_fn.sig.inputs.iter() {
71        match input {
72            FnArg::Typed(pt) => {
73                // pattern must be an identifier
74                if let Pat::Ident(pi) = &*pt.pat {
75                    let ident = pi.ident.clone();
76                    let ty = &*pt.ty;
77                    fields.push((ident.clone(), ty.clone()));
78                    param_names.push(ident.to_string());
79                } else {
80                    return syn::Error::new_spanned(&pt.pat, "unsupported pattern in function parameters; use simple identifiers")
81                        .to_compile_error()
82                        .into();
83                }
84            }
85            FnArg::Receiver(_) => {
86                return syn::Error::new_spanned(input, "methods with self are not supported; use free functions")
87                    .to_compile_error()
88                    .into();
89            }
90        }
91    }
92
93    // check params_meta keys match function params
94    for (k, _v) in params_meta.iter() {
95        if !param_names.contains(k) {
96            return syn::Error::new_spanned(&input_fn.sig.ident, format!("params list contains '{}' but function has no parameter with that name", k))
97                .to_compile_error()
98                .into();
99        }
100    }
101
102    // create Params struct identifier
103    let params_struct_ident = syn::Ident::new(&format!("{}Params", pascal_case(&fn_name)), input_fn.sig.ident.span());
104    let tool_struct_ident = syn::Ident::new(&format!("{}Tool", pascal_case(&fn_name)), input_fn.sig.ident.span());
105
106    // generate fields tokens
107    let field_defs: Vec<proc_macro2::TokenStream> = fields.iter().map(|(ident, ty)| {
108        quote! { pub #ident: #ty }
109    }).collect();
110
111    // build ArgSchema vector tokens
112    // figure out how the host crate refers to our library (it might be `crate` when
113    // expanding inside the library itself, or an external name when used from examples
114    // or other crates). Use `proc_macro_crate` to discover the correct root.
115    let host_crate_root = match crate_name("mini-langchain") {
116        Ok(FoundCrate::Itself) => quote! { crate },
117        Ok(FoundCrate::Name(name)) => {
118            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
119            quote! { ::#ident }
120        }
121        Err(_) => quote! { ::mini_langchain },
122    };
123
124    let mut args_entries = Vec::new();
125    for (ident, _ty) in fields.iter() {
126        // find description for this param
127        let mut desc = String::new();
128        for (k, v) in params_meta.iter() {
129            if k == &ident.to_string() { desc = v.clone(); break; }
130        }
131        if desc.is_empty() {
132            // emit error: description missing
133            return syn::Error::new_spanned(&input_fn.sig.ident, format!("missing description for parameter '{}' in tool attribute params(...)", ident))
134                .to_compile_error()
135                .into();
136        }
137        let name_lit = syn::LitStr::new(&ident.to_string(), ident.span());
138        let desc_lit = syn::LitStr::new(&desc, ident.span());
139        args_entries.push(quote! {
140            #host_crate_root ::tools::traits::ArgSchema {
141                name: #name_lit.to_string(),
142                arg_type: "string".to_string(),
143                description: #desc_lit.to_string(),
144                required: true,
145            }
146        });
147    }
148
149    // prepare calling the original function: collect param idents
150    let call_args: Vec<proc_macro2::TokenStream> = fields.iter().map(|(ident, _)| {
151        quote! { params.#ident }
152    }).collect();
153
154    // determine if the original function is async
155    let is_async = input_fn.sig.asyncness.is_some();
156
157    // generate output tokens
158    let fn_tokens = input_fn.to_token_stream();
159    let fn_ident = input_fn.sig.ident.clone();
160    let params_struct_ident2 = params_struct_ident.clone();
161    let tool_struct_ident2 = tool_struct_ident.clone();
162    let tool_name_lit = syn::LitStr::new(&tool_name, input_fn.sig.ident.span());
163    let description_lit = syn::LitStr::new(&description, input_fn.sig.ident.span());
164
165    let run_body = if is_async {
166        quote! {
167            let params: #params_struct_ident2 = serde_json::from_value(input)
168                .map_err(|e| crate::tools::error::ToolError::ParamsNotMatched(e.to_string()))?;
169            let out = #fn_ident( #(#call_args),* ).await;
170            Ok(out)
171        }
172    } else {
173        quote! {
174            let params: #params_struct_ident2 = serde_json::from_value(input)
175                .map_err(|e| crate::tools::error::ToolError::ParamsNotMatched(e.to_string()))?;
176            let out = #fn_ident( #(#call_args),* );
177            Ok(out)
178        }
179    };
180
181    // figure out how the host crate refers to our library (it might be `crate` when
182    // expanding inside the library itself, or an external name when used from examples
183    // or other crates). Use `proc_macro_crate` to discover the correct root.
184    let host_crate_root = match crate_name("mini-langchain") {
185        Ok(FoundCrate::Itself) => quote! { crate },
186        Ok(FoundCrate::Name(name)) => {
187            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
188            quote! { ::#ident }
189        }
190        Err(_) => quote! { ::mini_langchain },
191    };
192
193    let expanded = quote! {
194        #fn_tokens
195
196        #[derive(serde::Deserialize)]
197        pub struct #params_struct_ident2 {
198            #(#field_defs,)*
199        }
200
201        pub struct #tool_struct_ident2;
202        #[async_trait::async_trait]
203        impl #host_crate_root ::tools::traits::Tool for #tool_struct_ident2 {
204            fn name(&self) -> &str { #tool_name_lit }
205            fn description(&self) -> &str { #description_lit }
206            fn args(&self) -> Vec<#host_crate_root ::tools::traits::ArgSchema> {
207                vec![ #(#args_entries),* ]
208            }
209
210            async fn run(&self, input: serde_json::Value) -> Result<String, #host_crate_root ::tools::error::ToolError> {
211                #run_body
212            }
213        }
214    };
215
216    TokenStream::from(expanded)
217}
218
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// The input is split on `_`. Each segment's first character is upper-cased
/// (Unicode case mapping, so one character may expand to several), and the
/// rest of the segment is copied unchanged. Underscores are dropped, and empty
/// segments (from leading, trailing or doubled `_`) contribute nothing.
fn pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            result.extend(head.to_uppercase());
            result.push_str(rest.as_str());
        }
    }
    result
}