Skip to main content

gemini_proc_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse_macro_input, Attribute, Data, DeriveInput, Fields, FnArg, ItemFn, Lit, Meta, Pat, Type,
5};
6
7#[proc_macro_attribute]
8pub fn gemini_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
9    let mut input_fn = parse_macro_input!(item as ItemFn);
10    let fn_name = &input_fn.sig.ident;
11    let fn_description = extract_doc_comments(&input_fn.attrs);
12
13    let mut properties = Vec::new();
14    let mut required = Vec::new();
15    let mut param_names = Vec::new();
16    let mut param_types = Vec::new();
17
18    for arg in input_fn.sig.inputs.iter_mut() {
19        if let FnArg::Typed(pat_type) = arg {
20            if let Pat::Ident(pat_ident) = &*pat_type.pat {
21                let param_name = pat_ident.ident.clone();
22                let param_name_str = param_name.to_string();
23                let param_type = (*pat_type.ty).clone();
24                let param_desc = extract_doc_comments(&pat_type.attrs);
25
26                // Remove doc attributes from the function signature so it compiles
27                pat_type.attrs.retain(|attr| !attr.path().is_ident("doc"));
28
29                let is_optional = is_option(&param_type);
30
31                properties.push(quote! {
32                    let mut schema = <#param_type as GeminiSchema>::gemini_schema();
33                    if !#param_desc.is_empty() {
34                        if let Some(obj) = schema.as_object_mut() {
35                            obj.insert("description".to_string(), serde_json::json!(#param_desc));
36                        }
37                    }
38                    props.insert(#param_name_str.to_string(), schema);
39                });
40
41                if !is_optional {
42                    required.push(param_name_str);
43                }
44
45                param_names.push(param_name);
46                param_types.push(param_type);
47            }
48        }
49    }
50
51    let fn_name_str = fn_name.to_string();
52    let is_async = input_fn.sig.asyncness.is_some();
53    let call_await = if is_async {
54        quote! { .await }
55    } else {
56        quote! {}
57    };
58
59    let is_result = match &input_fn.sig.output {
60        syn::ReturnType::Default => false,
61        syn::ReturnType::Type(_, ty) => {
62            let s = quote!(#ty).to_string();
63            s.contains("Result")
64        }
65    };
66
67    let result_handling = if is_result {
68        quote! {
69            match result {
70                Ok(v) => Ok(serde_json::json!(v)),
71                Err(e) => Err(e.to_string()),
72            }
73        }
74    } else {
75        quote! {
76            Ok(serde_json::json!(result))
77        }
78    };
79
80    let expanded = quote! {
81        #input_fn
82
83        #[allow(non_camel_case_types)]
84        pub struct #fn_name { }
85
86        impl GeminiSchema for #fn_name {
87            fn gemini_schema() -> serde_json::Value {
88                use serde_json::{json, Map};
89                let mut props = Map::new();
90                #(#properties)*
91
92                json!({
93                    "name": #fn_name_str,
94                    "description": #fn_description,
95                    "parameters": {
96                        "type": "OBJECT",
97                        "properties": props,
98                        "required": [#(#required),*]
99                    }
100                })
101            }
102        }
103
104        impl #fn_name {
105            pub async fn execute(args: serde_json::Value) -> Result<serde_json::Value, String> {
106                use serde::Deserialize;
107                #[derive(Deserialize)]
108                struct Args {
109                    #(#param_names: #param_types),*
110                }
111                let args: Args = serde_json::from_value(args).map_err(|e| e.to_string())?;
112                let result = #fn_name(#(args.#param_names),*) #call_await;
113                #result_handling
114            }
115        }
116    };
117
118    TokenStream::from(expanded)
119}
120
121#[proc_macro]
122pub fn execute_function_calls(input: TokenStream) -> TokenStream {
123    use syn::parse::{Parse, ParseStream};
124    use syn::{Expr, Token};
125
126    struct ExecuteInput {
127        session: Expr,
128        _comma: Token![,],
129        functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
130    }
131
132    impl Parse for ExecuteInput {
133        fn parse(input: ParseStream) -> syn::Result<Self> {
134            Ok(ExecuteInput {
135                session: input.parse()?,
136                _comma: input.parse()?,
137                functions: input.parse_terminated(syn::Path::parse, Token![,])?,
138            })
139        }
140    }
141
142    let input = parse_macro_input!(input as ExecuteInput);
143    let session = &input.session;
144    let functions = &input.functions;
145
146    let match_arms = functions.iter().map(|path| {
147        let name_str = path.segments.last().unwrap().ident.to_string();
148        quote! {
149            #name_str => {
150                let args = call.args().clone().unwrap_or(gemini_client_api::serde_json::json!({}));
151                let fut: std::pin::Pin<Box<dyn std::future::Future<Output = (String, Result<gemini_client_api::serde_json::Value, String>)>>> = Box::pin(async move {
152                    (#name_str.to_string(), #path::execute(args).await)
153                });
154                futures.push(fut);
155            }
156        }
157    });
158
159    let expanded = quote! {
160        {
161            let mut all_results = Vec::new();
162            if let Some(chat) = #session.get_last_chat() {
163                let mut futures = Vec::new();
164                for part in chat.parts() {
165                    if let gemini_client_api::gemini::types::request::Part::functionCall(call) = part {
166                        match call.name().as_str() {
167                            #(#match_arms)*
168                            _ => {}
169                        }
170                    }
171                }
172                if !futures.is_empty() {
173                    let results = gemini_client_api::futures::future::join_all(futures).await;
174                    for (name, res) in results {
175                        if let Ok(ref val) = res {
176                            let _ = #session.add_function_response(name, val.clone());
177                        }
178                        all_results.push(res);
179                    }
180                }
181            }
182            all_results
183        }
184    };
185
186    TokenStream::from(expanded)
187}
188
189#[proc_macro_attribute]
190pub fn gemini_schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
191    let input = parse_macro_input!(item as DeriveInput);
192    let name = &input.ident;
193    let description = extract_doc_comments(&input.attrs);
194
195    let expanded = match &input.data {
196        Data::Struct(data) => {
197            let mut properties = Vec::new();
198            let mut required = Vec::new();
199
200            match &data.fields {
201                Fields::Named(fields) => {
202                    for field in &fields.named {
203                        let field_name = field.ident.as_ref().unwrap();
204                        let field_name_str = field_name.to_string();
205                        let field_type = &field.ty;
206                        let field_desc = extract_doc_comments(&field.attrs);
207
208                        let is_optional = is_option(field_type);
209
210                        properties.push(quote! {
211                            let mut schema = <#field_type as GeminiSchema>::gemini_schema();
212                            if !#field_desc.is_empty() {
213                                if let Some(obj) = schema.as_object_mut() {
214                                    obj.insert("description".to_string(), serde_json::json!(#field_desc));
215                                }
216                            }
217                            props.insert(#field_name_str.to_string(), schema);
218                        });
219
220                        if !is_optional {
221                            required.push(field_name_str);
222                        }
223                    }
224                }
225                _ => panic!("gemini_schema only supports named fields in structs"),
226            }
227
228            quote! {
229                impl GeminiSchema for #name {
230                    fn gemini_schema() -> serde_json::Value {
231                        use serde_json::{json, Map};
232                        let mut props = Map::new();
233                        #(#properties)*
234
235                        let mut schema = json!({
236                            "type": "OBJECT",
237                            "properties": props,
238                            "required": [#(#required),*]
239                        });
240
241                        if !#description.is_empty() {
242                            if let Some(obj) = schema.as_object_mut() {
243                                obj.insert("description".to_string(), json!(#description));
244                            }
245                        }
246                        schema
247                    }
248                }
249            }
250        }
251        Data::Enum(data) => {
252            let mut variants = Vec::new();
253            for variant in &data.variants {
254                if !matches!(variant.fields, Fields::Unit) {
255                    panic!("gemini_schema only supports unit variants in enums");
256                }
257                variants.push(variant.ident.to_string());
258            }
259
260            quote! {
261                impl GeminiSchema for #name {
262                    fn gemini_schema() -> serde_json::Value {
263                        use serde_json::json;
264                        let mut schema = json!({
265                            "type": "STRING",
266                            "enum": [#(#variants),*]
267                        });
268
269                        if !#description.is_empty() {
270                            if let Some(obj) = schema.as_object_mut() {
271                                obj.insert("description".to_string(), json!(#description));
272                            }
273                        }
274                        schema
275                    }
276                }
277            }
278        }
279        _ => panic!("gemini_schema only supports structs and enums"),
280    };
281
282    let output = quote! {
283        #input
284        #expanded
285    };
286
287    TokenStream::from(output)
288}
289
290fn extract_doc_comments(attrs: &[Attribute]) -> String {
291    let mut doc_comments = Vec::new();
292    for attr in attrs {
293        if attr.path().is_ident("doc") {
294            if let Meta::NameValue(nv) = &attr.meta {
295                if let syn::Expr::Lit(expr_lit) = &nv.value {
296                    if let Lit::Str(lit_str) = &expr_lit.lit {
297                        doc_comments.push(lit_str.value().trim().to_string());
298                    }
299                }
300            }
301        }
302    }
303    doc_comments.join("\n")
304}
305
306fn is_option(ty: &Type) -> bool {
307    if let Type::Path(tp) = ty {
308        if let Some(seg) = tp.path.segments.last() {
309            return seg.ident == "Option";
310        }
311    }
312    false
313}