Skip to main content

gemini_proc_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    Attribute, Data, DeriveInput, Fields, FnArg, ItemFn, Lit, Meta, Pat, Type, parse_macro_input,
5};
6
7/// Attribute macro to mark a function as callable by Gemini.
8///
9/// This macro generates a schema for the function and an `execute` method to call it
10/// with JSON arguments.
11///
12/// # Requirements
13/// - Function arguments must be owned types that implement `GeminiSchema` (e.g., `String`, `i32`, `bool`).
14/// - References are not supported.
15/// - The function can be `async` and can return a `Result` (the `Ok` value must implement `Serialize`).
16///
17/// # Example
18/// ```ignore
19/// #[gemini_function]
20/// /// Returns the current weather for a location.
21/// fn get_weather(location: String) -> String {
22///     format!("The weather in {} is sunny.", location)
23/// }
24/// ```
25#[proc_macro_attribute]
26pub fn gemini_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
27    let mut input_fn = parse_macro_input!(item as ItemFn);
28    let fn_name = &input_fn.sig.ident;
29    let fn_description = extract_doc_comments(&input_fn.attrs);
30
31    let mut properties = Vec::new();
32    let mut required = Vec::new();
33    let mut param_names = Vec::new();
34    let mut param_types = Vec::new();
35
36    for arg in input_fn.sig.inputs.iter_mut() {
37        if let FnArg::Typed(pat_type) = arg {
38            if let Pat::Ident(pat_ident) = &*pat_type.pat {
39                let param_name = pat_ident.ident.clone();
40                let param_name_str = param_name.to_string();
41                let param_type = (*pat_type.ty).clone();
42                let param_desc = extract_doc_comments(&pat_type.attrs);
43
44                if has_reference(&param_type) {
45                    return syn::Error::new_spanned(
46                        &param_type,
47                        "references are not supported in gemini_function. Use owned types like String instead.",
48                    )
49                    .to_compile_error()
50                    .into();
51                }
52
53                // Remove doc attributes from the function signature so it compiles
54                pat_type.attrs.retain(|attr| !attr.path().is_ident("doc"));
55
56                let is_optional = is_option(&param_type);
57
58                properties.push(quote! {
59                    let mut schema = <#param_type as GeminiSchema>::gemini_schema();
60                    if !#param_desc.is_empty() {
61                        if let Some(obj) = schema.as_object_mut() {
62                            obj.insert("description".to_string(), serde_json::json!(#param_desc));
63                        }
64                    }
65                    props.insert(#param_name_str.to_string(), schema);
66                });
67
68                if !is_optional {
69                    required.push(param_name_str);
70                }
71
72                param_names.push(param_name);
73                param_types.push(param_type);
74            }
75        }
76    }
77
78    let fn_name_str = fn_name.to_string();
79    let is_async = input_fn.sig.asyncness.is_some();
80    let call_await = if is_async {
81        quote! { .await }
82    } else {
83        quote! {}
84    };
85
86    let is_result = match &input_fn.sig.output {
87        syn::ReturnType::Default => false,
88        syn::ReturnType::Type(_, ty) => {
89            let s = quote!(#ty).to_string();
90            s.contains("Result")
91        }
92    };
93
94    let result_handling = if is_result {
95        quote! {
96            match result {
97                Ok(v) => Ok(serde_json::json!(v)),
98                Err(e) => Err(e.to_string()),
99            }
100        }
101    } else {
102        quote! {
103            Ok(serde_json::json!(result))
104        }
105    };
106
107    let expanded = quote! {
108        #input_fn
109
110        #[allow(non_camel_case_types)]
111        pub struct #fn_name { }
112
113        impl GeminiSchema for #fn_name {
114            fn gemini_schema() -> serde_json::Value {
115                use serde_json::{json, Map};
116                let mut props = Map::new();
117                #(#properties)*
118
119                json!({
120                    "name": #fn_name_str,
121                    "description": #fn_description,
122                    "parameters": {
123                        "type": "OBJECT",
124                        "properties": props,
125                        "required": [#(#required),*]
126                    }
127                })
128            }
129        }
130
131        impl #fn_name {
132            pub async fn execute(args: serde_json::Value) -> Result<serde_json::Value, String> {
133                use serde::Deserialize;
134                #[derive(Deserialize)]
135                struct Args {
136                    #(#param_names: #param_types,)*
137                }
138                let args = Args::deserialize(&args).map_err(|e| e.to_string())?;
139                let result = #fn_name(#(args.#param_names),*) #call_await;
140                #result_handling
141            }
142        }
143    };
144
145    TokenStream::from(expanded)
146}
147
148/// Macro to execute function calls requested by Gemini and update the session history.
149///
150/// # Usage
151/// `execute_function_calls!(session, function1, function2, ...)`
152///
153/// # Returns
154/// A `Vec<Option<Result<serde_json::Value, String>>>` containing the results of each function call.
155/// The length of the vector matches the number of functions provided.
156/// - `Some(Ok(value))` if the function was called and succeeded.
157/// - `Some(Err(err))` if the function was called but failed.
158/// - `None` if the function was not called.
159///
160/// # Note
161/// The `session` is automatically updated with the `FunctionResponse` for successful calls.
162#[proc_macro]
163pub fn execute_function_calls(input: TokenStream) -> TokenStream {
164    use syn::parse::{Parse, ParseStream};
165    use syn::{Expr, Token};
166
167    struct ExecuteInput {
168        session: Expr,
169        _comma: Token![,],
170        functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
171    }
172
173    impl Parse for ExecuteInput {
174        fn parse(input: ParseStream) -> syn::Result<Self> {
175            Ok(ExecuteInput {
176                session: input.parse()?,
177                _comma: input.parse()?,
178                functions: input.parse_terminated(syn::Path::parse, Token![,])?,
179            })
180        }
181    }
182
183    let input = parse_macro_input!(input as ExecuteInput);
184    let session = &input.session;
185    let functions = &input.functions;
186    let num_funcs = functions.len();
187
188    let match_arms = functions.iter().enumerate().map(|(i, path)| {
189        let name_str = path.segments.last().unwrap().ident.to_string();
190        quote! {
191            #name_str => {
192                let args = call.args().clone().unwrap_or(gemini_client_api::serde_json::json!({}));
193                let fut: gemini_client_api::futures::future::BoxFuture<'static, (usize, String, Result<gemini_client_api::serde_json::Value, String>)> = Box::pin(async move {
194                    (#i, #name_str.to_string(), #path::execute(args).await)
195                });
196                futures.push(fut);
197            }
198        }
199    });
200
201    let expanded = quote! {
202        {
203            let mut results_array = vec![None; #num_funcs];
204            if let Some(chat) = #session.get_last_chat() {
205                let mut futures = Vec::new();
206                for part in chat.parts() {
207                    if let gemini_client_api::gemini::types::request::PartType::FunctionCall(call) = part.data() {
208                        match call.name().as_str() {
209                            #(#match_arms)*
210                            _ => {}
211                        }
212                    }
213                }
214                if !futures.is_empty() {
215                    let results = gemini_client_api::futures::future::join_all(futures).await;
216                    for (idx, name, res) in results {
217                        if let Ok(ref val) = res {
218                            if let Err(e) = #session.add_function_response(name.clone(), val.clone()) {
219                                results_array[idx] = Some(Err(format!(
220                                    "failed to add function response for `{}`: {}",
221                                    name, e
222                                )));
223                                continue;
224                            }
225                        }
226                        results_array[idx] = Some(res);
227                    }
228                }
229            }
230            results_array
231        }
232    };
233
234    TokenStream::from(expanded)
235}
236
237/// Macro to execute function calls requested by Gemini and update the session history, with a custom callback for results.
238///
239/// # Usage
240/// `execute_function_calls_with_callback!(session, callback, function1, function2, ...)`
241///
242/// The `callback` should be a closure or function that takes `Result<serde_json::Value, String>`
243/// and returns `serde_json::Value`.
244///
245/// # Returns
246/// A `Vec<Option<Result<serde_json::Value, String>>>` containing the results of each function call.
247/// - `Some(Ok(value))` if the function was called and succeeded.
248/// - `Some(Err(err))` if the function was called but failed.
249/// - `None` if the function was not called.
250///
251/// # Note
252/// The `session` is automatically updated with the `FunctionResponse`.
253/// The `callback` is invoked with the result of the function execution (whether `Ok` or `Err`)
254/// and its return value is used to update the session.
255#[proc_macro]
256pub fn execute_function_calls_with_callback(input: TokenStream) -> TokenStream {
257    use syn::parse::{Parse, ParseStream};
258    use syn::{Expr, Token};
259
260    struct ExecuteWithCallbackInput {
261        session: Expr,
262        _comma1: Token![,],
263        callback: Expr,
264        _comma2: Token![,],
265        functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
266    }
267
268    impl Parse for ExecuteWithCallbackInput {
269        fn parse(input: ParseStream) -> syn::Result<Self> {
270            Ok(ExecuteWithCallbackInput {
271                session: input.parse()?,
272                _comma1: input.parse()?,
273                callback: input.parse()?,
274                _comma2: input.parse()?,
275                functions: input.parse_terminated(syn::Path::parse, Token![,])?,
276            })
277        }
278    }
279
280    let input = parse_macro_input!(input as ExecuteWithCallbackInput);
281    let session = &input.session;
282    let callback = &input.callback;
283    let functions = &input.functions;
284    let num_funcs = functions.len();
285
286    let match_arms = functions.iter().enumerate().map(|(i, path)| {
287        let name_str = path.segments.last().unwrap().ident.to_string();
288        quote! {
289            #name_str => {
290                let args = call.args().clone().unwrap_or(gemini_client_api::serde_json::json!({}));
291                let fut: gemini_client_api::futures::future::BoxFuture<'static, (usize, String, Result<gemini_client_api::serde_json::Value, String>)> = Box::pin(async move {
292                    (#i, #name_str.to_string(), #path::execute(args).await)
293                });
294                futures.push(fut);
295            }
296        }
297    });
298
299    let expanded = quote! {
300        {
301            let mut results_array = vec![None; #num_funcs];
302            // Define callback here to ensure it's available
303            let mut result_callback = #callback;
304
305            if let Some(chat) = #session.get_last_chat() {
306                let mut futures = Vec::new();
307                for part in chat.parts() {
308                    if let gemini_client_api::gemini::types::request::PartType::FunctionCall(call) = part.data() {
309                        match call.name().as_str() {
310                            #(#match_arms)*
311                            _ => {}
312                        }
313                    }
314                }
315                if !futures.is_empty() {
316                    let results = gemini_client_api::futures::future::join_all(futures).await;
317                    for (idx, name, res) in results {
318                        // Invoke callback regardless of success or failure
319                        let val_to_add = result_callback(res.clone());
320
321                        if let Err(e) = #session.add_function_response(name.clone(), val_to_add) {
322                             results_array[idx] = Some(Err(format!(
323                                "failed to add function response for `{}`: {}",
324                                name, e
325                            )));
326                            continue;
327                        }
328                        results_array[idx] = Some(res);
329                    }
330                }
331            }
332            results_array
333        }
334    };
335
336    TokenStream::from(expanded)
337}
338
339/// Attribute macro to derive the `GeminiSchema` trait for a struct or enum.
340///
341/// This allows the type to be used in structured output (`set_json_mode`) or as a parameter
342/// in a `gemini_function`.
343///
344/// # Requirements
345/// - For structs: only named fields are supported.
346/// - For enums: only unit variants (no data) are supported.
347/// - Field/variant types must also implement `GeminiSchema`.
348/// - Doc comments on fields and variants are extracted as descriptions in the schema.
349///
350/// # Example
351/// ```ignore
352/// #[gemini_schema]
353/// struct SearchResult {
354///     /// The title of the page.
355///     title: String,
356///     /// The URL of the page.
357///     url: String,
358/// }
359/// ```
360#[proc_macro_attribute]
361pub fn gemini_schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
362    let input = parse_macro_input!(item as DeriveInput);
363    let name = &input.ident;
364    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
365    let description = extract_doc_comments(&input.attrs);
366
367    let expanded = match &input.data {
368        Data::Struct(data) => {
369            let mut properties = Vec::new();
370            let mut required = Vec::new();
371
372            match &data.fields {
373                Fields::Named(fields) => {
374                    for field in &fields.named {
375                        let field_name = field.ident.as_ref().unwrap();
376                        let field_name_str = field_name.to_string();
377                        let field_type = &field.ty;
378                        let field_desc = extract_doc_comments(&field.attrs);
379
380                        if has_reference(field_type) {
381                            return syn::Error::new_spanned(
382                                field_type,
383                                "references are not supported in gemini_schema. Use owned types instead.",
384                            )
385                            .to_compile_error()
386                            .into();
387                        }
388
389                        let is_optional = is_option(field_type);
390
391                        properties.push(quote! {
392                            let mut schema = <#field_type as GeminiSchema>::gemini_schema();
393                            if !#field_desc.is_empty() {
394                                if let Some(obj) = schema.as_object_mut() {
395                                    obj.insert("description".to_string(), serde_json::json!(#field_desc));
396                                }
397                            }
398                            props.insert(#field_name_str.to_string(), schema);
399                        });
400
401                        if !is_optional {
402                            required.push(field_name_str);
403                        }
404                    }
405                }
406                _ => panic!("gemini_schema only supports named fields in structs"),
407            }
408
409            quote! {
410                impl #impl_generics GeminiSchema for #name #ty_generics #where_clause {
411                    fn gemini_schema() -> serde_json::Value {
412                        use serde_json::{json, Map};
413                        let mut props = Map::new();
414                        #(#properties)*
415
416                        let mut schema = json!({
417                            "type": "OBJECT",
418                            "properties": props,
419                            "required": [#(#required),*]
420                        });
421
422                        if !#description.is_empty() {
423                            if let Some(obj) = schema.as_object_mut() {
424                                obj.insert("description".to_string(), json!(#description));
425                            }
426                        }
427                        schema
428                    }
429                }
430            }
431        }
432        Data::Enum(data) => {
433            let mut variants = Vec::new();
434            for variant in &data.variants {
435                if !matches!(variant.fields, Fields::Unit) {
436                    panic!("gemini_schema only supports unit variants in enums");
437                }
438                variants.push(variant.ident.to_string());
439            }
440
441            quote! {
442                impl #impl_generics GeminiSchema for #name #ty_generics #where_clause {
443                    fn gemini_schema() -> serde_json::Value {
444                        use serde_json::json;
445                        let mut schema = json!({
446                            "type": "STRING",
447                            "enum": [#(#variants),*]
448                        });
449
450                        if !#description.is_empty() {
451                            if let Some(obj) = schema.as_object_mut() {
452                                obj.insert("description".to_string(), json!(#description));
453                            }
454                        }
455                        schema
456                    }
457                }
458            }
459        }
460        _ => panic!("gemini_schema only supports structs and enums"),
461    };
462
463    let output = quote! {
464        #input
465        #expanded
466    };
467
468    TokenStream::from(output)
469}
470
471fn extract_doc_comments(attrs: &[Attribute]) -> String {
472    let mut doc_comments = Vec::new();
473    for attr in attrs {
474        if attr.path().is_ident("doc") {
475            if let Meta::NameValue(nv) = &attr.meta {
476                if let syn::Expr::Lit(expr_lit) = &nv.value {
477                    if let Lit::Str(lit_str) = &expr_lit.lit {
478                        doc_comments.push(lit_str.value().trim().to_string());
479                    }
480                }
481            }
482        }
483    }
484    doc_comments.join("\n")
485}
486
487fn is_option(ty: &Type) -> bool {
488    if let Type::Path(tp) = ty {
489        if let Some(seg) = tp.path.segments.last() {
490            return seg.ident == "Option";
491        }
492    }
493    false
494}
495
496fn has_reference(ty: &Type) -> bool {
497    match ty {
498        Type::Reference(_) => true,
499        Type::Path(tp) => {
500            for seg in &tp.path.segments {
501                if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
502                    for arg in &ab.args {
503                        if let syn::GenericArgument::Type(inner) = arg {
504                            if has_reference(inner) {
505                                return true;
506                            }
507                        }
508                    }
509                }
510            }
511            false
512        }
513        _ => false,
514    }
515}