schema_bridge_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident, Lit, Meta};
4
5#[proc_macro_derive(SchemaBridge, attributes(schema_bridge, serde))]
6pub fn derive_schema_bridge(input: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8    let name = &input.ident;
9
10    let ts_impl = impl_to_ts(&input);
11    let schema_impl = impl_to_schema(name, &input.data);
12
13    // Check for string_conversion attribute
14    let string_conversion = has_string_conversion(&input.attrs);
15
16    let mut expanded = quote! {
17        impl ::schema_bridge::SchemaBridge for #name {
18            fn to_ts() -> String {
19                #ts_impl
20            }
21
22            fn to_schema() -> ::schema_bridge::Schema {
23                #schema_impl
24            }
25        }
26    };
27
28    // Generate Display and FromStr if requested
29    if string_conversion {
30        if let Data::Enum(_) = &input.data {
31            let display_impl = impl_display(&input);
32            let fromstr_impl = impl_fromstr(&input);
33
34            expanded = quote! {
35                #expanded
36
37                #display_impl
38
39                #fromstr_impl
40            };
41        }
42    }
43
44    TokenStream::from(expanded)
45}
46
47/// Check if #[schema_bridge(string_conversion)] attribute is present
48fn has_string_conversion(attrs: &[syn::Attribute]) -> bool {
49    for attr in attrs {
50        if attr.path().is_ident("schema_bridge") {
51            if let Meta::List(meta_list) = &attr.meta {
52                if let Ok(Meta::Path(path)) = syn::parse2(meta_list.tokens.clone()) {
53                    if path.is_ident("string_conversion") {
54                        return true;
55                    }
56                }
57            }
58        }
59    }
60    false
61}
62
63fn impl_to_ts(input: &DeriveInput) -> proc_macro2::TokenStream {
64    match &input.data {
65        Data::Struct(data) => {
66            match &data.fields {
67                Fields::Named(fields) => {
68                    // Check for serde rename_all attribute on struct
69                    let rename_all = get_serde_rename_all(&input.attrs);
70
71                    let fields_ts = fields.named.iter().map(|f| {
72                        let field_name = &f.ident;
73                        let field_str = field_name.as_ref().unwrap().to_string();
74                        let ty = &f.ty;
75
76                        // Apply rename_all transformation if present
77                        let ts_field_name = if let Some(ref rule) = rename_all {
78                            apply_rename_rule(&field_str, rule)
79                        } else {
80                            field_str
81                        };
82
83                        quote! {
84                            format!("{}: {};", #ts_field_name, <#ty as ::schema_bridge::SchemaBridge>::to_ts())
85                        }
86                    });
87
88                    quote! {
89                        let fields = vec![#(#fields_ts),*];
90                        format!("{{ {} }}", fields.join(" "))
91                    }
92                }
93                Fields::Unnamed(fields) => {
94                    // Support for tuple structs, especially newtype pattern
95                    if fields.unnamed.len() == 1 {
96                        // Newtype pattern: delegate to the inner type
97                        let inner_ty = &fields.unnamed[0].ty;
98                        quote! {
99                            <#inner_ty as ::schema_bridge::SchemaBridge>::to_ts()
100                        }
101                    } else {
102                        // Multiple field tuple struct - represent as tuple
103                        let field_types = fields.unnamed.iter().map(|f| {
104                            let ty = &f.ty;
105                            quote! {
106                                <#ty as ::schema_bridge::SchemaBridge>::to_ts()
107                            }
108                        });
109
110                        quote! {
111                            let types = vec![#(#field_types),*];
112                            format!("[{}]", types.join(", "))
113                        }
114                    }
115                }
116                Fields::Unit => quote! { "null".to_string() },
117            }
118        }
119        Data::Enum(data) => {
120            // Check for serde rename_all attribute
121            let rename_all = get_serde_rename_all(&input.attrs);
122
123            let variants = data.variants.iter().map(|v| {
124                let variant_name = &v.ident;
125                let variant_str = variant_name.to_string();
126
127                // Apply rename_all transformation if present
128                let ts_name = if let Some(ref rule) = rename_all {
129                    apply_rename_rule(&variant_str, rule)
130                } else {
131                    variant_str
132                };
133
134                quote! {
135                    format!("'{}'", #ts_name)
136                }
137            });
138
139            quote! {
140                let variants = vec![#(#variants),*];
141                variants.join(" | ")
142            }
143        }
144        _ => quote! { "any".to_string() },
145    }
146}
147
148/// Extract rename_all from #[serde(rename_all = "...")]
149fn get_serde_rename_all(attrs: &[syn::Attribute]) -> Option<String> {
150    for attr in attrs {
151        if attr.path().is_ident("serde") {
152            if let Meta::List(meta_list) = &attr.meta {
153                // Parse the meta list
154                let nested: Result<Meta, _> = syn::parse2(meta_list.tokens.clone());
155                if let Ok(Meta::NameValue(nv)) = nested {
156                    if nv.path.is_ident("rename_all") {
157                        if let syn::Expr::Lit(expr_lit) = &nv.value {
158                            if let Lit::Str(lit_str) = &expr_lit.lit {
159                                return Some(lit_str.value());
160                            }
161                        }
162                    }
163                }
164            }
165        }
166    }
167    None
168}
169
/// Heuristic snake_case check: a name counts as snake_case exactly when it
/// contains at least one underscore.
fn is_snake_case(name: &str) -> bool {
    // `_` is ASCII, so scanning bytes is equivalent to scanning chars.
    name.bytes().any(|byte| byte == b'_')
}
174
175/// Apply serde rename_all transformation
176fn apply_rename_rule(name: &str, rule: &str) -> String {
177    match rule {
178        "lowercase" => name.to_lowercase(),
179        "UPPERCASE" => name.to_uppercase(),
180        "PascalCase" => {
181            if is_snake_case(name) {
182                snake_to_pascal(name)
183            } else {
184                name.to_string() // Already PascalCase
185            }
186        }
187        "camelCase" => {
188            if is_snake_case(name) {
189                snake_to_camel(name)
190            } else {
191                pascal_to_camel(name)
192            }
193        }
194        "snake_case" => {
195            if is_snake_case(name) {
196                name.to_string() // Already snake_case
197            } else {
198                pascal_to_snake(name)
199            }
200        }
201        "SCREAMING_SNAKE_CASE" => {
202            if is_snake_case(name) {
203                name.to_uppercase()
204            } else {
205                pascal_to_screaming_snake(name)
206            }
207        }
208        "kebab-case" => {
209            if is_snake_case(name) {
210                name.replace('_', "-")
211            } else {
212                pascal_to_kebab(name)
213            }
214        }
215        _ => name.to_string(), // Unknown rule, keep as-is
216    }
217}
218
/// Converts a snake_case name to PascalCase.
///
/// The name is split on `_`; empty segments (from leading, trailing or doubled
/// underscores) are dropped, and each remaining segment has its first
/// character upper-cased while the rest is copied verbatim.
fn snake_to_pascal(name: &str) -> String {
    let mut pascal = String::with_capacity(name.len());
    for segment in name.split('_') {
        let mut rest = segment.chars();
        // An empty segment yields no first character and is skipped.
        if let Some(initial) = rest.next() {
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
232
/// Converts a snake_case name to camelCase.
///
/// Empty segments between underscores are ignored. The first non-empty segment
/// is lower-cased in full; every following segment gets its first character
/// upper-cased and the remainder copied verbatim. A name with no non-empty
/// segment produces an empty string.
fn snake_to_camel(name: &str) -> String {
    let mut segments = name.split('_').filter(|segment| !segment.is_empty());

    let mut camel = match segments.next() {
        Some(head) => head.to_lowercase(),
        None => return String::new(),
    };

    for segment in segments {
        let mut rest = segment.chars();
        if let Some(initial) = rest.next() {
            camel.extend(initial.to_uppercase());
            camel.push_str(rest.as_str());
        }
    }
    camel
}
249
/// Converts a PascalCase name to camelCase by lower-casing its first character
/// and leaving every other character untouched. An empty name stays empty.
fn pascal_to_camel(name: &str) -> String {
    let mut camel = String::with_capacity(name.len());
    let mut rest = name.chars();
    if let Some(initial) = rest.next() {
        camel.extend(initial.to_lowercase());
        camel.push_str(rest.as_str());
    }
    camel
}
258
/// Converts a PascalCase name to snake_case.
///
/// An underscore is inserted before every uppercase character except the very
/// first character of the name, and each character is replaced by the first
/// character of its lowercase mapping.
fn pascal_to_snake(name: &str) -> String {
    let mut snake = String::with_capacity(name.len());
    // A byte offset of zero identifies the first character.
    for (offset, c) in name.char_indices() {
        if offset != 0 && c.is_uppercase() {
            snake.push('_');
        }
        // `to_lowercase` always yields at least one char; only the first is kept.
        snake.extend(c.to_lowercase().next());
    }
    snake
}
270
/// Converts a PascalCase name to SCREAMING_SNAKE_CASE.
///
/// An underscore is inserted before every uppercase character except the very
/// first character of the name, and each character is replaced by the first
/// character of its uppercase mapping.
fn pascal_to_screaming_snake(name: &str) -> String {
    let mut screaming = String::with_capacity(name.len());
    let mut at_start = true;
    for c in name.chars() {
        if !at_start && c.is_uppercase() {
            screaming.push('_');
        }
        at_start = false;
        // `to_uppercase` always yields at least one char; only the first is kept.
        screaming.extend(c.to_uppercase().next());
    }
    screaming
}
282
/// Converts a PascalCase name to kebab-case.
///
/// A hyphen is inserted before every uppercase character except the very first
/// character of the name, and each character is replaced by the first
/// character of its lowercase mapping.
fn pascal_to_kebab(name: &str) -> String {
    name.chars()
        .enumerate()
        .fold(String::with_capacity(name.len()), |mut kebab, (pos, c)| {
            if pos > 0 && c.is_uppercase() {
                kebab.push('-');
            }
            // `to_lowercase` always yields at least one char; only the first is kept.
            kebab.extend(c.to_lowercase().next());
            kebab
        })
}
294
295fn impl_to_schema(_name: &Ident, _data: &Data) -> proc_macro2::TokenStream {
296    // Placeholder for now, focusing on TS generation first
297    quote! {
298        ::schema_bridge::Schema::Any
299    }
300}
301
302/// Generate Display implementation for enum
303fn impl_display(input: &DeriveInput) -> proc_macro2::TokenStream {
304    let name = &input.ident;
305
306    if let Data::Enum(data) = &input.data {
307        let rename_all = get_serde_rename_all(&input.attrs);
308
309        let match_arms = data.variants.iter().map(|v| {
310            let variant_name = &v.ident;
311            let variant_str = variant_name.to_string();
312
313            let display_str = if let Some(ref rule) = rename_all {
314                apply_rename_rule(&variant_str, rule)
315            } else {
316                variant_str
317            };
318
319            quote! {
320                #name::#variant_name => write!(f, "{}", #display_str)
321            }
322        });
323
324        quote! {
325            impl ::std::fmt::Display for #name {
326                fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
327                    match self {
328                        #(#match_arms),*
329                    }
330                }
331            }
332        }
333    } else {
334        quote! {}
335    }
336}
337
338/// Generate FromStr implementation for enum
339fn impl_fromstr(input: &DeriveInput) -> proc_macro2::TokenStream {
340    let name = &input.ident;
341
342    if let Data::Enum(data) = &input.data {
343        let rename_all = get_serde_rename_all(&input.attrs);
344
345        let match_arms = data.variants.iter().map(|v| {
346            let variant_name = &v.ident;
347            let variant_str = variant_name.to_string();
348
349            let pattern_str = if let Some(ref rule) = rename_all {
350                apply_rename_rule(&variant_str, rule)
351            } else {
352                variant_str
353            };
354
355            quote! {
356                #pattern_str => ::std::result::Result::Ok(#name::#variant_name)
357            }
358        });
359
360        quote! {
361            impl ::std::str::FromStr for #name {
362                type Err = String;
363
364                fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
365                    match s {
366                        #(#match_arms,)*
367                        _ => ::std::result::Result::Err(format!("Unknown {}: {}", stringify!(#name), s))
368                    }
369                }
370            }
371        }
372    } else {
373        quote! {}
374    }
375}