1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{Expr, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, Type, parse_macro_input};
5
6fn extract_doc(attrs: &[syn::Attribute]) -> Vec<String> {
7 attrs
8 .iter()
9 .filter_map(|attr| {
10 if !attr.path().is_ident("doc") {
11 return None;
12 }
13 if let Meta::NameValue(nv) = &attr.meta
14 && let Expr::Lit(el) = &nv.value
15 && let Lit::Str(s) = &el.lit
16 {
17 return Some(s.value().trim().to_string());
18 }
19 None
20 })
21 .collect()
22}
23
/// Splits doc-comment lines into a tool description and per-parameter descriptions.
///
/// A line of the form `key: text` is a parameter entry when `key` (trimmed) looks like
/// a Rust identifier and `text` (trimmed) is non-empty; the entry is stored as
/// `key -> text`. Non-empty lines that are not parameter entries and appear before the
/// first parameter entry form the description, joined with single spaces. Such lines
/// appearing after the first parameter entry are ignored.
///
/// Returns `(description, parameter descriptions)`.
fn parse_doc(lines: &[String]) -> (String, std::collections::HashMap<String, String>) {
    /// True when `key` could be a parameter name: non-empty, starting with a letter or
    /// `_`, and containing only alphanumerics / `_`. This rejects prose such as
    /// `": note"` (empty key) or `"12:30 ..."` (leading digit), which previously got
    /// misread as parameters and truncated the description.
    fn looks_like_param_name(key: &str) -> bool {
        let mut chars = key.chars();
        matches!(chars.next(), Some(c) if c.is_alphabetic() || c == '_')
            && chars.all(|c| c.is_alphanumeric() || c == '_')
    }

    let mut desc_lines: Vec<&str> = Vec::new();
    let mut params = std::collections::HashMap::new();
    for line in lines.iter().filter(|l| !l.is_empty()) {
        if let Some((key, val)) = line.split_once(':') {
            let key = key.trim();
            let val = val.trim();
            if looks_like_param_name(key) && !val.is_empty() {
                params.insert(key.to_string(), val.to_string());
                continue;
            }
        }
        // The description ends at the first parameter entry.
        if params.is_empty() {
            desc_lines.push(line.as_str());
        }
    }
    (desc_lines.join(" ").trim().to_string(), params)
}
45
46fn type_to_json_schema(ty: &Type) -> TokenStream2 {
52 match ty {
53 Type::Reference(r) => type_to_json_schema(&r.elem),
55
56 Type::Path(tp) => {
57 let seg = match tp.path.segments.last() {
60 Some(s) => s,
61 None => return unsupported(ty),
62 };
63
64 match seg.ident.to_string().as_str() {
65 "String" | "str" => quote!(serde_json::json!({"type": "string"})),
66 "bool" => quote!(serde_json::json!({"type": "boolean"})),
67 "f32" | "f64" => quote!(serde_json::json!({"type": "number"})),
68 "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | "i8" | "i16" | "i32" | "i64"
69 | "i128" | "isize" => {
70 quote!(serde_json::json!({"type": "integer"}))
71 }
72 "Option" => match inner_type_arg(seg) {
74 Some(inner) => type_to_json_schema(inner),
75 None => unsupported(ty),
76 },
77 "Vec" => match inner_type_arg(seg) {
79 Some(inner) => {
80 let items = type_to_json_schema(inner);
81 quote!(serde_json::json!({"type": "array", "items": #items}))
82 }
83 None => unsupported(ty),
84 },
85 _ => unsupported(ty),
86 }
87 }
88
89 _ => unsupported(ty),
90 }
91}
92
93fn unsupported(ty: &Type) -> TokenStream2 {
95 syn::Error::new_spanned(
96 ty,
97 "unsupported type in #[tool]: use String, bool, f32/f64, \
98 an integer primitive, Vec<T>, or Option<T>",
99 )
100 .to_compile_error()
101}
102
103fn inner_type_arg(seg: &syn::PathSegment) -> Option<&Type> {
106 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments
107 && let Some(syn::GenericArgument::Type(ty)) = args.args.first()
108 {
109 return Some(ty);
110 }
111 None
112}
113
114fn is_option(ty: &Type) -> bool {
115 if let Type::Path(tp) = ty
116 && let Some(seg) = tp.path.segments.last()
117 {
118 return seg.ident == "Option";
119 }
120 false
121}
122
123struct ToolMethod {
124 tool_name: String,
125 description: String,
126 params: Vec<ParamInfo>,
127 body: syn::Block,
128}
129
130struct ParamInfo {
131 name: String,
132 ty: Type,
133 desc: String,
134 optional: bool,
135}
136
137#[proc_macro_attribute]
138pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
139 let item_impl = parse_macro_input!(item as ItemImpl);
140
141 let override_name: Option<String> = if !attr.is_empty() {
142 let s = TokenStream2::from(attr).to_string();
143 s.find('"').and_then(|start| {
144 s.rfind('"')
145 .filter(|&end| end > start)
146 .map(|end| s[start + 1..end].to_string())
147 })
148 } else {
149 None
150 };
151
152 let mut tool_methods: Vec<ToolMethod> = vec![];
153
154 for item in &item_impl.items {
155 if let ImplItem::Fn(method) = item {
156 if method.sig.asyncness.is_none() {
157 continue;
158 }
159 let fn_name = method.sig.ident.to_string();
160 let tool_name = override_name.clone().unwrap_or_else(|| fn_name.clone());
161 let doc_lines = extract_doc(&method.attrs);
162 let (description, param_docs) = parse_doc(&doc_lines);
163
164 let mut params = vec![];
165 for arg in &method.sig.inputs {
166 if let FnArg::Typed(pt) = arg {
167 let name = if let Pat::Ident(pi) = &*pt.pat {
168 pi.ident.to_string()
169 } else {
170 continue;
171 };
172 let ty = (*pt.ty).clone();
173 let desc = param_docs.get(&name).cloned().unwrap_or_default();
174 let optional = is_option(&ty);
175 params.push(ParamInfo {
176 name,
177 ty,
178 desc,
179 optional,
180 });
181 }
182 }
183 tool_methods.push(ToolMethod {
184 tool_name,
185 description,
186 params,
187 body: method.block.clone(),
188 });
189 }
190 }
191
192 let raw_tools_body = tool_methods.iter().map(|m| {
193 let tool_name = &m.tool_name;
194 let description = &m.description;
195 let prop_inserts = m.params.iter().map(|p| {
196 let pname = &p.name;
197 let pdesc = &p.desc;
198 let schema = type_to_json_schema(&p.ty);
199 quote! {{
200 let mut prop = #schema;
201 prop["description"] = serde_json::json!(#pdesc);
202 properties.insert(#pname.to_string(), prop);
203 }}
204 });
205 let required: Vec<&str> = m
206 .params
207 .iter()
208 .filter(|p| !p.optional)
209 .map(|p| p.name.as_str())
210 .collect();
211 quote! {{
212 let mut properties = serde_json::Map::new();
213 #(#prop_inserts)*
214 let required: Vec<&str> = vec![#(#required),*];
215 ds_api::raw::request::tool::Tool {
216 r#type: ds_api::raw::request::message::ToolType::Function,
217 function: ds_api::raw::request::tool::Function {
218 name: #tool_name.to_string(),
219 description: Some(#description.to_string()),
220 parameters: serde_json::json!({
221 "type": "object",
222 "properties": properties,
223 "required": required,
224 }),
225 strict: None,
226 },
227 }
228 }}
229 });
230
231 let call_arms = tool_methods.iter().map(|m| {
232 let tool_name = &m.tool_name;
233 let body = &m.body;
234 let arg_parses = m.params.iter().map(|p| {
235 let pname = syn::Ident::new(&p.name, Span::call_site());
236 let pname_str = &p.name;
237 let ty = &p.ty;
238 quote! {
239 let #pname: #ty = match serde_json::from_value(
240 args.get(#pname_str).cloned().unwrap_or(serde_json::Value::Null)
241 ) {
242 Ok(v) => v,
243 Err(e) => return serde_json::json!({
244 "error": format!("invalid argument '{}': {}", #pname_str, e)
245 }),
246 };
247 }
248 });
249 quote! {
250 #tool_name => {
251 #(#arg_parses)*
252 let __result = { #body };
253 match serde_json::to_value(__result) {
254 Ok(v) => v,
255 Err(e) => serde_json::json!({ "error": format!("serialization error: {}", e) }),
256 }
257 }
258 }
259 });
260
261 let self_ty = &item_impl.self_ty;
262
263 let expanded = quote! {
264 #[async_trait::async_trait]
265 impl ds_api::tool_trait::Tool for #self_ty {
266 fn raw_tools(&self) -> Vec<ds_api::raw::request::tool::Tool> {
267 vec![#(#raw_tools_body),*]
268 }
269
270 async fn call(&self, name: &str, args: serde_json::Value) -> serde_json::Value {
271 match name {
272 #(#call_arms)*
273 _ => serde_json::json!({"error": format!("unknown tool: {}", name)}),
274 }
275 }
276 }
277 };
278
279 expanded.into()
280}