1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{Expr, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, Type, parse_macro_input};
5
6fn extract_doc(attrs: &[syn::Attribute]) -> Vec<String> {
7 attrs
8 .iter()
9 .filter_map(|attr| {
10 if !attr.path().is_ident("doc") {
11 return None;
12 }
13 if let Meta::NameValue(nv) = &attr.meta
14 && let Expr::Lit(el) = &nv.value
15 && let Lit::Str(s) = &el.lit
16 {
17 return Some(s.value().trim().to_string());
18 }
19 None
20 })
21 .collect()
22}
23
/// Splits trimmed doc-comment lines into a tool description and
/// per-parameter descriptions.
///
/// A non-empty line of the form `ident: text` is recorded as the description
/// of parameter `ident` (a later line for the same name wins). Here `ident`
/// must be a valid identifier: non-empty, starting with a letter or `_`, then
/// letters, digits or `_`. `text` must be non-empty after trimming. Every
/// other non-empty line is description text, but only until the first
/// parameter line; free text after that point is ignored.
///
/// Returns the description lines joined with single spaces, together with
/// the parameter-name -> description map.
fn parse_doc(lines: &[String]) -> (String, std::collections::HashMap<String, String>) {
    /// Whether `s` can name a function parameter. This rejects e.g. the empty
    /// key of `": text"` or `2024: text`, which used to be swallowed as
    /// parameters.
    fn is_identifier(s: &str) -> bool {
        let mut chars = s.chars();
        chars.next().is_some_and(|c| c.is_alphabetic() || c == '_')
            && chars.all(|c| c.is_alphanumeric() || c == '_')
    }

    let mut desc_lines: Vec<&str> = vec![];
    let mut params = std::collections::HashMap::new();
    for line in lines.iter().filter(|l| !l.is_empty()) {
        if let Some((key, val)) = line.split_once(':') {
            let (key, val) = (key.trim(), val.trim());
            if is_identifier(key) && !val.is_empty() {
                params.insert(key.to_string(), val.to_string());
                continue;
            }
        }
        // Description text is only collected before the first parameter line.
        if params.is_empty() {
            desc_lines.push(line.as_str());
        }
    }
    (desc_lines.join(" ").trim().to_string(), params)
}
45
46fn type_to_json_schema(ty: &Type) -> TokenStream2 {
47 let type_str = quote!(#ty).to_string().replace(" ", "");
48 match type_str.as_str() {
49 "String" | "&str" => quote!(serde_json::json!({"type": "string"})),
50 "bool" => quote!(serde_json::json!({"type": "boolean"})),
51 "f32" | "f64" => quote!(serde_json::json!({"type": "number"})),
52 s if s.starts_with("Option<") => {
53 let inner = &type_str[7..type_str.len() - 1];
54 match inner {
55 "String" | "&str" => quote!(serde_json::json!({"type": "string"})),
56 "bool" => quote!(serde_json::json!({"type": "boolean"})),
57 "f32" | "f64" => quote!(serde_json::json!({"type": "number"})),
58 _ => quote!(serde_json::json!({"type": "integer"})),
59 }
60 }
61 _ => quote!(serde_json::json!({"type": "integer"})),
62 }
63}
64
65fn is_option(ty: &Type) -> bool {
66 quote!(#ty)
67 .to_string()
68 .replace(" ", "")
69 .starts_with("Option<")
70}
71
72struct ToolMethod {
73 tool_name: String,
74 description: String,
75 params: Vec<ParamInfo>,
76 body: syn::Block,
77}
78
79struct ParamInfo {
80 name: String,
81 ty: Type,
82 desc: String,
83 optional: bool,
84}
85
86#[proc_macro_attribute]
87pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
88 let item_impl = parse_macro_input!(item as ItemImpl);
89
90 let override_name: Option<String> = if !attr.is_empty() {
91 let s = TokenStream2::from(attr).to_string();
92 s.find('"').and_then(|start| {
93 s.rfind('"')
94 .filter(|&end| end > start)
95 .map(|end| s[start + 1..end].to_string())
96 })
97 } else {
98 None
99 };
100
101 let mut tool_methods: Vec<ToolMethod> = vec![];
102
103 for item in &item_impl.items {
104 if let ImplItem::Fn(method) = item {
105 if method.sig.asyncness.is_none() {
106 continue;
107 }
108 let fn_name = method.sig.ident.to_string();
109 let tool_name = override_name.clone().unwrap_or_else(|| fn_name.clone());
110 let doc_lines = extract_doc(&method.attrs);
111 let (description, param_docs) = parse_doc(&doc_lines);
112
113 let mut params = vec![];
114 for arg in &method.sig.inputs {
115 if let FnArg::Typed(pt) = arg {
116 let name = if let Pat::Ident(pi) = &*pt.pat {
117 pi.ident.to_string()
118 } else {
119 continue;
120 };
121 let ty = (*pt.ty).clone();
122 let desc = param_docs.get(&name).cloned().unwrap_or_default();
123 let optional = is_option(&ty);
124 params.push(ParamInfo {
125 name,
126 ty,
127 desc,
128 optional,
129 });
130 }
131 }
132 tool_methods.push(ToolMethod {
133 tool_name,
134 description,
135 params,
136 body: method.block.clone(),
137 });
138 }
139 }
140
141 let raw_tools_body = tool_methods.iter().map(|m| {
142 let tool_name = &m.tool_name;
143 let description = &m.description;
144 let prop_inserts = m.params.iter().map(|p| {
145 let pname = &p.name;
146 let pdesc = &p.desc;
147 let schema = type_to_json_schema(&p.ty);
148 quote! {{
149 let mut prop = #schema;
150 prop["description"] = serde_json::json!(#pdesc);
151 properties.insert(#pname.to_string(), prop);
152 }}
153 });
154 let required: Vec<&str> = m
155 .params
156 .iter()
157 .filter(|p| !p.optional)
158 .map(|p| p.name.as_str())
159 .collect();
160 quote! {{
161 let mut properties = serde_json::Map::new();
162 #(#prop_inserts)*
163 let required: Vec<&str> = vec![#(#required),*];
164 ds_api::raw::request::tool::Tool {
165 r#type: ds_api::raw::request::message::ToolType::Function,
166 function: ds_api::raw::request::tool::Function {
167 name: #tool_name.to_string(),
168 description: Some(#description.to_string()),
169 parameters: serde_json::json!({
170 "type": "object",
171 "properties": properties,
172 "required": required,
173 }),
174 strict: None,
175 },
176 }
177 }}
178 });
179
180 let call_arms = tool_methods.iter().map(|m| {
181 let tool_name = &m.tool_name;
182 let body = &m.body;
183 let arg_parses = m.params.iter().map(|p| {
184 let pname = syn::Ident::new(&p.name, Span::call_site());
185 let pname_str = &p.name;
186 let ty = &p.ty;
187 quote! {
188 let #pname: #ty = serde_json::from_value(
189 args.get(#pname_str).cloned().unwrap_or(serde_json::Value::Null)
190 ).expect(concat!("failed to parse param: ", #pname_str));
191 }
192 });
193 quote! {
194 #tool_name => {
195 #(#arg_parses)*
196 { #body }
197 }
198 }
199 });
200
201 let self_ty = &item_impl.self_ty;
202
203 let expanded = quote! {
204 #[async_trait::async_trait]
205 impl ds_api::tool_trait::Tool for #self_ty {
206 fn raw_tools(&self) -> Vec<ds_api::raw::request::tool::Tool> {
207 vec![#(#raw_tools_body),*]
208 }
209
210 async fn call(&self, name: &str, args: serde_json::Value) -> serde_json::Value {
211 match name {
212 #(#call_arms)*
213 _ => serde_json::json!({"error": format!("unknown tool: {}", name)}),
214 }
215 }
216 }
217 };
218
219 expanded.into()
220}