// zod_rs_ts/lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields, Meta};
4
5#[proc_macro_derive(ZodTs, attributes(zod))]
6pub fn derive_zod_ts(input: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8    let name = &input.ident;
9    let name_str = name.to_string();
10
11    match &input.data {
12        Data::Struct(data_struct) => match &data_struct.fields {
13            Fields::Named(fields) => {
14                let field_schemas: Vec<String> = fields
15                    .named
16                    .iter()
17                    .map(|field| {
18                        let field_name = field.ident.as_ref().unwrap().to_string();
19                        let field_type = &field.ty;
20                        let attrs = parse_zod_attributes(&field.attrs);
21                        let is_optional = is_option_type(field_type);
22
23                        let base_type = if is_optional {
24                            get_option_inner_type_str(field_type)
25                        } else {
26                            type_to_string(field_type)
27                        };
28
29                        let zod_type = rust_type_to_zod(&base_type, &attrs);
30                        let final_type = if is_optional {
31                            format!("{}.optional()", zod_type)
32                        } else {
33                            zod_type
34                        };
35
36                        format!("  {}: {}", field_name, final_type)
37                    })
38                    .collect();
39
40                let fields_str = field_schemas.join(",\n");
41                let schema_name = format!("{}Schema", name_str);
42
43                let ts_code = format!(
44                    r#"import {{ z }} from 'zod';
45
46export const {} = z.object({{
47{}
48}});
49
50export type {} = z.infer<typeof {}>;"#,
51                    schema_name, fields_str, name_str, schema_name
52                );
53
54                let expanded = quote! {
55                    impl #name {
56                        pub fn zod_ts() -> String {
57                            #ts_code.to_string()
58                        }
59                    }
60                };
61
62                TokenStream::from(expanded)
63            }
64            _ => {
65                let error = syn::Error::new_spanned(
66                    &input,
67                    "ZodTs can only be derived for structs with named fields",
68                );
69                TokenStream::from(error.to_compile_error())
70            }
71        },
72        Data::Enum(data_enum) => {
73            let variant_schemas: Vec<String> = data_enum
74                .variants
75                .iter()
76                .map(|variant| {
77                    let variant_name = variant.ident.to_string();
78                    generate_variant_ts(&variant_name, &variant.fields)
79                })
80                .collect();
81
82            let variants_str = variant_schemas.join(",\n  ");
83            let schema_name = format!("{}Schema", name_str);
84
85            let ts_code = format!(
86                r#"import {{ z }} from 'zod';
87
88export const {} = z.union([
89  {}
90]);
91
92export type {} = z.infer<typeof {}>;"#,
93                schema_name, variants_str, name_str, schema_name
94            );
95
96            let expanded = quote! {
97                impl #name {
98                    pub fn zod_ts() -> String {
99                        #ts_code.to_string()
100                    }
101                }
102            };
103
104            TokenStream::from(expanded)
105        }
106        Data::Union(_) => {
107            let error =
108                syn::Error::new_spanned(&input, "ZodTs cannot be derived for Rust unions");
109            TokenStream::from(error.to_compile_error())
110        }
111    }
112}
113
114fn generate_variant_ts(variant_name: &str, fields: &Fields) -> String {
115    match fields {
116        Fields::Unit => {
117            format!("z.object({{ {}: z.null() }})", variant_name)
118        }
119        Fields::Unnamed(fields_unnamed) => {
120            let field_count = fields_unnamed.unnamed.len();
121            if field_count == 1 {
122                let field = fields_unnamed.unnamed.first().unwrap();
123                let field_type = type_to_string(&field.ty);
124                let attrs = parse_zod_attributes(&field.attrs);
125                let zod_type = rust_type_to_zod(&field_type, &attrs);
126                format!("z.object({{ {}: {} }})", variant_name, zod_type)
127            } else {
128                let element_types: Vec<String> = fields_unnamed
129                    .unnamed
130                    .iter()
131                    .map(|field| {
132                        let field_type = type_to_string(&field.ty);
133                        let attrs = parse_zod_attributes(&field.attrs);
134                        rust_type_to_zod(&field_type, &attrs)
135                    })
136                    .collect();
137                let tuple_str = element_types.join(", ");
138                format!("z.object({{ {}: z.tuple([{}]) }})", variant_name, tuple_str)
139            }
140        }
141        Fields::Named(fields_named) => {
142            let field_schemas: Vec<String> = fields_named
143                .named
144                .iter()
145                .map(|field| {
146                    let field_name = field.ident.as_ref().unwrap().to_string();
147                    let field_type = type_to_string(&field.ty);
148                    let attrs = parse_zod_attributes(&field.attrs);
149                    let is_optional = is_option_type(&field.ty);
150
151                    let base_type = if is_optional {
152                        get_option_inner_type_str(&field.ty)
153                    } else {
154                        field_type
155                    };
156
157                    let zod_type = rust_type_to_zod(&base_type, &attrs);
158                    let final_type = if is_optional {
159                        format!("{}.optional()", zod_type)
160                    } else {
161                        zod_type
162                    };
163
164                    format!("{}: {}", field_name, final_type)
165                })
166                .collect();
167            let fields_str = field_schemas.join(", ");
168            format!(
169                "z.object({{ {}: z.object({{ {} }}) }})",
170                variant_name, fields_str
171            )
172        }
173    }
174}
175
/// Validation options collected from the `#[zod(...)]` attributes of one field.
///
/// Every option starts out unset (`None` / `false`, via the derived `Default`);
/// `rust_type_to_zod` decides which of them apply to the field's type.
#[derive(Default)]
struct ZodAttributes {
    // Length constraints, applied to strings and to `Vec` fields.
    length: Option<usize>,
    min_length: Option<usize>,
    max_length: Option<usize>,

