Skip to main content

schema_bridge_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse::Parse, parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Fields, Ident, Lit,
5    Meta, Token, Type,
6};
7
8#[proc_macro_derive(SchemaBridge, attributes(schema_bridge, schema, serde))]
9pub fn derive_schema_bridge(input: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(input as DeriveInput);
11    let name = &input.ident;
12
13    let ts_impl = impl_to_ts(&input);
14    let schema_impl = impl_to_schema(name, &input);
15
16    // Check for string_conversion attribute
17    let string_conversion = has_string_conversion(&input.attrs);
18
19    let mut expanded = quote! {
20        impl ::schema_bridge::SchemaBridge for #name {
21            fn to_ts() -> String {
22                #ts_impl
23            }
24
25            fn to_schema() -> ::schema_bridge::Schema {
26                #schema_impl
27            }
28        }
29    };
30
31    // Generate Display and FromStr if requested
32    if string_conversion {
33        if let Data::Enum(_) = &input.data {
34            let display_impl = impl_display(&input);
35            let fromstr_impl = impl_fromstr(&input);
36
37            expanded = quote! {
38                #expanded
39
40                #display_impl
41
42                #fromstr_impl
43            };
44        }
45    }
46
47    TokenStream::from(expanded)
48}
49
50/// Check if #[schema_bridge(string_conversion)] attribute is present
51fn has_string_conversion(attrs: &[syn::Attribute]) -> bool {
52    for attr in attrs {
53        if attr.path().is_ident("schema_bridge") {
54            if let Meta::List(meta_list) = &attr.meta {
55                if let Ok(Meta::Path(path)) = syn::parse2(meta_list.tokens.clone()) {
56                    if path.is_ident("string_conversion") {
57                        return true;
58                    }
59                }
60            }
61        }
62    }
63    false
64}
65
66fn impl_to_ts(input: &DeriveInput) -> proc_macro2::TokenStream {
67    match &input.data {
68        Data::Struct(data) => {
69            match &data.fields {
70                Fields::Named(fields) => {
71                    // Check for serde rename_all attribute on struct
72                    let rename_all = get_serde_rename_all(&input.attrs);
73
74                    let fields_ts = fields.named.iter().map(|f| {
75                        let field_name = &f.ident;
76                        let field_str = field_name.as_ref().unwrap().to_string();
77                        let ty = &f.ty;
78
79                        // Apply rename_all transformation if present
80                        let ts_field_name = if let Some(ref rule) = rename_all {
81                            apply_rename_rule(&field_str, rule)
82                        } else {
83                            field_str
84                        };
85
86                        quote! {
87                            format!("{}: {};", #ts_field_name, <#ty as ::schema_bridge::SchemaBridge>::to_ts())
88                        }
89                    });
90
91                    quote! {
92                        let fields = vec![#(#fields_ts),*];
93                        format!("{{ {} }}", fields.join(" "))
94                    }
95                }
96                Fields::Unnamed(fields) => {
97                    // Support for tuple structs, especially newtype pattern
98                    if fields.unnamed.len() == 1 {
99                        // Newtype pattern: delegate to the inner type
100                        let inner_ty = &fields.unnamed[0].ty;
101                        quote! {
102                            <#inner_ty as ::schema_bridge::SchemaBridge>::to_ts()
103                        }
104                    } else {
105                        // Multiple field tuple struct - represent as tuple
106                        let field_types = fields.unnamed.iter().map(|f| {
107                            let ty = &f.ty;
108                            quote! {
109                                <#ty as ::schema_bridge::SchemaBridge>::to_ts()
110                            }
111                        });
112
113                        quote! {
114                            let types = vec![#(#field_types),*];
115                            format!("[{}]", types.join(", "))
116                        }
117                    }
118                }
119                Fields::Unit => quote! { "null".to_string() },
120            }
121        }
122        Data::Enum(data) => {
123            // Check for serde rename_all attribute
124            let rename_all = get_serde_rename_all(&input.attrs);
125
126            let variants = data.variants.iter().map(|v| {
127                let variant_name = &v.ident;
128                let variant_str = variant_name.to_string();
129
130                // Apply rename_all transformation if present
131                let ts_name = if let Some(ref rule) = rename_all {
132                    apply_rename_rule(&variant_str, rule)
133                } else {
134                    variant_str
135                };
136
137                quote! {
138                    format!("'{}'", #ts_name)
139                }
140            });
141
142            quote! {
143                let variants = vec![#(#variants),*];
144                variants.join(" | ")
145            }
146        }
147        _ => quote! { "any".to_string() },
148    }
149}
150
151/// Extract rename_all from #[serde(rename_all = "...")]
152fn get_serde_rename_all(attrs: &[syn::Attribute]) -> Option<String> {
153    for attr in attrs {
154        if attr.path().is_ident("serde") {
155            if let Meta::List(meta_list) = &attr.meta {
156                // Parse the meta list
157                let nested: Result<Meta, _> = syn::parse2(meta_list.tokens.clone());
158                if let Ok(Meta::NameValue(nv)) = nested {
159                    if nv.path.is_ident("rename_all") {
160                        if let syn::Expr::Lit(expr_lit) = &nv.value {
161                            if let Lit::Str(lit_str) = &expr_lit.lit {
162                                return Some(lit_str.value());
163                            }
164                        }
165                    }
166                }
167            }
168        }
169    }
170    None
171}
172
/// Heuristically decide whether `name` is a snake_case identifier (a struct field)
/// rather than a PascalCase one (an enum variant).
///
/// Fields are lowercase with optional underscores while variants start with an
/// uppercase letter, so a name counts as snake_case when it has no uppercase
/// characters. Checking only for `_` misclassified single-word fields such as `name`,
/// which made `rename_all = "PascalCase"` leave them unchanged instead of `Name`.
fn is_snake_case(name: &str) -> bool {
    !name.chars().any(char::is_uppercase)
}
177
178/// Apply serde rename_all transformation
179fn apply_rename_rule(name: &str, rule: &str) -> String {
180    match rule {
181        "lowercase" => name.to_lowercase(),
182        "UPPERCASE" => name.to_uppercase(),
183        "PascalCase" => {
184            if is_snake_case(name) {
185                snake_to_pascal(name)
186            } else {
187                name.to_string() // Already PascalCase
188            }
189        }
190        "camelCase" => {
191            if is_snake_case(name) {
192                snake_to_camel(name)
193            } else {
194                pascal_to_camel(name)
195            }
196        }
197        "snake_case" => {
198            if is_snake_case(name) {
199                name.to_string() // Already snake_case
200            } else {
201                pascal_to_snake(name)
202            }
203        }
204        "SCREAMING_SNAKE_CASE" => {
205            if is_snake_case(name) {
206                name.to_uppercase()
207            } else {
208                pascal_to_screaming_snake(name)
209            }
210        }
211        "kebab-case" => {
212            if is_snake_case(name) {
213                name.replace('_', "-")
214            } else {
215                pascal_to_kebab(name)
216            }
217        }
218        _ => name.to_string(), // Unknown rule, keep as-is
219    }
220}
221
/// Turn a snake_case identifier into PascalCase (`user_id` -> `UserId`).
///
/// Each `_`-separated segment gets its first character upper-cased (keeping every char
/// of a multi-char upper-case mapping). Empty segments from leading, trailing or doubled
/// underscores contribute nothing.
fn snake_to_pascal(name: &str) -> String {
    let mut pascal = String::with_capacity(name.len());
    for segment in name.split('_') {
        let mut letters = segment.chars();
        if let Some(head) = letters.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(letters.as_str());
        }
    }
    pascal
}
235
/// Turn a snake_case identifier into camelCase (`user_id` -> `userId`).
///
/// The first non-empty segment is lower-cased entirely. Every later segment has its
/// first character upper-cased, and empty segments are skipped. Returns an empty
/// string when `name` has no non-empty segment.
fn snake_to_camel(name: &str) -> String {
    let mut words = name.split('_').filter(|word| !word.is_empty());

    let head = match words.next() {
        Some(word) => word.to_lowercase(),
        None => return String::new(),
    };

    words.fold(head, |mut camel, word| {
        let mut letters = word.chars();
        if let Some(first) = letters.next() {
            camel.extend(first.to_uppercase());
            camel.push_str(letters.as_str());
        }
        camel
    })
}
252
/// Turn a PascalCase identifier into camelCase (`UserId` -> `userId`).
///
/// Only the leading character is lower-cased (all chars of its lowercase mapping are
/// kept); the remainder is copied verbatim. An empty input yields an empty string.
fn pascal_to_camel(name: &str) -> String {
    let mut rest = name.chars();
    if let Some(head) = rest.next() {
        let mut camel: String = head.to_lowercase().collect();
        camel.push_str(rest.as_str());
        camel
    } else {
        String::new()
    }
}
261
/// Convert PascalCase to snake_case (`HttpServer` -> `http_server`).
///
/// An `_` is inserted before every uppercase character except the first, and every
/// character is lower-cased. Characters whose lowercase form is several chars (e.g.
/// `'İ'` -> `"i\u{307}"`) are emitted in full rather than truncated to the first one.
fn pascal_to_snake(name: &str) -> String {
    let mut result = String::with_capacity(name.len() + 4);
    for (i, ch) in name.chars().enumerate() {
        if i > 0 && ch.is_uppercase() {
            result.push('_');
        }
        result.extend(ch.to_lowercase());
    }
    result
}
273
/// Convert PascalCase to SCREAMING_SNAKE_CASE (`HttpServer` -> `HTTP_SERVER`).
///
/// An `_` is inserted before every uppercase character except the first, and every
/// character is upper-cased. Multi-char uppercase mappings (e.g. `'ß'` -> `"SS"`) are
/// emitted in full rather than truncated to the first char.
fn pascal_to_screaming_snake(name: &str) -> String {
    let mut result = String::with_capacity(name.len() + 4);
    for (i, ch) in name.chars().enumerate() {
        if i > 0 && ch.is_uppercase() {
            result.push('_');
        }
        result.extend(ch.to_uppercase());
    }
    result
}
285
/// Convert PascalCase to kebab-case (`HttpServer` -> `http-server`).
///
/// A `-` is inserted before every uppercase character except the first, and every
/// character is lower-cased. Multi-char lowercase mappings (e.g. `'İ'`) are emitted in
/// full rather than truncated to the first char.
fn pascal_to_kebab(name: &str) -> String {
    let mut result = String::with_capacity(name.len() + 4);
    for (i, ch) in name.chars().enumerate() {
        if i > 0 && ch.is_uppercase() {
            result.push('-');
        }
        result.extend(ch.to_lowercase());
    }
    result
}
297
/// Per-field options gathered from `#[schema(...)]` attributes.
///
/// Every member is `None` until the matching key is seen, so
/// `SchemaFieldAttrs::default()` means "no explicit options". Recognised keys:
/// `required`, `min = N`, `max = N`, `min_len = N`, `max_len = N`,
/// `one_of("a", "b", ...)`.
#[derive(Default)]
struct SchemaFieldAttrs {
    /// Explicit requiredness; `None` lets the caller apply its own default.
    required: Option<bool>,
    /// Numeric lower bound from `min = N` (integer or float literal).
    min: Option<f64>,
    /// Numeric upper bound from `max = N` (integer or float literal).
    max: Option<f64>,
    /// Length lower bound from `min_len = N`.
    min_len: Option<usize>,
    /// Length upper bound from `max_len = N`.
    max_len: Option<usize>,
    /// Allowed string values from `one_of("a", "b", ...)`.
    one_of: Option<Vec<String>>,
}
310
311fn parse_schema_attrs(attrs: &[syn::Attribute]) -> SchemaFieldAttrs {
312    let mut result = SchemaFieldAttrs::default();
313
314    for attr in attrs {
315        if !attr.path().is_ident("schema") {
316            continue;
317        }
318        let _ = attr.parse_nested_meta(|meta| {
319            if meta.path.is_ident("required") {
320                result.required = Some(true);
321                return Ok(());
322            }
323            if meta.path.is_ident("min") {
324                let value = meta.value()?;
325                let lit: Lit = value.parse()?;
326                if let Lit::Float(f) = &lit {
327                    result.min = Some(f.base10_parse::<f64>()?);
328                } else if let Lit::Int(i) = &lit {
329                    result.min = Some(i.base10_parse::<f64>()?);
330                }
331                return Ok(());
332            }
333            if meta.path.is_ident("max") {
334                let value = meta.value()?;
335                let lit: Lit = value.parse()?;
336                if let Lit::Float(f) = &lit {
337                    result.max = Some(f.base10_parse::<f64>()?);
338                } else if let Lit::Int(i) = &lit {
339                    result.max = Some(i.base10_parse::<f64>()?);
340                }
341                return Ok(());
342            }
343            if meta.path.is_ident("min_len") {
344                let value = meta.value()?;
345                let lit: Lit = value.parse()?;
346                if let Lit::Int(i) = &lit {
347                    result.min_len = Some(i.base10_parse::<usize>()?);
348                }
349                return Ok(());
350            }
351            if meta.path.is_ident("max_len") {
352                let value = meta.value()?;
353                let lit: Lit = value.parse()?;
354                if let Lit::Int(i) = &lit {
355                    result.max_len = Some(i.base10_parse::<usize>()?);
356                }
357                return Ok(());
358            }
359            if meta.path.is_ident("one_of") {
360                let content;
361                syn::parenthesized!(content in meta.input);
362                let lits: Punctuated<Lit, Token![,]> =
363                    content.parse_terminated(Lit::parse, Token![,])?;
364                let values: Vec<String> = lits
365                    .into_iter()
366                    .filter_map(|lit| {
367                        if let Lit::Str(s) = lit {
368                            Some(s.value())
369                        } else {
370                            None
371                        }
372                    })
373                    .collect();
374                if !values.is_empty() {
375                    result.one_of = Some(values);
376                }
377                return Ok(());
378            }
379            Err(meta.error("unknown schema attribute"))
380        });
381    }
382
383    result
384}
385
386/// Check if a type is Option<T> and return the inner type T
387fn extract_option_inner(ty: &Type) -> Option<&Type> {
388    if let Type::Path(type_path) = ty {
389        let segment = type_path.path.segments.last()?;
390        if segment.ident == "Option" {
391            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
392                if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
393                    return Some(inner);
394                }
395            }
396        }
397    }
398    None
399}
400
401fn impl_to_schema(_name: &Ident, input: &DeriveInput) -> proc_macro2::TokenStream {
402    match &input.data {
403        Data::Struct(data) => match &data.fields {
404            Fields::Named(fields) => {
405                let rename_all = get_serde_rename_all(&input.attrs);
406
407                let field_exprs = fields.named.iter().map(|f| {
408                    let field_ident = f.ident.as_ref().unwrap();
409                    let field_str = field_ident.to_string();
410                    let ty = &f.ty;
411                    let schema_attrs = parse_schema_attrs(&f.attrs);
412
413                    let field_name = if let Some(ref rule) = rename_all {
414                        apply_rename_rule(&field_str, rule)
415                    } else {
416                        field_str
417                    };
418
419                    // Determine if Option<T> and extract inner type
420                    let (schema_expr, is_option) = if let Some(inner) = extract_option_inner(ty) {
421                        (
422                            quote! { <#inner as ::schema_bridge::SchemaBridge>::to_schema() },
423                            true,
424                        )
425                    } else {
426                        (
427                            quote! { <#ty as ::schema_bridge::SchemaBridge>::to_schema() },
428                            false,
429                        )
430                    };
431
432                    // Required: explicit #[schema(required)] > Option detection > default true
433                    let required = match schema_attrs.required {
434                        Some(r) => r,
435                        None => !is_option,
436                    };
437
438                    // Build constraints
439                    let min_expr = match schema_attrs.min {
440                        Some(v) => quote! { Some(#v) },
441                        None => quote! { None },
442                    };
443                    let max_expr = match schema_attrs.max {
444                        Some(v) => quote! { Some(#v) },
445                        None => quote! { None },
446                    };
447                    let min_len_expr = match schema_attrs.min_len {
448                        Some(v) => quote! { Some(#v) },
449                        None => quote! { None },
450                    };
451                    let max_len_expr = match schema_attrs.max_len {
452                        Some(v) => quote! { Some(#v) },
453                        None => quote! { None },
454                    };
455                    let one_of_expr = match &schema_attrs.one_of {
456                        Some(vals) => {
457                            let lit_vals = vals.iter().map(|s| quote! { #s.to_string() });
458                            quote! { Some(vec![#(#lit_vals),*]) }
459                        }
460                        None => quote! { None },
461                    };
462
463                    quote! {
464                        ::schema_bridge::Field {
465                            name: #field_name.to_string(),
466                            schema: #schema_expr,
467                            required: #required,
468                            constraints: ::schema_bridge::Constraints {
469                                min: #min_expr,
470                                max: #max_expr,
471                                min_len: #min_len_expr,
472                                max_len: #max_len_expr,
473                                one_of: #one_of_expr,
474                            },
475                        }
476                    }
477                });
478
479                quote! {
480                    ::schema_bridge::Schema::Object(vec![
481                        #(#field_exprs),*
482                    ])
483                }
484            }
485            Fields::Unnamed(fields) => {
486                if fields.unnamed.len() == 1 {
487                    let inner_ty = &fields.unnamed[0].ty;
488                    quote! {
489                        <#inner_ty as ::schema_bridge::SchemaBridge>::to_schema()
490                    }
491                } else {
492                    let types = fields.unnamed.iter().map(|f| {
493                        let ty = &f.ty;
494                        quote! { <#ty as ::schema_bridge::SchemaBridge>::to_schema() }
495                    });
496                    quote! {
497                        ::schema_bridge::Schema::Tuple(vec![#(#types),*])
498                    }
499                }
500            }
501            Fields::Unit => quote! { ::schema_bridge::Schema::Null },
502        },
503        Data::Enum(data) => {
504            let rename_all = get_serde_rename_all(&input.attrs);
505            let variants = data.variants.iter().map(|v| {
506                let variant_str = v.ident.to_string();
507                let display_name = if let Some(ref rule) = rename_all {
508                    apply_rename_rule(&variant_str, rule)
509                } else {
510                    variant_str
511                };
512                quote! { #display_name.to_string() }
513            });
514            quote! {
515                ::schema_bridge::Schema::Enum(vec![#(#variants),*])
516            }
517        }
518        _ => quote! { ::schema_bridge::Schema::Any },
519    }
520}
521
522/// Generate Display implementation for enum
523fn impl_display(input: &DeriveInput) -> proc_macro2::TokenStream {
524    let name = &input.ident;
525
526    if let Data::Enum(data) = &input.data {
527        let rename_all = get_serde_rename_all(&input.attrs);
528
529        let match_arms = data.variants.iter().map(|v| {
530            let variant_name = &v.ident;
531            let variant_str = variant_name.to_string();
532
533            let display_str = if let Some(ref rule) = rename_all {
534                apply_rename_rule(&variant_str, rule)
535            } else {
536                variant_str
537            };
538
539            quote! {
540                #name::#variant_name => write!(f, "{}", #display_str)
541            }
542        });
543
544        quote! {
545            impl ::std::fmt::Display for #name {
546                fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
547                    match self {
548                        #(#match_arms),*
549                    }
550                }
551            }
552        }
553    } else {
554        quote! {}
555    }
556}
557
558/// Generate FromStr implementation for enum
559fn impl_fromstr(input: &DeriveInput) -> proc_macro2::TokenStream {
560    let name = &input.ident;
561
562    if let Data::Enum(data) = &input.data {
563        let rename_all = get_serde_rename_all(&input.attrs);
564
565        let match_arms = data.variants.iter().map(|v| {
566            let variant_name = &v.ident;
567            let variant_str = variant_name.to_string();
568
569            let pattern_str = if let Some(ref rule) = rename_all {
570                apply_rename_rule(&variant_str, rule)
571            } else {
572                variant_str
573            };
574
575            quote! {
576                #pattern_str => ::std::result::Result::Ok(#name::#variant_name)
577            }
578        });
579
580        quote! {
581            impl ::std::str::FromStr for #name {
582                type Err = String;
583
584                fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
585                    match s {
586                        #(#match_arms,)*
587                        _ => ::std::result::Result::Err(format!("Unknown {}: {}", stringify!(#name), s))
588                    }
589                }
590            }
591        }
592    } else {
593        quote! {}
594    }
595}