gemini_proc_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 parse_macro_input, Attribute, Data, DeriveInput, Fields, FnArg, ItemFn, Lit, Meta, Pat, Type,
5};
6
7#[proc_macro_attribute]
8pub fn gemini_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
9 let mut input_fn = parse_macro_input!(item as ItemFn);
10 let fn_name = &input_fn.sig.ident;
11 let fn_description = extract_doc_comments(&input_fn.attrs);
12
13 let mut properties = Vec::new();
14 let mut required = Vec::new();
15 let mut param_names = Vec::new();
16 let mut param_types = Vec::new();
17
18 for arg in input_fn.sig.inputs.iter_mut() {
19 if let FnArg::Typed(pat_type) = arg {
20 if let Pat::Ident(pat_ident) = &*pat_type.pat {
21 let param_name = pat_ident.ident.clone();
22 let param_name_str = param_name.to_string();
23 let param_type = (*pat_type.ty).clone();
24 let param_desc = extract_doc_comments(&pat_type.attrs);
25
26 pat_type.attrs.retain(|attr| !attr.path().is_ident("doc"));
28
29 let is_optional = is_option(¶m_type);
30
31 properties.push(quote! {
32 let mut schema = <#param_type as GeminiSchema>::gemini_schema();
33 if !#param_desc.is_empty() {
34 if let Some(obj) = schema.as_object_mut() {
35 obj.insert("description".to_string(), serde_json::json!(#param_desc));
36 }
37 }
38 props.insert(#param_name_str.to_string(), schema);
39 });
40
41 if !is_optional {
42 required.push(param_name_str);
43 }
44
45 param_names.push(param_name);
46 param_types.push(param_type);
47 }
48 }
49 }
50
51 let fn_name_str = fn_name.to_string();
52 let is_async = input_fn.sig.asyncness.is_some();
53 let call_await = if is_async {
54 quote! { .await }
55 } else {
56 quote! {}
57 };
58
59 let is_result = match &input_fn.sig.output {
60 syn::ReturnType::Default => false,
61 syn::ReturnType::Type(_, ty) => {
62 let s = quote!(#ty).to_string();
63 s.contains("Result")
64 }
65 };
66
67 let result_handling = if is_result {
68 quote! {
69 match result {
70 Ok(v) => Ok(serde_json::json!(v)),
71 Err(e) => Err(e.to_string()),
72 }
73 }
74 } else {
75 quote! {
76 Ok(serde_json::json!(result))
77 }
78 };
79
80 let expanded = quote! {
81 #input_fn
82
83 #[allow(non_camel_case_types)]
84 pub struct #fn_name { }
85
86 impl GeminiSchema for #fn_name {
87 fn gemini_schema() -> serde_json::Value {
88 use serde_json::{json, Map};
89 let mut props = Map::new();
90 #(#properties)*
91
92 json!({
93 "name": #fn_name_str,
94 "description": #fn_description,
95 "parameters": {
96 "type": "OBJECT",
97 "properties": props,
98 "required": [#(#required),*]
99 }
100 })
101 }
102 }
103
104 impl #fn_name {
105 pub async fn execute(args: serde_json::Value) -> Result<serde_json::Value, String> {
106 use serde::Deserialize;
107 #[derive(Deserialize)]
108 struct Args {
109 #(#param_names: #param_types),*
110 }
111 let args: Args = serde_json::from_value(args).map_err(|e| e.to_string())?;
112 let result = #fn_name(#(args.#param_names),*) #call_await;
113 #result_handling
114 }
115 }
116 };
117
118 TokenStream::from(expanded)
119}
120
121#[proc_macro]
122pub fn execute_function_calls(input: TokenStream) -> TokenStream {
127 use syn::parse::{Parse, ParseStream};
128 use syn::{Expr, Token};
129
130 struct ExecuteInput {
131 session: Expr,
132 _comma: Token![,],
133 functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
134 }
135
136 impl Parse for ExecuteInput {
137 fn parse(input: ParseStream) -> syn::Result<Self> {
138 Ok(ExecuteInput {
139 session: input.parse()?,
140 _comma: input.parse()?,
141 functions: input.parse_terminated(syn::Path::parse, Token![,])?,
142 })
143 }
144 }
145
146 let input = parse_macro_input!(input as ExecuteInput);
147 let session = &input.session;
148 let functions = &input.functions;
149
150 let match_arms = functions.iter().map(|path| {
151 let name_str = path.segments.last().unwrap().ident.to_string();
152 quote! {
153 #name_str => {
154 let args = call.args().clone().unwrap_or(gemini_client_api::serde_json::json!({}));
155 let fut: std::pin::Pin<Box<dyn std::future::Future<Output = (String, Result<gemini_client_api::serde_json::Value, String>)>>> = Box::pin(async move {
156 (#name_str.to_string(), #path::execute(args).await)
157 });
158 futures.push(fut);
159 }
160 }
161 });
162
163 let expanded = quote! {
164 {
165 let mut all_results = Vec::new();
166 if let Some(chat) = #session.get_last_chat() {
167 let mut futures = Vec::new();
168 for part in chat.parts() {
169 if let gemini_client_api::gemini::types::request::PartType::FunctionCall(call) = part.data() {
170 match call.name().as_str() {
171 #(#match_arms)*
172 _ => {}
173 }
174 }
175 }
176 if !futures.is_empty() {
177 let results = gemini_client_api::futures::future::join_all(futures).await;
178 for (name, res) in results {
179 if let Ok(ref val) = res {
180 let _ = #session.add_function_response(name, val.clone());
181 }
182 all_results.push(res);
183 }
184 }
185 }
186 all_results
187 }
188 };
189
190 TokenStream::from(expanded)
191}
192
193#[proc_macro_attribute]
194pub fn gemini_schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
195 let input = parse_macro_input!(item as DeriveInput);
196 let name = &input.ident;
197 let description = extract_doc_comments(&input.attrs);
198
199 let expanded = match &input.data {
200 Data::Struct(data) => {
201 let mut properties = Vec::new();
202 let mut required = Vec::new();
203
204 match &data.fields {
205 Fields::Named(fields) => {
206 for field in &fields.named {
207 let field_name = field.ident.as_ref().unwrap();
208 let field_name_str = field_name.to_string();
209 let field_type = &field.ty;
210 let field_desc = extract_doc_comments(&field.attrs);
211
212 let is_optional = is_option(field_type);
213
214 properties.push(quote! {
215 let mut schema = <#field_type as GeminiSchema>::gemini_schema();
216 if !#field_desc.is_empty() {
217 if let Some(obj) = schema.as_object_mut() {
218 obj.insert("description".to_string(), serde_json::json!(#field_desc));
219 }
220 }
221 props.insert(#field_name_str.to_string(), schema);
222 });
223
224 if !is_optional {
225 required.push(field_name_str);
226 }
227 }
228 }
229 _ => panic!("gemini_schema only supports named fields in structs"),
230 }
231
232 quote! {
233 impl GeminiSchema for #name {
234 fn gemini_schema() -> serde_json::Value {
235 use serde_json::{json, Map};
236 let mut props = Map::new();
237 #(#properties)*
238
239 let mut schema = json!({
240 "type": "OBJECT",
241 "properties": props,
242 "required": [#(#required),*]
243 });
244
245 if !#description.is_empty() {
246 if let Some(obj) = schema.as_object_mut() {
247 obj.insert("description".to_string(), json!(#description));
248 }
249 }
250 schema
251 }
252 }
253 }
254 }
255 Data::Enum(data) => {
256 let mut variants = Vec::new();
257 for variant in &data.variants {
258 if !matches!(variant.fields, Fields::Unit) {
259 panic!("gemini_schema only supports unit variants in enums");
260 }
261 variants.push(variant.ident.to_string());
262 }
263
264 quote! {
265 impl GeminiSchema for #name {
266 fn gemini_schema() -> serde_json::Value {
267 use serde_json::json;
268 let mut schema = json!({
269 "type": "STRING",
270 "enum": [#(#variants),*]
271 });
272
273 if !#description.is_empty() {
274 if let Some(obj) = schema.as_object_mut() {
275 obj.insert("description".to_string(), json!(#description));
276 }
277 }
278 schema
279 }
280 }
281 }
282 }
283 _ => panic!("gemini_schema only supports structs and enums"),
284 };
285
286 let output = quote! {
287 #input
288 #expanded
289 };
290
291 TokenStream::from(output)
292}
293
294fn extract_doc_comments(attrs: &[Attribute]) -> String {
295 let mut doc_comments = Vec::new();
296 for attr in attrs {
297 if attr.path().is_ident("doc") {
298 if let Meta::NameValue(nv) = &attr.meta {
299 if let syn::Expr::Lit(expr_lit) = &nv.value {
300 if let Lit::Str(lit_str) = &expr_lit.lit {
301 doc_comments.push(lit_str.value().trim().to_string());
302 }
303 }
304 }
305 }
306 }
307 doc_comments.join("\n")
308}
309
310fn is_option(ty: &Type) -> bool {
311 if let Type::Path(tp) = ty {
312 if let Some(seg) = tp.path.segments.last() {
313 return seg.ident == "Option";
314 }
315 }
316 false
317}