// gemini_proc_macros/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{
    ext::IdentExt, parse_macro_input, Attribute, Data, DeriveInput, Fields, FnArg, ItemFn, Lit,
    Meta, Pat, Type,
};

7#[proc_macro_attribute]
8pub fn gemini_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
9 let mut input_fn = parse_macro_input!(item as ItemFn);
10 let fn_name = &input_fn.sig.ident;
11 let fn_description = extract_doc_comments(&input_fn.attrs);
12
13 let mut properties = Vec::new();
14 let mut required = Vec::new();
15 let mut param_names = Vec::new();
16 let mut param_types = Vec::new();
17
18 for arg in input_fn.sig.inputs.iter_mut() {
19 if let FnArg::Typed(pat_type) = arg {
20 if let Pat::Ident(pat_ident) = &*pat_type.pat {
21 let param_name = pat_ident.ident.clone();
22 let param_name_str = param_name.to_string();
23 let param_type = (*pat_type.ty).clone();
24 let param_desc = extract_doc_comments(&pat_type.attrs);
25
26 if has_reference(¶m_type) {
27 return syn::Error::new_spanned(
28 ¶m_type,
29 "references are not supported in gemini_function. Use owned types like String instead.",
30 )
31 .to_compile_error()
32 .into();
33 }
34
35 pat_type.attrs.retain(|attr| !attr.path().is_ident("doc"));
37
38 let is_optional = is_option(¶m_type);
39
40 properties.push(quote! {
41 let mut schema = <#param_type as GeminiSchema>::gemini_schema();
42 if !#param_desc.is_empty() {
43 if let Some(obj) = schema.as_object_mut() {
44 obj.insert("description".to_string(), serde_json::json!(#param_desc));
45 }
46 }
47 props.insert(#param_name_str.to_string(), schema);
48 });
49
50 if !is_optional {
51 required.push(param_name_str);
52 }
53
54 param_names.push(param_name);
55 param_types.push(param_type);
56 }
57 }
58 }
59
60 let fn_name_str = fn_name.to_string();
61 let is_async = input_fn.sig.asyncness.is_some();
62 let call_await = if is_async {
63 quote! { .await }
64 } else {
65 quote! {}
66 };
67
68 let is_result = match &input_fn.sig.output {
69 syn::ReturnType::Default => false,
70 syn::ReturnType::Type(_, ty) => {
71 let s = quote!(#ty).to_string();
72 s.contains("Result")
73 }
74 };
75
76 let result_handling = if is_result {
77 quote! {
78 match result {
79 Ok(v) => Ok(serde_json::json!(v)),
80 Err(e) => Err(e.to_string()),
81 }
82 }
83 } else {
84 quote! {
85 Ok(serde_json::json!(result))
86 }
87 };
88
89 let expanded = quote! {
90 #input_fn
91
92 #[allow(non_camel_case_types)]
93 pub struct #fn_name { }
94
95 impl GeminiSchema for #fn_name {
96 fn gemini_schema() -> serde_json::Value {
97 use serde_json::{json, Map};
98 let mut props = Map::new();
99 #(#properties)*
100
101 json!({
102 "name": #fn_name_str,
103 "description": #fn_description,
104 "parameters": {
105 "type": "OBJECT",
106 "properties": props,
107 "required": [#(#required),*]
108 }
109 })
110 }
111 }
112
113 impl #fn_name {
114 pub async fn execute(args: serde_json::Value) -> Result<serde_json::Value, String> {
115 use serde::Deserialize;
116 #[derive(Deserialize)]
117 struct Args {
118 #(#param_names: #param_types,)*
119 }
120 let args = Args::deserialize(&args).map_err(|e| e.to_string())?;
121 let result = #fn_name(#(args.#param_names),*) #call_await;
122 #result_handling
123 }
124 }
125 };
126
127 TokenStream::from(expanded)
128}
129
130#[proc_macro]
131pub fn execute_function_calls(input: TokenStream) -> TokenStream {
138 use syn::parse::{Parse, ParseStream};
139 use syn::{Expr, Token};
140
141 struct ExecuteInput {
142 session: Expr,
143 _comma: Token![,],
144 functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
145 }
146
147 impl Parse for ExecuteInput {
148 fn parse(input: ParseStream) -> syn::Result<Self> {
149 Ok(ExecuteInput {
150 session: input.parse()?,
151 _comma: input.parse()?,
152 functions: input.parse_terminated(syn::Path::parse, Token![,])?,
153 })
154 }
155 }
156
157 let input = parse_macro_input!(input as ExecuteInput);
158 let session = &input.session;
159 let functions = &input.functions;
160 let num_funcs = functions.len();
161
162 let match_arms = functions.iter().enumerate().map(|(i, path)| {
163 let name_str = path.segments.last().unwrap().ident.to_string();
164 quote! {
165 #name_str => {
166 let args = call.args().clone().unwrap_or(gemini_client_api::serde_json::json!({}));
167 let fut: gemini_client_api::futures::future::BoxFuture<'static, (usize, String, Result<gemini_client_api::serde_json::Value, String>)> = Box::pin(async move {
168 (#i, #name_str.to_string(), #path::execute(args).await)
169 });
170 futures.push(fut);
171 }
172 }
173 });
174
175 let expanded = quote! {
176 {
177 let mut results_array = vec![None; #num_funcs];
178 if let Some(chat) = #session.get_last_chat() {
179 let mut futures = Vec::new();
180 for part in chat.parts() {
181 if let gemini_client_api::gemini::types::request::PartType::FunctionCall(call) = part.data() {
182 match call.name().as_str() {
183 #(#match_arms)*
184 _ => {}
185 }
186 }
187 }
188 if !futures.is_empty() {
189 let results = gemini_client_api::futures::future::join_all(futures).await;
190 for (idx, name, res) in results {
191 if let Ok(ref val) = res {
192 if let Err(e) = #session.add_function_response(name.clone(), val.clone()) {
193 results_array[idx] = Some(Err(format!(
194 "failed to add function response for `{}`: {}",
195 name, e
196 )));
197 continue;
198 }
199 }
200 results_array[idx] = Some(res);
201 }
202 }
203 }
204 results_array
205 }
206 };
207
208 TokenStream::from(expanded)
209}
210
211#[proc_macro_attribute]
212pub fn gemini_schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
213 let input = parse_macro_input!(item as DeriveInput);
214 let name = &input.ident;
215 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
216 let description = extract_doc_comments(&input.attrs);
217
218 let expanded = match &input.data {
219 Data::Struct(data) => {
220 let mut properties = Vec::new();
221 let mut required = Vec::new();
222
223 match &data.fields {
224 Fields::Named(fields) => {
225 for field in &fields.named {
226 let field_name = field.ident.as_ref().unwrap();
227 let field_name_str = field_name.to_string();
228 let field_type = &field.ty;
229 let field_desc = extract_doc_comments(&field.attrs);
230
231 if has_reference(field_type) {
232 return syn::Error::new_spanned(
233 field_type,
234 "references are not supported in gemini_schema. Use owned types instead.",
235 )
236 .to_compile_error()
237 .into();
238 }
239
240 let is_optional = is_option(field_type);
241
242 properties.push(quote! {
243 let mut schema = <#field_type as GeminiSchema>::gemini_schema();
244 if !#field_desc.is_empty() {
245 if let Some(obj) = schema.as_object_mut() {
246 obj.insert("description".to_string(), serde_json::json!(#field_desc));
247 }
248 }
249 props.insert(#field_name_str.to_string(), schema);
250 });
251
252 if !is_optional {
253 required.push(field_name_str);
254 }
255 }
256 }
257 _ => panic!("gemini_schema only supports named fields in structs"),
258 }
259
260 quote! {
261 impl #impl_generics GeminiSchema for #name #ty_generics #where_clause {
262 fn gemini_schema() -> serde_json::Value {
263 use serde_json::{json, Map};
264 let mut props = Map::new();
265 #(#properties)*
266
267 let mut schema = json!({
268 "type": "OBJECT",
269 "properties": props,
270 "required": [#(#required),*]
271 });
272
273 if !#description.is_empty() {
274 if let Some(obj) = schema.as_object_mut() {
275 obj.insert("description".to_string(), json!(#description));
276 }
277 }
278 schema
279 }
280 }
281 }
282 }
283 Data::Enum(data) => {
284 let mut variants = Vec::new();
285 for variant in &data.variants {
286 if !matches!(variant.fields, Fields::Unit) {
287 panic!("gemini_schema only supports unit variants in enums");
288 }
289 variants.push(variant.ident.to_string());
290 }
291
292 quote! {
293 impl #impl_generics GeminiSchema for #name #ty_generics #where_clause {
294 fn gemini_schema() -> serde_json::Value {
295 use serde_json::json;
296 let mut schema = json!({
297 "type": "STRING",
298 "enum": [#(#variants),*]
299 });
300
301 if !#description.is_empty() {
302 if let Some(obj) = schema.as_object_mut() {
303 obj.insert("description".to_string(), json!(#description));
304 }
305 }
306 schema
307 }
308 }
309 }
310 }
311 _ => panic!("gemini_schema only supports structs and enums"),
312 };
313
314 let output = quote! {
315 #input
316 #expanded
317 };
318
319 TokenStream::from(output)
320}
321
322fn extract_doc_comments(attrs: &[Attribute]) -> String {
323 let mut doc_comments = Vec::new();
324 for attr in attrs {
325 if attr.path().is_ident("doc") {
326 if let Meta::NameValue(nv) = &attr.meta {
327 if let syn::Expr::Lit(expr_lit) = &nv.value {
328 if let Lit::Str(lit_str) = &expr_lit.lit {
329 doc_comments.push(lit_str.value().trim().to_string());
330 }
331 }
332 }
333 }
334 }
335 doc_comments.join("\n")
336}
337
338fn is_option(ty: &Type) -> bool {
339 if let Type::Path(tp) = ty {
340 if let Some(seg) = tp.path.segments.last() {
341 return seg.ident == "Option";
342 }
343 }
344 false
345}
346
347fn has_reference(ty: &Type) -> bool {
348 match ty {
349 Type::Reference(_) => true,
350 Type::Path(tp) => {
351 for seg in &tp.path.segments {
352 if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
353 for arg in &ab.args {
354 if let syn::GenericArgument::Type(inner) = arg {
355 if has_reference(inner) {
356 return true;
357 }
358 }
359 }
360 }
361 }
362 false
363 }
364 _ => false,
365 }
366}