1use std::collections::HashSet;
2
3use proc_macro2::{Span, TokenStream};
4use quote::{quote, ToTokens, TokenStreamExt};
5use regex::Regex;
6use syn::spanned::Spanned;
7use syn::{parse_macro_input, GenericArgument, ItemImpl, PathArguments, Signature};
8use syn::{FnArg, Ident, Pat, Type};
9
10fn create_tool_schema_const_indentifier(struct_name: &str) -> Ident {
11 Ident::new(
12 &format!("_{}_SCHEMA", struct_name.to_uppercase(),),
13 Span::call_site(),
14 )
15}
16
17struct FunctionDefintion {
18 is_async: bool,
19 name: Ident,
20 name_str: String,
21 parameters: Vec<Parameter>,
22 return_type: ReturnType,
23 description: Option<String>,
25}
26
27impl FunctionDefintion {
28 fn create_schema_const_indentifier(&self, struct_name: &str) -> Ident {
29 Ident::new(
30 &format!(
31 "_{}_{}_PARMETER_SCHEMA",
32 struct_name.to_uppercase(),
33 self.name_str.to_uppercase()
34 ),
35 Span::call_site(),
36 )
37 }
38}
39
40struct Parameter {
41 name: Ident,
42 name_str: String,
43 param_type: syn::Type,
44 description: Option<String>,
46}
47
48enum ReturnType {
49 Result(ResultReturnType),
50 Other(OtherReturnType),
51}
52
53struct ResultReturnType {
54 okay: Type,
55 error: Type,
56}
57
58struct OtherReturnType {
59 other: Type,
60}
61
62#[proc_macro_attribute]
63pub fn tool(
64 _attr: proc_macro::TokenStream,
65 item: proc_macro::TokenStream,
66) -> proc_macro::TokenStream {
67 let mut input = parse_macro_input!(item as ItemImpl);
68 let struct_name = match &*input.self_ty {
69 Type::Path(type_path) => &type_path.path.segments.last().unwrap().ident,
70 _ => panic!("Invalid impl type"),
71 };
72 let generics = &input.generics;
73 let struct_name_str = struct_name.to_token_stream().to_string();
74
75 let methods: Vec<_> = input
76 .items
77 .clone()
78 .into_iter()
79 .filter_map(|item| {
80 if let syn::ImplItem::Fn(method) = item {
81 let attrs = &method.attrs;
82 for attr in attrs.iter() {
83 let path = attr.path();
84 if path.is_ident("tool_part") {
85 return Some(method);
86 }
87 }
88 }
89 None
90 })
91 .collect();
92
93
94 input
95 .items
96 .iter_mut()
97 .for_each(|item| {
98 if let syn::ImplItem::Fn(method) = item {
99 method.attrs.retain(|attr|{
100 !attr.path().is_ident("tool_part")
101 });
102 }
103 });
104
105
106 let mut function_definitions = Vec::new();
107 for method in methods {
108 let syn::ImplItemFn {
109 attrs,
110 vis: _,
111 defaultness: _,
112 sig,
113 block: _,
114 } = method;
115 let mut function_definition = match extract_function_defintion(sig) {
116 Ok(okay) => okay,
117 Err(error) => return error.into_compile_error().into(),
118 };
119 match extract_description(&mut function_definition, attrs) {
120 Ok(_) => {}
121 Err(error) => return error.into_compile_error().into(),
122 }
123 function_definitions.push(function_definition);
124 }
125
126 if function_definitions.is_empty() {
127 return syn::Error::new_spanned(
128 struct_name,
129 "No functions found in this tool. Please add functions to the tool with the `#[tool_part]` attribute.",
130 )
131 .into_compile_error()
132 .into();
133 }
134
135 let function_schema = create_tool_json_schema(&struct_name_str, &mut function_definitions);
136 let parameter_json_schema = function_definitions.iter_mut().map(|function_definition| {
137 create_function_parameter_json_schema(&struct_name_str, function_definition)
138 }).fold(TokenStream::new(), |mut acc, item| { acc.append_all(item); acc });
139
140 let impl_traits = impl_traits(&struct_name, &struct_name_str, generics, &function_definitions);
141
142 let expanded = quote! {
143 #input
144
145 #function_schema
146
147 #parameter_json_schema
148
149 #impl_traits
150 };
151
152 proc_macro::TokenStream::from(expanded)
153}
154
155struct CommonReturnTypes<'a> {
156 result_err: HashSet<&'a Type>,
157 result_ok_and_regular: HashSet<&'a Type>,
158}
159
160impl<'a> CommonReturnTypes<'a> {
161 pub fn new() -> Self {
162 Self {
163 result_err: HashSet::new(),
164 result_ok_and_regular: HashSet::new(),
165 }
166 }
167}
168
169fn impl_traits(struct_name: &syn::Ident, struct_name_str: &str, generics: &syn::Generics, function_definitions: &Vec<FunctionDefintion>) -> TokenStream {
170 let mut common_return_types = CommonReturnTypes::new();
171 for function_definition in function_definitions.iter() {
172 match &function_definition.return_type {
173 ReturnType::Result(result_return_type) => {
174 common_return_types.result_err.insert(&result_return_type.error);
175 common_return_types.result_ok_and_regular.insert(&result_return_type.okay);
176 }
177 ReturnType::Other(other_return_type) => {
178 common_return_types.result_ok_and_regular.insert(&other_return_type.other);
179 }
180 }
181 }
182
183 let mut common_err_type: Option<Type> = None;
184 let all_are_results_with_same_err_type = common_return_types.result_err.len() == 1;
185 if all_are_results_with_same_err_type {
186 let first = *common_return_types.result_err.iter().next().unwrap();
187 common_err_type = Some(first.clone());
188 }
189 let mut common_ok_type: Option<Type> = None;
190 let all_have_same_ok_type = common_return_types.result_ok_and_regular.len() == 1;
191 if all_have_same_ok_type {
192 let first = *common_return_types.result_ok_and_regular.iter().next().unwrap();
193 common_ok_type = Some(first.clone());
194 }
195
196 let all_functions_are_regular = common_return_types.result_err.len() == 0; let impls_needed = determine_impls_needed(common_ok_type, common_err_type, all_functions_are_regular);
198
199 let mut all_impl_tokens = TokenStream::new();
200
201 let box_any_type = quote! {
202 Box<dyn std::any::Any>
203 };
204 let box_error_type = quote! {
205 Box<dyn std::error::Error>
206 };
207 let infallible_type = quote! {
208 std::convert::Infallible
209 };
210 for impl_needed in impls_needed {
211 let tokens = match impl_needed {
212 ImplTypes::BoxAndBox => impl_trait(struct_name, struct_name_str, generics,function_definitions, true, true, &box_any_type, &box_error_type),
213 ImplTypes::BoxAndSpecific(err_type) => impl_trait(struct_name, struct_name_str, generics,function_definitions, true, false, &box_any_type, &err_type.to_token_stream()),
214 ImplTypes::SpecificAndBox(ok_type) => impl_trait(struct_name, struct_name_str, generics,function_definitions, false, true, &ok_type.to_token_stream(), &box_error_type),
215 ImplTypes::SpecificAndSpecific(ok_type, err_type) => impl_trait(struct_name, struct_name_str, generics,function_definitions, false, false, &ok_type.to_token_stream(), &err_type.to_token_stream()),
216 ImplTypes::BoxAndInfallible => impl_trait(struct_name, struct_name_str, generics,function_definitions, true, false, &box_any_type, &infallible_type),
217 ImplTypes::SpecificAndInfallible(ok_type) => impl_trait(struct_name, struct_name_str, generics,function_definitions, false, false, &ok_type.to_token_stream(), &infallible_type),
218 };
219 all_impl_tokens.append_all(tokens);
220 }
221
222 all_impl_tokens
223}
224
225enum ImplTypes {
226 BoxAndBox,
227 BoxAndSpecific(Type),
228 SpecificAndBox(Type),
229 SpecificAndSpecific(Type, Type),
230 BoxAndInfallible,
231 SpecificAndInfallible(Type),
232}
233
234fn determine_impls_needed(common_ok_type: Option<Type>, common_err_type: Option<Type>, all_functions_are_regular: bool) -> Vec<ImplTypes> {
235 let mut vecs = match (common_ok_type.clone(), common_err_type.clone()) {
236 (None, None) => vec![],
237 (None, Some(err_type)) => vec![ImplTypes::BoxAndSpecific(err_type)],
238 (Some(ok_type), None) => vec![ImplTypes::SpecificAndBox(ok_type)],
239 (Some(ok_type), Some(err_type)) => vec![ImplTypes::BoxAndSpecific(err_type.clone()), ImplTypes::SpecificAndBox(ok_type.clone()), ImplTypes::SpecificAndSpecific(ok_type, err_type)],
240 };
241 if all_functions_are_regular {
242 assert!(common_err_type.is_none(), "If there are no result functions, there should be no error type");
243 vecs.push(ImplTypes::BoxAndInfallible);
244 if let Some(common_ok_type) = common_ok_type {
245 vecs.push(ImplTypes::SpecificAndInfallible(common_ok_type));
246 }
247 }
248 vecs.push(ImplTypes::BoxAndBox);
249 vecs
250}
251
252fn impl_trait(struct_name: &syn::Ident, struct_name_str:&str, generics: &syn::Generics, function_definitions: &Vec<FunctionDefintion>, ok_needs_box: bool, err_needs_box: bool, ok_type: &TokenStream, err_type: &TokenStream) -> TokenStream {
253 let function_names = function_definitions.iter().map(|function_definition| {
254 &function_definition.name_str
255 });
256
257 let run_arms = function_definitions.iter().map(|function_definition| {
258 let function_parameter_statements = function_definition.parameters.iter().map(|parameter|{
259 let Parameter {
260 name,
261 name_str,
262 param_type,
263 description: _,
264 } = parameter;
265 let serde_message = format!("Parameter `{}` does not follow schema", name_str);
266 let missing_message = format!("Missing `{}` parameter", name_str);
267 let deserialize= match param_type {
268 Type::Reference(type_reference) => match &*type_reference.elem {
269 Type::Path(type_path) => {
270 if type_path.path.get_ident().is_some_and(|item| &*item.to_string() == "str") {
271 Some(quote! {
272 let #name: &str = &*serde_json::from_value::<String>(#name).map_err(|_| llmtoolbox::FunctionCallError::parsing(#serde_message.to_owned()))?;
273 })
274 }
275 else {
276 Some(quote! {
277 let #name: #param_type = &serde_json::from_value::<#type_path>(#name).map_err(|_| llmtoolbox::FunctionCallError::parsing(#serde_message.to_owned()))?;
278 })
279 }
280 },
281 _ => None,
282 },
283 _ => None,
284 }.unwrap_or(quote! {
285 let #name: #param_type = serde_json::from_value::<#param_type>(#name).map_err(|_| llmtoolbox::FunctionCallError::parsing(#serde_message.to_owned()))?;
286 });
287 quote! {
288 let #name = parameters.remove(#name_str).ok_or_else(|| llmtoolbox::FunctionCallError::parsing(#missing_message.to_owned()))?;
289 #deserialize
290 }
291 });
292 let return_statement =
293 make_return_statement(function_definition, ok_needs_box, err_needs_box);
294 let function_name_str = &function_definition.name_str;
295 quote! {
296 #function_name_str => {
297 #(#function_parameter_statements)*
298 #return_statement
299 }
300 }
301 }).fold(TokenStream::new(), |mut acc, item| { acc.append_all(item); acc });
302
303 let schema = create_tool_schema_const_indentifier(struct_name_str);
304 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
305 quote! {
306 impl #impl_generics llmtoolbox::Tool<#ok_type, #err_type> for #struct_name #ty_generics #where_clause {
308 fn function_names(&self) -> &[&'static str] {
309 &[
310 #(#function_names),*
311 ]
312 }
313
314 fn schema(&self) -> &'static serde_json::Map<String, serde_json::Value> {
315 #schema.as_object().unwrap()
316 }
317
318 fn call_function<'life0, 'life1, 'async_trait>(
319 &'life0 self,
320 name: &'life1 str,
321 parameters: serde_json::Map<String, serde_json::Value>,
322 ) -> ::core::pin::Pin<
323 Box<
324 dyn ::core::future::Future<
325 Output = Result<
326 Result<#ok_type, #err_type>,
327 llmtoolbox::FunctionCallError,
328 >,
329 > + ::core::marker::Send
330 + 'async_trait,
331 >,
332 >
333 where
334 'life0: 'async_trait,
335 'life1: 'async_trait,
336 Self: 'async_trait,
337 {
338 Box::pin(async move {
339 if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<
340 Result<
341 Result<#ok_type, #err_type>,
342 llmtoolbox::FunctionCallError,
343 >,
344 > {
345 #[allow(unreachable_code)]
346 return __ret;
347 }
348 let __self = self;
349 let mut parameters = parameters;
350 let __ret: Result<
351 Result<#ok_type, #err_type>,
352 llmtoolbox::FunctionCallError,
353 > = {
354 match &*name {
355 #run_arms
356 _ => return Err(llmtoolbox::FunctionCallError::function_not_found(name.to_owned())),
357 }
358 };
359 #[allow(unreachable_code)]
360 __ret
361 })
362 }
363 }
376 }
377}
378
379fn make_return_statement(function_definition: &FunctionDefintion, ok_needs_box: bool, err_needs_box: bool) -> TokenStream {
380 let async_part;
381 if function_definition.is_async {
382 async_part = quote! {
383 .await
384 }
385 }
386 else {
387 async_part = quote! {}
388 }
389 let function_parameters = function_definition.parameters.iter().map(|parameter| {
390 ¶meter.name
391 });
392 let function_name = &function_definition.name;
393 match function_definition.return_type {
394 ReturnType::Result(_) => {
395 if ok_needs_box {
396 if err_needs_box {
397 quote! {
398 return Ok(match self.#function_name(#(#function_parameters),*)#async_part {
399 Ok(value) => Ok(Box::new(value) as Box<dyn std::any::Any>),
400 Err(value) => Err(Box::new(value) as Box<dyn std::error::Error>),
401 });
402 }
403 }
404 else {
405 quote! {
406 return Ok(self.#function_name(#(#function_parameters),*)#async_part.map(|value| Box::new(value) as Box<dyn std::any::Any>));
407 }
408 }
409 }
410 else {
411 if err_needs_box {
412 quote! {
413 return Ok(self.#function_name(#(#function_parameters),*)#async_part.map_err(|error| Box::new(error) as Box<dyn std::error::Error>));
414 }
415 }
416 else {
417 quote! {
418 return Ok(self.#function_name(#(#function_parameters),*)#async_part);
419 }
420 }
421 }
422 },
423 ReturnType::Other(_) => {
424 if ok_needs_box {
425 quote! {
426 return Ok(Ok(Box::new(self.#function_name(#(#function_parameters),*)#async_part)));
427 }
428 }
429 else {
430 quote! {
431 return Ok(Ok(self.#function_name(#(#function_parameters),*)#async_part));
432 }
433 }
434 }
435 }
436}
437
438fn extract_function_defintion(signature: Signature) -> syn::Result<FunctionDefintion> {
439 let inputs = &signature.inputs;
440 let parameters = inputs
441 .iter()
442 .filter_map(|arg| {
443 if let FnArg::Typed(arg) = arg {
444 if let Pat::Ident(pat_ident) = &*arg.pat {
445 let name_str = pat_ident.ident.to_string();
446 let name = pat_ident.ident.clone();
447 let type_ = *arg.ty.clone();
449
450 Some(Parameter {
451 name,
452 name_str,
453 param_type: type_,
454 description: None,
455 })
456 } else {
457 None
458 }
459 } else {
460 None
461 }
462 })
463 .collect::<Vec<_>>();
464
465 let return_type = match signature.output {
466 syn::ReturnType::Default => {
467 return Err(syn::Error::new_spanned(
468 signature,
469 "Currently, tool functions must have a return type, even if it is just `()`.",
470 ))
471 }
472 syn::ReturnType::Type(_, return_type) => *return_type,
473 };
474 let return_type = (|| {
475 match &return_type {
476 Type::Path(type_path) => {
477 let segments = &type_path.path.segments;
478 if segments.len() != 1 {
479 return ReturnType::Other(OtherReturnType { other: return_type });
480 }
481 let segment = segments.last().unwrap();
482 if let PathArguments::AngleBracketed(angle_bracketed_args) = &segment.arguments {
483 let mut generics = angle_bracketed_args.args.iter();
484
485 if let (Some(GenericArgument::Type(okay)), Some(GenericArgument::Type(error))) =
486 (generics.next(), generics.next())
487 {
488 return ReturnType::Result(ResultReturnType {
489 okay: okay.clone(),
490 error: error.clone(),
491 });
492 }
493 }
494 }
495 _ => {}
496 }
497 return ReturnType::Other(OtherReturnType { other: return_type });
498 })();
499
500 let is_async = signature.asyncness.is_some();
501 let name = signature.ident;
502 let name_str = name.to_string();
503 Ok(FunctionDefintion {
504 is_async,
505 name,
506 name_str,
507 parameters,
508 return_type,
509 description: None,
510 })
511}
512
513fn extract_description(
514 function_definition: &mut FunctionDefintion,
515 attrs: Vec<syn::Attribute>,
516) -> syn::Result<()> {
517 let FunctionDefintion {
518 is_async: _,
519 name,
520 name_str,
521 parameters,
522 return_type: _,
523 description,
524 } = function_definition;
525 let re = Regex::new(r".*?`(?<name>.*?)`\s*-\s*(?<description>.*)$").unwrap();
526 for attr in attrs.iter() {
527 match &attr.meta {
528 syn::Meta::NameValue(name_value) => match &name_value.value {
529 syn::Expr::Lit(lit) => match &lit.lit {
530 syn::Lit::Str(str) => {
531 let haystack = str.value();
532 let arg_caps = match re.captures(&haystack) {
533 Some(caps) => caps,
534 None => {
535 if let Some(description) = description {
536 description.push_str(&*format!("{}\n", &str.value().trim()));
537 } else {
538 let _ = description.insert(str.value().trim().to_string());
539 }
540 continue;
541 }
542 };
543 let name = arg_caps["name"].to_string();
544 let desc = arg_caps["description"].to_string();
545 if let Some(param) = parameters.iter_mut().find(|p| p.name_str == name) {
546 param.description = Some(desc);
547 } else {
548 return Err(syn::Error::new_spanned(
549 attr,
550 format!("parameter `{}` not found in function definition", name),
551 ));
552 }
553 }
554 _ => {}
555 },
556 _ => {}
557 },
558 _ => {}
559 }
560 }
561 for parameter in parameters {
562 if parameter.description.is_none() {
563 return Err(syn::Error::new_spanned(
564 parameter.name.clone(),
565 format!("missing description for parameter `{}`. Descriptions are doc comments the form of:\n\
566 /// `parameter_name` - This is the description for the parameter.", parameter.name_str),
567 ));
568 }
569 }
570 if function_definition.description.is_none() {
571 return Err(syn::Error::new_spanned(
572 name.clone(),
573 format!("missing description for function `{}`", name_str),
574 ));
575 }
576 Ok(())
577}
578
579fn rust_type_to_known_json_schema_type(ty: &Type) -> Option<&'static str> {
581 match ty {
582 Type::Path(type_path) => {
583 if let Some(segment) = type_path.path.segments.last() {
584 return match segment.ident.to_string().as_str() {
585 "String" | "str" => Some("string"),
586 "i8" | "i16" | "i32" | "i64" | "isize" => Some("integer"),
588 "u8" | "u16" | "u32" | "u64" | "usize" => Some("integer"), "u128" | "i128" => Some("integer"), "f32" | "f64" => Some("number"),
591 "bool" => Some("boolean"),
592 _ => None,
593 };
594 } else {
595 None
596 }
597 }
598 Type::Reference(type_ref) => rust_type_to_known_json_schema_type(&type_ref.elem),
599 _ => None,
600 }
601}
602
603fn create_tool_json_schema(
604 struct_name: &str,
605 function_definitions: &Vec<FunctionDefintion>,
606) -> proc_macro2::TokenStream {
607 let mut function_schemas = Vec::new();
608 for function_definition in function_definitions {
609 let id = function_definition.create_schema_const_indentifier(struct_name);
610 let description = &function_definition.description;
611 let name = &function_definition.name;
612
613 function_schemas.push(quote! {
614 serde_json::json!(
615 {
616 "type": "object",
617 "description": stringify!(#description),
618 "properties": {
619 "function_name": {
620 "const": stringify!(#name),
621 },
622 "parameters": *#id
623 },
624 "required": ["function_name", "parameters"]
625 }
626 )
627 });
628 }
629 let id = create_tool_schema_const_indentifier(struct_name);
630 quote! {
631 const #id: std::cell::LazyCell<&'static serde_json::Value> = std::cell::LazyCell::new(|| {
632 Box::leak(Box::new(serde_json::json!(
633 {
634 "$schema": "http://json-schema.org/draft-07/schema#",
635 "oneOf": [
636 #(#function_schemas),*
637 ]
638 }
639 )))
640 });
641 }
642}
643
644fn create_function_parameter_json_schema(
645 struct_name: &str,
646 function_definition: &mut FunctionDefintion,
647) -> proc_macro2::TokenStream {
648 let parameters = &function_definition.parameters;
649 let mut known_properties = Vec::new();
650 let mut known_required_property_name = Vec::new();
651 let mut computed_required_property_name = Vec::new();
652 let mut computed_properties_outer_definitions = Vec::new();
654 let mut computed_properties = Vec::new();
655 let mut num_of_computed_properties = 0;
656 for parameter in parameters {
657 let name = ¶meter.name_str;
658 let description = ¶meter.description;
659 let param_type = ¶meter.param_type;
660 let json_schema_type = rust_type_to_known_json_schema_type(¶meter.param_type);
661 if let Some(param_type) = json_schema_type {
662 known_properties.push(quote! {
663 #name: {
664 "type": #param_type,
665 "description": #description
666 }
667 });
668 known_required_property_name.push(quote! {
669 #name
670 });
671 } else {
672 num_of_computed_properties +=1;
673 let id = Ident::new(
674 &format!("computed{num_of_computed_properties}"),
675 json_schema_type.span(),
676 );
677 computed_properties_outer_definitions.push(quote! {
678 let #id = (|| {
679 let schema_settings = schemars::generate::SchemaSettings::draft07();
680 let schema = schemars::SchemaGenerator::new(schema_settings).into_root_schema_for::<#param_type>();
681 let mut schema = schema.to_value();
682 llmtoolbox::clean_up_schema(&mut schema);
683 match schema {
684 serde_json::Value::Object(ref mut map) => {
685 map.insert("description".to_string(), serde_json::Value::String(#description.to_string()));
686 },
687 _ => panic!("schema should always generate a map type.")
688 }
689 return schema;
690 })();
691 });
692 computed_properties.push(quote! {
693 #name: #id
694 });
695 computed_required_property_name.push(quote! {
696 #name
697 });
698 }
699 }
700 let id = function_definition.create_schema_const_indentifier(struct_name);
701 quote! {
702 const #id: std::cell::LazyCell<serde_json::Value> = std::cell::LazyCell::new(|| {
703 #(#computed_properties_outer_definitions)*
704 serde_json::json!(
705 {
706 "type": "object",
707 "required": [
708 #(#known_required_property_name),*
709 #(#computed_required_property_name),*
710 ],
711 "properties": {
712 #(#known_properties),*
713 #(#computed_properties),*
714 },
715 }
716 )
717 });
718 }
719}