mini_langchain_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{parse_macro_input, AttributeArgs, ItemFn, NestedMeta, Meta, Lit, Pat, FnArg};
4use proc_macro_crate::{crate_name, FoundCrate};
5
6#[proc_macro_attribute]
10pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
11 let args = parse_macro_input!(attr as AttributeArgs);
13 let input_fn = parse_macro_input!(item as ItemFn);
14 let mut name_override: Option<String> = None;
16 let mut description: Option<String> = None;
17 let mut params_meta: Vec<(String, String)> = Vec::new();
18
19 for nested in args.into_iter() {
20 match nested {
21 NestedMeta::Meta(Meta::NameValue(nv)) => {
22 let ident = nv.path.get_ident().map(|i| i.to_string());
23 if let Some(key) = ident {
24 match nv.lit {
25 Lit::Str(s) => {
26 if key == "name" {
27 name_override = Some(s.value());
28 } else if key == "description" {
29 description = Some(s.value());
30 }
31 }
32 _ => {}
33 }
34 }
35 }
36 NestedMeta::Meta(Meta::List(list)) => {
37 if list.path.is_ident("params") {
38 for nm in list.nested.iter() {
39 match nm {
40 NestedMeta::Meta(Meta::NameValue(nv)) => {
41 if let Some(ident) = nv.path.get_ident() {
42 if let Lit::Str(s) = &nv.lit {
43 params_meta.push((ident.to_string(), s.value()));
44 }
45 }
46 }
47 _ => {}
48 }
49 }
50 }
51 }
52 _ => {}
53 }
54 }
55
56 if description.is_none() {
58 return syn::Error::new_spanned(&input_fn.sig.ident, "tool attribute requires a 'description' = \"...\"")
59 .to_compile_error()
60 .into();
61 }
62
63 let fn_name = input_fn.sig.ident.to_string();
64 let tool_name = name_override.unwrap_or(fn_name.clone());
65 let description = description.unwrap();
66
67 let mut fields = Vec::new();
69 let mut param_names = Vec::new();
70 for input in input_fn.sig.inputs.iter() {
71 match input {
72 FnArg::Typed(pt) => {
73 if let Pat::Ident(pi) = &*pt.pat {
75 let ident = pi.ident.clone();
76 let ty = &*pt.ty;
77 fields.push((ident.clone(), ty.clone()));
78 param_names.push(ident.to_string());
79 } else {
80 return syn::Error::new_spanned(&pt.pat, "unsupported pattern in function parameters; use simple identifiers")
81 .to_compile_error()
82 .into();
83 }
84 }
85 FnArg::Receiver(_) => {
86 return syn::Error::new_spanned(input, "methods with self are not supported; use free functions")
87 .to_compile_error()
88 .into();
89 }
90 }
91 }
92
93 for (k, _v) in params_meta.iter() {
95 if !param_names.contains(k) {
96 return syn::Error::new_spanned(&input_fn.sig.ident, format!("params list contains '{}' but function has no parameter with that name", k))
97 .to_compile_error()
98 .into();
99 }
100 }
101
102 let params_struct_ident = syn::Ident::new(&format!("{}Params", pascal_case(&fn_name)), input_fn.sig.ident.span());
104 let tool_struct_ident = syn::Ident::new(&format!("{}Tool", pascal_case(&fn_name)), input_fn.sig.ident.span());
105
106 let field_defs: Vec<proc_macro2::TokenStream> = fields.iter().map(|(ident, ty)| {
108 quote! { pub #ident: #ty }
109 }).collect();
110
111 let host_crate_root = match crate_name("mini-langchain") {
116 Ok(FoundCrate::Itself) => quote! { crate },
117 Ok(FoundCrate::Name(name)) => {
118 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
119 quote! { ::#ident }
120 }
121 Err(_) => quote! { ::mini_langchain },
122 };
123
124 let mut args_entries = Vec::new();
125 for (ident, _ty) in fields.iter() {
126 let mut desc = String::new();
128 for (k, v) in params_meta.iter() {
129 if k == &ident.to_string() { desc = v.clone(); break; }
130 }
131 if desc.is_empty() {
132 return syn::Error::new_spanned(&input_fn.sig.ident, format!("missing description for parameter '{}' in tool attribute params(...)", ident))
134 .to_compile_error()
135 .into();
136 }
137 let name_lit = syn::LitStr::new(&ident.to_string(), ident.span());
138 let desc_lit = syn::LitStr::new(&desc, ident.span());
139 args_entries.push(quote! {
140 #host_crate_root ::tools::traits::ArgSchema {
141 name: #name_lit.to_string(),
142 arg_type: "string".to_string(),
143 description: #desc_lit.to_string(),
144 required: true,
145 }
146 });
147 }
148
149 let call_args: Vec<proc_macro2::TokenStream> = fields.iter().map(|(ident, _)| {
151 quote! { params.#ident }
152 }).collect();
153
154 let is_async = input_fn.sig.asyncness.is_some();
156
157 let fn_tokens = input_fn.to_token_stream();
159 let fn_ident = input_fn.sig.ident.clone();
160 let params_struct_ident2 = params_struct_ident.clone();
161 let tool_struct_ident2 = tool_struct_ident.clone();
162 let tool_name_lit = syn::LitStr::new(&tool_name, input_fn.sig.ident.span());
163 let description_lit = syn::LitStr::new(&description, input_fn.sig.ident.span());
164
165 let run_body = if is_async {
166 quote! {
167 let params: #params_struct_ident2 = serde_json::from_value(input)
168 .map_err(|e| crate::tools::error::ToolError::ParamsNotMatched(e.to_string()))?;
169 let out = #fn_ident( #(#call_args),* ).await;
170 Ok(out)
171 }
172 } else {
173 quote! {
174 let params: #params_struct_ident2 = serde_json::from_value(input)
175 .map_err(|e| crate::tools::error::ToolError::ParamsNotMatched(e.to_string()))?;
176 let out = #fn_ident( #(#call_args),* );
177 Ok(out)
178 }
179 };
180
181 let host_crate_root = match crate_name("mini-langchain") {
185 Ok(FoundCrate::Itself) => quote! { crate },
186 Ok(FoundCrate::Name(name)) => {
187 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
188 quote! { ::#ident }
189 }
190 Err(_) => quote! { ::mini_langchain },
191 };
192
193 let expanded = quote! {
194 #fn_tokens
195
196 #[derive(serde::Deserialize)]
197 pub struct #params_struct_ident2 {
198 #(#field_defs,)*
199 }
200
201 pub struct #tool_struct_ident2;
202 #[async_trait::async_trait]
203 impl #host_crate_root ::tools::traits::Tool for #tool_struct_ident2 {
204 fn name(&self) -> &str { #tool_name_lit }
205 fn description(&self) -> &str { #description_lit }
206 fn args(&self) -> Vec<#host_crate_root ::tools::traits::ArgSchema> {
207 vec![ #(#args_entries),* ]
208 }
209
210 async fn run(&self, input: serde_json::Value) -> Result<String, #host_crate_root ::tools::error::ToolError> {
211 #run_body
212 }
213 }
214 };
215
216 TokenStream::from(expanded)
217}
218
/// Converts a `snake_case` identifier into `PascalCase` (e.g. `get_weather` -> `GetWeather`).
///
/// Every `_`-separated segment has its first character upper-cased (Unicode
/// upper-casing, which may yield several characters), and the rest of the segment is
/// copied unchanged. Empty segments, produced by leading, trailing or repeated
/// underscores, contribute nothing.
fn pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            out.extend(head.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}