    // String format and content checks.
    email: bool,
    url: bool,
    regex: Option<String>,
    starts_with: Option<String>,
    ends_with: Option<String>,
    includes: Option<String>,

    // Numeric bounds and flags.
    min: Option<f64>,
    max: Option<f64>,
    positive: bool,
    negative: bool,
    nonnegative: bool,
    nonpositive: bool,
    int: bool,
    finite: bool,
}
196
197fn parse_zod_attributes(attrs: &[Attribute]) -> ZodAttributes {
198    let mut zod_attrs = ZodAttributes::default();
199
200    for attr in attrs {
201        if attr.path().is_ident("zod") {
202            if let Meta::List(meta_list) = &attr.meta {
203                let tokens: Vec<_> = meta_list.tokens.clone().into_iter().collect();
204                let mut i = 0;
205
206                while i < tokens.len() {
207                    let token_str = tokens[i].to_string();
208
209                    match token_str.as_str() {
210                        "min_length" => {
211                            if i + 1 < tokens.len() {
212                                let value_token = tokens[i + 1].to_string();
213                                if let Some(value) = extract_number_from_parens(&value_token) {
214                                    zod_attrs.min_length = Some(value);
215                                }
216                                i += 1;
217                            }
218                        }
219                        "max_length" => {
220                            if i + 1 < tokens.len() {
221                                let value_token = tokens[i + 1].to_string();
222                                if let Some(value) = extract_number_from_parens(&value_token) {
223                                    zod_attrs.max_length = Some(value);
224                                }
225                                i += 1;
226                            }
227                        }
228                        "length" => {
229                            if i + 1 < tokens.len() {
230                                let value_token = tokens[i + 1].to_string();
231                                if let Some(value) = extract_number_from_parens(&value_token) {
232                                    zod_attrs.length = Some(value);
233                                }
234                                i += 1;
235                            }
236                        }
237                        "min" => {
238                            if i + 1 < tokens.len() {
239                                let value_token = tokens[i + 1].to_string();
240                                if let Some(value_str) = extract_string_from_parens(&value_token) {
241                                    if let Ok(value) = value_str.parse::<f64>() {
242                                        zod_attrs.min = Some(value);
243                                    }
244                                }
245                                i += 1;
246                            }
247                        }
248                        "max" => {
249                            if i + 1 < tokens.len() {
250                                let value_token = tokens[i + 1].to_string();
251                                if let Some(value_str) = extract_string_from_parens(&value_token) {
252                                    if let Ok(value) = value_str.parse::<f64>() {
253                                        zod_attrs.max = Some(value);
254                                    }
255                                }
256                                i += 1;
257                            }
258                        }
259                        "starts_with" => {
260                            if i + 1 < tokens.len() {
261                                let value_token = tokens[i + 1].to_string();
262                                if let Some(value) = extract_string_from_parens(&value_token) {
263                                    zod_attrs.starts_with = Some(strip_quotes(&value));
264                                }
265                                i += 1;
266                            }
267                        }
268                        "ends_with" => {
269                            if i + 1 < tokens.len() {
270                                let value_token = tokens[i + 1].to_string();
271                                if let Some(value) = extract_string_from_parens(&value_token) {
272                                    zod_attrs.ends_with = Some(strip_quotes(&value));
273                                }
274                                i += 1;
275                            }
276                        }
277                        "includes" => {
278                            if i + 1 < tokens.len() {
279                                let value_token = tokens[i + 1].to_string();
280                                if let Some(value) = extract_string_from_parens(&value_token) {
281                                    zod_attrs.includes = Some(strip_quotes(&value));
282                                }
283                                i += 1;
284                            }
285                        }
286                        "regex" => {
287                            if i + 1 < tokens.len() {
288                                let value_token = tokens[i + 1].to_string();
289                                if let Some(value) = extract_string_from_parens(&value_token) {
290                                    zod_attrs.regex = Some(strip_quotes(&value));
291                                }
292                                i += 1;
293                            }
294                        }
295                        "email" => {
296                            zod_attrs.email = true;
297                        }
298                        "url" => {
299                            zod_attrs.url = true;
300                        }
301                        "positive" => {
302                            zod_attrs.positive = true;
303                        }
304                        "negative" => {
305                            zod_attrs.negative = true;
306                        }
307                        "nonnegative" => {
308                            zod_attrs.nonnegative = true;
309                        }
310                        "nonpositive" => {
311                            zod_attrs.nonpositive = true;
312                        }
313                        "int" => {
314                            zod_attrs.int = true;
315                        }
316                        "finite" => {
317                            zod_attrs.finite = true;
318                        }
319                        "," => {}
320                        _ => {}
321                    }
322
323                    i += 1;
324                }
325            }
326        }
327    }
328
329    zod_attrs
330}
331
/// Parses the unsigned integer inside a parenthesized group token such as `(3)`.
///
/// Whitespace, `_` digit separators and an integer type suffix are accepted, so
/// `(1_000)` and `(64usize)` parse like `(1000)` and `(64)`. Returns `None` if `token`
/// is not wrapped in parentheses or its content is not a non-negative integer.
fn extract_number_from_parens(token: &str) -> Option<usize> {
    const INTEGER_SUFFIXES: [&str; 12] = [
        "usize", "isize", "u128", "i128", "u64", "i64", "u32", "i32", "u16", "i16", "u8", "i8",
    ];
    let inner = token.strip_prefix('(')?.strip_suffix(')')?;
    let digits: String = inner
        .chars()
        .filter(|c| !c.is_whitespace() && *c != '_')
        .collect();
    let number = INTEGER_SUFFIXES
        .iter()
        .find_map(|suffix| digits.strip_suffix(*suffix))
        .unwrap_or(digits.as_str());
    number.parse::<usize>().ok()
}
338
/// Returns the text between the outer `(` and `)` of a group token such as `("x")`,
/// exactly as written, or `None` if `token` is not wrapped in parentheses.
fn extract_string_from_parens(token: &str) -> Option<String> {
    let without_open = token.strip_prefix('(')?;
    let inner = without_open.strip_suffix(')')?;
    Some(String::from(inner))
}
345
/// Removes the delimiters of a string-literal token and returns its content as written
/// in the source (escape sequences are not decoded).
///
/// Handles normal literals (`"abc"`) and raw literals with any number of hashes
/// (`r"abc"`, `r#"a "quoted" b"#`, ...). Any other text, including a raw literal
/// whose opening and closing hash counts differ, is returned unchanged.
fn strip_quotes(value: &str) -> String {
    if let Some(inner) = value.strip_prefix('"').and_then(|s| s.strip_suffix('"')) {
        return inner.to_string();
    }
    if let Some(after_r) = value.strip_prefix('r') {
        let body = after_r.trim_start_matches('#');
        // The run of `#` between `r` and the opening quote; the close must repeat it.
        let hashes = &after_r[..after_r.len() - body.len()];
        let inner = body
            .strip_prefix('"')
            .and_then(|s| s.strip_suffix(hashes))
            .and_then(|s| s.strip_suffix('"'));
        if let Some(inner) = inner {
            return inner.to_string();
        }
    }
    value.to_string()
}
355
356fn rust_type_to_zod(rust_type: &str, attrs: &ZodAttributes) -> String {
357    let base = match rust_type {
358        "String" | "&str" | "str" => {
359            let mut chain = String::from("z.string()");
360
361            if let Some(len) = attrs.length {
362                chain.push_str(&format!(".length({})", len));
363            }
364            if let Some(min) = attrs.min_length {
365                chain.push_str(&format!(".min({})", min));
366            }
367            if let Some(max) = attrs.max_length {
368                chain.push_str(&format!(".max({})", max));
369            }
370            if attrs.email {
371                chain.push_str(".email()");
372            }
373            if attrs.url {
374                chain.push_str(".url()");
375            }
376            if let Some(ref pattern) = attrs.regex {
377                chain.push_str(&format!(".regex(/{}/)", pattern));
378            }
379            if let Some(ref prefix) = attrs.starts_with {
380                chain.push_str(&format!(".startsWith(\"{}\")", prefix));
381            }
382            if let Some(ref suffix) = attrs.ends_with {
383                chain.push_str(&format!(".endsWith(\"{}\")", suffix));
384            }
385            if let Some(ref substr) = attrs.includes {
386                chain.push_str(&format!(".includes(\"{}\")", substr));
387            }
388
389            chain
390        }
391        "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128"
392        | "usize" => {
393            let mut chain = String::from("z.number().int()");
394            append_number_validators(&mut chain, attrs);
395            chain
396        }
397        "f32" | "f64" => {
398            let mut chain = String::from("z.number()");
399            if attrs.int {
400                chain.push_str(".int()");
401            }
402            append_number_validators(&mut chain, attrs);
403            chain
404        }
405        "bool" => String::from("z.boolean()"),
406        other => {
407            if other.starts_with("Vec<") {
408                let inner = other
409                    .strip_prefix("Vec<")
410                    .and_then(|s| s.strip_suffix('>'))
411                    .unwrap_or("unknown");
412                let inner_zod = rust_type_to_zod(inner, &ZodAttributes::default());
413                let mut chain = format!("z.array({})", inner_zod);
414
415                if let Some(len) = attrs.length {
416                    chain.push_str(&format!(".length({})", len));
417                }
418                if let Some(min) = attrs.min_length {
419                    chain.push_str(&format!(".min({})", min));
420                }
421                if let Some(max) = attrs.max_length {
422                    chain.push_str(&format!(".max({})", max));
423                }
424
425                chain
426            } else {
427                format!("{}Schema", other)
428            }
429        }
430    };
431
432    base
433}
434
435fn append_number_validators(chain: &mut String, attrs: &ZodAttributes) {
436    if let Some(min) = attrs.min {
437        chain.push_str(&format!(".min({})", min));
438    }
439    if let Some(max) = attrs.max {
440        chain.push_str(&format!(".max({})", max));
441    }
442    if attrs.positive {
443        chain.push_str(".positive()");
444    }
445    if attrs.negative {
446        chain.push_str(".negative()");
447    }
448    if attrs.nonnegative {
449        chain.push_str(".nonnegative()");
450    }
451    if attrs.nonpositive {
452        chain.push_str(".nonpositive()");
453    }
454    if attrs.finite {
455        chain.push_str(".finite()");
456    }
457}
458
459fn type_to_string(ty: &syn::Type) -> String {
460    if let syn::Type::Path(type_path) = ty {
461        let segments: Vec<String> = type_path
462            .path
463            .segments
464            .iter()
465            .map(|seg| {
466                let ident = seg.ident.to_string();
467                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
468                    let args_str: Vec<String> = args
469                        .args
470                        .iter()
471                        .filter_map(|arg| {
472                            if let syn::GenericArgument::Type(t) = arg {
473                                Some(type_to_string(t))
474                            } else {
475                                None
476                            }
477                        })
478                        .collect();
479                    if args_str.is_empty() {
480                        ident
481                    } else {
482                        format!("{}<{}>", ident, args_str.join(", "))
483                    }
484                } else {
485                    ident
486                }
487            })
488            .collect();
489        segments.join("::")
490    } else {
491        "unknown".to_string()
492    }
493}
494
495fn is_option_type(ty: &syn::Type) -> bool {
496    if let syn::Type::Path(type_path) = ty {
497        if let Some(segment) = type_path.path.segments.last() {
498            return segment.ident == "Option";
499        }
500    }
501    false
502}
503
504fn get_option_inner_type_str(ty: &syn::Type) -> String {
505    if let syn::Type::Path(type_path) = ty {
506        if let Some(segment) = type_path.path.segments.last() {
507            if segment.ident == "Option" {
508                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
509                    if let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() {
510                        return type_to_string(inner_type);
511                    }
512                }
513            }
514        }
515    }
516    "unknown".to_string()
517}