Skip to main content

gemini_proc_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse_macro_input, Attribute, Data, DeriveInput, Fields, FnArg, ItemFn, Lit, Meta, Pat, Type,
5};
6
7#[proc_macro_attribute]
8pub fn gemini_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
9    let mut input_fn = parse_macro_input!(item as ItemFn);
10    let fn_name = &input_fn.sig.ident;
11    let fn_description = extract_doc_comments(&input_fn.attrs);
12
13    let mut properties = Vec::new();
14    let mut required = Vec::new();
15    let mut param_names = Vec::new();
16    let mut param_types = Vec::new();
17
18    for arg in input_fn.sig.inputs.iter_mut() {
19        if let FnArg::Typed(pat_type) = arg {
20            if let Pat::Ident(pat_ident) = &*pat_type.pat {
21                let param_name = pat_ident.ident.clone();
22                let param_name_str = param_name.to_string();
23                let param_type = (*pat_type.ty).clone();
24                let param_desc = extract_doc_comments(&pat_type.attrs);
25
26                // Remove doc attributes from the function signature so it compiles
27                pat_type.attrs.retain(|attr| !attr.path().is_ident("doc"));
28
29                let is_optional = is_option(&param_type);
30
31                properties.push(quote! {
32                    let mut schema = <#param_type as GeminiSchema>::gemini_schema();
33                    if !#param_desc.is_empty() {
34                        if let Some(obj) = schema.as_object_mut() {
35                            obj.insert("description".to_string(), serde_json::json!(#param_desc));
36                        }
37                    }
38                    props.insert(#param_name_str.to_string(), schema);
39                });
40
41                if !is_optional {
42                    required.push(param_name_str);
43                }
44
45                param_names.push(param_name);
46                param_types.push(param_type);
47            }
48        }
49    }
50
51    let fn_name_str = fn_name.to_string();
52    let is_async = input_fn.sig.asyncness.is_some();
53    let call_await = if is_async {
54        quote! { .await }
55    } else {
56        quote! {}
57    };
58
59    let is_result = match &input_fn.sig.output {
60        syn::ReturnType::Default => false,
61        syn::ReturnType::Type(_, ty) => {
62            let s = quote!(#ty).to_string();
63            s.contains("Result")
64        }
65    };
66
67    let result_handling = if is_result {
68        quote! {
69            match result {
70                Ok(v) => Ok(serde_json::json!(v)),
71                Err(e) => Err(e.to_string()),
72            }
73        }
74    } else {
75        quote! {
76            Ok(serde_json::json!(result))
77        }
78    };
79
80    let expanded = quote! {
81        #input_fn
82
83        #[allow(non_camel_case_types)]
84        pub struct #fn_name { }
85
86        impl GeminiSchema for #fn_name {
87            fn gemini_schema() -> serde_json::Value {
88                use serde_json::{json, Map};
89                let mut props = Map::new();
90                #(#properties)*
91
92                json!({
93                    "name": #fn_name_str,
94                    "description": #fn_description,
95                    "parameters": {
96                        "type": "OBJECT",
97                        "properties": props,
98                        "required": [#(#required),*]
99                    }
100                })
101            }
102        }
103
104        impl #fn_name {
105            pub async fn execute(args: serde_json::Value) -> Result<serde_json::Value, String> {
106                use serde::Deserialize;
107                #[derive(Deserialize)]
108                struct Args {
109                    #(#param_names: #param_types),*
110                }
111                let args: Args = serde_json::from_value(args).map_err(|e| e.to_string())?;
112                let result = #fn_name(#(args.#param_names),*) #call_await;
113                #result_handling
114            }
115        }
116    };
117
118    TokenStream::from(expanded)
119}
120
121#[proc_macro]
122/// - Provide all functions to be called `execute_function_calls!(session, f1, f2...)`
123/// - `Returns` vec![Result of f2, Result of f2 ...]
124/// *if function don't return type Result, it always return `Result::Ok(value)`*
125/// - `Session` struct is automatically updated with FunctionResponse only for `Ok` result
126pub fn execute_function_calls(input: TokenStream) -> TokenStream {
127    use syn::parse::{Parse, ParseStream};
128    use syn::{Expr, Token};
129
130    struct ExecuteInput {
131        session: Expr,
132        _comma: Token![,],
133        functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
134    }
135
136    impl Parse for ExecuteInput {
137        fn parse(input: ParseStream) -> syn::Result<Self> {
138            Ok(ExecuteInput {
139                session: input.parse()?,
140                _comma: input.parse()?,
141                functions: input.parse_terminated(syn::Path::parse, Token![,])?,
142            })
143        }
144    }
145
146    let input = parse_macro_input!(input as ExecuteInput);
147    let session = &input.session;
148    let functions = &input.functions;
149
150    let match_arms = functions.iter().map(|path| {
151        let name_str = path.segments.last().unwrap().ident.to_string();
152        quote! {
153            #name_str => {
154                let args = call.args().clone().unwrap_or(gemini_client_api::serde_json::json!({}));
155                let fut: std::pin::Pin<Box<dyn std::future::Future<Output = (String, Result<gemini_client_api::serde_json::Value, String>)>>> = Box::pin(async move {
156                    (#name_str.to_string(), #path::execute(args).await)
157                });
158                futures.push(fut);
159            }
160        }
161    });
162
163    let expanded = quote! {
164        {
165            let mut all_results = Vec::new();
166            if let Some(chat) = #session.get_last_chat() {
167                let mut futures = Vec::new();
168                for part in chat.parts() {
169                    if let gemini_client_api::gemini::types::request::PartType::FunctionCall(call) = part.data() {
170                        match call.name().as_str() {
171                            #(#match_arms)*
172                            _ => {}
173                        }
174                    }
175                }
176                if !futures.is_empty() {
177                    let results = gemini_client_api::futures::future::join_all(futures).await;
178                    for (name, res) in results {
179                        if let Ok(ref val) = res {
180                            let _ = #session.add_function_response(name, val.clone());
181                        }
182                        all_results.push(res);
183                    }
184                }
185            }
186            all_results
187        }
188    };
189
190    TokenStream::from(expanded)
191}
192
193#[proc_macro_attribute]
194pub fn gemini_schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
195    let input = parse_macro_input!(item as DeriveInput);
196    let name = &input.ident;
197    let description = extract_doc_comments(&input.attrs);
198
199    let expanded = match &input.data {
200        Data::Struct(data) => {
201            let mut properties = Vec::new();
202            let mut required = Vec::new();
203
204            match &data.fields {
205                Fields::Named(fields) => {
206                    for field in &fields.named {
207                        let field_name = field.ident.as_ref().unwrap();
208                        let field_name_str = field_name.to_string();
209                        let field_type = &field.ty;
210                        let field_desc = extract_doc_comments(&field.attrs);
211
212                        let is_optional = is_option(field_type);
213
214                        properties.push(quote! {
215                            let mut schema = <#field_type as GeminiSchema>::gemini_schema();
216                            if !#field_desc.is_empty() {
217                                if let Some(obj) = schema.as_object_mut() {
218                                    obj.insert("description".to_string(), serde_json::json!(#field_desc));
219                                }
220                            }
221                            props.insert(#field_name_str.to_string(), schema);
222                        });
223
224                        if !is_optional {
225                            required.push(field_name_str);
226                        }
227                    }
228                }
229                _ => panic!("gemini_schema only supports named fields in structs"),
230            }
231
232            quote! {
233                impl GeminiSchema for #name {
234                    fn gemini_schema() -> serde_json::Value {
235                        use serde_json::{json, Map};
236                        let mut props = Map::new();
237                        #(#properties)*
238
239                        let mut schema = json!({
240                            "type": "OBJECT",
241                            "properties": props,
242                            "required": [#(#required),*]
243                        });
244
245                        if !#description.is_empty() {
246                            if let Some(obj) = schema.as_object_mut() {
247                                obj.insert("description".to_string(), json!(#description));
248                            }
249                        }
250                        schema
251                    }
252                }
253            }
254        }
255        Data::Enum(data) => {
256            let mut variants = Vec::new();
257            for variant in &data.variants {
258                if !matches!(variant.fields, Fields::Unit) {
259                    panic!("gemini_schema only supports unit variants in enums");
260                }
261                variants.push(variant.ident.to_string());
262            }
263
264            quote! {
265                impl GeminiSchema for #name {
266                    fn gemini_schema() -> serde_json::Value {
267                        use serde_json::json;
268                        let mut schema = json!({
269                            "type": "STRING",
270                            "enum": [#(#variants),*]
271                        });
272
273                        if !#description.is_empty() {
274                            if let Some(obj) = schema.as_object_mut() {
275                                obj.insert("description".to_string(), json!(#description));
276                            }
277                        }
278                        schema
279                    }
280                }
281            }
282        }
283        _ => panic!("gemini_schema only supports structs and enums"),
284    };
285
286    let output = quote! {
287        #input
288        #expanded
289    };
290
291    TokenStream::from(output)
292}
293
294fn extract_doc_comments(attrs: &[Attribute]) -> String {
295    let mut doc_comments = Vec::new();
296    for attr in attrs {
297        if attr.path().is_ident("doc") {
298            if let Meta::NameValue(nv) = &attr.meta {
299                if let syn::Expr::Lit(expr_lit) = &nv.value {
300                    if let Lit::Str(lit_str) = &expr_lit.lit {
301                        doc_comments.push(lit_str.value().trim().to_string());
302                    }
303                }
304            }
305        }
306    }
307    doc_comments.join("\n")
308}
309
310fn is_option(ty: &Type) -> bool {
311    if let Type::Path(tp) = ty {
312        if let Some(seg) = tp.path.segments.last() {
313            return seg.ident == "Option";
314        }
315    }
316    false
317}