Skip to main content

gemini_proc_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse_macro_input, Attribute, Data, DeriveInput, Fields, FnArg, ItemFn, Lit, Meta, Pat, Type,
5};
6
7#[proc_macro_attribute]
8pub fn gemini_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
9    let mut input_fn = parse_macro_input!(item as ItemFn);
10    let fn_name = &input_fn.sig.ident;
11    let fn_description = extract_doc_comments(&input_fn.attrs);
12
13    let mut properties = Vec::new();
14    let mut required = Vec::new();
15    let mut param_names = Vec::new();
16    let mut param_types = Vec::new();
17
18    for arg in input_fn.sig.inputs.iter_mut() {
19        if let FnArg::Typed(pat_type) = arg {
20            if let Pat::Ident(pat_ident) = &*pat_type.pat {
21                let param_name = pat_ident.ident.clone();
22                let param_name_str = param_name.to_string();
23                let param_type = (*pat_type.ty).clone();
24                let param_desc = extract_doc_comments(&pat_type.attrs);
25
26                if has_reference(&param_type) {
27                    return syn::Error::new_spanned(
28                        &param_type,
29                        "references are not supported in gemini_function. Use owned types like String instead.",
30                    )
31                    .to_compile_error()
32                    .into();
33                }
34
35                // Remove doc attributes from the function signature so it compiles
36                pat_type.attrs.retain(|attr| !attr.path().is_ident("doc"));
37
38                let is_optional = is_option(&param_type);
39
40                properties.push(quote! {
41                    let mut schema = <#param_type as GeminiSchema>::gemini_schema();
42                    if !#param_desc.is_empty() {
43                        if let Some(obj) = schema.as_object_mut() {
44                            obj.insert("description".to_string(), serde_json::json!(#param_desc));
45                        }
46                    }
47                    props.insert(#param_name_str.to_string(), schema);
48                });
49
50                if !is_optional {
51                    required.push(param_name_str);
52                }
53
54                param_names.push(param_name);
55                param_types.push(param_type);
56            }
57        }
58    }
59
60    let fn_name_str = fn_name.to_string();
61    let is_async = input_fn.sig.asyncness.is_some();
62    let call_await = if is_async {
63        quote! { .await }
64    } else {
65        quote! {}
66    };
67
68    let is_result = match &input_fn.sig.output {
69        syn::ReturnType::Default => false,
70        syn::ReturnType::Type(_, ty) => {
71            let s = quote!(#ty).to_string();
72            s.contains("Result")
73        }
74    };
75
76    let result_handling = if is_result {
77        quote! {
78            match result {
79                Ok(v) => Ok(serde_json::json!(v)),
80                Err(e) => Err(e.to_string()),
81            }
82        }
83    } else {
84        quote! {
85            Ok(serde_json::json!(result))
86        }
87    };
88
89    let expanded = quote! {
90        #input_fn
91
92        #[allow(non_camel_case_types)]
93        pub struct #fn_name { }
94
95        impl GeminiSchema for #fn_name {
96            fn gemini_schema() -> serde_json::Value {
97                use serde_json::{json, Map};
98                let mut props = Map::new();
99                #(#properties)*
100
101                json!({
102                    "name": #fn_name_str,
103                    "description": #fn_description,
104                    "parameters": {
105                        "type": "OBJECT",
106                        "properties": props,
107                        "required": [#(#required),*]
108                    }
109                })
110            }
111        }
112
113        impl #fn_name {
114            pub async fn execute(args: serde_json::Value) -> Result<serde_json::Value, String> {
115                use serde::Deserialize;
116                #[derive(Deserialize)]
117                struct Args {
118                    #(#param_names: #param_types,)*
119                }
120                let args = Args::deserialize(&args).map_err(|e| e.to_string())?;
121                let result = #fn_name(#(args.#param_names),*) #call_await;
122                #result_handling
123            }
124        }
125    };
126
127    TokenStream::from(expanded)
128}
129
130#[proc_macro]
131/// - Provide all functions to be called `execute_function_calls!(session, f1, f2...)`
132/// - `Returns` Vec<Option<Result<serde_json::Value, String>>>
133/// - Returned vec length always equals the number of functions passed
134/// - `None` if f_i was not called by Gemini
135/// *if function don't return type Result, it always return `Result::Ok(value)`*
136/// - `Session` struct is automatically updated with FunctionResponse only for `Ok` result
137pub fn execute_function_calls(input: TokenStream) -> TokenStream {
138    use syn::parse::{Parse, ParseStream};
139    use syn::{Expr, Token};
140
141    struct ExecuteInput {
142        session: Expr,
143        _comma: Token![,],
144        functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
145    }
146
147    impl Parse for ExecuteInput {
148        fn parse(input: ParseStream) -> syn::Result<Self> {
149            Ok(ExecuteInput {
150                session: input.parse()?,
151                _comma: input.parse()?,
152                functions: input.parse_terminated(syn::Path::parse, Token![,])?,
153            })
154        }
155    }
156
157    let input = parse_macro_input!(input as ExecuteInput);
158    let session = &input.session;
159    let functions = &input.functions;
160    let num_funcs = functions.len();
161
162    let match_arms = functions.iter().enumerate().map(|(i, path)| {
163        let name_str = path.segments.last().unwrap().ident.to_string();
164        quote! {
165            #name_str => {
166                let args = call.args().clone().unwrap_or(gemini_client_api::serde_json::json!({}));
167                let fut: gemini_client_api::futures::future::BoxFuture<'static, (usize, String, Result<gemini_client_api::serde_json::Value, String>)> = Box::pin(async move {
168                    (#i, #name_str.to_string(), #path::execute(args).await)
169                });
170                futures.push(fut);
171            }
172        }
173    });
174
175    let expanded = quote! {
176        {
177            let mut results_array = vec![None; #num_funcs];
178            if let Some(chat) = #session.get_last_chat() {
179                let mut futures = Vec::new();
180                for part in chat.parts() {
181                    if let gemini_client_api::gemini::types::request::PartType::FunctionCall(call) = part.data() {
182                        match call.name().as_str() {
183                            #(#match_arms)*
184                            _ => {}
185                        }
186                    }
187                }
188                if !futures.is_empty() {
189                    let results = gemini_client_api::futures::future::join_all(futures).await;
190                    for (idx, name, res) in results {
191                        if let Ok(ref val) = res {
192                            if let Err(e) = #session.add_function_response(name.clone(), val.clone()) {
193                                results_array[idx] = Some(Err(format!(
194                                    "failed to add function response for `{}`: {}",
195                                    name, e
196                                )));
197                                continue;
198                            }
199                        }
200                        results_array[idx] = Some(res);
201                    }
202                }
203            }
204            results_array
205        }
206    };
207
208    TokenStream::from(expanded)
209}
210
211#[proc_macro_attribute]
212pub fn gemini_schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
213    let input = parse_macro_input!(item as DeriveInput);
214    let name = &input.ident;
215    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
216    let description = extract_doc_comments(&input.attrs);
217
218    let expanded = match &input.data {
219        Data::Struct(data) => {
220            let mut properties = Vec::new();
221            let mut required = Vec::new();
222
223            match &data.fields {
224                Fields::Named(fields) => {
225                    for field in &fields.named {
226                        let field_name = field.ident.as_ref().unwrap();
227                        let field_name_str = field_name.to_string();
228                        let field_type = &field.ty;
229                        let field_desc = extract_doc_comments(&field.attrs);
230
231                        if has_reference(field_type) {
232                            return syn::Error::new_spanned(
233                                field_type,
234                                "references are not supported in gemini_schema. Use owned types instead.",
235                            )
236                            .to_compile_error()
237                            .into();
238                        }
239
240                        let is_optional = is_option(field_type);
241
242                        properties.push(quote! {
243                            let mut schema = <#field_type as GeminiSchema>::gemini_schema();
244                            if !#field_desc.is_empty() {
245                                if let Some(obj) = schema.as_object_mut() {
246                                    obj.insert("description".to_string(), serde_json::json!(#field_desc));
247                                }
248                            }
249                            props.insert(#field_name_str.to_string(), schema);
250                        });
251
252                        if !is_optional {
253                            required.push(field_name_str);
254                        }
255                    }
256                }
257                _ => panic!("gemini_schema only supports named fields in structs"),
258            }
259
260            quote! {
261                impl #impl_generics GeminiSchema for #name #ty_generics #where_clause {
262                    fn gemini_schema() -> serde_json::Value {
263                        use serde_json::{json, Map};
264                        let mut props = Map::new();
265                        #(#properties)*
266
267                        let mut schema = json!({
268                            "type": "OBJECT",
269                            "properties": props,
270                            "required": [#(#required),*]
271                        });
272
273                        if !#description.is_empty() {
274                            if let Some(obj) = schema.as_object_mut() {
275                                obj.insert("description".to_string(), json!(#description));
276                            }
277                        }
278                        schema
279                    }
280                }
281            }
282        }
283        Data::Enum(data) => {
284            let mut variants = Vec::new();
285            for variant in &data.variants {
286                if !matches!(variant.fields, Fields::Unit) {
287                    panic!("gemini_schema only supports unit variants in enums");
288                }
289                variants.push(variant.ident.to_string());
290            }
291
292            quote! {
293                impl #impl_generics GeminiSchema for #name #ty_generics #where_clause {
294                    fn gemini_schema() -> serde_json::Value {
295                        use serde_json::json;
296                        let mut schema = json!({
297                            "type": "STRING",
298                            "enum": [#(#variants),*]
299                        });
300
301                        if !#description.is_empty() {
302                            if let Some(obj) = schema.as_object_mut() {
303                                obj.insert("description".to_string(), json!(#description));
304                            }
305                        }
306                        schema
307                    }
308                }
309            }
310        }
311        _ => panic!("gemini_schema only supports structs and enums"),
312    };
313
314    let output = quote! {
315        #input
316        #expanded
317    };
318
319    TokenStream::from(output)
320}
321
322fn extract_doc_comments(attrs: &[Attribute]) -> String {
323    let mut doc_comments = Vec::new();
324    for attr in attrs {
325        if attr.path().is_ident("doc") {
326            if let Meta::NameValue(nv) = &attr.meta {
327                if let syn::Expr::Lit(expr_lit) = &nv.value {
328                    if let Lit::Str(lit_str) = &expr_lit.lit {
329                        doc_comments.push(lit_str.value().trim().to_string());
330                    }
331                }
332            }
333        }
334    }
335    doc_comments.join("\n")
336}
337
338fn is_option(ty: &Type) -> bool {
339    if let Type::Path(tp) = ty {
340        if let Some(seg) = tp.path.segments.last() {
341            return seg.ident == "Option";
342        }
343    }
344    false
345}
346
347fn has_reference(ty: &Type) -> bool {
348    match ty {
349        Type::Reference(_) => true,
350        Type::Path(tp) => {
351            for seg in &tp.path.segments {
352                if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
353                    for arg in &ab.args {
354                        if let syn::GenericArgument::Type(inner) = arg {
355                            if has_reference(inner) {
356                                return true;
357                            }
358                        }
359                    }
360                }
361            }
362            false
363        }
364        _ => false,
365    }
366}