Skip to main content

dog_schema_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, spanned::Spanned, Attribute, ItemMod, LitBool, LitStr, Meta, NestedMeta};
4
5#[proc_macro_attribute]
6pub fn schema(args: TokenStream, item: TokenStream) -> TokenStream {
7    let args = parse_macro_input!(args as syn::AttributeArgs);
8    let mut service: Option<LitStr> = None;
9    let mut error_message: Option<LitStr> = None;
10    let mut backend: Option<LitStr> = None;
11
12    let mut module = parse_macro_input!(item as ItemMod);
13
14    for arg in args {
15        match arg {
16            NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("service") => {
17                if let syn::Lit::Str(s) = nv.lit {
18                    service = Some(s);
19                }
20            }
21            NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("error_message") => {
22                if let syn::Lit::Str(s) = nv.lit {
23                    error_message = Some(s);
24                }
25            }
26            NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("backend") => {
27                if let syn::Lit::Str(s) = nv.lit {
28                    backend = Some(s);
29                }
30            }
31            _ => {}
32        }
33    }
34
35    let service = service.unwrap_or_else(|| LitStr::new("", proc_macro2::Span::call_site()));
36    let error_message = error_message
37        .unwrap_or_else(|| LitStr::new("Schema validation failed", proc_macro2::Span::call_site()));
38    let backend = backend.unwrap_or_else(|| LitStr::new("built_in", proc_macro2::Span::call_site()));
39
40    let (_, items) = match &mut module.content {
41        Some((brace, items)) => (brace, items),
42        None => {
43            return syn::Error::new(module.span(), "#[schema] requires an inline module").to_compile_error().into();
44        }
45    };
46
47    let mut create_struct: Option<syn::ItemStruct> = None;
48    let mut patch_struct: Option<syn::ItemStruct> = None;
49
50    for it in items.iter() {
51        if let syn::Item::Struct(s) = it {
52            if has_marker_attr(&s.attrs, "create") {
53                create_struct = Some(s.clone());
54            }
55            if has_marker_attr(&s.attrs, "patch") {
56                patch_struct = Some(s.clone());
57            }
58        }
59    }
60
61    let Some(create_struct) = create_struct else {
62        return syn::Error::new(module.span(), "#[schema] module must contain a #[create] struct")
63            .to_compile_error()
64            .into();
65    };
66
67    let create_rules = collect_field_rules(&create_struct);
68    let patch_rules = patch_struct.as_ref().map(collect_field_rules);
69
70    // Remove internal marker attrs so they don't reach rustc.
71    // They are only inputs to this macro.
72    strip_internal_attrs(items);
73
74    let create_ident = create_struct.ident.clone();
75    let patch_ident = patch_struct.as_ref().map(|s| s.ident.clone());
76
77    let resolve_create_fn = gen_resolve_create(&create_rules, &error_message);
78    let validate_create_fn = gen_validate_create(&create_rules, &error_message, &backend, &create_ident);
79    let validate_patch_fn = patch_rules
80        .as_ref()
81        .map(|rules| {
82            let patch_ident = patch_ident.as_ref().expect("patch rules implies patch struct");
83            gen_validate_patch(rules, &error_message, &backend, patch_ident)
84        })
85        .unwrap_or_else(|| quote! {});
86
87    let register_fn = gen_register_fn(&service, patch_rules.is_some());
88
89    // Append generated functions into the existing module body.
90    // (This keeps the module name stable: `posts_schema::register(...)`)
91    if let Ok(it) = syn::parse2::<syn::Item>(resolve_create_fn) {
92        items.push(it);
93    }
94    if let Ok(it) = syn::parse2::<syn::Item>(validate_create_fn) {
95        items.push(it);
96    }
97    if !validate_patch_fn.is_empty() {
98        if let Ok(it) = syn::parse2::<syn::Item>(validate_patch_fn) {
99            items.push(it);
100        }
101    }
102    if let Ok(it) = syn::parse2::<syn::Item>(register_fn) {
103        items.push(it);
104    }
105
106    TokenStream::from(quote!(#module))
107}
108
109fn has_marker_attr(attrs: &[Attribute], name: &str) -> bool {
110    attrs.iter().any(|a| a.path.is_ident(name))
111}
112
113fn strip_internal_attrs(items: &mut Vec<syn::Item>) {
114    for it in items.iter_mut() {
115        if let syn::Item::Struct(s) = it {
116            s.attrs.push(syn::parse_quote!(#[allow(dead_code)]));
117
118            // strip #[create]/#[patch]
119            s.attrs.retain(|a| {
120                !(a.path.is_ident("create") || a.path.is_ident("patch"))
121            });
122
123            // strip #[dog(...)] on fields
124            if let syn::Fields::Named(named) = &mut s.fields {
125                for f in named.named.iter_mut() {
126                    f.attrs.retain(|a| !a.path.is_ident("dog"));
127                }
128            }
129        }
130    }
131}
132
/// Coarse classification of a field's type, used to choose which checks are generated.
/// An `Option<T>` field is classified by its inner `T`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FieldKind {
    /// A type whose last path segment is `String`.
    String,
    /// A type whose last path segment is `bool`.
    Bool,
    /// Any other type; the built-in validators only check presence/nullness.
    Other,
}

/// Validation and resolution settings for one named field of a `#[create]` or `#[patch]`
/// struct.
#[derive(Debug, Clone, PartialEq, Eq)]
struct FieldRule {
    /// Key looked up in the incoming JSON object.
    json_key: String,
    /// Type class of the field.
    kind: FieldKind,
    /// `#[dog(trim)]`: string values are trimmed by `resolve_create`.
    trim: bool,
    /// `#[dog(min_len(N))]`: minimum number of characters for string values.
    min_len: Option<usize>,
    /// `#[dog(default = true|false)]`: value inserted by `resolve_create` when the key is
    /// missing.
    default_bool: Option<bool>,
    /// The field is `Option<_>` or marked `#[dog(optional)]`, so it may be omitted.
    optional: bool,
}
149
150fn collect_field_rules(st: &syn::ItemStruct) -> Vec<FieldRule> {
151    let mut rules = Vec::new();
152
153    let fields = match &st.fields {
154        syn::Fields::Named(n) => &n.named,
155        _ => return rules,
156    };
157
158    for f in fields {
159        let Some(ident) = f.ident.clone() else { continue };
160        let json_key = ident.to_string();
161
162        let mut rule = FieldRule {
163            json_key,
164            kind: field_kind(&f.ty),
165            trim: false,
166            min_len: None,
167            default_bool: None,
168            optional: is_option_type(&f.ty),
169        };
170
171        // Allow: #[dog(...)] on fields
172        for attr in &f.attrs {
173            if !attr.path.is_ident("dog") {
174                continue;
175            }
176            if let Ok(meta) = attr.parse_meta() {
177                match meta {
178                    Meta::List(list) => {
179                        for nested in list.nested {
180                            match nested {
181                                NestedMeta::Meta(Meta::Path(p)) => {
182                                    if p.is_ident("trim") {
183                                        rule.trim = true;
184                                    } else if p.is_ident("optional") {
185                                        rule.optional = true;
186                                    }
187                                }
188                                NestedMeta::Meta(Meta::List(ml)) => {
189                                    if ml.path.is_ident("min_len") {
190                                        if let Some(NestedMeta::Lit(syn::Lit::Int(n))) = ml.nested.first() {
191                                            if let Ok(v) = n.base10_parse::<usize>() {
192                                                rule.min_len = Some(v);
193                                            }
194                                        }
195                                    }
196                                }
197                                NestedMeta::Meta(Meta::NameValue(nv)) => {
198                                    if nv.path.is_ident("default") {
199                                        if let syn::Lit::Bool(LitBool { value, .. }) = nv.lit {
200                                            rule.default_bool = Some(value);
201                                        }
202                                    }
203                                }
204                                _ => {}
205                            }
206                        }
207                    }
208                    _ => {}
209                }
210            }
211        }
212
213        rules.push(rule);
214    }
215
216    rules
217}
218
219fn is_option_type(ty: &syn::Type) -> bool {
220    match ty {
221        syn::Type::Path(p) => p
222            .path
223            .segments
224            .last()
225            .is_some_and(|s| s.ident == "Option"),
226        _ => false,
227    }
228}
229
230fn field_kind(ty: &syn::Type) -> FieldKind {
231    // Detect Option<T>
232    let inner = match ty {
233        syn::Type::Path(p) => {
234            let last = p.path.segments.last();
235            if let Some(seg) = last {
236                if seg.ident == "Option" {
237                    if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
238                        if let Some(syn::GenericArgument::Type(t)) = ab.args.first() {
239                            return field_kind(t);
240                        }
241                    }
242                }
243            }
244            ty
245        }
246        _ => ty,
247    };
248
249    match inner {
250        syn::Type::Path(p) => {
251            if let Some(seg) = p.path.segments.last() {
252                if seg.ident == "String" {
253                    return FieldKind::String;
254                }
255                if seg.ident == "bool" {
256                    return FieldKind::Bool;
257                }
258            }
259            FieldKind::Other
260        }
261        _ => FieldKind::Other,
262    }
263}
264
265fn gen_resolve_create(rules: &[FieldRule], _error_message: &LitStr) -> proc_macro2::TokenStream {
266    // trim string fields + apply default bools if missing
267    let trim_stmts = rules
268        .iter()
269        .filter(|r| r.trim && matches!(r.kind, FieldKind::String))
270        .map(|r| {
271        let key = &r.json_key;
272        quote! {
273            if let Some(serde_json::Value::String(s)) = obj.get_mut(#key) {
274                *s = s.trim().to_string();
275            }
276        }
277    });
278
279    let default_stmts = rules.iter().filter_map(|r| r.default_bool.map(|v| (r, v))).map(|(r, v)| {
280        let key = &r.json_key;
281        quote! {
282            if !obj.contains_key(#key) {
283                obj.insert(#key.to_string(), serde_json::Value::Bool(#v));
284            }
285        }
286    });
287
288    quote! {
289        pub fn resolve_create<P>(data: &mut serde_json::Value, _meta: &dog_core::schema::HookMeta<serde_json::Value, P>) -> anyhow::Result<()>
290        where
291            P: Send + Clone + 'static,
292        {
293            let Some(obj) = data.as_object_mut() else {
294                return Ok(());
295            };
296
297            #(#trim_stmts)*
298            #(#default_stmts)*
299
300            Ok(())
301        }
302    }
303}
304
305fn gen_validate_create(
306    rules: &[FieldRule],
307    error_message: &LitStr,
308    backend: &LitStr,
309    create_ident: &syn::Ident,
310) -> proc_macro2::TokenStream {
311    if backend.value() == "validator" {
312        return quote! {
313            pub fn validate_create<P>(
314                data: &serde_json::Value,
315                _meta: &dog_core::schema::HookMeta<serde_json::Value, P>,
316            ) -> anyhow::Result<()>
317            where
318                P: Send + Clone + 'static,
319            {
320                let _parsed: #create_ident = dog_schema_validator::validate::<#create_ident>(data, #error_message)?;
321                Ok(())
322            }
323        };
324    }
325
326    let checks = rules.iter().map(|r| {
327        let key = &r.json_key;
328        let min_len = r.min_len;
329
330        match r.kind {
331            FieldKind::String => {
332                let min_len_check = if let Some(n) = min_len {
333                    quote! {
334                        if v.chars().count() < #n {
335                            errs.push_field(#key, format!("must be at least {} chars", #n));
336                        }
337                    }
338                } else {
339                    quote! {}
340                };
341
342                if r.optional {
343                    quote! {
344                        if let Some(v) = obj.get(#key).and_then(|v| v.as_str()) {
345                            if v.trim().is_empty() {
346                                errs.push_field(#key, "must not be empty");
347                            }
348                            #min_len_check
349                        }
350                    }
351                } else {
352                    quote! {
353                        match obj.get(#key) {
354                            None => errs.push_schema(format!("missing field `{}`", #key)),
355                            Some(val) => {
356                                if let Some(v) = val.as_str() {
357                                    if v.trim().is_empty() {
358                                        errs.push_field(#key, "must not be empty");
359                                    }
360                                    #min_len_check
361                                } else {
362                                    errs.push_field(#key, "must be a string");
363                                }
364                            }
365                        }
366                    }
367                }
368            }
369            FieldKind::Bool => {
370                let allow_missing = r.default_bool.is_some() || r.optional;
371                if allow_missing {
372                    quote! {
373                        if let Some(val) = obj.get(#key) {
374                            if !val.is_boolean() {
375                                errs.push_field(#key, "must be a boolean");
376                            }
377                        }
378                    }
379                } else {
380                    quote! {
381                        match obj.get(#key) {
382                            None => errs.push_schema(format!("missing field `{}`", #key)),
383                            Some(val) => {
384                                if !val.is_boolean() {
385                                    errs.push_field(#key, "must be a boolean");
386                                }
387                            }
388                        }
389                    }
390                }
391            }
392            FieldKind::Other => {
393                // For MVP: only enforce presence for non-optional fields.
394                if r.optional {
395                    quote! {}
396                } else {
397                    quote! {
398                        if obj.get(#key).is_none() {
399                            errs.push_schema(format!("missing field `{}`", #key));
400                        }
401                    }
402                }
403            }
404        }
405    });
406
407    quote! {
408        pub fn validate_create<P>(data: &serde_json::Value, _meta: &dog_core::schema::HookMeta<serde_json::Value, P>) -> anyhow::Result<()>
409        where
410            P: Send + Clone + 'static,
411        {
412            let Some(obj) = data.as_object() else {
413                return Err(dog_schema::schema_error(#error_message, "expected JSON object"));
414            };
415
416            let mut errs = dog_schema::SchemaErrors::default();
417
418            #(#checks)*
419
420            if errs.is_empty() {
421                Ok(())
422            } else {
423                Err(errs.into_unprocessable_anyhow(#error_message))
424            }
425        }
426    }
427}
428
429fn gen_validate_patch(
430    rules: &[FieldRule],
431    error_message: &LitStr,
432    backend: &LitStr,
433    patch_ident: &syn::Ident,
434) -> proc_macro2::TokenStream {
435    if backend.value() == "validator" {
436        return quote! {
437            pub fn validate_patch<P>(
438                data: &serde_json::Value,
439                _meta: &dog_core::schema::HookMeta<serde_json::Value, P>,
440            ) -> anyhow::Result<()>
441            where
442                P: Send + Clone + 'static,
443            {
444                let _parsed: #patch_ident = dog_schema_validator::validate::<#patch_ident>(data, #error_message)?;
445                Ok(())
446            }
447        };
448    }
449
450    let checks = rules.iter().map(|r| {
451        let key = &r.json_key;
452        let min_len = r.min_len;
453
454        match r.kind {
455            FieldKind::String => {
456                let min_len_check = if let Some(n) = min_len {
457                    quote! {
458                        if v.chars().count() < #n {
459                            errs.push_field(#key, format!("must be at least {} chars", #n));
460                        }
461                    }
462                } else {
463                    quote! {}
464                };
465
466                quote! {
467                    if let Some(val) = obj.get(#key) {
468                        if val.is_null() {
469                            // allow null (treat as not provided)
470                        } else if let Some(v) = val.as_str() {
471                            if v.trim().is_empty() {
472                                errs.push_field(#key, "must not be empty");
473                            }
474                            #min_len_check
475                        } else {
476                            errs.push_field(#key, "must be a string");
477                        }
478                    }
479                }
480            }
481            FieldKind::Bool => {
482                quote! {
483                    if let Some(val) = obj.get(#key) {
484                        if val.is_null() {
485                            // allow null
486                        } else if !val.is_boolean() {
487                            errs.push_field(#key, "must be a boolean");
488                        }
489                    }
490                }
491            }
492            FieldKind::Other => {
493                quote! {
494                    if let Some(val) = obj.get(#key) {
495                        if val.is_null() {
496                            // allow null
497                        }
498                    }
499                }
500            }
501        }
502    });
503
504    quote! {
505        pub fn validate_patch<P>(data: &serde_json::Value, _meta: &dog_core::schema::HookMeta<serde_json::Value, P>) -> anyhow::Result<()>
506        where
507            P: Send + Clone + 'static,
508        {
509            let Some(obj) = data.as_object() else {
510                return Err(dog_schema::schema_error(#error_message, "expected JSON object"));
511            };
512
513            let mut errs = dog_schema::SchemaErrors::default();
514
515            #(#checks)*
516
517            if errs.is_empty() {
518                Ok(())
519            } else {
520                Err(errs.into_unprocessable_anyhow(#error_message))
521            }
522        }
523    }
524}
525
526fn gen_register_fn(service: &LitStr, has_patch: bool) -> proc_macro2::TokenStream {
527    let svc = service.value();
528    let svc_lit = LitStr::new(&svc, service.span());
529
530    let patch = if has_patch {
531        quote! {
532            s.on_patch().validate(validate_patch);
533        }
534    } else {
535        quote! {}
536    };
537
538    quote! {
539        pub fn register<P>(app: &dog_core::DogApp<serde_json::Value, P>) -> anyhow::Result<()>
540        where
541            P: Send + Clone + 'static,
542        {
543            use dog_core::schema::SchemaHooksExt;
544
545            app.service(#svc_lit)?.hooks(|h| {
546                h.schema(|s| {
547                    s.on_create().resolve(resolve_create).validate(validate_create);
548                    #patch
549                    s.on_update().validate(validate_create);
550                });
551            });
552
553            Ok(())
554        }
555    }
556}