1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 parse::Parse, parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Fields, Ident, Lit,
5 Meta, Token, Type,
6};
7
8#[proc_macro_derive(SchemaBridge, attributes(schema_bridge, schema, serde))]
9pub fn derive_schema_bridge(input: TokenStream) -> TokenStream {
10 let input = parse_macro_input!(input as DeriveInput);
11 let name = &input.ident;
12
13 let ts_impl = impl_to_ts(&input);
14 let schema_impl = impl_to_schema(name, &input);
15
16 let string_conversion = has_string_conversion(&input.attrs);
18
19 let mut expanded = quote! {
20 impl ::schema_bridge::SchemaBridge for #name {
21 fn to_ts() -> String {
22 #ts_impl
23 }
24
25 fn to_schema() -> ::schema_bridge::Schema {
26 #schema_impl
27 }
28 }
29 };
30
31 if string_conversion {
33 if let Data::Enum(_) = &input.data {
34 let display_impl = impl_display(&input);
35 let fromstr_impl = impl_fromstr(&input);
36
37 expanded = quote! {
38 #expanded
39
40 #display_impl
41
42 #fromstr_impl
43 };
44 }
45 }
46
47 TokenStream::from(expanded)
48}
49
50fn has_string_conversion(attrs: &[syn::Attribute]) -> bool {
52 for attr in attrs {
53 if attr.path().is_ident("schema_bridge") {
54 if let Meta::List(meta_list) = &attr.meta {
55 if let Ok(Meta::Path(path)) = syn::parse2(meta_list.tokens.clone()) {
56 if path.is_ident("string_conversion") {
57 return true;
58 }
59 }
60 }
61 }
62 }
63 false
64}
65
66fn impl_to_ts(input: &DeriveInput) -> proc_macro2::TokenStream {
67 match &input.data {
68 Data::Struct(data) => {
69 match &data.fields {
70 Fields::Named(fields) => {
71 let rename_all = get_serde_rename_all(&input.attrs);
73
74 let fields_ts = fields.named.iter().map(|f| {
75 let field_name = &f.ident;
76 let field_str = field_name.as_ref().unwrap().to_string();
77 let ty = &f.ty;
78
79 let ts_field_name = if let Some(ref rule) = rename_all {
81 apply_rename_rule(&field_str, rule)
82 } else {
83 field_str
84 };
85
86 quote! {
87 format!("{}: {};", #ts_field_name, <#ty as ::schema_bridge::SchemaBridge>::to_ts())
88 }
89 });
90
91 quote! {
92 let fields = vec![#(#fields_ts),*];
93 format!("{{ {} }}", fields.join(" "))
94 }
95 }
96 Fields::Unnamed(fields) => {
97 if fields.unnamed.len() == 1 {
99 let inner_ty = &fields.unnamed[0].ty;
101 quote! {
102 <#inner_ty as ::schema_bridge::SchemaBridge>::to_ts()
103 }
104 } else {
105 let field_types = fields.unnamed.iter().map(|f| {
107 let ty = &f.ty;
108 quote! {
109 <#ty as ::schema_bridge::SchemaBridge>::to_ts()
110 }
111 });
112
113 quote! {
114 let types = vec![#(#field_types),*];
115 format!("[{}]", types.join(", "))
116 }
117 }
118 }
119 Fields::Unit => quote! { "null".to_string() },
120 }
121 }
122 Data::Enum(data) => {
123 let rename_all = get_serde_rename_all(&input.attrs);
125
126 let variants = data.variants.iter().map(|v| {
127 let variant_name = &v.ident;
128 let variant_str = variant_name.to_string();
129
130 let ts_name = if let Some(ref rule) = rename_all {
132 apply_rename_rule(&variant_str, rule)
133 } else {
134 variant_str
135 };
136
137 quote! {
138 format!("'{}'", #ts_name)
139 }
140 });
141
142 quote! {
143 let variants = vec![#(#variants),*];
144 variants.join(" | ")
145 }
146 }
147 _ => quote! { "any".to_string() },
148 }
149}
150
151fn get_serde_rename_all(attrs: &[syn::Attribute]) -> Option<String> {
153 for attr in attrs {
154 if attr.path().is_ident("serde") {
155 if let Meta::List(meta_list) = &attr.meta {
156 let nested: Result<Meta, _> = syn::parse2(meta_list.tokens.clone());
158 if let Ok(Meta::NameValue(nv)) = nested {
159 if nv.path.is_ident("rename_all") {
160 if let syn::Expr::Lit(expr_lit) = &nv.value {
161 if let Lit::Str(lit_str) = &expr_lit.lit {
162 return Some(lit_str.value());
163 }
164 }
165 }
166 }
167 }
168 }
169 }
170 None
171}
172
/// Heuristic used by the rename helpers: a name counts as snake_case as soon
/// as it contains at least one underscore.
fn is_snake_case(name: &str) -> bool {
    name.chars().any(|c| c == '_')
}
177
178fn apply_rename_rule(name: &str, rule: &str) -> String {
180 match rule {
181 "lowercase" => name.to_lowercase(),
182 "UPPERCASE" => name.to_uppercase(),
183 "PascalCase" => {
184 if is_snake_case(name) {
185 snake_to_pascal(name)
186 } else {
187 name.to_string() }
189 }
190 "camelCase" => {
191 if is_snake_case(name) {
192 snake_to_camel(name)
193 } else {
194 pascal_to_camel(name)
195 }
196 }
197 "snake_case" => {
198 if is_snake_case(name) {
199 name.to_string() } else {
201 pascal_to_snake(name)
202 }
203 }
204 "SCREAMING_SNAKE_CASE" => {
205 if is_snake_case(name) {
206 name.to_uppercase()
207 } else {
208 pascal_to_screaming_snake(name)
209 }
210 }
211 "kebab-case" => {
212 if is_snake_case(name) {
213 name.replace('_', "-")
214 } else {
215 pascal_to_kebab(name)
216 }
217 }
218 _ => name.to_string(), }
220}
221
/// Converts `snake_case` to `PascalCase`.
///
/// The name is split on `_`; empty segments are dropped, the first character
/// of each remaining segment is upper-cased and the rest is copied verbatim.
fn snake_to_pascal(name: &str) -> String {
    let mut pascal = String::new();
    for word in name.split('_') {
        let mut rest = word.chars();
        // An empty segment yields no first character and is skipped.
        if let Some(initial) = rest.next() {
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
235
/// Converts `snake_case` to `camelCase`.
///
/// Empty `_`-separated segments are dropped. The first remaining segment is
/// lower-cased in full; every later segment gets its first character
/// upper-cased with the rest copied verbatim. Returns an empty string when no
/// segment remains.
fn snake_to_camel(name: &str) -> String {
    let mut words = name.split('_').filter(|word| !word.is_empty());

    let mut camel = match words.next() {
        Some(head) => head.to_lowercase(),
        None => return String::new(),
    };

    for word in words {
        let mut rest = word.chars();
        if let Some(initial) = rest.next() {
            camel.extend(initial.to_uppercase());
            camel.push_str(rest.as_str());
        }
    }
    camel
}
252
/// Converts `PascalCase` to `camelCase` by lower-casing only the first
/// character; the remainder is copied verbatim.
fn pascal_to_camel(name: &str) -> String {
    let mut rest = name.chars();
    let mut camel = String::with_capacity(name.len());
    if let Some(initial) = rest.next() {
        camel.extend(initial.to_lowercase());
        camel.push_str(rest.as_str());
    }
    camel
}
261
/// Converts `PascalCase` to `snake_case`.
///
/// An underscore is inserted before every uppercase character except the
/// first character of the name, and each character is lower-cased. Runs of
/// capitals are therefore split letter by letter (`HTTP` -> `h_t_t_p`).
fn pascal_to_snake(name: &str) -> String {
    name.chars()
        .enumerate()
        .fold(String::new(), |mut snake, (idx, c)| {
            if idx > 0 && c.is_uppercase() {
                snake.push('_');
            }
            // Only the first character of the lowercase mapping is kept.
            snake.extend(c.to_lowercase().take(1));
            snake
        })
}
273
/// Converts `PascalCase` to `SCREAMING_SNAKE_CASE`.
///
/// An underscore is inserted before every uppercase character except the
/// first character of the name, and each character is upper-cased.
fn pascal_to_screaming_snake(name: &str) -> String {
    let mut screaming = String::new();
    for (idx, c) in name.chars().enumerate() {
        let starts_new_word = idx > 0 && c.is_uppercase();
        if starts_new_word {
            screaming.push('_');
        }
        // Only the first character of the uppercase mapping is kept.
        screaming.extend(c.to_uppercase().take(1));
    }
    screaming
}
285
/// Converts `PascalCase` to `kebab-case`.
///
/// A hyphen is inserted before every uppercase character except the first
/// character of the name, and each character is lower-cased.
fn pascal_to_kebab(name: &str) -> String {
    name.chars()
        .enumerate()
        .flat_map(|(idx, c)| {
            let separator = if idx > 0 && c.is_uppercase() {
                Some('-')
            } else {
                None
            };
            // Only the first character of the lowercase mapping is kept.
            separator.into_iter().chain(c.to_lowercase().take(1))
        })
        .collect()
}
297
/// Options collected from a field's `#[schema(...)]` attributes.
///
/// Every member starts out as `None` (via `Default`) and is only filled in
/// when the corresponding key is present and carries a usable value.
#[derive(Default)]
struct SchemaFieldAttrs {
    /// `Some(true)` when the bare `required` flag is present.
    required: Option<bool>,
    /// Value of `min = <int or float literal>`.
    min: Option<f64>,
    /// Value of `max = <int or float literal>`.
    max: Option<f64>,
    /// Value of `min_len = <int literal>`.
    min_len: Option<usize>,
    /// Value of `max_len = <int literal>`.
    max_len: Option<usize>,
    /// String literals listed in `one_of(...)`; never an empty list.
    one_of: Option<Vec<String>>,
}
310
311fn parse_schema_attrs(attrs: &[syn::Attribute]) -> SchemaFieldAttrs {
312 let mut result = SchemaFieldAttrs::default();
313
314 for attr in attrs {
315 if !attr.path().is_ident("schema") {
316 continue;
317 }
318 let _ = attr.parse_nested_meta(|meta| {
319 if meta.path.is_ident("required") {
320 result.required = Some(true);
321 return Ok(());
322 }
323 if meta.path.is_ident("min") {
324 let value = meta.value()?;
325 let lit: Lit = value.parse()?;
326 if let Lit::Float(f) = &lit {
327 result.min = Some(f.base10_parse::<f64>()?);
328 } else if let Lit::Int(i) = &lit {
329 result.min = Some(i.base10_parse::<f64>()?);
330 }
331 return Ok(());
332 }
333 if meta.path.is_ident("max") {
334 let value = meta.value()?;
335 let lit: Lit = value.parse()?;
336 if let Lit::Float(f) = &lit {
337 result.max = Some(f.base10_parse::<f64>()?);
338 } else if let Lit::Int(i) = &lit {
339 result.max = Some(i.base10_parse::<f64>()?);
340 }
341 return Ok(());
342 }
343 if meta.path.is_ident("min_len") {
344 let value = meta.value()?;
345 let lit: Lit = value.parse()?;
346 if let Lit::Int(i) = &lit {
347 result.min_len = Some(i.base10_parse::<usize>()?);
348 }
349 return Ok(());
350 }
351 if meta.path.is_ident("max_len") {
352 let value = meta.value()?;
353 let lit: Lit = value.parse()?;
354 if let Lit::Int(i) = &lit {
355 result.max_len = Some(i.base10_parse::<usize>()?);
356 }
357 return Ok(());
358 }
359 if meta.path.is_ident("one_of") {
360 let content;
361 syn::parenthesized!(content in meta.input);
362 let lits: Punctuated<Lit, Token![,]> =
363 content.parse_terminated(Lit::parse, Token![,])?;
364 let values: Vec<String> = lits
365 .into_iter()
366 .filter_map(|lit| {
367 if let Lit::Str(s) = lit {
368 Some(s.value())
369 } else {
370 None
371 }
372 })
373 .collect();
374 if !values.is_empty() {
375 result.one_of = Some(values);
376 }
377 return Ok(());
378 }
379 Err(meta.error("unknown schema attribute"))
380 });
381 }
382
383 result
384}
385
386fn extract_option_inner(ty: &Type) -> Option<&Type> {
388 if let Type::Path(type_path) = ty {
389 let segment = type_path.path.segments.last()?;
390 if segment.ident == "Option" {
391 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
392 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
393 return Some(inner);
394 }
395 }
396 }
397 }
398 None
399}
400
401fn impl_to_schema(_name: &Ident, input: &DeriveInput) -> proc_macro2::TokenStream {
402 match &input.data {
403 Data::Struct(data) => match &data.fields {
404 Fields::Named(fields) => {
405 let rename_all = get_serde_rename_all(&input.attrs);
406
407 let field_exprs = fields.named.iter().map(|f| {
408 let field_ident = f.ident.as_ref().unwrap();
409 let field_str = field_ident.to_string();
410 let ty = &f.ty;
411 let schema_attrs = parse_schema_attrs(&f.attrs);
412
413 let field_name = if let Some(ref rule) = rename_all {
414 apply_rename_rule(&field_str, rule)
415 } else {
416 field_str
417 };
418
419 let (schema_expr, is_option) = if let Some(inner) = extract_option_inner(ty) {
421 (
422 quote! { <#inner as ::schema_bridge::SchemaBridge>::to_schema() },
423 true,
424 )
425 } else {
426 (
427 quote! { <#ty as ::schema_bridge::SchemaBridge>::to_schema() },
428 false,
429 )
430 };
431
432 let required = match schema_attrs.required {
434 Some(r) => r,
435 None => !is_option,
436 };
437
438 let min_expr = match schema_attrs.min {
440 Some(v) => quote! { Some(#v) },
441 None => quote! { None },
442 };
443 let max_expr = match schema_attrs.max {
444 Some(v) => quote! { Some(#v) },
445 None => quote! { None },
446 };
447 let min_len_expr = match schema_attrs.min_len {
448 Some(v) => quote! { Some(#v) },
449 None => quote! { None },
450 };
451 let max_len_expr = match schema_attrs.max_len {
452 Some(v) => quote! { Some(#v) },
453 None => quote! { None },
454 };
455 let one_of_expr = match &schema_attrs.one_of {
456 Some(vals) => {
457 let lit_vals = vals.iter().map(|s| quote! { #s.to_string() });
458 quote! { Some(vec![#(#lit_vals),*]) }
459 }
460 None => quote! { None },
461 };
462
463 quote! {
464 ::schema_bridge::Field {
465 name: #field_name.to_string(),
466 schema: #schema_expr,
467 required: #required,
468 constraints: ::schema_bridge::Constraints {
469 min: #min_expr,
470 max: #max_expr,
471 min_len: #min_len_expr,
472 max_len: #max_len_expr,
473 one_of: #one_of_expr,
474 },
475 }
476 }
477 });
478
479 quote! {
480 ::schema_bridge::Schema::Object(vec![
481 #(#field_exprs),*
482 ])
483 }
484 }
485 Fields::Unnamed(fields) => {
486 if fields.unnamed.len() == 1 {
487 let inner_ty = &fields.unnamed[0].ty;
488 quote! {
489 <#inner_ty as ::schema_bridge::SchemaBridge>::to_schema()
490 }
491 } else {
492 let types = fields.unnamed.iter().map(|f| {
493 let ty = &f.ty;
494 quote! { <#ty as ::schema_bridge::SchemaBridge>::to_schema() }
495 });
496 quote! {
497 ::schema_bridge::Schema::Tuple(vec![#(#types),*])
498 }
499 }
500 }
501 Fields::Unit => quote! { ::schema_bridge::Schema::Null },
502 },
503 Data::Enum(data) => {
504 let rename_all = get_serde_rename_all(&input.attrs);
505 let variants = data.variants.iter().map(|v| {
506 let variant_str = v.ident.to_string();
507 let display_name = if let Some(ref rule) = rename_all {
508 apply_rename_rule(&variant_str, rule)
509 } else {
510 variant_str
511 };
512 quote! { #display_name.to_string() }
513 });
514 quote! {
515 ::schema_bridge::Schema::Enum(vec![#(#variants),*])
516 }
517 }
518 _ => quote! { ::schema_bridge::Schema::Any },
519 }
520}
521
522fn impl_display(input: &DeriveInput) -> proc_macro2::TokenStream {
524 let name = &input.ident;
525
526 if let Data::Enum(data) = &input.data {
527 let rename_all = get_serde_rename_all(&input.attrs);
528
529 let match_arms = data.variants.iter().map(|v| {
530 let variant_name = &v.ident;
531 let variant_str = variant_name.to_string();
532
533 let display_str = if let Some(ref rule) = rename_all {
534 apply_rename_rule(&variant_str, rule)
535 } else {
536 variant_str
537 };
538
539 quote! {
540 #name::#variant_name => write!(f, "{}", #display_str)
541 }
542 });
543
544 quote! {
545 impl ::std::fmt::Display for #name {
546 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
547 match self {
548 #(#match_arms),*
549 }
550 }
551 }
552 }
553 } else {
554 quote! {}
555 }
556}
557
558fn impl_fromstr(input: &DeriveInput) -> proc_macro2::TokenStream {
560 let name = &input.ident;
561
562 if let Data::Enum(data) = &input.data {
563 let rename_all = get_serde_rename_all(&input.attrs);
564
565 let match_arms = data.variants.iter().map(|v| {
566 let variant_name = &v.ident;
567 let variant_str = variant_name.to_string();
568
569 let pattern_str = if let Some(ref rule) = rename_all {
570 apply_rename_rule(&variant_str, rule)
571 } else {
572 variant_str
573 };
574
575 quote! {
576 #pattern_str => ::std::result::Result::Ok(#name::#variant_name)
577 }
578 });
579
580 quote! {
581 impl ::std::str::FromStr for #name {
582 type Err = String;
583
584 fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
585 match s {
586 #(#match_arms,)*
587 _ => ::std::result::Result::Err(format!("Unknown {}: {}", stringify!(#name), s))
588 }
589 }
590 }
591 }
592 } else {
593 quote! {}
594 }
595}