// adamastor_macros/lib.rs

extern crate proc_macro;

use proc_macro::TokenStream;
use quote::quote;
use syn::{
    Attribute, Fields, GenericArgument, ItemStruct, Lit, Meta, PathArguments, Type,
    parse_macro_input,
};

7#[proc_macro_attribute]
8pub fn schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
9    let mut ast = parse_macro_input!(item as ItemStruct);
10    let name = &ast.ident;
11
12    ast.vis = syn::parse_quote!(pub);
13    let derives_attr: Attribute =
14        syn::parse_quote!(#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, Default)]);
15    ast.attrs.push(derives_attr);
16
17    let (gemini_schema_impl, cleaned_fields) =
18        generate_gemini_schema_impl_and_clean_fields(&name, &ast.fields);
19
20    if let Fields::Named(ref mut fields) = ast.fields {
21        fields.named = cleaned_fields;
22    }
23
24    let output = quote! {
25        #ast
26        #gemini_schema_impl
27    };
28
29    output.into()
30}
31
32fn generate_gemini_schema_impl_and_clean_fields(
33    name: &syn::Ident,
34    fields: &Fields,
35) -> (
36    proc_macro2::TokenStream,
37    syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
38) {
39    let fields_iter = match fields {
40        Fields::Named(fields) => fields.named.iter(),
41        _ => panic!("#[schema] can only be used on structs with named fields"),
42    };
43
44    let mut cleaned_fields = syn::punctuated::Punctuated::new();
45
46    let properties_quotes = fields_iter.map(|f| {
47        let mut cleaned_field = f.clone();
48
49        let field_name = f.ident.as_ref().unwrap();
50        let field_name_str = field_name.to_string();
51        let field_type = &f.ty;
52
53        let description = f.attrs.iter().find_map(|attr| {
54            if attr.path().is_ident("doc") {
55                if let Meta::NameValue(nv) = &attr.meta {
56                    if let syn::Expr::Lit(expr_lit) = &nv.value {
57                        if let Lit::Str(lit_str) = &expr_lit.lit {
58                            return Some(lit_str.value().trim().to_string());
59                        }
60                    }
61                }
62            }
63            None
64        });
65
66        cleaned_fields.push(cleaned_field);
67
68        let desc_quote = if let Some(desc) = description {
69            quote! {
70                if let Some(obj) = field_schema.as_object_mut() {
71                    obj.insert("description".to_string(), serde_json::json!(#desc));
72                }
73            }
74        } else {
75            quote! {}
76        };
77
78        let type_str = quote!(#field_type).to_string().replace(" ", "");
79
80        let field_schema_type =
81            if type_str.contains("Vec<String>") || type_str.contains("Vec<&str>") {
82                quote! { serde_json::json!({"type": "ARRAY", "items": {"type": "STRING"}}) }
83            } else if type_str.contains("Vec<") {
84                quote! { serde_json::json!({"type": "ARRAY", "items": {"type": "STRING"}}) }
85            } else if type_str.contains("Option<") {
86                quote! { serde_json::json!({"type": "STRING", "nullable": true}) }
87            } else if type_str.contains("String") {
88                quote! { serde_json::json!({"type": "STRING"}) }
89            } else if type_str.contains("u32") || type_str.contains("i32") {
90                quote! { serde_json::json!({"type": "INTEGER", "format": "int32"}) }
91            } else if type_str.contains("u64") || type_str.contains("i64") {
92                quote! { serde_json::json!({"type": "INTEGER", "format": "int64"}) }
93            } else if type_str.contains("f32") {
94                quote! { serde_json::json!({"type": "NUMBER", "format": "float"}) }
95            } else if type_str.contains("f64") {
96                quote! { serde_json::json!({"type": "NUMBER", "format": "double"}) }
97            } else if type_str.contains("bool") {
98                quote! { serde_json::json!({"type": "BOOLEAN"}) }
99            } else {
100                quote! { serde_json::json!({"type": "STRING"}) }
101            };
102
103        quote! {
104            {
105                let mut field_schema = #field_schema_type;
106                #desc_quote
107                properties.insert(#field_name_str.to_string(), field_schema);
108                required.push(#field_name_str.to_string());
109            }
110        }
111    });
112
113    let impl_block = quote! {
114        impl adamastor::GeminiSchema for #name {
115            fn gemini_schema() -> serde_json::Value {
116                let mut properties = serde_json::Map::new();
117                let mut required = vec![];
118
119                #(#properties_quotes)*
120
121                serde_json::json!({
122                    "type": "OBJECT",
123                    "properties": properties,
124                    "required": required
125                })
126            }
127        }
128    };
129
130    (impl_block, cleaned_fields)
131}