Skip to main content

ds_api_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{Expr, FnArg, ImplItem, ItemFn, ItemImpl, Lit, Meta, Pat, Type, parse_macro_input};
5
6fn extract_doc(attrs: &[syn::Attribute]) -> Vec<String> {
7    attrs
8        .iter()
9        .filter_map(|attr| {
10            if !attr.path().is_ident("doc") {
11                return None;
12            }
13            if let Meta::NameValue(nv) = &attr.meta
14                && let Expr::Lit(el) = &nv.value
15                && let Lit::Str(s) = &el.lit
16            {
17                return Some(s.value().trim().to_string());
18            }
19            None
20        })
21        .collect()
22}
23
/// Split doc-comment lines into a free-text description and per-parameter docs.
///
/// A line of the form `name: text` — where `name` is a non-empty run of
/// alphanumeric characters or `_` and `text` is non-empty — documents the
/// parameter `name`. Every other non-empty line belongs to the description,
/// but only while no parameter line has been seen yet; later free-text lines
/// are dropped.
///
/// Returns the description (lines joined with a single space, trimmed) and a
/// map from parameter name to its documentation.
fn parse_doc(lines: &[String]) -> (String, std::collections::HashMap<String, String>) {
    let mut desc_lines: Vec<&str> = Vec::new();
    let mut params = std::collections::HashMap::new();
    for line in lines {
        if line.is_empty() {
            continue;
        }
        if let Some((key, val)) = line.split_once(':') {
            let (key, val) = (key.trim(), val.trim());
            // `all` is vacuously true for an empty key, so the emptiness check is
            // required to keep lines like ": text" from becoming a parameter named "".
            let is_param_name =
                !key.is_empty() && key.chars().all(|c| c.is_alphanumeric() || c == '_');
            if is_param_name && !val.is_empty() {
                params.insert(key.to_string(), val.to_string());
                continue;
            }
        }
        // Free text only counts as description before the first parameter line.
        if params.is_empty() {
            desc_lines.push(line.as_str());
        }
    }
    (desc_lines.join(" ").trim().to_string(), params)
}
45
46/// Recursively map a `syn::Type` to a JSON Schema snippet.
47///
48/// Matches on the *structure* of the type rather than its string representation,
49/// so path aliases (`std::string::String`), references (`&str`), and generic
50/// wrappers (`Option<T>`, `Vec<T>`) all resolve correctly.
51fn type_to_json_schema(ty: &Type) -> TokenStream2 {
52    match ty {
53        // &str, &String, &T — strip the reference and recurse
54        Type::Reference(r) => type_to_json_schema(&r.elem),
55
56        Type::Path(tp) => {
57            // Only look at the final path segment so that
58            // `std::string::String` and `String` both work.
59            let seg = match tp.path.segments.last() {
60                Some(s) => s,
61                None => return unsupported(ty),
62            };
63
64            match seg.ident.to_string().as_str() {
65                "String" | "str" => quote!(serde_json::json!({"type": "string"})),
66                "bool" => quote!(serde_json::json!({"type": "boolean"})),
67                "f32" | "f64" => quote!(serde_json::json!({"type": "number"})),
68                "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | "i8" | "i16" | "i32" | "i64"
69                | "i128" | "isize" => {
70                    quote!(serde_json::json!({"type": "integer"}))
71                }
72                // Option<T> — recurse into T
73                "Option" => match inner_type_arg(seg) {
74                    Some(inner) => type_to_json_schema(inner),
75                    None => unsupported(ty),
76                },
77                // Vec<T> — recurse into T for the items schema
78                "Vec" => match inner_type_arg(seg) {
79                    Some(inner) => {
80                        let items = type_to_json_schema(inner);
81                        quote!(serde_json::json!({"type": "array", "items": #items}))
82                    }
83                    None => unsupported(ty),
84                },
85                _ => unsupported(ty),
86            }
87        }
88
89        _ => unsupported(ty),
90    }
91}
92
93/// Emit a compile-time error pointing at the offending type.
94fn unsupported(ty: &Type) -> TokenStream2 {
95    syn::Error::new_spanned(
96        ty,
97        "unsupported type in #[tool]: use String, bool, f32/f64, \
98         an integer primitive, Vec<T>, or Option<T>",
99    )
100    .to_compile_error()
101}
102
103/// Extract the first generic type argument from a path segment, e.g. the `T`
104/// in `Option<T>` or `Vec<T>`.
105fn inner_type_arg(seg: &syn::PathSegment) -> Option<&Type> {
106    if let syn::PathArguments::AngleBracketed(args) = &seg.arguments
107        && let Some(syn::GenericArgument::Type(ty)) = args.args.first()
108    {
109        return Some(ty);
110    }
111    None
112}
113
114fn is_option(ty: &Type) -> bool {
115    if let Type::Path(tp) = ty
116        && let Some(seg) = tp.path.segments.last()
117    {
118        return seg.ident == "Option";
119    }
120    false
121}
122
123struct ToolMethod {
124    tool_name: String,
125    description: String,
126    params: Vec<ParamInfo>,
127    body: syn::Block,
128}
129
130struct ParamInfo {
131    name: String,
132    ty: Type,
133    desc: String,
134    optional: bool,
135}
136
137#[proc_macro_attribute]
138pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
139    // 先尝试解析为独立 async fn
140    if let Ok(item_fn) = syn::parse::<ItemFn>(item.clone()) {
141        if item_fn.sig.asyncness.is_some() {
142            return tool_from_fn(attr, item_fn);
143        }
144    }
145    // 否则走 impl 块路径
146    tool_from_impl(attr, item)
147}
148
149fn tool_from_fn(attr: TokenStream, item_fn: ItemFn) -> TokenStream {
150    let fn_name = item_fn.sig.ident.to_string();
151    let struct_ident = item_fn.sig.ident.clone();
152
153    let override_name: Option<String> = if !attr.is_empty() {
154        let s = TokenStream2::from(attr).to_string();
155        s.find('"').and_then(|start| {
156            s.rfind('"')
157                .filter(|&end| end > start)
158                .map(|end| s[start + 1..end].to_string())
159        })
160    } else {
161        None
162    };
163
164    let tool_name = override_name.unwrap_or_else(|| fn_name.clone());
165    let doc_lines = extract_doc(&item_fn.attrs);
166    let (description, param_docs) = parse_doc(&doc_lines);
167
168    let mut params = vec![];
169    for arg in &item_fn.sig.inputs {
170        if let FnArg::Typed(pt) = arg {
171            let name = if let Pat::Ident(pi) = &*pt.pat {
172                pi.ident.to_string()
173            } else {
174                continue;
175            };
176            let ty = (*pt.ty).clone();
177            let desc = param_docs.get(&name).cloned().unwrap_or_default();
178            let optional = is_option(&ty);
179            params.push(ParamInfo {
180                name,
181                ty,
182                desc,
183                optional,
184            });
185        }
186    }
187
188    let method = ToolMethod {
189        tool_name,
190        description,
191        params,
192        body: *item_fn.block,
193    };
194
195    let raw_tools_body = {
196        let tool_name = &method.tool_name;
197        let description = &method.description;
198        let prop_inserts = method.params.iter().map(|p| {
199            let pname = &p.name;
200            let pdesc = &p.desc;
201            let schema = type_to_json_schema(&p.ty);
202            quote! {{
203                let mut prop = #schema;
204                prop["description"] = serde_json::json!(#pdesc);
205                properties.insert(#pname.to_string(), prop);
206            }}
207        });
208        let required: Vec<&str> = method
209            .params
210            .iter()
211            .filter(|p| !p.optional)
212            .map(|p| p.name.as_str())
213            .collect();
214        quote! {{
215            let mut properties = serde_json::Map::new();
216            #(#prop_inserts)*
217            let required: Vec<&str> = vec![#(#required),*];
218            ds_api::raw::request::tool::Tool {
219                r#type: ds_api::raw::request::message::ToolType::Function,
220                function: ds_api::raw::request::tool::Function {
221                    name: #tool_name.to_string(),
222                    description: Some(#description.to_string()),
223                    parameters: serde_json::json!({
224                        "type": "object",
225                        "properties": properties,
226                        "required": required,
227                    }),
228                    strict: None,
229                },
230            }
231        }}
232    };
233
234    let call_arm = {
235        let tool_name = &method.tool_name;
236        let body = &method.body;
237        let arg_parses = method.params.iter().map(|p| {
238            let pname = syn::Ident::new(&p.name, Span::call_site());
239            let pname_str = &p.name;
240            let ty = &p.ty;
241            quote! {
242                let #pname: #ty = match serde_json::from_value(
243                    args.get(#pname_str).cloned().unwrap_or(serde_json::Value::Null)
244                ) {
245                    Ok(v) => v,
246                    Err(e) => return serde_json::json!({
247                        "error": format!("invalid argument '{}': {}", #pname_str, e)
248                    }),
249                };
250            }
251        });
252        quote! {
253            #tool_name => {
254                #(#arg_parses)*
255                let __result = (async move || { #body })().await;
256                match serde_json::to_value(__result) {
257                    Ok(v) => v,
258                    Err(e) => serde_json::json!({ "error": format!("serialization error: {}", e) }),
259                }
260            }
261        }
262    };
263
264    let expanded = quote! {
265        #[allow(non_camel_case_types)]
266        pub struct #struct_ident;
267
268        #[async_trait::async_trait]
269        impl ds_api::tool_trait::Tool for #struct_ident {
270            fn raw_tools(&self) -> Vec<ds_api::raw::request::tool::Tool> {
271                vec![#raw_tools_body]
272            }
273
274            async fn call(&self, name: &str, args: serde_json::Value) -> serde_json::Value {
275                match name {
276                    #call_arm
277                    _ => serde_json::json!({"error": format!("unknown tool: {}", name)}),
278                }
279            }
280        }
281    };
282
283    expanded.into()
284}
285
286fn tool_from_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
287    let item_impl = parse_macro_input!(item as ItemImpl);
288
289    let override_name: Option<String> = if !attr.is_empty() {
290        let s = TokenStream2::from(attr).to_string();
291        s.find('"').and_then(|start| {
292            s.rfind('"')
293                .filter(|&end| end > start)
294                .map(|end| s[start + 1..end].to_string())
295        })
296    } else {
297        None
298    };
299
300    let mut tool_methods: Vec<ToolMethod> = vec![];
301
302    for item in &item_impl.items {
303        if let ImplItem::Fn(method) = item {
304            if method.sig.asyncness.is_none() {
305                continue;
306            }
307            let fn_name = method.sig.ident.to_string();
308            let tool_name = override_name.clone().unwrap_or_else(|| fn_name.clone());
309            let doc_lines = extract_doc(&method.attrs);
310            let (description, param_docs) = parse_doc(&doc_lines);
311
312            let mut params = vec![];
313            for arg in &method.sig.inputs {
314                if let FnArg::Typed(pt) = arg {
315                    let name = if let Pat::Ident(pi) = &*pt.pat {
316                        pi.ident.to_string()
317                    } else {
318                        continue;
319                    };
320                    let ty = (*pt.ty).clone();
321                    let desc = param_docs.get(&name).cloned().unwrap_or_default();
322                    let optional = is_option(&ty);
323                    params.push(ParamInfo {
324                        name,
325                        ty,
326                        desc,
327                        optional,
328                    });
329                }
330            }
331            tool_methods.push(ToolMethod {
332                tool_name,
333                description,
334                params,
335                body: method.block.clone(),
336            });
337        }
338    }
339
340    let raw_tools_body = tool_methods.iter().map(|m| {
341        let tool_name = &m.tool_name;
342        let description = &m.description;
343        let prop_inserts = m.params.iter().map(|p| {
344            let pname = &p.name;
345            let pdesc = &p.desc;
346            let schema = type_to_json_schema(&p.ty);
347            quote! {{
348                let mut prop = #schema;
349                prop["description"] = serde_json::json!(#pdesc);
350                properties.insert(#pname.to_string(), prop);
351            }}
352        });
353        let required: Vec<&str> = m
354            .params
355            .iter()
356            .filter(|p| !p.optional)
357            .map(|p| p.name.as_str())
358            .collect();
359        quote! {{
360            let mut properties = serde_json::Map::new();
361            #(#prop_inserts)*
362            let required: Vec<&str> = vec![#(#required),*];
363            ds_api::raw::request::tool::Tool {
364                r#type: ds_api::raw::request::message::ToolType::Function,
365                function: ds_api::raw::request::tool::Function {
366                    name: #tool_name.to_string(),
367                    description: Some(#description.to_string()),
368                    parameters: serde_json::json!({
369                        "type": "object",
370                        "properties": properties,
371                        "required": required,
372                    }),
373                    strict: None,
374                },
375            }
376        }}
377    });
378
379    let call_arms = tool_methods.iter().map(|m| {
380        let tool_name = &m.tool_name;
381        let body = &m.body;
382        let arg_parses = m.params.iter().map(|p| {
383            let pname = syn::Ident::new(&p.name, Span::call_site());
384            let pname_str = &p.name;
385            let ty = &p.ty;
386            quote! {
387                let #pname: #ty = match serde_json::from_value(
388                    args.get(#pname_str).cloned().unwrap_or(serde_json::Value::Null)
389                ) {
390                    Ok(v) => v,
391                    Err(e) => return serde_json::json!({
392                        "error": format!("invalid argument '{}': {}", #pname_str, e)
393                    }),
394                };
395            }
396        });
397        quote! {
398            #tool_name => {
399                #(#arg_parses)*
400                let __result = (async move || { #body })().await;
401                match serde_json::to_value(__result) {
402                    Ok(v) => v,
403                    Err(e) => serde_json::json!({ "error": format!("serialization error: {}", e) }),
404                }
405            }
406        }
407    });
408
409    let self_ty = &item_impl.self_ty;
410
411    let expanded = quote! {
412        #[async_trait::async_trait]
413        impl ds_api::tool_trait::Tool for #self_ty {
414            fn raw_tools(&self) -> Vec<ds_api::raw::request::tool::Tool> {
415                vec![#(#raw_tools_body),*]
416            }
417
418            async fn call(&self, name: &str, args: serde_json::Value) -> serde_json::Value {
419                match name {
420                    #(#call_arms)*
421                    _ => serde_json::json!({"error": format!("unknown tool: {}", name)}),
422                }
423            }
424        }
425    };
426
427    expanded.into()
428}