1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{Expr, FnArg, ImplItem, ItemFn, ItemImpl, Lit, Meta, Pat, Type, parse_macro_input};
5
6fn extract_doc(attrs: &[syn::Attribute]) -> Vec<String> {
7 attrs
8 .iter()
9 .filter_map(|attr| {
10 if !attr.path().is_ident("doc") {
11 return None;
12 }
13 if let Meta::NameValue(nv) = &attr.meta
14 && let Expr::Lit(el) = &nv.value
15 && let Lit::Str(s) = &el.lit
16 {
17 return Some(s.value().trim().to_string());
18 }
19 None
20 })
21 .collect()
22}
23
/// Splits trimmed doc-comment lines into a tool description and per-parameter descriptions.
///
/// A line of the form `name: text` is recorded as the description of parameter `name` when
/// `name` is a non-empty run of alphanumeric characters or underscores and `text` is
/// non-empty. A later line for the same name overwrites an earlier one.
///
/// Every other non-empty line that appears *before* the first parameter line belongs to the
/// tool description; those lines are joined with single spaces. Non-parameter lines after the
/// first parameter line are ignored.
///
/// Note: any prose line shaped like `Word: text` (e.g. `Returns: ...`) is also taken as a
/// parameter entry.
fn parse_doc(lines: &[String]) -> (String, std::collections::HashMap<String, String>) {
    let mut desc_lines: Vec<&str> = Vec::new();
    let mut params = std::collections::HashMap::new();
    for line in lines.iter().filter(|l| !l.is_empty()) {
        if let Some((key, val)) = line.split_once(':') {
            let (key, val) = (key.trim(), val.trim());
            // An empty key (a line starting with ':') is prose, not a parameter; without this
            // check `all` is vacuously true and a parameter named "" would be recorded.
            let is_param_name =
                !key.is_empty() && key.chars().all(|c| c.is_alphanumeric() || c == '_');
            if is_param_name && !val.is_empty() {
                params.insert(key.to_string(), val.to_string());
                continue;
            }
        }
        if params.is_empty() {
            desc_lines.push(line);
        }
    }
    (desc_lines.join(" ").trim().to_string(), params)
}
45
46fn type_to_json_schema(ty: &Type) -> TokenStream2 {
52 match ty {
53 Type::Reference(r) => type_to_json_schema(&r.elem),
55
56 Type::Path(tp) => {
57 let seg = match tp.path.segments.last() {
60 Some(s) => s,
61 None => return unsupported(ty),
62 };
63
64 match seg.ident.to_string().as_str() {
65 "String" | "str" => quote!(serde_json::json!({"type": "string"})),
66 "bool" => quote!(serde_json::json!({"type": "boolean"})),
67 "f32" | "f64" => quote!(serde_json::json!({"type": "number"})),
68 "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | "i8" | "i16" | "i32" | "i64"
69 | "i128" | "isize" => {
70 quote!(serde_json::json!({"type": "integer"}))
71 }
72 "Option" => match inner_type_arg(seg) {
74 Some(inner) => type_to_json_schema(inner),
75 None => unsupported(ty),
76 },
77 "Vec" => match inner_type_arg(seg) {
79 Some(inner) => {
80 let items = type_to_json_schema(inner);
81 quote!(serde_json::json!({"type": "array", "items": #items}))
82 }
83 None => unsupported(ty),
84 },
85 _ => unsupported(ty),
86 }
87 }
88
89 _ => unsupported(ty),
90 }
91}
92
93fn unsupported(ty: &Type) -> TokenStream2 {
95 syn::Error::new_spanned(
96 ty,
97 "unsupported type in #[tool]: use String, bool, f32/f64, \
98 an integer primitive, Vec<T>, or Option<T>",
99 )
100 .to_compile_error()
101}
102
103fn inner_type_arg(seg: &syn::PathSegment) -> Option<&Type> {
106 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments
107 && let Some(syn::GenericArgument::Type(ty)) = args.args.first()
108 {
109 return Some(ty);
110 }
111 None
112}
113
114fn is_option(ty: &Type) -> bool {
115 if let Type::Path(tp) = ty
116 && let Some(seg) = tp.path.segments.last()
117 {
118 return seg.ident == "Option";
119 }
120 false
121}
122
123struct ToolMethod {
124 tool_name: String,
125 description: String,
126 params: Vec<ParamInfo>,
127 body: syn::Block,
128}
129
130struct ParamInfo {
131 name: String,
132 ty: Type,
133 desc: String,
134 optional: bool,
135}
136
137#[proc_macro_attribute]
138pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
139 if let Ok(item_fn) = syn::parse::<ItemFn>(item.clone()) {
141 if item_fn.sig.asyncness.is_some() {
142 return tool_from_fn(attr, item_fn);
143 }
144 }
145 tool_from_impl(attr, item)
147}
148
149fn tool_from_fn(attr: TokenStream, item_fn: ItemFn) -> TokenStream {
150 let fn_name = item_fn.sig.ident.to_string();
151 let struct_ident = item_fn.sig.ident.clone();
152
153 let override_name: Option<String> = if !attr.is_empty() {
154 let s = TokenStream2::from(attr).to_string();
155 s.find('"').and_then(|start| {
156 s.rfind('"')
157 .filter(|&end| end > start)
158 .map(|end| s[start + 1..end].to_string())
159 })
160 } else {
161 None
162 };
163
164 let tool_name = override_name.unwrap_or_else(|| fn_name.clone());
165 let doc_lines = extract_doc(&item_fn.attrs);
166 let (description, param_docs) = parse_doc(&doc_lines);
167
168 let mut params = vec![];
169 for arg in &item_fn.sig.inputs {
170 if let FnArg::Typed(pt) = arg {
171 let name = if let Pat::Ident(pi) = &*pt.pat {
172 pi.ident.to_string()
173 } else {
174 continue;
175 };
176 let ty = (*pt.ty).clone();
177 let desc = param_docs.get(&name).cloned().unwrap_or_default();
178 let optional = is_option(&ty);
179 params.push(ParamInfo {
180 name,
181 ty,
182 desc,
183 optional,
184 });
185 }
186 }
187
188 let method = ToolMethod {
189 tool_name,
190 description,
191 params,
192 body: *item_fn.block,
193 };
194
195 let raw_tools_body = {
196 let tool_name = &method.tool_name;
197 let description = &method.description;
198 let prop_inserts = method.params.iter().map(|p| {
199 let pname = &p.name;
200 let pdesc = &p.desc;
201 let schema = type_to_json_schema(&p.ty);
202 quote! {{
203 let mut prop = #schema;
204 prop["description"] = serde_json::json!(#pdesc);
205 properties.insert(#pname.to_string(), prop);
206 }}
207 });
208 let required: Vec<&str> = method
209 .params
210 .iter()
211 .filter(|p| !p.optional)
212 .map(|p| p.name.as_str())
213 .collect();
214 quote! {{
215 let mut properties = serde_json::Map::new();
216 #(#prop_inserts)*
217 let required: Vec<&str> = vec![#(#required),*];
218 ds_api::raw::request::tool::Tool {
219 r#type: ds_api::raw::request::message::ToolType::Function,
220 function: ds_api::raw::request::tool::Function {
221 name: #tool_name.to_string(),
222 description: Some(#description.to_string()),
223 parameters: serde_json::json!({
224 "type": "object",
225 "properties": properties,
226 "required": required,
227 }),
228 strict: None,
229 },
230 }
231 }}
232 };
233
234 let call_arm = {
235 let tool_name = &method.tool_name;
236 let body = &method.body;
237 let arg_parses = method.params.iter().map(|p| {
238 let pname = syn::Ident::new(&p.name, Span::call_site());
239 let pname_str = &p.name;
240 let ty = &p.ty;
241 quote! {
242 let #pname: #ty = match serde_json::from_value(
243 args.get(#pname_str).cloned().unwrap_or(serde_json::Value::Null)
244 ) {
245 Ok(v) => v,
246 Err(e) => return serde_json::json!({
247 "error": format!("invalid argument '{}': {}", #pname_str, e)
248 }),
249 };
250 }
251 });
252 quote! {
253 #tool_name => {
254 #(#arg_parses)*
255 let __result = (async move || { #body })().await;
256 match serde_json::to_value(__result) {
257 Ok(v) => v,
258 Err(e) => serde_json::json!({ "error": format!("serialization error: {}", e) }),
259 }
260 }
261 }
262 };
263
264 let expanded = quote! {
265 #[allow(non_camel_case_types)]
266 pub struct #struct_ident;
267
268 #[async_trait::async_trait]
269 impl ds_api::tool_trait::Tool for #struct_ident {
270 fn raw_tools(&self) -> Vec<ds_api::raw::request::tool::Tool> {
271 vec![#raw_tools_body]
272 }
273
274 async fn call(&self, name: &str, args: serde_json::Value) -> serde_json::Value {
275 match name {
276 #call_arm
277 _ => serde_json::json!({"error": format!("unknown tool: {}", name)}),
278 }
279 }
280 }
281 };
282
283 expanded.into()
284}
285
286fn tool_from_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
287 let item_impl = parse_macro_input!(item as ItemImpl);
288
289 let override_name: Option<String> = if !attr.is_empty() {
290 let s = TokenStream2::from(attr).to_string();
291 s.find('"').and_then(|start| {
292 s.rfind('"')
293 .filter(|&end| end > start)
294 .map(|end| s[start + 1..end].to_string())
295 })
296 } else {
297 None
298 };
299
300 let mut tool_methods: Vec<ToolMethod> = vec![];
301
302 for item in &item_impl.items {
303 if let ImplItem::Fn(method) = item {
304 if method.sig.asyncness.is_none() {
305 continue;
306 }
307 let fn_name = method.sig.ident.to_string();
308 let tool_name = override_name.clone().unwrap_or_else(|| fn_name.clone());
309 let doc_lines = extract_doc(&method.attrs);
310 let (description, param_docs) = parse_doc(&doc_lines);
311
312 let mut params = vec![];
313 for arg in &method.sig.inputs {
314 if let FnArg::Typed(pt) = arg {
315 let name = if let Pat::Ident(pi) = &*pt.pat {
316 pi.ident.to_string()
317 } else {
318 continue;
319 };
320 let ty = (*pt.ty).clone();
321 let desc = param_docs.get(&name).cloned().unwrap_or_default();
322 let optional = is_option(&ty);
323 params.push(ParamInfo {
324 name,
325 ty,
326 desc,
327 optional,
328 });
329 }
330 }
331 tool_methods.push(ToolMethod {
332 tool_name,
333 description,
334 params,
335 body: method.block.clone(),
336 });
337 }
338 }
339
340 let raw_tools_body = tool_methods.iter().map(|m| {
341 let tool_name = &m.tool_name;
342 let description = &m.description;
343 let prop_inserts = m.params.iter().map(|p| {
344 let pname = &p.name;
345 let pdesc = &p.desc;
346 let schema = type_to_json_schema(&p.ty);
347 quote! {{
348 let mut prop = #schema;
349 prop["description"] = serde_json::json!(#pdesc);
350 properties.insert(#pname.to_string(), prop);
351 }}
352 });
353 let required: Vec<&str> = m
354 .params
355 .iter()
356 .filter(|p| !p.optional)
357 .map(|p| p.name.as_str())
358 .collect();
359 quote! {{
360 let mut properties = serde_json::Map::new();
361 #(#prop_inserts)*
362 let required: Vec<&str> = vec![#(#required),*];
363 ds_api::raw::request::tool::Tool {
364 r#type: ds_api::raw::request::message::ToolType::Function,
365 function: ds_api::raw::request::tool::Function {
366 name: #tool_name.to_string(),
367 description: Some(#description.to_string()),
368 parameters: serde_json::json!({
369 "type": "object",
370 "properties": properties,
371 "required": required,
372 }),
373 strict: None,
374 },
375 }
376 }}
377 });
378
379 let call_arms = tool_methods.iter().map(|m| {
380 let tool_name = &m.tool_name;
381 let body = &m.body;
382 let arg_parses = m.params.iter().map(|p| {
383 let pname = syn::Ident::new(&p.name, Span::call_site());
384 let pname_str = &p.name;
385 let ty = &p.ty;
386 quote! {
387 let #pname: #ty = match serde_json::from_value(
388 args.get(#pname_str).cloned().unwrap_or(serde_json::Value::Null)
389 ) {
390 Ok(v) => v,
391 Err(e) => return serde_json::json!({
392 "error": format!("invalid argument '{}': {}", #pname_str, e)
393 }),
394 };
395 }
396 });
397 quote! {
398 #tool_name => {
399 #(#arg_parses)*
400 let __result = (async move || { #body })().await;
401 match serde_json::to_value(__result) {
402 Ok(v) => v,
403 Err(e) => serde_json::json!({ "error": format!("serialization error: {}", e) }),
404 }
405 }
406 }
407 });
408
409 let self_ty = &item_impl.self_ty;
410
411 let expanded = quote! {
412 #[async_trait::async_trait]
413 impl ds_api::tool_trait::Tool for #self_ty {
414 fn raw_tools(&self) -> Vec<ds_api::raw::request::tool::Tool> {
415 vec![#(#raw_tools_body),*]
416 }
417
418 async fn call(&self, name: &str, args: serde_json::Value) -> serde_json::Value {
419 match name {
420 #(#call_arms)*
421 _ => serde_json::json!({"error": format!("unknown tool: {}", name)}),
422 }
423 }
424 }
425 };
426
427 expanded.into()
428}