Skip to main content

modkit_errors_macro/
lib.rs

1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
2//! Proc-macro for generating strongly-typed error catalogs from JSON.
3//!
4//! This macro reads a JSON file at compile time, validates error definitions,
5//! and generates type-safe error code enums and helper macros.
6//!
7//! ## Usage
8//!
9//! The macro is self-contained and handles imports automatically.
10//!
11//! ```rust,ignore
12//! declare_errors! {
13//!     path = "gts/errors_system.json",
14//!     namespace = "system_errors",
15//!     vis = "pub"
16//! }
17//! ```
18
19use proc_macro::TokenStream;
20use proc_macro2::{Span, TokenStream as TokenStream2};
21use quote::quote;
22use serde::Deserialize;
23use syn::parse::{Parse, ParseStream};
24use syn::{LitStr, Token, parse_macro_input};
25
26/// JSON schema for a single error definition
27#[derive(Debug, Clone, Deserialize)]
28struct ErrorEntry {
29    status: u16,
30    title: String,
31    code: String,
32    #[serde(rename = "type")]
33    type_url: Option<String>,
34    #[serde(default)]
35    alias: Option<String>,
36}
37
38/// Parsed macro input
39struct DeclareErrorsInput {
40    path: String,
41    namespace: String,
42    vis: syn::Visibility,
43}
44
45impl Parse for DeclareErrorsInput {
46    fn parse(input: ParseStream) -> syn::Result<Self> {
47        let mut path = None;
48        let mut namespace = None;
49        let mut vis = syn::Visibility::Inherited;
50
51        while !input.is_empty() {
52            let key: syn::Ident = input.parse()?;
53            input.parse::<Token![=]>()?;
54
55            match key.to_string().as_str() {
56                "path" => {
57                    let lit: LitStr = input.parse()?;
58                    path = Some(lit.value());
59                }
60                "namespace" => {
61                    let lit: LitStr = input.parse()?;
62                    namespace = Some(lit.value());
63                }
64                "vis" => {
65                    let lit: LitStr = input.parse()?;
66                    vis = match lit.value().as_str() {
67                        "pub" => syn::Visibility::Public(syn::token::Pub::default()),
68                        _ => syn::Visibility::Inherited,
69                    };
70                }
71                _ => return Err(syn::Error::new(key.span(), "Unknown parameter")),
72            }
73
74            if !input.is_empty() {
75                input.parse::<Token![,]>()?;
76            }
77        }
78
79        Ok(DeclareErrorsInput {
80            path: path.ok_or_else(|| input.error("Missing 'path' parameter"))?,
81            namespace: namespace.ok_or_else(|| input.error("Missing 'namespace' parameter"))?,
82            vis,
83        })
84    }
85}
86
87/// Main proc-macro entry point
88#[proc_macro]
89pub fn declare_errors(input: TokenStream) -> TokenStream {
90    let input = parse_macro_input!(input as DeclareErrorsInput);
91
92    match generate_errors(&input) {
93        Ok(tokens) => tokens.into(),
94        Err(e) => e.to_compile_error().into(),
95    }
96}
97
98fn generate_errors(input: &DeclareErrorsInput) -> syn::Result<TokenStream2> {
99    // Load and parse JSON file
100    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
101        .map_err(|_| syn::Error::new(Span::call_site(), "CARGO_MANIFEST_DIR not set"))?;
102    let json_path = std::path::Path::new(&manifest_dir).join(&input.path);
103
104    let json_content = std::fs::read_to_string(&json_path).map_err(|e| {
105        syn::Error::new(
106            Span::call_site(),
107            format!(
108                "Failed to read error catalog at {}: {}",
109                json_path.display(),
110                e
111            ),
112        )
113    })?;
114
115    let entries: Vec<ErrorEntry> = serde_json::from_str(&json_content).map_err(|e| {
116        syn::Error::new(
117            Span::call_site(),
118            format!(
119                "Failed to parse error catalog JSON at {}: {}",
120                json_path.display(),
121                e
122            ),
123        )
124    })?;
125
126    // Validate entries
127    validate_entries(&entries)?;
128
129    // Compute short names and check for collisions
130    let short_names = compute_short_names(&entries)?;
131
132    let namespace_ident = syn::Ident::new(&input.namespace, Span::call_site());
133    let vis = &input.vis;
134    let json_file_path = &input.path;
135
136    let enum_variants = generate_enum_variants(&entries);
137    let const_defs = generate_const_defs(&entries);
138    let impl_methods = generate_impl_methods(&entries);
139    let short_accessors = generate_short_accessors(&entries, &short_names);
140    let from_literal_impl = generate_from_literal(&entries);
141    let macro_rules_single = generate_macro_rules_single(&entries, &namespace_ident);
142    let macro_rules_double = generate_macro_rules_double(&entries, &namespace_ident);
143    let response_macro_rules = generate_response_macro_rules(&entries, &namespace_ident);
144
145    Ok(quote! {
146        // Force Cargo to rebuild if errors.json changes
147        const _: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/", #json_file_path));
148
149        // Fully-qualified imports (work both inside and outside modkit)
150        use ::modkit_errors::catalog::ErrDef;
151        use ::modkit_errors::problem::Problem;
152
153        /// Strongly-typed error codes generated from the JSON catalog
154        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
155        #[non_exhaustive]
156        #[allow(non_camel_case_types)]
157        #vis enum ErrorCode {
158            #(#enum_variants),*
159        }
160
161        impl ErrorCode {
162            /// Get the HTTP status code for this error
163            pub const fn status(&self) -> u16 {
164                match self {
165                    #(#const_defs),*
166                }
167            }
168
169            /// Get the error definition for this error code
170            pub const fn def(&self) -> ErrDef {
171                match self {
172                    #(#impl_methods),*
173                }
174            }
175
176            /// Convert to Problem with detail (without instance/trace)
177            pub fn as_problem(&self, detail: impl Into<String>) -> Problem {
178                self.def().as_problem(detail)
179            }
180
181            /// Create a Problem with `instance` and optional `trace_id` context.
182            pub fn with_context(
183                &self,
184                detail: impl Into<String>,
185                instance: &str,
186                trace_id: Option<String>,
187            ) -> Problem {
188                let mut p = self.as_problem(detail);
189                p = p.with_instance(instance);
190                if let Some(tid) = trace_id {
191                    p = p.with_trace_id(tid);
192                }
193                p
194            }
195
196            // Short ergonomic accessor functions
197            #(#short_accessors)*
198
199            /// Internal helper to get ErrorCode from a literal string
200            #[doc(hidden)]
201            pub fn from_literal(code: &str) -> Self {
202                match code {
203                    #(#from_literal_impl,)*
204                    _ => panic!("Unknown error code literal — must be present in errors.json"),
205                }
206            }
207        }
208
209        /// Macro to create a Problem from a literal error code (compile-time validated)
210        #[macro_export]
211        macro_rules! problem_from_catalog {
212            #(#macro_rules_single)*
213            #(#macro_rules_double)*
214
215            // Catch-all for unknown codes
216            ($unknown:literal) => {
217                compile_error!(concat!("Unknown error code: ", $unknown))
218            };
219            ($unknown:literal, $detail:expr) => {
220                compile_error!(concat!("Unknown error code: ", $unknown))
221            };
222        }
223        use problem_from_catalog;
224
225        /// Macro to create a Problem directly from a literal error code with instance/trace
226        #[macro_export]
227        macro_rules! response_from_catalog {
228            #(#response_macro_rules)*
229
230            // Catch-all for unknown codes
231            ($unknown:literal, $instance:expr, $trace:expr, $($arg:tt)+) => {
232                compile_error!(concat!("Unknown error code: ", $unknown))
233            };
234            ($unknown:literal, $instance:expr, $trace:expr) => {
235                compile_error!(concat!("Unknown error code: ", $unknown))
236            };
237        }
238        use response_from_catalog;
239    })
240}
241
242fn validate_entries(entries: &[ErrorEntry]) -> syn::Result<()> {
243    let mut codes = std::collections::HashSet::new();
244    let mut titles_and_statuses = std::collections::HashMap::new();
245
246    for entry in entries {
247        // Validate status code
248        if !(100..=599).contains(&entry.status) {
249            return Err(syn::Error::new(
250                Span::call_site(),
251                format!(
252                    "Invalid HTTP status code {} for error '{}'",
253                    entry.status, entry.code
254                ),
255            ));
256        }
257
258        // Validate non-empty title
259        if entry.title.trim().is_empty() {
260            return Err(syn::Error::new(
261                Span::call_site(),
262                format!("Empty title for error '{}'", entry.code),
263            ));
264        }
265
266        // Check for duplicate codes
267        if !codes.insert(&entry.code) {
268            return Err(syn::Error::new(
269                Span::call_site(),
270                format!("Duplicate error code: '{}'", entry.code),
271            ));
272        }
273
274        // Strict GTS validation
275        validate_gts_format(&entry.code)?;
276
277        // Optional: Detect redundancy (same title+status)
278        let key = (entry.title.trim(), entry.status);
279        if let Some(existing_code) = titles_and_statuses.get(&key) {
280            eprintln!(
281                "Warning: Error codes '{}' and '{}' share identical title+status ({}:{}). Consider consolidating.",
282                existing_code, entry.code, entry.title, entry.status
283            );
284        } else {
285            titles_and_statuses.insert(key, entry.code.clone());
286        }
287    }
288
289    Ok(())
290}
291
292/// Strict GTS format validation
293///
294/// Valid format: `gts.vendor.package.namespace.type.version~chain1~chain2~...~instanceGTX`
295/// Where the final GTX (instance) must have at least 5 segments: vendor.package.namespace.type.version
296fn validate_gts_format(code: &str) -> syn::Result<()> {
297    // Must start with 'gts.'
298    if !code.starts_with("gts.") {
299        return Err(syn::Error::new(
300            Span::call_site(),
301            format!("GTS code '{code}' must start with 'gts.'"),
302        ));
303    }
304
305    // Split by '~' to get GTX chain
306    let parts: Vec<&str> = code.split('~').collect();
307    if parts.is_empty() {
308        return Err(syn::Error::new(
309            Span::call_site(),
310            format!("GTS code '{code}' is empty or malformed"),
311        ));
312    }
313
314    // Validate each GTX in the chain
315    for (idx, gtx) in parts.iter().enumerate() {
316        let segments: Vec<&str> = gtx.split('.').collect();
317
318        // First GTX must start with 'gts'
319        if idx == 0 && segments.first().is_none_or(|s| *s != "gts") {
320            return Err(syn::Error::new(
321                Span::call_site(),
322                format!("GTS code '{code}' must start with 'gts' in the first GTX"),
323            ));
324        }
325
326        // All GTX segments must be non-empty and lowercase alphanumeric (with underscores)
327        for segment in &segments {
328            if segment.is_empty() {
329                return Err(syn::Error::new(
330                    Span::call_site(),
331                    format!("GTS code '{code}' contains empty segment"),
332                ));
333            }
334            if !segment
335                .chars()
336                .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_')
337            {
338                return Err(syn::Error::new(
339                    Span::call_site(),
340                    format!(
341                        "GTS code '{code}' has invalid segment '{segment}': only lowercase letters, digits and underscores are allowed"
342                    ),
343                ));
344            }
345        }
346
347        // Final GTX (instance) must have at least 5 segments after 'gts'
348        if idx == parts.len() - 1 {
349            // Subtract 1 for 'gts' prefix
350            let meaningful_segments = if segments.first().is_some_and(|s| *s == "gts") {
351                segments.len() - 1
352            } else {
353                segments.len()
354            };
355
356            if meaningful_segments < 5 {
357                return Err(syn::Error::new(
358                    Span::call_site(),
359                    format!(
360                        "GTS code '{code}' is expected to have at least 5 segments in final GTX: vendor.package.namespace.type.version (found {meaningful_segments} segments)"
361                    ),
362                ));
363            }
364
365            // Validate that the final segment is a version (vN or vN.M format)
366            if let Some(last) = segments.last() {
367                let is_version = last.starts_with('v')
368                    && last.len() > 1
369                    && last[1..].chars().all(|c| c.is_ascii_digit() || c == '.')
370                    && last[1..].split('.').all(|t| !t.is_empty());
371                if !is_version {
372                    return Err(syn::Error::new(
373                        Span::call_site(),
374                        format!(
375                            "GTS code '{code}' final GTX must end with version 'vN' or 'vN.M' (found '{last}')"
376                        ),
377                    ));
378                }
379            }
380        }
381    }
382
383    // Validate that there's at least one chained GTX (format: gts.X~Y or gts.X)
384    if parts.is_empty() {
385        return Err(syn::Error::new(
386            Span::call_site(),
387            format!("GTS code '{code}' must have at least one GTX"),
388        ));
389    }
390
391    Ok(())
392}
393
394fn generate_enum_variants(entries: &[ErrorEntry]) -> Vec<TokenStream2> {
395    entries
396        .iter()
397        .map(|e| {
398            let variant = code_to_ident(&e.code);
399            let code = &e.code;
400            quote! {
401                #[doc = #code]
402                #variant
403            }
404        })
405        .collect()
406}
407
408fn generate_const_defs(entries: &[ErrorEntry]) -> Vec<TokenStream2> {
409    entries
410        .iter()
411        .map(|e| {
412            let variant = code_to_ident(&e.code);
413            let status = e.status;
414            quote! {
415                ErrorCode::#variant => #status
416            }
417        })
418        .collect()
419}
420
421fn generate_impl_methods(entries: &[ErrorEntry]) -> Vec<TokenStream2> {
422    entries
423        .iter()
424        .map(|e| {
425            let variant = code_to_ident(&e.code);
426            let status = e.status;
427            let title = &e.title;
428            let code = &e.code;
429            let type_url = match &e.type_url {
430                Some(s) => s.clone(),
431                None => format!("https://errors.example.com/{}", e.code),
432            };
433
434            quote! {
435                ErrorCode::#variant => ErrDef {
436                    status: #status,
437                    title: #title,
438                    code: #code,
439                    type_url: #type_url,
440                }
441            }
442        })
443        .collect()
444}
445
446fn generate_macro_rules_single(
447    entries: &[ErrorEntry],
448    namespace: &syn::Ident,
449) -> Vec<TokenStream2> {
450    entries
451        .iter()
452        .map(|e| {
453            let code_lit = &e.code;
454            let variant = code_to_ident(&e.code);
455
456            quote! {
457                (#code_lit) => {
458                    $crate::#namespace::ErrorCode::#variant.as_problem("")
459                };
460            }
461        })
462        .collect()
463}
464
465fn generate_macro_rules_double(
466    entries: &[ErrorEntry],
467    namespace: &syn::Ident,
468) -> Vec<TokenStream2> {
469    entries
470        .iter()
471        .map(|e| {
472            let code_lit = &e.code;
473            let variant = code_to_ident(&e.code);
474
475            quote! {
476                (#code_lit, $detail:expr) => {
477                    $crate::#namespace::ErrorCode::#variant.as_problem($detail)
478                };
479            }
480        })
481        .collect()
482}
483
484/// Convert a dotted error code to a valid Rust identifier
485fn code_to_ident(code: &str) -> syn::Ident {
486    let mut sanitized = code.replace(['.', '-', '/', '~'], "_");
487
488    // Prefix with underscore if it starts with a digit
489    if sanitized.chars().next().is_some_and(|c| c.is_ascii_digit()) {
490        sanitized = format!("_{sanitized}");
491    }
492
493    syn::Ident::new(&sanitized, Span::call_site())
494}
495
/// Returns the part of `code` after its last `~`, or all of `code` when it
/// contains no `~`.
fn last_gtx_segment(code: &str) -> &str {
    code.rsplit_once('~').map_or(code, |(_, tail)| tail)
}
505
506/// Given a GTX segment "vendor.package.namespace.type.version",
507/// produce alias "`package_namespace_type_version`".
508///
509/// - Drops the vendor (first path segment)
510/// - Replaces dots with underscores
511/// - Ensures a valid Rust identifier (prefix '_' if starts with a digit)
512fn derive_alias_from_gts(code: &str) -> syn::Result<String> {
513    let gtx = last_gtx_segment(code);
514    // Expect vendor.package.namespace.type.version
515    let parts: Vec<&str> = gtx.split('.').collect();
516    if parts.len() < 5 {
517        return Err(syn::Error::new(
518            Span::call_site(),
519            format!(
520                "GTS code '{code}' is expected to have at least 5 segments in final GTX: vendor.package.namespace.type.version"
521            ),
522        ));
523    }
524    // parts[0] = vendor; we drop it
525    let rest = &parts[1..]; // package, namespace, type, version, (optionally extra minor parts are already in version)
526    let alias_raw = rest.join("_");
527
528    // Ensure valid Rust identifier (lowercase is already per spec)
529    let mut ident = alias_raw.replace(['-', '/', '~'], "_"); // just in case
530    if ident.chars().next().is_some_and(|c| c.is_ascii_digit()) {
531        ident = format!("_{ident}");
532    }
533    Ok(ident)
534}
535
536/// Compute short names for all entries, detecting collisions
537fn compute_short_names(entries: &[ErrorEntry]) -> syn::Result<Vec<String>> {
538    use std::collections::HashMap;
539
540    let mut name_to_codes: HashMap<String, Vec<&str>> = HashMap::new();
541
542    // Collect all short names (alias or derived via GTS)
543    for entry in entries {
544        let short = if let Some(alias) = &entry.alias {
545            alias.clone()
546        } else {
547            // Use new GTS-aware derivation
548            derive_alias_from_gts(&entry.code)?
549        };
550
551        name_to_codes.entry(short).or_default().push(&entry.code);
552    }
553
554    // Collision detection
555    for (name, codes) in &name_to_codes {
556        if codes.len() > 1 {
557            return Err(syn::Error::new(
558                Span::call_site(),
559                format!(
560                    "Short name collision: '{}' would be used by multiple error codes: {}. \
561                     Please add explicit 'alias' fields in errors.json to resolve this.",
562                    name,
563                    codes.join(", ")
564                ),
565            ));
566        }
567    }
568
569    // Return short names in the same order as entries
570    entries
571        .iter()
572        .map(|e| {
573            if let Some(alias) = &e.alias {
574                Ok(alias.clone())
575            } else {
576                derive_alias_from_gts(&e.code)
577            }
578        })
579        // Turn Vec<Result<String>> into Result<Vec<String>>
580        .collect::<syn::Result<Vec<String>>>()
581}
582
583/// Generate short ergonomic accessor functions
584fn generate_short_accessors(entries: &[ErrorEntry], short_names: &[String]) -> Vec<TokenStream2> {
585    entries
586        .iter()
587        .zip(short_names.iter())
588        .map(|(entry, short_name)| {
589            let full_variant = code_to_ident(&entry.code);
590            let short_ident = syn::Ident::new(short_name, Span::call_site());
591            let code = &entry.code;
592
593            quote! {
594                #[doc = concat!("Returns the error code for `", #code, "`.")]
595                pub const fn #short_ident() -> Self {
596                    Self::#full_variant
597                }
598            }
599        })
600        .collect()
601}
602
603/// Generate `from_literal` match arms
604fn generate_from_literal(entries: &[ErrorEntry]) -> Vec<TokenStream2> {
605    entries
606        .iter()
607        .map(|e| {
608            let code_lit = &e.code;
609            let variant = code_to_ident(&e.code);
610
611            quote! {
612                #code_lit => Self::#variant
613            }
614        })
615        .collect()
616}
617
618/// Generate `response_from_catalog`! macro rules (with format support)
619fn generate_response_macro_rules(
620    entries: &[ErrorEntry],
621    namespace: &syn::Ident,
622) -> Vec<TokenStream2> {
623    let mut rules = Vec::new();
624
625    for entry in entries {
626        let code_lit = &entry.code;
627        let variant = code_to_ident(&entry.code);
628
629        // Rule with formatted detail
630        rules.push(quote! {
631            (#code_lit, $instance:expr, $trace:expr, $($arg:tt)+) => {
632                $crate::#namespace::ErrorCode::#variant.with_context(
633                    format!($($arg)+),
634                    $instance,
635                    $trace
636                )
637            };
638        });
639
640        // Rule with static/empty detail
641        rules.push(quote! {
642            (#code_lit, $instance:expr, $trace:expr) => {
643                $crate::#namespace::ErrorCode::#variant.with_context("", $instance, $trace)
644            };
645        });
646    }
647
648    rules
649}