genai_rs_macros/
lib.rs

1#![cfg_attr(test, allow(dead_code))]
2
3use proc_macro::TokenStream;
4use syn::Pat;
5use utoipa::openapi::RefOr;
6use utoipa::openapi::schema::{ObjectBuilder, Schema};
7
8mod codegen;
9mod parsing;
10mod schema;
11
12use parsing::parse_input;
13use schema::{build_param_schema, get_type_info};
14
15/// Generates a function that returns a `FunctionDeclaration` for the annotated function.
16///
17/// # Example
18/// ```ignore
19/// use genai_rs_macros::tool;
20///
21/// #[tool(
22///     location(description = "The city and state"),
23///     unit(enum_values = ["celsius", "fahrenheit"])
24/// )]
25/// fn get_weather(location: String, unit: Option<String>) -> String {
26///     format!("Weather for {}", location)
27/// }
28///
29/// // The macro generates:
30/// // pub fn get_weather_declaration() -> genai_rs::FunctionDeclaration { ... }
31/// ```
32#[proc_macro_attribute]
33pub fn tool(attr_input: TokenStream, item: TokenStream) -> TokenStream {
34    let input = match parse_input(attr_input, item) {
35        Ok(input) => input,
36        Err(e) => return e.to_compile_error().into(),
37    };
38
39    let func = input.func;
40    let config_map = input.param_configs;
41    let func_name = func.sig.ident.to_string();
42
43    // Collect actual function parameter names
44    let mut actual_param_names = std::collections::HashSet::new();
45    for fn_arg in &func.sig.inputs {
46        if let syn::FnArg::Typed(pat_type) = fn_arg
47            && let Pat::Ident(pat_ident) = &*pat_type.pat
48        {
49            actual_param_names.insert(pat_ident.ident.to_string());
50        }
51    }
52
53    // Check that all macro-referenced parameters actually exist in the function
54    for referenced_param in config_map.keys() {
55        if !actual_param_names.contains(referenced_param) {
56            return syn::Error::new(
57                func.sig.ident.span(),
58                format!(
59                    "Parameter '{}' referenced in #[tool] attribute does not exist in function '{}'. \
60                     Available parameters: {:?}",
61                    referenced_param,
62                    func_name,
63                    actual_param_names.iter().collect::<Vec<_>>()
64                ),
65            )
66            .to_compile_error()
67            .into();
68        }
69    }
70
71    let func_description = parsing::extract_doc_comments(&func.attrs);
72    let mut object_builder = ObjectBuilder::new();
73    let mut required_params_for_struct_field = Vec::new();
74
75    for fn_arg in &func.sig.inputs {
76        if let syn::FnArg::Typed(pat_type) = fn_arg
77            && let Pat::Ident(pat_ident) = &*pat_type.pat
78        {
79            let param_name = pat_ident.ident.to_string();
80            let config = config_map.get(&param_name);
81            let param_schema = build_param_schema(pat_type, config);
82
83            object_builder = object_builder.property(param_name.clone(), param_schema);
84
85            let (is_option, _) = get_type_info(&pat_type.ty);
86            if !is_option {
87                required_params_for_struct_field.push(param_name.clone());
88            }
89        }
90    }
91
92    let parameters_schema_obj = object_builder.build();
93    let parameters_schema_ref_or: RefOr<Schema> = RefOr::T(Schema::Object(parameters_schema_obj));
94
95    codegen::generate_declaration_function(
96        &func,
97        &func_name,
98        &func_description,
99        &parameters_schema_ref_or,
100        &required_params_for_struct_field,
101    )
102    .into()
103}