Skip to main content

ds_api_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{Expr, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, Type, parse_macro_input};
5
6fn extract_doc(attrs: &[syn::Attribute]) -> Vec<String> {
7    attrs
8        .iter()
9        .filter_map(|attr| {
10            if !attr.path().is_ident("doc") {
11                return None;
12            }
13            if let Meta::NameValue(nv) = &attr.meta
14                && let Expr::Lit(el) = &nv.value
15                && let Lit::Str(s) = &el.lit
16            {
17                return Some(s.value().trim().to_string());
18            }
19            None
20        })
21        .collect()
22}
23
/// Splits a method's doc lines into a free-text description and per-parameter docs.
///
/// A line of the form `name: text` (or `` `name`: text ``) whose `name` is an identifier
/// (non-empty, starting with a letter or `_`, then letters/digits/`_`) and whose `text` is
/// non-empty is recorded as the doc for parameter `name`; a later line for the same name
/// overwrites an earlier one. Every other non-empty line is appended to the description,
/// but only until the first parameter line has been seen — free text after the parameter
/// list is ignored.
///
/// Returns the description lines joined with single spaces (trimmed) and the map from
/// parameter name to its description.
fn parse_doc(lines: &[String]) -> (String, std::collections::HashMap<String, String>) {
    let mut desc_lines: Vec<&str> = vec![];
    let mut params = std::collections::HashMap::new();
    for line in lines {
        if line.is_empty() {
            continue;
        }
        if let Some((key, val)) = split_param_line(line) {
            params.insert(key.to_string(), val.to_string());
            continue;
        }
        if params.is_empty() {
            desc_lines.push(line);
        }
    }
    (desc_lines.join(" ").trim().to_string(), params)
}

