gemini_proc_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 parse_macro_input, Attribute, Data, DeriveInput, Fields, FnArg, ItemFn, Lit, Meta, Pat, Type,
5};
6
7#[proc_macro_attribute]
8pub fn gemini_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
9 let mut input_fn = parse_macro_input!(item as ItemFn);
10 let fn_name = &input_fn.sig.ident;
11 let fn_description = extract_doc_comments(&input_fn.attrs);
12
13 let mut properties = Vec::new();
14 let mut required = Vec::new();
15 let mut param_names = Vec::new();
16 let mut param_types = Vec::new();
17
18 for arg in input_fn.sig.inputs.iter_mut() {
19 if let FnArg::Typed(pat_type) = arg {
20 if let Pat::Ident(pat_ident) = &*pat_type.pat {
21 let param_name = pat_ident.ident.clone();
22 let param_name_str = param_name.to_string();
23 let param_type = (*pat_type.ty).clone();
24 let param_desc = extract_doc_comments(&pat_type.attrs);
25
26 if has_reference(¶m_type) {
27 return syn::Error::new_spanned(
28 ¶m_type,
29 "references are not supported in gemini_function. Use owned types like String instead.",
30 )
31 .to_compile_error()
32 .into();
33 }
34
35 pat_type.attrs.retain(|attr| !attr.path().is_ident("doc"));
37
38 let is_optional = is_option(¶m_type);
39
40 properties.push(quote! {
41 let mut schema = <#param_type as GeminiSchema>::gemini_schema();
42 if !#param_desc.is_empty() {
43 if let Some(obj) = schema.as_object_mut() {
44 obj.insert("description".to_string(), serde_json::json!(#param_desc));
45 }
46 }
47 props.insert(#param_name_str.to_string(), schema);
48 });
49
50 if !is_optional {
51 required.push(param_name_str);
52 }
53
54 param_names.push(param_name);
55 param_types.push(param_type);
56 }
57 }
58 }
59
60 let fn_name_str = fn_name.to_string();
61 let is_async = input_fn.sig.asyncness.is_some();
62 let call_await = if is_async {
63 quote! { .await }
64 } else {
65 quote! {}
66 };
67
68 let is_result = match &input_fn.sig.output {
69 syn::ReturnType::Default => false,
70 syn::ReturnType::Type(_, ty) => {
71 let s = quote!(#ty).to_string();
72 s.contains("Result")
73 }
74 };
75
76 let result_handling = if is_result {
77 quote! {
78 match result {
79 Ok(v) => Ok(serde_json::json!(v)),
80 Err(e) => Err(e.to_string()),
81 }
82 }
83 } else {
84 quote! {
85 Ok(serde_json::json!(result))
86 }
87 };
88
89 let expanded = quote! {
90 #input_fn
91
92 #[allow(non_camel_case_types)]
93 pub struct #fn_name { }
94
95 impl GeminiSchema for #fn_name {
96 fn gemini_schema() -> serde_json::Value {
97 use serde_json::{json, Map};
98 let mut props = Map::new();
99 #(#properties)*
100
101 json!({
102 "name": #fn_name_str,
103 "description": #fn_description,
104 "parameters": {
105 "type": "OBJECT",
106 "properties": props,
107 "required": [#(#required),*]
108 }
109 })
110 }
111 }
112
113 impl #fn_name {
114 pub async fn execute(args: serde_json::Value) -> Result<serde_json::Value, String> {
115 use serde::Deserialize;
116 #[derive(Deserialize)]
117 struct Args {
118 #(#param_names: #param_types,)*
119 }
120 let args = Args::deserialize(&args).map_err(|e| e.to_string())?;
121 let result = #fn_name(#(args.#param_names),*) #call_await;
122 #result_handling
123 }
124 }
125 };
126
127 TokenStream::from(expanded)
128}
129
130#[proc_macro]
131pub fn execute_function_calls(input: TokenStream) -> TokenStream {
138 use syn::parse::{Parse, ParseStream};
139 use syn::{Expr, Token};
140
141 struct ExecuteInput {
142 session: Expr,
143 _comma: Token![,],
144 functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
145 }
146
147 impl Parse for ExecuteInput {
148 fn parse(input: ParseStream) -> syn::Result<Self> {
149 Ok(ExecuteInput {
150 session: input.parse()?,
151 _comma: input.parse()?,
152 functions: input.parse_terminated(syn::Path::parse, Token![,])?,
153 })
154 }
155 }
156
157 let input = parse_macro_input!(input as ExecuteInput);
158 let session = &input.session;
159 let functions = &input.functions;
160 let num_funcs = functions.len();
161
162 let match_arms = functions.iter().enumerate().map(|(i, path)| {
163 let name_str = path.segments.last().unwrap().ident.to_string();
164 quote! {
165 #name_str => {
166 let args = call.args().clone().unwrap_or(gemini_client_api::serde_json::json!({}));
167 let fut: gemini_client_api::futures::future::BoxFuture<'static, (usize, String, Result<gemini_client_api::serde_json::Value, String>)> = Box::pin(async move {
168 (#i, #name_str.to_string(), #path::execute(args).await)
169 });
170 futures.push(fut);
171 }
172 }
173 });
174
175 let expanded = quote! {
176 {
177 let mut results_array = vec![None; #num_funcs];
178 if let Some(chat) = #session.get_last_chat() {
179 let mut futures = Vec::new();
180 for part in chat.parts() {
181 if let gemini_client_api::gemini::types::request::PartType::FunctionCall(call) = part.data() {
182 match call.name().as_str() {
183 #(#match_arms)*
184 _ => {}
185 }
186 }
187 }
188 if !futures.is_empty() {
189 let results = gemini_client_api::futures::future::join_all(futures).await;
190 for (idx, name, res) in results {
191 if let Ok(ref val) = res {
192 let _ = #session.add_function_response(name, val.clone());
193 }
194 results_array[idx] = Some(res);
195 }
196 }
197 }
198 results_array
199 }
200 };
201
202 TokenStream::from(expanded)
203}
204
205#[proc_macro_attribute]
206pub fn gemini_schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
207 let input = parse_macro_input!(item as DeriveInput);
208 let name = &input.ident;
209 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
210 let description = extract_doc_comments(&input.attrs);
211
212 let expanded = match &input.data {
213 Data::Struct(data) => {
214 let mut properties = Vec::new();
215 let mut required = Vec::new();
216
217 match &data.fields {
218 Fields::Named(fields) => {
219 for field in &fields.named {
220 let field_name = field.ident.as_ref().unwrap();
221 let field_name_str = field_name.to_string();
222 let field_type = &field.ty;
223 let field_desc = extract_doc_comments(&field.attrs);
224
225 if has_reference(field_type) {
226 return syn::Error::new_spanned(
227 field_type,
228 "references are not supported in gemini_schema. Use owned types instead.",
229 )
230 .to_compile_error()
231 .into();
232 }
233
234 let is_optional = is_option(field_type);
235
236 properties.push(quote! {
237 let mut schema = <#field_type as GeminiSchema>::gemini_schema();
238 if !#field_desc.is_empty() {
239 if let Some(obj) = schema.as_object_mut() {
240 obj.insert("description".to_string(), serde_json::json!(#field_desc));
241 }
242 }
243 props.insert(#field_name_str.to_string(), schema);
244 });
245
246 if !is_optional {
247 required.push(field_name_str);
248 }
249 }
250 }
251 _ => panic!("gemini_schema only supports named fields in structs"),
252 }
253
254 quote! {
255 impl #impl_generics GeminiSchema for #name #ty_generics #where_clause {
256 fn gemini_schema() -> serde_json::Value {
257 use serde_json::{json, Map};
258 let mut props = Map::new();
259 #(#properties)*
260
261 let mut schema = json!({
262 "type": "OBJECT",
263 "properties": props,
264 "required": [#(#required),*]
265 });
266
267 if !#description.is_empty() {
268 if let Some(obj) = schema.as_object_mut() {
269 obj.insert("description".to_string(), json!(#description));
270 }
271 }
272 schema
273 }
274 }
275 }
276 }
277 Data::Enum(data) => {
278 let mut variants = Vec::new();
279 for variant in &data.variants {
280 if !matches!(variant.fields, Fields::Unit) {
281 panic!("gemini_schema only supports unit variants in enums");
282 }
283 variants.push(variant.ident.to_string());
284 }
285
286 quote! {
287 impl #impl_generics GeminiSchema for #name #ty_generics #where_clause {
288 fn gemini_schema() -> serde_json::Value {
289 use serde_json::json;
290 let mut schema = json!({
291 "type": "STRING",
292 "enum": [#(#variants),*]
293 });
294
295 if !#description.is_empty() {
296 if let Some(obj) = schema.as_object_mut() {
297 obj.insert("description".to_string(), json!(#description));
298 }
299 }
300 schema
301 }
302 }
303 }
304 }
305 _ => panic!("gemini_schema only supports structs and enums"),
306 };
307
308 let output = quote! {
309 #input
310 #expanded
311 };
312
313 TokenStream::from(output)
314}
315
316fn extract_doc_comments(attrs: &[Attribute]) -> String {
317 let mut doc_comments = Vec::new();
318 for attr in attrs {
319 if attr.path().is_ident("doc") {
320 if let Meta::NameValue(nv) = &attr.meta {
321 if let syn::Expr::Lit(expr_lit) = &nv.value {
322 if let Lit::Str(lit_str) = &expr_lit.lit {
323 doc_comments.push(lit_str.value().trim().to_string());
324 }
325 }
326 }
327 }
328 }
329 doc_comments.join("\n")
330}
331
332fn is_option(ty: &Type) -> bool {
333 if let Type::Path(tp) = ty {
334 if let Some(seg) = tp.path.segments.last() {
335 return seg.ident == "Option";
336 }
337 }
338 false
339}
340
341fn has_reference(ty: &Type) -> bool {
342 match ty {
343 Type::Reference(_) => true,
344 Type::Path(tp) => {
345 for seg in &tp.path.segments {
346 if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
347 for arg in &ab.args {
348 if let syn::GenericArgument::Type(inner) = arg {
349 if has_reference(inner) {
350 return true;
351 }
352 }
353 }
354 }
355 }
356 false
357 }
358 _ => false,
359 }
360}