Skip to main content

gemini_proc_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    Attribute, Data, DeriveInput, Fields, FnArg, ItemFn, Lit, Meta, Pat, Type, parse_macro_input,
5};
6
7/// Attribute macro to mark a function as callable by Gemini.
8///
9/// This macro generates a schema for the function and an `execute` method to call it
10/// with JSON arguments.
11///
12/// # Requirements
13/// - Function arguments must be owned types that implement `GeminiSchema` (e.g., `String`, `i32`, `bool`).
14/// - References are not supported.
15/// - The function can be `async` and can return a `Result` (the `Ok` value must implement `Serialize`).
16///
17/// # Example
18/// ```ignore
19/// #[gemini_function]
20/// /// Returns the current weather for a location.
21/// fn get_weather(location: String) -> String {
22///     format!("The weather in {} is sunny.", location)
23/// }
24/// ```
25#[proc_macro_attribute]
26pub fn gemini_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
27    let mut input_fn = parse_macro_input!(item as ItemFn);
28    let fn_name = &input_fn.sig.ident;
29    let args_struct_name = syn::Ident::new(&format!("{}_args", fn_name), fn_name.span());
30    let fn_description = extract_doc_comments(&input_fn.attrs);
31
32    let mut properties = Vec::new();
33    let mut required = Vec::new();
34    let mut param_names = Vec::new();
35    let mut param_types = Vec::new();
36
37    for arg in input_fn.sig.inputs.iter_mut() {
38        if let FnArg::Typed(pat_type) = arg {
39            if let Pat::Ident(pat_ident) = &*pat_type.pat {
40                let param_name = pat_ident.ident.clone();
41                let param_name_str = param_name.to_string();
42                let param_type = (*pat_type.ty).clone();
43                let param_desc = extract_doc_comments(&pat_type.attrs);
44
45                if has_reference(&param_type) {
46                    return syn::Error::new_spanned(
47                        &param_type,
48                        "references are not supported in gemini_function. Use owned types like String instead.",
49                    )
50                    .to_compile_error()
51                    .into();
52                }
53
54                // Remove doc attributes from the function signature so it compiles
55                pat_type.attrs.retain(|attr| !attr.path().is_ident("doc"));
56
57                let is_optional = is_option(&param_type);
58
59                properties.push(quote! {
60                    let mut schema = <#param_type as GeminiSchema>::gemini_schema();
61                    if !#param_desc.is_empty() {
62                        if let Some(obj) = schema.as_object_mut() {
63                            obj.insert("description".to_string(), serde_json::json!(#param_desc));
64                        }
65                    }
66                    props.insert(#param_name_str.to_string(), schema);
67                });
68
69                if !is_optional {
70                    required.push(param_name_str);
71                }
72
73                param_names.push(param_name);
74                param_types.push(param_type);
75            }
76        }
77    }
78
79    let fn_name_str = fn_name.to_string();
80    let is_async = input_fn.sig.asyncness.is_some();
81    let call_await = if is_async {
82        quote! { .await }
83    } else {
84        quote! {}
85    };
86
87    let is_result = match &input_fn.sig.output {
88        syn::ReturnType::Default => false,
89        syn::ReturnType::Type(_, ty) => {
90            let s = quote!(#ty).to_string();
91            s.contains("Result")
92        }
93    };
94
95    let result_handling = if is_result {
96        quote! {
97            match result {
98                Ok(v) => Ok(serde_json::json!(v)),
99                Err(e) => Err(e.to_string()),
100            }
101        }
102    } else {
103        quote! {
104            Ok(serde_json::json!(result))
105        }
106    };
107
108    let expanded = quote! {
109        #input_fn
110
111        #[allow(non_camel_case_types)]
112        pub struct #fn_name { }
113
114        #[allow(non_camel_case_types)]
115        #[derive(gemini_client_api::serde::Deserialize)]
116        pub struct #args_struct_name {
117            #(pub #param_names: #param_types,)*
118        }
119
120        impl GeminiSchema for #fn_name {
121            fn gemini_schema() -> serde_json::Value {
122                use serde_json::{json, Map};
123                let mut props = Map::new();
124                #(#properties)*
125
126                json!({
127                    "name": #fn_name_str,
128                    "description": #fn_description,
129                    "parameters": {
130                        "type": "OBJECT",
131                        "properties": props,
132                        "required": [#(#required),*]
133                    }
134                })
135            }
136        }
137
138        impl #fn_name {
139            pub async fn execute(args: serde_json::Value) -> Result<serde_json::Value, String> {
140                use gemini_client_api::serde::Deserialize;
141                let args = #args_struct_name::deserialize(&args).map_err(|e| e.to_string())?;
142                let result = #fn_name(#(args.#param_names),*) #call_await;
143                #result_handling
144            }
145            pub fn execute_with_closure<F, T>(args: &serde_json::Value, f: F) -> Result<T, serde_json::Error>
146            where
147                F: FnOnce(#(#param_types),*) -> T,
148            {
149                use gemini_client_api::serde::Deserialize;
150                let args = #args_struct_name::deserialize(args)?;
151                Ok(f(#(args.#param_names),*))
152            }
153        }
154    };
155
156    TokenStream::from(expanded)
157}
158
159/// Macro to execute function calls requested by Gemini and update the session history, with a custom callback for results.
160///
161/// # Usage
162/// `execute_function_calls_with_callback!(session, callback, function1, function2, ...)`
163///
164/// The `callback` should be a closure or function that takes `(String, Result<serde_json::Value, String>)`
165/// and returns `serde_json::Value`.
166/// (function_name, result) is passed to it
167///
168/// # Returns
169/// A `Vec<Option<Result<serde_json::Value, String>>>` containing the results of each function call.
170/// - `Some(Ok(value))` if the function was called and succeeded.
171/// - `Some(Err(err))` if the function was called but failed.
172/// - `None` if the function was not called.
173///
174/// # Note
175/// The `session` is automatically updated with the `FunctionResponse`.
176/// The `callback` is invoked with the result of the function execution (whether `Ok` or `Err`)
177/// and its return value is used to update the session.
178#[proc_macro]
179pub fn execute_function_calls_with_callback(input: TokenStream) -> TokenStream {
180    use syn::parse::{Parse, ParseStream};
181    use syn::{Expr, Token};
182
183    struct ExecuteWithCallbackInput {
184        session: Expr,
185        _comma1: Token![,],
186        callback: Expr,
187        _comma2: Token![,],
188        functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
189    }
190
191    impl Parse for ExecuteWithCallbackInput {
192        fn parse(input: ParseStream) -> syn::Result<Self> {
193            Ok(ExecuteWithCallbackInput {
194                session: input.parse()?,
195                _comma1: input.parse()?,
196                callback: input.parse()?,
197                _comma2: input.parse()?,
198                functions: input.parse_terminated(syn::Path::parse, Token![,])?,
199            })
200        }
201    }
202
203    let input = parse_macro_input!(input as ExecuteWithCallbackInput);
204    generate_execute_logic(&input.session, &input.callback, &input.functions)
205}
206
207/// Attribute macro to derive the `GeminiSchema` trait for a struct or enum.
208///
209/// This allows the type to be used in structured output (`set_json_mode`) or as a parameter
210/// in a `gemini_function`.
211///
212/// # Requirements
213/// - For structs: only named fields are supported.
214/// - For enums: only unit variants (no data) are supported.
215/// - Field/variant types must also implement `GeminiSchema`.
216/// - Doc comments on fields and variants are extracted as descriptions in the schema.
217///
218/// # Example
219/// ```ignore
220/// #[gemini_schema]
221/// struct SearchResult {
222///     /// The title of the page.
223///     title: String,
224///     /// The URL of the page.
225///     url: String,
226/// }
227/// ```
228#[proc_macro_attribute]
229pub fn gemini_schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
230    let input = parse_macro_input!(item as DeriveInput);
231    let name = &input.ident;
232    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
233    let description = extract_doc_comments(&input.attrs);
234
235    let expanded = match &input.data {
236        Data::Struct(data) => {
237            let mut properties = Vec::new();
238            let mut required = Vec::new();
239
240            match &data.fields {
241                Fields::Named(fields) => {
242                    for field in &fields.named {
243                        let field_name = field.ident.as_ref().unwrap();
244                        let field_name_str = field_name.to_string();
245                        let field_type = &field.ty;
246                        let field_desc = extract_doc_comments(&field.attrs);
247
248                        if has_reference(field_type) {
249                            return syn::Error::new_spanned(
250                                field_type,
251                                "references are not supported in gemini_schema. Use owned types instead.",
252                            )
253                            .to_compile_error()
254                            .into();
255                        }
256
257                        let is_optional = is_option(field_type);
258
259                        properties.push(quote! {
260                            let mut schema = <#field_type as GeminiSchema>::gemini_schema();
261                            if !#field_desc.is_empty() {
262                                if let Some(obj) = schema.as_object_mut() {
263                                    obj.insert("description".to_string(), serde_json::json!(#field_desc));
264                                }
265                            }
266                            props.insert(#field_name_str.to_string(), schema);
267                        });
268
269                        if !is_optional {
270                            required.push(field_name_str);
271                        }
272                    }
273                }
274                _ => panic!("gemini_schema only supports named fields in structs"),
275            }
276
277            quote! {
278                impl #impl_generics GeminiSchema for #name #ty_generics #where_clause {
279                    fn gemini_schema() -> serde_json::Value {
280                        use serde_json::{json, Map};
281                        let mut props = Map::new();
282                        #(#properties)*
283
284                        let mut schema = json!({
285                            "type": "OBJECT",
286                            "properties": props,
287                            "required": [#(#required),*]
288                        });
289
290                        if !#description.is_empty() {
291                            if let Some(obj) = schema.as_object_mut() {
292                                obj.insert("description".to_string(), json!(#description));
293                            }
294                        }
295                        schema
296                    }
297                }
298            }
299        }
300        Data::Enum(data) => {
301            let mut variants = Vec::new();
302            for variant in &data.variants {
303                if !matches!(variant.fields, Fields::Unit) {
304                    panic!("gemini_schema only supports unit variants in enums");
305                }
306                variants.push(variant.ident.to_string());
307            }
308
309            quote! {
310                impl #impl_generics GeminiSchema for #name #ty_generics #where_clause {
311                    fn gemini_schema() -> serde_json::Value {
312                        use serde_json::json;
313                        let mut schema = json!({
314                            "type": "STRING",
315                            "enum": [#(#variants),*]
316                        });
317
318                        if !#description.is_empty() {
319                            if let Some(obj) = schema.as_object_mut() {
320                                obj.insert("description".to_string(), json!(#description));
321                            }
322                        }
323                        schema
324                    }
325                }
326            }
327        }
328        _ => panic!("gemini_schema only supports structs and enums"),
329    };
330
331    let output = quote! {
332        #input
333        #expanded
334    };
335
336    TokenStream::from(output)
337}
338
339fn extract_doc_comments(attrs: &[Attribute]) -> String {
340    let mut doc_comments = Vec::new();
341    for attr in attrs {
342        if attr.path().is_ident("doc") {
343            if let Meta::NameValue(nv) = &attr.meta {
344                if let syn::Expr::Lit(expr_lit) = &nv.value {
345                    if let Lit::Str(lit_str) = &expr_lit.lit {
346                        doc_comments.push(lit_str.value().trim().to_string());
347                    }
348                }
349            }
350        }
351    }
352    doc_comments.join("\n")
353}
354
355fn is_option(ty: &Type) -> bool {
356    if let Type::Path(tp) = ty {
357        if let Some(seg) = tp.path.segments.last() {
358            return seg.ident == "Option";
359        }
360    }
361    false
362}
363
364fn has_reference(ty: &Type) -> bool {
365    match ty {
366        Type::Reference(_) => true,
367        Type::Path(tp) => {
368            for seg in &tp.path.segments {
369                if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
370                    for arg in &ab.args {
371                        if let syn::GenericArgument::Type(inner) = arg {
372                            if has_reference(inner) {
373                                return true;
374                            }
375                        }
376                    }
377                }
378            }
379            false
380        }
381        _ => false,
382    }
383}
384
385/// Macro to execute function calls requested by Gemini and update the session history.
386///
387/// # Usage
388/// `execute_function_calls!(session, function1, function2, ...)`
389///
390/// # Returns
391/// A `Vec<Option<Result<serde_json::Value, String>>>` containing the results of each function call.
392/// The length of the vector matches the number of functions provided.
393/// - `Some(Ok(value))` if the function was called and succeeded.
394/// - `Some(Err(err))` if the function was called but failed.
395/// - `None` if the function was not called.
396///
397/// # Note
398/// The `session` is automatically updated with the `FunctionResponse` for successful calls.
399/// If a function call fails, the error is converted to a JSON object `{"Error": error_message}`
400/// and sent to the session as the function response.
401#[proc_macro]
402pub fn execute_function_calls(input: TokenStream) -> TokenStream {
403    use syn::parse::{Parse, ParseStream};
404    use syn::{Expr, Token};
405
406    struct ExecuteInput {
407        session: Expr,
408        _comma: Token![,],
409        functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
410    }
411
412    impl Parse for ExecuteInput {
413        fn parse(input: ParseStream) -> syn::Result<Self> {
414            Ok(ExecuteInput {
415                session: input.parse()?,
416                _comma: input.parse()?,
417                functions: input.parse_terminated(syn::Path::parse, Token![,])?,
418            })
419        }
420    }
421
422    let input = parse_macro_input!(input as ExecuteInput);
423    let callback: Expr = syn::parse_quote! {
424        |_name: String, result: Result<gemini_client_api::serde_json::Value, String>| {
425            match result {
426                Ok(value) => value,
427                Err(e) => gemini_client_api::serde_json::json!({"Error": e}),
428            }
429        }
430    };
431
432    generate_execute_logic(&input.session, &callback, &input.functions)
433}
434
435fn generate_execute_logic(
436    session: &syn::Expr,
437    callback: &syn::Expr,
438    functions: &syn::punctuated::Punctuated<syn::Path, syn::Token![,]>,
439) -> TokenStream {
440    let num_funcs = functions.len();
441
442    let match_arms = functions.iter().enumerate().map(|(i, path)| {
443        let name_str = path.segments.last().unwrap().ident.to_string();
444        quote! {
445            #name_str => {
446                let args = call.args().clone().unwrap_or(gemini_client_api::serde_json::json!({}));
447                let fut: gemini_client_api::futures::future::BoxFuture<'static, (usize, String, Result<gemini_client_api::serde_json::Value, String>)> = Box::pin(async move {
448                    (#i, #name_str.to_string(), #path::execute(args).await)
449                });
450                futures.push(fut);
451            }
452        }
453    });
454
455    let expanded = quote! {
456        {
457            let mut results_array = vec![None; #num_funcs];
458            // Define callback here to ensure it's available
459            let mut result_callback = #callback;
460
461            if let Some(chat) = #session.get_last_chat() {
462                let mut futures = Vec::new();
463                for part in chat.parts() {
464                    if let gemini_client_api::gemini::types::request::PartType::FunctionCall(call) = part.data() {
465                        match call.name().as_str() {
466                            #(#match_arms)*
467                            _ => {}
468                        }
469                    }
470                }
471                if !futures.is_empty() {
472                    let results = gemini_client_api::futures::future::join_all(futures).await;
473                    for (idx, name, res) in results {
474                        // Invoke callback regardless of success or failure
475                        let val_to_add = result_callback(name.clone(), res.clone());
476
477                        if let Err(e) = #session.add_function_response(name.clone(), val_to_add) {
478                             results_array[idx] = Some(Err(format!(
479                                "failed to add function response for `{}`: {}",
480                                name, e
481                            )));
482                            continue;
483                        }
484                        results_array[idx] = Some(res);
485                    }
486                }
487            }
488            results_array
489        }
490    };
491
492    TokenStream::from(expanded)
493}