/// Recognizes a `name: description` parameter doc line, returning the trimmed pair.
fn split_param_line(line: &str) -> Option<(&str, &str)> {
    let (key, val) = line.split_once(':')?;
    let key = key.trim();
    // Accept rustdoc-style code formatting of the name: `` `city`: ... ``.
    let key = key
        .strip_prefix('`')
        .and_then(|k| k.strip_suffix('`'))
        .unwrap_or(key);
    let val = val.trim();
    // Require a real identifier so lines like ": x" or "1: first step" stay prose.
    let is_ident = key
        .chars()
        .next()
        .is_some_and(|c| c.is_alphabetic() || c == '_')
        && key.chars().all(|c| c.is_alphanumeric() || c == '_');
    (is_ident && !val.is_empty()).then_some((key, val))
}
45
46fn type_to_json_schema(ty: &Type) -> TokenStream2 {
47    let type_str = quote!(#ty).to_string().replace(" ", "");
48    match type_str.as_str() {
49        "String" | "&str" => quote!(serde_json::json!({"type": "string"})),
50        "bool" => quote!(serde_json::json!({"type": "boolean"})),
51        "f32" | "f64" => quote!(serde_json::json!({"type": "number"})),
52        s if s.starts_with("Option<") => {
53            let inner = &type_str[7..type_str.len() - 1];
54            match inner {
55                "String" | "&str" => quote!(serde_json::json!({"type": "string"})),
56                "bool" => quote!(serde_json::json!({"type": "boolean"})),
57                "f32" | "f64" => quote!(serde_json::json!({"type": "number"})),
58                _ => quote!(serde_json::json!({"type": "integer"})),
59            }
60        }
61        _ => quote!(serde_json::json!({"type": "integer"})),
62    }
63}
64
65fn is_option(ty: &Type) -> bool {
66    quote!(#ty)
67        .to_string()
68        .replace(" ", "")
69        .starts_with("Option<")
70}
71
72struct ToolMethod {
73    tool_name: String,
74    description: String,
75    params: Vec<ParamInfo>,
76    body: syn::Block,
77}
78
79struct ParamInfo {
80    name: String,
81    ty: Type,
82    desc: String,
83    optional: bool,
84}
85
86#[proc_macro_attribute]
87pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
88    let item_impl = parse_macro_input!(item as ItemImpl);
89
90    let override_name: Option<String> = if !attr.is_empty() {
91        let s = TokenStream2::from(attr).to_string();
92        s.find('"').and_then(|start| {
93            s.rfind('"')
94                .filter(|&end| end > start)
95                .map(|end| s[start + 1..end].to_string())
96        })
97    } else {
98        None
99    };
100
101    let mut tool_methods: Vec<ToolMethod> = vec![];
102
103    for item in &item_impl.items {
104        if let ImplItem::Fn(method) = item {
105            if method.sig.asyncness.is_none() {
106                continue;
107            }
108            let fn_name = method.sig.ident.to_string();
109            let tool_name = override_name.clone().unwrap_or_else(|| fn_name.clone());
110            let doc_lines = extract_doc(&method.attrs);
111            let (description, param_docs) = parse_doc(&doc_lines);
112
113            let mut params = vec![];
114            for arg in &method.sig.inputs {
115                if let FnArg::Typed(pt) = arg {
116                    let name = if let Pat::Ident(pi) = &*pt.pat {
117                        pi.ident.to_string()
118                    } else {
119                        continue;
120                    };
121                    let ty = (*pt.ty).clone();
122                    let desc = param_docs.get(&name).cloned().unwrap_or_default();
123                    let optional = is_option(&ty);
124                    params.push(ParamInfo {
125                        name,
126                        ty,
127                        desc,
128                        optional,
129                    });
130                }
131            }
132            tool_methods.push(ToolMethod {
133                tool_name,
134                description,
135                params,
136                body: method.block.clone(),
137            });
138        }
139    }
140
141    let raw_tools_body = tool_methods.iter().map(|m| {
142        let tool_name = &m.tool_name;
143        let description = &m.description;
144        let prop_inserts = m.params.iter().map(|p| {
145            let pname = &p.name;
146            let pdesc = &p.desc;
147            let schema = type_to_json_schema(&p.ty);
148            quote! {{
149                let mut prop = #schema;
150                prop["description"] = serde_json::json!(#pdesc);
151                properties.insert(#pname.to_string(), prop);
152            }}
153        });
154        let required: Vec<&str> = m
155            .params
156            .iter()
157            .filter(|p| !p.optional)
158            .map(|p| p.name.as_str())
159            .collect();
160        quote! {{
161            let mut properties = serde_json::Map::new();
162            #(#prop_inserts)*
163            let required: Vec<&str> = vec![#(#required),*];
164            ds_api::raw::request::tool::Tool {
165                r#type: ds_api::raw::request::message::ToolType::Function,
166                function: ds_api::raw::request::tool::Function {
167                    name: #tool_name.to_string(),
168                    description: Some(#description.to_string()),
169                    parameters: serde_json::json!({
170                        "type": "object",
171                        "properties": properties,
172                        "required": required,
173                    }),
174                    strict: None,
175                },
176            }
177        }}
178    });
179
180    let call_arms = tool_methods.iter().map(|m| {
181        let tool_name = &m.tool_name;
182        let body = &m.body;
183        let arg_parses = m.params.iter().map(|p| {
184            let pname = syn::Ident::new(&p.name, Span::call_site());
185            let pname_str = &p.name;
186            let ty = &p.ty;
187            quote! {
188                let #pname: #ty = serde_json::from_value(
189                    args.get(#pname_str).cloned().unwrap_or(serde_json::Value::Null)
190                ).expect(concat!("failed to parse param: ", #pname_str));
191            }
192        });
193        quote! {
194            #tool_name => {
195                #(#arg_parses)*
196                { #body }
197            }
198        }
199    });
200
201    let self_ty = &item_impl.self_ty;
202
203    let expanded = quote! {
204        #[async_trait::async_trait]
205        impl ds_api::tool_trait::Tool for #self_ty {
206            fn raw_tools(&self) -> Vec<ds_api::raw::request::tool::Tool> {
207                vec![#(#raw_tools_body),*]
208            }
209
210            async fn call(&self, name: &str, args: serde_json::Value) -> serde_json::Value {
211                match name {
212                    #(#call_arms)*
213                    _ => serde_json::json!({"error": format!("unknown tool: {}", name)}),
214                }
215            }
216        }
217    };
218
219    expanded.into()
220}