csv_schema_validator_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Expr, ExprLit, Fields, Ident, Lit, Meta};
6
7// Structure to store field validations
8struct FieldValidation {
9    field_name: Ident,
10    validations: Vec<Validation>,
11}
12
13// Enum representing the different types of validations that can be derived.
14enum Validation {
15    Range { min: f64, max: f64 },
16    Regex { regex: String },
17    Required,
18    Custom { path: syn::Path },
19}
20
21impl Validation {
22    /// Parses the content of `#[validate(...)]` into a vector of `Validation`.
23    fn parse_validations(input: syn::parse::ParseStream) -> syn::Result<Vec<Self>> {
24        let mut validations = Vec::new();
25
26        let meta_items =
27            syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated(input)?;
28
29        for meta in meta_items {
30            match meta {
31                Meta::Path(path) => {
32                    if path.is_ident("required") {
33                        validations.push(Validation::Required);
34                    }
35                }
36                Meta::NameValue(mnv) => {
37                    if mnv.path.is_ident("regex") {
38                        if let Expr::Lit(ExprLit {
39                            lit: Lit::Str(lit_str),
40                            ..
41                        }) = mnv.value
42                        {
43                            validations.push(Validation::Regex {
44                                regex: lit_str.value(),
45                            });
46                        } else {
47                            return Err(syn::Error::new_spanned(
48                                mnv.value,
49                                "Expected string literal for `regex`",
50                            ));
51                        }
52                    } else if mnv.path.is_ident("custom") {
53                        if let Expr::Lit(ExprLit {
54                            lit: Lit::Str(lit_str),
55                            ..
56                        }) = mnv.value
57                        {
58                            let path: syn::Path = syn::parse_str(&lit_str.value())
59                                .map_err(|e| syn::Error::new_spanned(lit_str, e))?;
60                            validations.push(Validation::Custom { path });
61                        } else {
62                            return Err(syn::Error::new_spanned(mnv.value, "Expected string literal for `custom` (e.g., `custom = \"path::to::function\"`)"));
63                        }
64                    }
65                }
66                Meta::List(meta_list) => {
67                    if meta_list.path.is_ident("range") {
68                        let mut min: Option<f64> = None;
69                        let mut max: Option<f64> = None;
70
71                        let range_items: syn::punctuated::Punctuated<
72                            syn::MetaNameValue,
73                            syn::Token![,],
74                        > = meta_list
75                            .parse_args_with(syn::punctuated::Punctuated::parse_terminated)?;
76
77                        for kv in range_items {
78                            let key = kv.path;
79                            let value = kv.value;
80                            if key.is_ident("min") {
81                                if let Expr::Lit(ExprLit {
82                                    lit: Lit::Float(lit_float),
83                                    ..
84                                }) = value
85                                {
86                                    min = Some(lit_float.base10_parse::<f64>()?);
87                                } else {
88                                    return Err(syn::Error::new_spanned(
89                                        value,
90                                        "`min` value for `range` must be a float literal",
91                                    ));
92                                }
93                            } else if key.is_ident("max") {
94                                if let Expr::Lit(ExprLit {
95                                    lit: Lit::Float(lit_float),
96                                    ..
97                                }) = value
98                                {
99                                    max = Some(lit_float.base10_parse::<f64>()?);
100                                } else {
101                                    return Err(syn::Error::new_spanned(
102                                        value,
103                                        "`max` value for `range` must be a float literal",
104                                    ));
105                                }
106                            }
107                        }
108                        if min.is_none() && max.is_none() {
109                            return Err(syn::Error::new_spanned(
110                                meta_list,
111                                "`range` validation requires at least one of `min` or `max`",
112                            ));
113                        }
114                        validations.push(Validation::Range {
115                            min: min.unwrap_or(f64::NEG_INFINITY),
116                            max: max.unwrap_or(f64::INFINITY),
117                        });
118                    }
119                }
120            }
121        }
122        Ok(validations)
123    }
124}
125
126#[proc_macro_derive(ValidateCsv, attributes(validate))]
127pub fn validate_csv_derive(input: TokenStream) -> TokenStream {
128    let input = parse_macro_input!(input as DeriveInput);
129    let name = &input.ident;
130
131    let fields = match &input.data {
132        Data::Struct(data) => match &data.fields {
133            Fields::Named(f) => &f.named,
134            _ => {
135                return syn::Error::new_spanned(
136                    &data.fields,
137                    "only structs with named fields (e.g., `struct S { a: T }`) are supported",
138                )
139                .to_compile_error()
140                .into();
141            }
142        },
143        _ => {
144            return syn::Error::new_spanned(&input, "only structs are supported")
145                .to_compile_error()
146                .into();
147        }
148    };
149
150    let mut field_validations = Vec::new();
151
152    for field in fields {
153        let field_name = field.ident.as_ref().unwrap().clone();
154        let mut validations = Vec::new();
155
156        for attr in &field.attrs {
157            if attr.path().is_ident("validate") {
158                match attr.parse_args_with(Validation::parse_validations) {
159                    Ok(mut parsed_validations) => {
160                        validations.append(&mut parsed_validations);
161                    }
162                    Err(e) => {
163                        return e.to_compile_error().into();
164                    }
165                }
166            }
167        }
168
169        if !validations.is_empty() {
170            field_validations.push(FieldValidation {
171                field_name,
172                validations,
173            });
174        }
175    }
176
177    let validation_arms = field_validations.into_iter().map(|fv| {
178        let field_name_str = fv.field_name.to_string();
179        let field_name_ident = fv.field_name;
180
181        let checks = fv.validations.into_iter().map(|validation| {
182            match validation {
183                Validation::Required => {
184                    quote! {
185                        if (&self.#field_name_ident).is_none() {
186                            errors.push(::csv_schema_validator::ValidationError {
187                                field: #field_name_str.to_string(),
188                                message: "mandatory field".to_string(),
189                            });
190                        }
191                    }
192                }
193                Validation::Range { min, max } => {
194                    quote! {
195                        let value = &self.#field_name_ident;
196                        if !(#min <= *value && *value <= #max) {
197                            errors.push(::csv_schema_validator::ValidationError {
198                                field: #field_name_str.to_string(),
199                                message: format!("value out of expected range: {} to {}", #min, #max),
200                            });
201                        }
202                    }
203                }
204                Validation::Regex { regex } => {
205                    quote! {
206                        use ::csv_schema_validator::__private::once_cell::sync::Lazy;
207                        use ::csv_schema_validator::__private::regex;
208                        static RE: Lazy<Result<regex::Regex, regex::Error>> = Lazy::new(|| regex::Regex::new(#regex));
209
210                        match RE.as_ref() {
211                            Ok(compiled_regex) => {
212                                if !compiled_regex.is_match(&self.#field_name_ident) {
213                                    errors.push(::csv_schema_validator::ValidationError {
214                                        field: #field_name_str.to_string(),
215                                        message: "does not match the expected pattern".to_string(),
216                                    });
217                                }
218                            }
219                            Err(e) => {
220                                errors.push(::csv_schema_validator::ValidationError {
221                                    field: #field_name_str.to_string(),
222                                    message: format!("invalid regex '{}': {}", #regex, e),
223                                });
224                            }
225                        }
226                    }
227                }
228                Validation::Custom { path } => {
229                    quote! {
230                        match #path(&self.#field_name_ident) {
231                            Err(err) => {
232                                errors.push(::csv_schema_validator::ValidationError {
233                                    field: #field_name_str.to_string(),
234                                    message: format!("{}", err),
235                                });
236                            }
237                            Ok(()) => {}
238                        }
239                    }
240                }
241            }
242        });
243
244        quote! {
245            #(#checks)*
246        }
247    });
248
249    let expanded = quote! {
250        impl #name {
251            pub fn validate_csv(&self) -> ::core::result::Result<(), ::std::vec::Vec<::csv_schema_validator::ValidationError>> {
252                let mut errors = ::std::vec::Vec::new();
253                #(#validation_arms)*
254                if errors.is_empty() {
255                    ::core::result::Result::Ok(())
256                } else {
257                    ::core::result::Result::Err(errors)
258                }
259            }
260        }
261    };
262
263    TokenStream::from(expanded)
264}