1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 Attribute, Data, DeriveInput, Fields, FnArg, ItemFn, Lit, Meta, Pat, Type, parse_macro_input,
5};
6
7#[proc_macro_attribute]
26pub fn gemini_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
27 let mut input_fn = parse_macro_input!(item as ItemFn);
28 let fn_name = &input_fn.sig.ident;
29 let fn_description = extract_doc_comments(&input_fn.attrs);
30
31 let mut properties = Vec::new();
32 let mut required = Vec::new();
33 let mut param_names = Vec::new();
34 let mut param_types = Vec::new();
35
36 for arg in input_fn.sig.inputs.iter_mut() {
37 if let FnArg::Typed(pat_type) = arg {
38 if let Pat::Ident(pat_ident) = &*pat_type.pat {
39 let param_name = pat_ident.ident.clone();
40 let param_name_str = param_name.to_string();
41 let param_type = (*pat_type.ty).clone();
42 let param_desc = extract_doc_comments(&pat_type.attrs);
43
44 if has_reference(¶m_type) {
45 return syn::Error::new_spanned(
46 ¶m_type,
47 "references are not supported in gemini_function. Use owned types like String instead.",
48 )
49 .to_compile_error()
50 .into();
51 }
52
53 pat_type.attrs.retain(|attr| !attr.path().is_ident("doc"));
55
56 let is_optional = is_option(¶m_type);
57
58 properties.push(quote! {
59 let mut schema = <#param_type as GeminiSchema>::gemini_schema();
60 if !#param_desc.is_empty() {
61 if let Some(obj) = schema.as_object_mut() {
62 obj.insert("description".to_string(), serde_json::json!(#param_desc));
63 }
64 }
65 props.insert(#param_name_str.to_string(), schema);
66 });
67
68 if !is_optional {
69 required.push(param_name_str);
70 }
71
72 param_names.push(param_name);
73 param_types.push(param_type);
74 }
75 }
76 }
77
78 let fn_name_str = fn_name.to_string();
79 let is_async = input_fn.sig.asyncness.is_some();
80 let call_await = if is_async {
81 quote! { .await }
82 } else {
83 quote! {}
84 };
85
86 let is_result = match &input_fn.sig.output {
87 syn::ReturnType::Default => false,
88 syn::ReturnType::Type(_, ty) => {
89 let s = quote!(#ty).to_string();
90 s.contains("Result")
91 }
92 };
93
94 let result_handling = if is_result {
95 quote! {
96 match result {
97 Ok(v) => Ok(serde_json::json!(v)),
98 Err(e) => Err(e.to_string()),
99 }
100 }
101 } else {
102 quote! {
103 Ok(serde_json::json!(result))
104 }
105 };
106
107 let expanded = quote! {
108 #input_fn
109
110 #[allow(non_camel_case_types)]
111 pub struct #fn_name { }
112
113 impl GeminiSchema for #fn_name {
114 fn gemini_schema() -> serde_json::Value {
115 use serde_json::{json, Map};
116 let mut props = Map::new();
117 #(#properties)*
118
119 json!({
120 "name": #fn_name_str,
121 "description": #fn_description,
122 "parameters": {
123 "type": "OBJECT",
124 "properties": props,
125 "required": [#(#required),*]
126 }
127 })
128 }
129 }
130
131 impl #fn_name {
132 pub async fn execute(args: serde_json::Value) -> Result<serde_json::Value, String> {
133 use serde::Deserialize;
134 #[derive(Deserialize)]
135 struct Args {
136 #(#param_names: #param_types,)*
137 }
138 let args = Args::deserialize(&args).map_err(|e| e.to_string())?;
139 let result = #fn_name(#(args.#param_names),*) #call_await;
140 #result_handling
141 }
142 }
143 };
144
145 TokenStream::from(expanded)
146}
147
148#[proc_macro]
163pub fn execute_function_calls(input: TokenStream) -> TokenStream {
164 use syn::parse::{Parse, ParseStream};
165 use syn::{Expr, Token};
166
167 struct ExecuteInput {
168 session: Expr,
169 _comma: Token![,],
170 functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
171 }
172
173 impl Parse for ExecuteInput {
174 fn parse(input: ParseStream) -> syn::Result<Self> {
175 Ok(ExecuteInput {
176 session: input.parse()?,
177 _comma: input.parse()?,
178 functions: input.parse_terminated(syn::Path::parse, Token![,])?,
179 })
180 }
181 }
182
183 let input = parse_macro_input!(input as ExecuteInput);
184 let session = &input.session;
185 let functions = &input.functions;
186 let num_funcs = functions.len();
187
188 let match_arms = functions.iter().enumerate().map(|(i, path)| {
189 let name_str = path.segments.last().unwrap().ident.to_string();
190 quote! {
191 #name_str => {
192 let args = call.args().clone().unwrap_or(gemini_client_api::serde_json::json!({}));
193 let fut: gemini_client_api::futures::future::BoxFuture<'static, (usize, String, Result<gemini_client_api::serde_json::Value, String>)> = Box::pin(async move {
194 (#i, #name_str.to_string(), #path::execute(args).await)
195 });
196 futures.push(fut);
197 }
198 }
199 });
200
201 let expanded = quote! {
202 {
203 let mut results_array = vec![None; #num_funcs];
204 if let Some(chat) = #session.get_last_chat() {
205 let mut futures = Vec::new();
206 for part in chat.parts() {
207 if let gemini_client_api::gemini::types::request::PartType::FunctionCall(call) = part.data() {
208 match call.name().as_str() {
209 #(#match_arms)*
210 _ => {}
211 }
212 }
213 }
214 if !futures.is_empty() {
215 let results = gemini_client_api::futures::future::join_all(futures).await;
216 for (idx, name, res) in results {
217 if let Ok(ref val) = res {
218 if let Err(e) = #session.add_function_response(name.clone(), val.clone()) {
219 results_array[idx] = Some(Err(format!(
220 "failed to add function response for `{}`: {}",
221 name, e
222 )));
223 continue;
224 }
225 }
226 results_array[idx] = Some(res);
227 }
228 }
229 }
230 results_array
231 }
232 };
233
234 TokenStream::from(expanded)
235}
236
237#[proc_macro]
256pub fn execute_function_calls_with_callback(input: TokenStream) -> TokenStream {
257 use syn::parse::{Parse, ParseStream};
258 use syn::{Expr, Token};
259
260 struct ExecuteWithCallbackInput {
261 session: Expr,
262 _comma1: Token![,],
263 callback: Expr,
264 _comma2: Token![,],
265 functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
266 }
267
268 impl Parse for ExecuteWithCallbackInput {
269 fn parse(input: ParseStream) -> syn::Result<Self> {
270 Ok(ExecuteWithCallbackInput {
271 session: input.parse()?,
272 _comma1: input.parse()?,
273 callback: input.parse()?,
274 _comma2: input.parse()?,
275 functions: input.parse_terminated(syn::Path::parse, Token![,])?,
276 })
277 }
278 }
279
280 let input = parse_macro_input!(input as ExecuteWithCallbackInput);
281 let session = &input.session;
282 let callback = &input.callback;
283 let functions = &input.functions;
284 let num_funcs = functions.len();
285
286 let match_arms = functions.iter().enumerate().map(|(i, path)| {
287 let name_str = path.segments.last().unwrap().ident.to_string();
288 quote! {
289 #name_str => {
290 let args = call.args().clone().unwrap_or(gemini_client_api::serde_json::json!({}));
291 let fut: gemini_client_api::futures::future::BoxFuture<'static, (usize, String, Result<gemini_client_api::serde_json::Value, String>)> = Box::pin(async move {
292 (#i, #name_str.to_string(), #path::execute(args).await)
293 });
294 futures.push(fut);
295 }
296 }
297 });
298
299 let expanded = quote! {
300 {
301 let mut results_array = vec![None; #num_funcs];
302 let mut result_callback = #callback;
304
305 if let Some(chat) = #session.get_last_chat() {
306 let mut futures = Vec::new();
307 for part in chat.parts() {
308 if let gemini_client_api::gemini::types::request::PartType::FunctionCall(call) = part.data() {
309 match call.name().as_str() {
310 #(#match_arms)*
311 _ => {}
312 }
313 }
314 }
315 if !futures.is_empty() {
316 let results = gemini_client_api::futures::future::join_all(futures).await;
317 for (idx, name, res) in results {
318 let val_to_add = result_callback(res.clone());
320
321 if let Err(e) = #session.add_function_response(name.clone(), val_to_add) {
322 results_array[idx] = Some(Err(format!(
323 "failed to add function response for `{}`: {}",
324 name, e
325 )));
326 continue;
327 }
328 results_array[idx] = Some(res);
329 }
330 }
331 }
332 results_array
333 }
334 };
335
336 TokenStream::from(expanded)
337}
338
339#[proc_macro_attribute]
361pub fn gemini_schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
362 let input = parse_macro_input!(item as DeriveInput);
363 let name = &input.ident;
364 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
365 let description = extract_doc_comments(&input.attrs);
366
367 let expanded = match &input.data {
368 Data::Struct(data) => {
369 let mut properties = Vec::new();
370 let mut required = Vec::new();
371
372 match &data.fields {
373 Fields::Named(fields) => {
374 for field in &fields.named {
375 let field_name = field.ident.as_ref().unwrap();
376 let field_name_str = field_name.to_string();
377 let field_type = &field.ty;
378 let field_desc = extract_doc_comments(&field.attrs);
379
380 if has_reference(field_type) {
381 return syn::Error::new_spanned(
382 field_type,
383 "references are not supported in gemini_schema. Use owned types instead.",
384 )
385 .to_compile_error()
386 .into();
387 }
388
389 let is_optional = is_option(field_type);
390
391 properties.push(quote! {
392 let mut schema = <#field_type as GeminiSchema>::gemini_schema();
393 if !#field_desc.is_empty() {
394 if let Some(obj) = schema.as_object_mut() {
395 obj.insert("description".to_string(), serde_json::json!(#field_desc));
396 }
397 }
398 props.insert(#field_name_str.to_string(), schema);
399 });
400
401 if !is_optional {
402 required.push(field_name_str);
403 }
404 }
405 }
406 _ => panic!("gemini_schema only supports named fields in structs"),
407 }
408
409 quote! {
410 impl #impl_generics GeminiSchema for #name #ty_generics #where_clause {
411 fn gemini_schema() -> serde_json::Value {
412 use serde_json::{json, Map};
413 let mut props = Map::new();
414 #(#properties)*
415
416 let mut schema = json!({
417 "type": "OBJECT",
418 "properties": props,
419 "required": [#(#required),*]
420 });
421
422 if !#description.is_empty() {
423 if let Some(obj) = schema.as_object_mut() {
424 obj.insert("description".to_string(), json!(#description));
425 }
426 }
427 schema
428 }
429 }
430 }
431 }
432 Data::Enum(data) => {
433 let mut variants = Vec::new();
434 for variant in &data.variants {
435 if !matches!(variant.fields, Fields::Unit) {
436 panic!("gemini_schema only supports unit variants in enums");
437 }
438 variants.push(variant.ident.to_string());
439 }
440
441 quote! {
442 impl #impl_generics GeminiSchema for #name #ty_generics #where_clause {
443 fn gemini_schema() -> serde_json::Value {
444 use serde_json::json;
445 let mut schema = json!({
446 "type": "STRING",
447 "enum": [#(#variants),*]
448 });
449
450 if !#description.is_empty() {
451 if let Some(obj) = schema.as_object_mut() {
452 obj.insert("description".to_string(), json!(#description));
453 }
454 }
455 schema
456 }
457 }
458 }
459 }
460 _ => panic!("gemini_schema only supports structs and enums"),
461 };
462
463 let output = quote! {
464 #input
465 #expanded
466 };
467
468 TokenStream::from(output)
469}
470
471fn extract_doc_comments(attrs: &[Attribute]) -> String {
472 let mut doc_comments = Vec::new();
473 for attr in attrs {
474 if attr.path().is_ident("doc") {
475 if let Meta::NameValue(nv) = &attr.meta {
476 if let syn::Expr::Lit(expr_lit) = &nv.value {
477 if let Lit::Str(lit_str) = &expr_lit.lit {
478 doc_comments.push(lit_str.value().trim().to_string());
479 }
480 }
481 }
482 }
483 }
484 doc_comments.join("\n")
485}
486
487fn is_option(ty: &Type) -> bool {
488 if let Type::Path(tp) = ty {
489 if let Some(seg) = tp.path.segments.last() {
490 return seg.ident == "Option";
491 }
492 }
493 false
494}
495
496fn has_reference(ty: &Type) -> bool {
497 match ty {
498 Type::Reference(_) => true,
499 Type::Path(tp) => {
500 for seg in &tp.path.segments {
501 if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
502 for arg in &ab.args {
503 if let syn::GenericArgument::Type(inner) = arg {
504 if has_reference(inner) {
505 return true;
506 }
507 }
508 }
509 }
510 }
511 false
512 }
513 _ => false,
514 }
515}