gemini_proc_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 parse_macro_input, Attribute, Data, DeriveInput, Fields, FnArg, ItemFn, Lit, Meta, Pat, Type,
5};
6
7#[proc_macro_attribute]
8pub fn gemini_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
9 let mut input_fn = parse_macro_input!(item as ItemFn);
10 let fn_name = &input_fn.sig.ident;
11 let fn_description = extract_doc_comments(&input_fn.attrs);
12
13 let mut properties = Vec::new();
14 let mut required = Vec::new();
15 let mut param_names = Vec::new();
16 let mut param_types = Vec::new();
17
18 for arg in input_fn.sig.inputs.iter_mut() {
19 if let FnArg::Typed(pat_type) = arg {
20 if let Pat::Ident(pat_ident) = &*pat_type.pat {
21 let param_name = pat_ident.ident.clone();
22 let param_name_str = param_name.to_string();
23 let param_type = (*pat_type.ty).clone();
24 let param_desc = extract_doc_comments(&pat_type.attrs);
25
26 pat_type.attrs.retain(|attr| !attr.path().is_ident("doc"));
28
29 let is_optional = is_option(¶m_type);
30
31 properties.push(quote! {
32 let mut schema = <#param_type as GeminiSchema>::gemini_schema();
33 if !#param_desc.is_empty() {
34 if let Some(obj) = schema.as_object_mut() {
35 obj.insert("description".to_string(), serde_json::json!(#param_desc));
36 }
37 }
38 props.insert(#param_name_str.to_string(), schema);
39 });
40
41 if !is_optional {
42 required.push(param_name_str);
43 }
44
45 param_names.push(param_name);
46 param_types.push(param_type);
47 }
48 }
49 }
50
51 let fn_name_str = fn_name.to_string();
52 let is_async = input_fn.sig.asyncness.is_some();
53 let call_await = if is_async {
54 quote! { .await }
55 } else {
56 quote! {}
57 };
58
59 let is_result = match &input_fn.sig.output {
60 syn::ReturnType::Default => false,
61 syn::ReturnType::Type(_, ty) => {
62 let s = quote!(#ty).to_string();
63 s.contains("Result")
64 }
65 };
66
67 let result_handling = if is_result {
68 quote! {
69 match result {
70 Ok(v) => Ok(serde_json::json!(v)),
71 Err(e) => Err(e.to_string()),
72 }
73 }
74 } else {
75 quote! {
76 Ok(serde_json::json!(result))
77 }
78 };
79
80 let expanded = quote! {
81 #input_fn
82
83 #[allow(non_camel_case_types)]
84 pub struct #fn_name { }
85
86 impl GeminiSchema for #fn_name {
87 fn gemini_schema() -> serde_json::Value {
88 use serde_json::{json, Map};
89 let mut props = Map::new();
90 #(#properties)*
91
92 json!({
93 "name": #fn_name_str,
94 "description": #fn_description,
95 "parameters": {
96 "type": "OBJECT",
97 "properties": props,
98 "required": [#(#required),*]
99 }
100 })
101 }
102 }
103
104 impl #fn_name {
105 pub async fn execute(args: serde_json::Value) -> Result<serde_json::Value, String> {
106 use serde::Deserialize;
107 #[derive(Deserialize)]
108 struct Args {
109 #(#param_names: #param_types),*
110 }
111 let args: Args = serde_json::from_value(args).map_err(|e| e.to_string())?;
112 let result = #fn_name(#(args.#param_names),*) #call_await;
113 #result_handling
114 }
115 }
116 };
117
118 TokenStream::from(expanded)
119}
120
121#[proc_macro]
122pub fn execute_function_calls(input: TokenStream) -> TokenStream {
123 use syn::parse::{Parse, ParseStream};
124 use syn::{Expr, Token};
125
126 struct ExecuteInput {
127 session: Expr,
128 _comma: Token![,],
129 functions: syn::punctuated::Punctuated<syn::Path, Token![,]>,
130 }
131
132 impl Parse for ExecuteInput {
133 fn parse(input: ParseStream) -> syn::Result<Self> {
134 Ok(ExecuteInput {
135 session: input.parse()?,
136 _comma: input.parse()?,
137 functions: input.parse_terminated(syn::Path::parse, Token![,])?,
138 })
139 }
140 }
141
142 let input = parse_macro_input!(input as ExecuteInput);
143 let session = &input.session;
144 let functions = &input.functions;
145
146 let match_arms = functions.iter().map(|path| {
147 let name_str = path.segments.last().unwrap().ident.to_string();
148 quote! {
149 #name_str => {
150 let args = call.args().clone().unwrap_or(gemini_client_api::serde_json::json!({}));
151 let fut: std::pin::Pin<Box<dyn std::future::Future<Output = (String, Result<gemini_client_api::serde_json::Value, String>)>>> = Box::pin(async move {
152 (#name_str.to_string(), #path::execute(args).await)
153 });
154 futures.push(fut);
155 }
156 }
157 });
158
159 let expanded = quote! {
160 {
161 let mut all_results = Vec::new();
162 if let Some(chat) = #session.get_last_chat() {
163 let mut futures = Vec::new();
164 for part in chat.parts() {
165 if let gemini_client_api::gemini::types::request::Part::functionCall(call) = part {
166 match call.name().as_str() {
167 #(#match_arms)*
168 _ => {}
169 }
170 }
171 }
172 if !futures.is_empty() {
173 let results = gemini_client_api::futures::future::join_all(futures).await;
174 for (name, res) in results {
175 if let Ok(ref val) = res {
176 let _ = #session.add_function_response(name, val.clone());
177 }
178 all_results.push(res);
179 }
180 }
181 }
182 all_results
183 }
184 };
185
186 TokenStream::from(expanded)
187}
188
189#[proc_macro_attribute]
190pub fn gemini_schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
191 let input = parse_macro_input!(item as DeriveInput);
192 let name = &input.ident;
193 let description = extract_doc_comments(&input.attrs);
194
195 let expanded = match &input.data {
196 Data::Struct(data) => {
197 let mut properties = Vec::new();
198 let mut required = Vec::new();
199
200 match &data.fields {
201 Fields::Named(fields) => {
202 for field in &fields.named {
203 let field_name = field.ident.as_ref().unwrap();
204 let field_name_str = field_name.to_string();
205 let field_type = &field.ty;
206 let field_desc = extract_doc_comments(&field.attrs);
207
208 let is_optional = is_option(field_type);
209
210 properties.push(quote! {
211 let mut schema = <#field_type as GeminiSchema>::gemini_schema();
212 if !#field_desc.is_empty() {
213 if let Some(obj) = schema.as_object_mut() {
214 obj.insert("description".to_string(), serde_json::json!(#field_desc));
215 }
216 }
217 props.insert(#field_name_str.to_string(), schema);
218 });
219
220 if !is_optional {
221 required.push(field_name_str);
222 }
223 }
224 }
225 _ => panic!("gemini_schema only supports named fields in structs"),
226 }
227
228 quote! {
229 impl GeminiSchema for #name {
230 fn gemini_schema() -> serde_json::Value {
231 use serde_json::{json, Map};
232 let mut props = Map::new();
233 #(#properties)*
234
235 let mut schema = json!({
236 "type": "OBJECT",
237 "properties": props,
238 "required": [#(#required),*]
239 });
240
241 if !#description.is_empty() {
242 if let Some(obj) = schema.as_object_mut() {
243 obj.insert("description".to_string(), json!(#description));
244 }
245 }
246 schema
247 }
248 }
249 }
250 }
251 Data::Enum(data) => {
252 let mut variants = Vec::new();
253 for variant in &data.variants {
254 if !matches!(variant.fields, Fields::Unit) {
255 panic!("gemini_schema only supports unit variants in enums");
256 }
257 variants.push(variant.ident.to_string());
258 }
259
260 quote! {
261 impl GeminiSchema for #name {
262 fn gemini_schema() -> serde_json::Value {
263 use serde_json::json;
264 let mut schema = json!({
265 "type": "STRING",
266 "enum": [#(#variants),*]
267 });
268
269 if !#description.is_empty() {
270 if let Some(obj) = schema.as_object_mut() {
271 obj.insert("description".to_string(), json!(#description));
272 }
273 }
274 schema
275 }
276 }
277 }
278 }
279 _ => panic!("gemini_schema only supports structs and enums"),
280 };
281
282 let output = quote! {
283 #input
284 #expanded
285 };
286
287 TokenStream::from(output)
288}
289
290fn extract_doc_comments(attrs: &[Attribute]) -> String {
291 let mut doc_comments = Vec::new();
292 for attr in attrs {
293 if attr.path().is_ident("doc") {
294 if let Meta::NameValue(nv) = &attr.meta {
295 if let syn::Expr::Lit(expr_lit) = &nv.value {
296 if let Lit::Str(lit_str) = &expr_lit.lit {
297 doc_comments.push(lit_str.value().trim().to_string());
298 }
299 }
300 }
301 }
302 }
303 doc_comments.join("\n")
304}
305
306fn is_option(ty: &Type) -> bool {
307 if let Type::Path(tp) = ty {
308 if let Some(seg) = tp.path.segments.last() {
309 return seg.ident == "Option";
310 }
311 }
312 false
313}