Skip to main content

ds_api_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{Expr, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, Type, parse_macro_input};
5
6fn extract_doc(attrs: &[syn::Attribute]) -> Vec<String> {
7    attrs
8        .iter()
9        .filter_map(|attr| {
10            if !attr.path().is_ident("doc") {
11                return None;
12            }
13            if let Meta::NameValue(nv) = &attr.meta
14                && let Expr::Lit(el) = &nv.value
15                && let Lit::Str(s) = &el.lit
16            {
17                return Some(s.value().trim().to_string());
18            }
19            None
20        })
21        .collect()
22}
23
/// Split trimmed doc-comment lines into a tool description and per-parameter
/// descriptions.
///
/// A line of the form `name: text` is recorded as the description of
/// parameter `name` when both of these hold:
/// - `name` looks like a Rust identifier: a letter or `_`, followed by letters,
///   digits or `_`;
/// - `text` is non-empty.
///
/// A later line for the same name overwrites an earlier one.
///
/// Every other non-empty line is description text, but only until the first
/// parameter line has been seen. After that, free text is ignored. The kept
/// description lines are joined with single spaces.
///
/// Requiring an identifier-shaped key keeps lines such as `1: first step` or
/// `: note` in the description. Such lines used to be taken as parameters,
/// which also discarded every description line after them.
fn parse_doc(lines: &[String]) -> (String, std::collections::HashMap<String, String>) {
    // Identifier check: non-empty, first char is a letter or `_`.
    fn is_param_name(key: &str) -> bool {
        let mut chars = key.chars();
        matches!(chars.next(), Some(c) if c.is_alphabetic() || c == '_')
            && chars.all(|c| c.is_alphanumeric() || c == '_')
    }

    let mut desc_lines: Vec<&str> = vec![];
    let mut params = std::collections::HashMap::new();
    for line in lines {
        if line.is_empty() {
            continue;
        }
        if let Some((key, val)) = line.split_once(':') {
            let key = key.trim();
            let val = val.trim();
            if is_param_name(key) && !val.is_empty() {
                params.insert(key.to_string(), val.to_string());
                continue;
            }
        }
        // Free text only counts as description before the parameter section.
        if params.is_empty() {
            desc_lines.push(line);
        }
    }
    (desc_lines.join(" ").trim().to_string(), params)
}
45
46/// Recursively map a `syn::Type` to a JSON Schema snippet.
47///
48/// Matches on the *structure* of the type rather than its string representation,
49/// so path aliases (`std::string::String`), references (`&str`), and generic
50/// wrappers (`Option<T>`, `Vec<T>`) all resolve correctly.
51fn type_to_json_schema(ty: &Type) -> TokenStream2 {
52    match ty {
53        // &str, &String, &T — strip the reference and recurse
54        Type::Reference(r) => type_to_json_schema(&r.elem),
55
56        Type::Path(tp) => {
57            // Only look at the final path segment so that
58            // `std::string::String` and `String` both work.
59            let seg = match tp.path.segments.last() {
60                Some(s) => s,
61                None => return unsupported(ty),
62            };
63
64            match seg.ident.to_string().as_str() {
65                "String" | "str" => quote!(serde_json::json!({"type": "string"})),
66                "bool" => quote!(serde_json::json!({"type": "boolean"})),
67                "f32" | "f64" => quote!(serde_json::json!({"type": "number"})),
68                "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | "i8" | "i16" | "i32" | "i64"
69                | "i128" | "isize" => {
70                    quote!(serde_json::json!({"type": "integer"}))
71                }
72                // Option<T> — recurse into T
73                "Option" => match inner_type_arg(seg) {
74                    Some(inner) => type_to_json_schema(inner),
75                    None => unsupported(ty),
76                },
77                // Vec<T> — recurse into T for the items schema
78                "Vec" => match inner_type_arg(seg) {
79                    Some(inner) => {
80                        let items = type_to_json_schema(inner);
81                        quote!(serde_json::json!({"type": "array", "items": #items}))
82                    }
83                    None => unsupported(ty),
84                },
85                _ => unsupported(ty),
86            }
87        }
88
89        _ => unsupported(ty),
90    }
91}
92
93/// Emit a compile-time error pointing at the offending type.
94fn unsupported(ty: &Type) -> TokenStream2 {
95    syn::Error::new_spanned(
96        ty,
97        "unsupported type in #[tool]: use String, bool, f32/f64, \
98         an integer primitive, Vec<T>, or Option<T>",
99    )
100    .to_compile_error()
101}
102
103/// Extract the first generic type argument from a path segment, e.g. the `T`
104/// in `Option<T>` or `Vec<T>`.
105fn inner_type_arg(seg: &syn::PathSegment) -> Option<&Type> {
106    if let syn::PathArguments::AngleBracketed(args) = &seg.arguments
107        && let Some(syn::GenericArgument::Type(ty)) = args.args.first()
108    {
109        return Some(ty);
110    }
111    None
112}
113
114fn is_option(ty: &Type) -> bool {
115    if let Type::Path(tp) = ty
116        && let Some(seg) = tp.path.segments.last()
117    {
118        return seg.ident == "Option";
119    }
120    false
121}
122
123struct ToolMethod {
124    tool_name: String,
125    description: String,
126    params: Vec<ParamInfo>,
127    body: syn::Block,
128}
129
130struct ParamInfo {
131    name: String,
132    ty: Type,
133    desc: String,
134    optional: bool,
135}
136
137#[proc_macro_attribute]
138pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
139    let item_impl = parse_macro_input!(item as ItemImpl);
140
141    let override_name: Option<String> = if !attr.is_empty() {
142        let s = TokenStream2::from(attr).to_string();
143        s.find('"').and_then(|start| {
144            s.rfind('"')
145                .filter(|&end| end > start)
146                .map(|end| s[start + 1..end].to_string())
147        })
148    } else {
149        None
150    };
151
152    let mut tool_methods: Vec<ToolMethod> = vec![];
153
154    for item in &item_impl.items {
155        if let ImplItem::Fn(method) = item {
156            if method.sig.asyncness.is_none() {
157                continue;
158            }
159            let fn_name = method.sig.ident.to_string();
160            let tool_name = override_name.clone().unwrap_or_else(|| fn_name.clone());
161            let doc_lines = extract_doc(&method.attrs);
162            let (description, param_docs) = parse_doc(&doc_lines);
163
164            let mut params = vec![];
165            for arg in &method.sig.inputs {
166                if let FnArg::Typed(pt) = arg {
167                    let name = if let Pat::Ident(pi) = &*pt.pat {
168                        pi.ident.to_string()
169                    } else {
170                        continue;
171                    };
172                    let ty = (*pt.ty).clone();
173                    let desc = param_docs.get(&name).cloned().unwrap_or_default();
174                    let optional = is_option(&ty);
175                    params.push(ParamInfo {
176                        name,
177                        ty,
178                        desc,
179                        optional,
180                    });
181                }
182            }
183            tool_methods.push(ToolMethod {
184                tool_name,
185                description,
186                params,
187                body: method.block.clone(),
188            });
189        }
190    }
191
192    let raw_tools_body = tool_methods.iter().map(|m| {
193        let tool_name = &m.tool_name;
194        let description = &m.description;
195        let prop_inserts = m.params.iter().map(|p| {
196            let pname = &p.name;
197            let pdesc = &p.desc;
198            let schema = type_to_json_schema(&p.ty);
199            quote! {{
200                let mut prop = #schema;
201                prop["description"] = serde_json::json!(#pdesc);
202                properties.insert(#pname.to_string(), prop);
203            }}
204        });
205        let required: Vec<&str> = m
206            .params
207            .iter()
208            .filter(|p| !p.optional)
209            .map(|p| p.name.as_str())
210            .collect();
211        quote! {{
212            let mut properties = serde_json::Map::new();
213            #(#prop_inserts)*
214            let required: Vec<&str> = vec![#(#required),*];
215            ds_api::raw::request::tool::Tool {
216                r#type: ds_api::raw::request::message::ToolType::Function,
217                function: ds_api::raw::request::tool::Function {
218                    name: #tool_name.to_string(),
219                    description: Some(#description.to_string()),
220                    parameters: serde_json::json!({
221                        "type": "object",
222                        "properties": properties,
223                        "required": required,
224                    }),
225                    strict: None,
226                },
227            }
228        }}
229    });
230
231    let call_arms = tool_methods.iter().map(|m| {
232        let tool_name = &m.tool_name;
233        let body = &m.body;
234        let arg_parses = m.params.iter().map(|p| {
235            let pname = syn::Ident::new(&p.name, Span::call_site());
236            let pname_str = &p.name;
237            let ty = &p.ty;
238            quote! {
239                let #pname: #ty = match serde_json::from_value(
240                    args.get(#pname_str).cloned().unwrap_or(serde_json::Value::Null)
241                ) {
242                    Ok(v) => v,
243                    Err(e) => return serde_json::json!({
244                        "error": format!("invalid argument '{}': {}", #pname_str, e)
245                    }),
246                };
247            }
248        });
249        quote! {
250            #tool_name => {
251                #(#arg_parses)*
252                let __result = { #body };
253                match serde_json::to_value(__result) {
254                    Ok(v) => v,
255                    Err(e) => serde_json::json!({ "error": format!("serialization error: {}", e) }),
256                }
257            }
258        }
259    });
260
261    let self_ty = &item_impl.self_ty;
262
263    let expanded = quote! {
264        #[async_trait::async_trait]
265        impl ds_api::tool_trait::Tool for #self_ty {
266            fn raw_tools(&self) -> Vec<ds_api::raw::request::tool::Tool> {
267                vec![#(#raw_tools_body),*]
268            }
269
270            async fn call(&self, name: &str, args: serde_json::Value) -> serde_json::Value {
271                match name {
272                    #(#call_arms)*
273                    _ => serde_json::json!({"error": format!("unknown tool: {}", name)}),
274                }
275            }
276        }
277    };
278
279    expanded.into()
280}