Skip to main content

orion_error_derive/
lib.rs

1//! Derive macros for `orion-error`.
2//!
3//! Most downstream crates should depend on `orion-error` and use its default
4//! `derive` feature instead of depending on this crate directly.
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{quote, ToTokens};
9use syn::{
10    parse_macro_input, spanned::Spanned, Attribute, Data, DeriveInput, Error, Expr, ExprLit,
11    ExprPath, Fields, Lit, LitStr, Result, Variant,
12};
13
14#[proc_macro_derive(ErrorCode, attributes(orion_error))]
15pub fn derive_error_code(input: TokenStream) -> TokenStream {
16    expand_error_code(parse_macro_input!(input as DeriveInput)).into()
17}
18
19#[proc_macro_derive(ErrorIdentityProvider, attributes(orion_error))]
20pub fn derive_error_identity_provider(input: TokenStream) -> TokenStream {
21    expand_error_identity_provider(parse_macro_input!(input as DeriveInput)).into()
22}
23
24#[proc_macro_derive(OrionError, attributes(orion_error))]
25pub fn derive_orion_error(input: TokenStream) -> TokenStream {
26    expand_orion_error(parse_macro_input!(input as DeriveInput)).into()
27}
28
29fn expand_error_code(input: DeriveInput) -> TokenStream2 {
30    match impl_error_code(input, MissingCode::Error) {
31        Ok(tokens) => tokens,
32        Err(err) => err.to_compile_error(),
33    }
34}
35
36fn expand_error_identity_provider(input: DeriveInput) -> TokenStream2 {
37    match impl_error_identity_provider(input) {
38        Ok(tokens) => tokens,
39        Err(err) => err.to_compile_error(),
40    }
41}
42
43fn expand_orion_error(input: DeriveInput) -> TokenStream2 {
44    let display = impl_display(input.clone());
45    let error_code = impl_error_code(input.clone(), MissingCode::Default);
46    let identity_provider = impl_error_identity_provider(input.clone());
47    let domain_reason = impl_domain_reason(input);
48
49    let mut out = TokenStream2::new();
50    let mut errors = Vec::new();
51    for result in [display, error_code, identity_provider, domain_reason] {
52        match result {
53            Ok(tokens) => out.extend(tokens),
54            Err(err) => errors.push(err),
55        }
56    }
57
58    match errors.into_iter().reduce(|mut first, second| {
59        first.combine(second);
60        first
61    }) {
62        Some(first) => first.to_compile_error(),
63        None => out,
64    }
65}
66
67fn impl_domain_reason(input: DeriveInput) -> Result<TokenStream2> {
68    let ident = input.ident;
69    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
70
71    match input.data {
72        Data::Enum(_) | Data::Struct(_) => Ok(quote! {
73            impl #impl_generics ::orion_error::DomainReason for #ident #ty_generics #where_clause {}
74        }),
75        Data::Union(_) => Err(Error::new(
76            ident.span(),
77            "OrionError can only be derived for enums or structs",
78        )),
79    }
80}
81
82fn impl_display(input: DeriveInput) -> Result<TokenStream2> {
83    let ident = input.ident;
84    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
85
86    match input.data {
87        Data::Enum(data) => {
88            let arms = data
89                .variants
90                .iter()
91                .map(|variant| {
92                    let args = OrionAttrs::from_attrs(&variant.attrs)?;
93                    if args.transparent {
94                        let (pattern, inner) = transparent_variant_pattern(variant)?;
95                        Ok(quote! {
96                            #pattern => ::std::fmt::Display::fmt(#inner, f)
97                        })
98                    } else if let Some(message) = args.display_message() {
99                        let pattern = variant_pattern(variant);
100                        Ok(quote! {
101                            #pattern => f.write_str(#message)
102                        })
103                    } else {
104                        Err(Error::new(
105                            variant.span(),
106                            "missing #[orion_error(message = ...)] or string literal #[orion_error(identity = ...)]",
107                        ))
108                    }
109                })
110                .collect::<Result<Vec<_>>>()?;
111
112            Ok(quote! {
113                impl #impl_generics ::std::fmt::Display for #ident #ty_generics #where_clause {
114                    fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
115                        match self {
116                            #(#arms,)*
117                        }
118                    }
119                }
120            })
121        }
122        Data::Struct(data) => {
123            let args = OrionAttrs::from_attrs(&input.attrs)?;
124            let body = if args.transparent {
125                let inner = transparent_struct_binding(&data.fields)?;
126                let pattern = struct_pattern(&ident, &data.fields);
127                quote! {
128                    match self {
129                        #pattern => ::std::fmt::Display::fmt(#inner, f),
130                    }
131                }
132            } else if let Some(message) = args.display_message() {
133                quote! { f.write_str(#message) }
134            } else {
135                return Err(Error::new(
136                    ident.span(),
137                    "missing container #[orion_error(message = ...)] or string literal #[orion_error(identity = ...)]",
138                ));
139            };
140
141            Ok(quote! {
142                impl #impl_generics ::std::fmt::Display for #ident #ty_generics #where_clause {
143                    fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
144                        #body
145                    }
146                }
147            })
148        }
149        Data::Union(_) => Err(Error::new(
150            ident.span(),
151            "OrionError can only be derived for enums or structs",
152        )),
153    }
154}
155
/// How `impl_error_code` treats an item that has neither `code = ...` nor `transparent`.
#[derive(Copy, Clone)]
enum MissingCode {
    /// Report a compile error (used by the standalone `ErrorCode` derive).
    Error,
    /// Fall back to the code `500` (used by the `OrionError` derive).
    Default,
}
161
162fn impl_error_code(input: DeriveInput, missing_code: MissingCode) -> Result<TokenStream2> {
163    let ident = input.ident;
164    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
165
166    match input.data {
167        Data::Enum(data) => {
168            let arms = data
169                .variants
170                .iter()
171                .map(|variant| {
172                    let args = OrionAttrs::from_attrs(&variant.attrs)?;
173                    if args.transparent {
174                        let (pattern, inner) = transparent_variant_pattern(variant)?;
175                        Ok(quote! {
176                            #pattern => ::orion_error::ErrorCode::error_code(#inner)
177                        })
178                    } else if let Some(code) = args.code {
179                        let pattern = variant_pattern(variant);
180                        Ok(quote! {
181                            #pattern => #code
182                        })
183                    } else if matches!(missing_code, MissingCode::Default) {
184                        let pattern = variant_pattern(variant);
185                        Ok(quote! {
186                            #pattern => 500
187                        })
188                    } else {
189                        Err(Error::new(
190                            variant.span(),
191                            "missing #[orion_error(code = ...)] or #[orion_error(transparent)]",
192                        ))
193                    }
194                })
195                .collect::<Result<Vec<_>>>()?;
196
197            Ok(quote! {
198                impl #impl_generics ::orion_error::ErrorCode for #ident #ty_generics #where_clause {
199                    fn error_code(&self) -> i32 {
200                        match self {
201                            #(#arms,)*
202                        }
203                    }
204                }
205            })
206        }
207        Data::Struct(data) => {
208            let args = OrionAttrs::from_attrs(&input.attrs)?;
209            let body = if args.transparent {
210                let inner = transparent_struct_binding(&data.fields)?;
211                let pattern = struct_pattern(&ident, &data.fields);
212                quote! {
213                    match self {
214                        #pattern => ::orion_error::ErrorCode::error_code(#inner),
215                    }
216                }
217            } else if let Some(code) = args.code {
218                quote! { #code }
219            } else if matches!(missing_code, MissingCode::Default) {
220                quote! { 500 }
221            } else {
222                return Err(Error::new(
223                    ident.span(),
224                    "missing container #[orion_error(code = ...)] or #[orion_error(transparent)]",
225                ));
226            };
227
228            Ok(quote! {
229                impl #impl_generics ::orion_error::ErrorCode for #ident #ty_generics #where_clause {
230                    fn error_code(&self) -> i32 {
231                        #body
232                    }
233                }
234            })
235        }
236        Data::Union(_) => Err(Error::new(
237            ident.span(),
238            "ErrorCode can only be derived for enums or structs",
239        )),
240    }
241}
242
243fn impl_error_identity_provider(input: DeriveInput) -> Result<TokenStream2> {
244    let ident = input.ident;
245    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
246
247    match input.data {
248        Data::Enum(data) => {
249            let stable_arms = data
250                .variants
251                .iter()
252                .map(|variant| {
253                    let args = OrionAttrs::from_attrs(&variant.attrs)?;
254                    if args.transparent {
255                        let (pattern, inner) = transparent_variant_pattern(variant)?;
256                        Ok(quote! {
257                            #pattern => ::orion_error::ErrorIdentityProvider::stable_code(#inner)
258                        })
259                    } else if let Some(identity) = args.identity {
260                        let pattern = variant_pattern(variant);
261                        Ok(quote! {
262                            #pattern => #identity
263                        })
264                    } else {
265                        Err(Error::new(
266                            variant.span(),
267                            "missing #[orion_error(identity = ...)] or #[orion_error(transparent)]",
268                        ))
269                    }
270                })
271                .collect::<Result<Vec<_>>>()?;
272
273            let category_arms = data
274                .variants
275                .iter()
276                .map(|variant| {
277                    let args = OrionAttrs::from_attrs(&variant.attrs)?;
278                    if args.transparent {
279                        let (pattern, inner) = transparent_variant_pattern(variant)?;
280                        Ok(quote! {
281                            #pattern => ::orion_error::ErrorIdentityProvider::error_category(#inner)
282                        })
283                    } else if let Some(category) = args.error_category()? {
284                        let pattern = variant_pattern(variant);
285                        Ok(quote! {
286                            #pattern => #category
287                        })
288                    } else {
289                        Err(Error::new(
290                            variant.span(),
291                            "missing #[orion_error(category = ...)] or category-prefixed string literal #[orion_error(identity = ...)]",
292                        ))
293                    }
294                })
295                .collect::<Result<Vec<_>>>()?;
296
297            Ok(quote! {
298                impl #impl_generics ::orion_error::ErrorIdentityProvider for #ident #ty_generics #where_clause {
299                    fn stable_code(&self) -> &'static str {
300                        match self {
301                            #(#stable_arms,)*
302                        }
303                    }
304
305                    fn error_category(&self) -> ::orion_error::ErrorCategory {
306                        match self {
307                            #(#category_arms,)*
308                        }
309                    }
310                }
311            })
312        }
313        Data::Struct(data) => {
314            let args = OrionAttrs::from_attrs(&input.attrs)?;
315            let (stable_body, category_body) = if args.transparent {
316                let inner = transparent_struct_binding(&data.fields)?;
317                let pattern = struct_pattern(&ident, &data.fields);
318                (
319                    quote! {
320                        match self {
321                            #pattern => ::orion_error::ErrorIdentityProvider::stable_code(#inner),
322                        }
323                    },
324                    quote! {
325                        match self {
326                            #pattern => ::orion_error::ErrorIdentityProvider::error_category(#inner),
327                        }
328                    },
329                )
330            } else {
331                let identity = args.identity.clone().ok_or_else(|| {
332                    Error::new(
333                        ident.span(),
334                        "missing container #[orion_error(identity = ...)]",
335                    )
336                })?;
337                let category = args.error_category()?.ok_or_else(|| {
338                    Error::new(
339                        ident.span(),
340                        "missing container #[orion_error(category = ...)] or category-prefixed string literal #[orion_error(identity = ...)]",
341                    )
342                })?;
343                (quote! { #identity }, quote! { #category })
344            };
345
346            Ok(quote! {
347                impl #impl_generics ::orion_error::ErrorIdentityProvider for #ident #ty_generics #where_clause {
348                    fn stable_code(&self) -> &'static str {
349                        #stable_body
350                    }
351
352                    fn error_category(&self) -> ::orion_error::ErrorCategory {
353                        #category_body
354                    }
355                }
356            })
357        }
358        Data::Union(_) => Err(Error::new(
359            ident.span(),
360            "ErrorIdentityProvider can only be derived for enums or structs",
361        )),
362    }
363}
364
365#[derive(Default)]
366struct OrionAttrs {
367    message: Option<Expr>,
368    code: Option<Expr>,
369    identity: Option<Expr>,
370    category: Option<TokenStream2>,
371    transparent: bool,
372}
373
374impl OrionAttrs {
375    fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
376        let mut out = Self::default();
377        for attr in attrs {
378            if !attr.path().is_ident("orion_error") {
379                continue;
380            }
381
382            attr.parse_nested_meta(|meta| {
383                if meta.path.is_ident("transparent") {
384                    out.transparent = true;
385                    return Ok(());
386                }
387
388                if meta.path.is_ident("code") {
389                    out.code = Some(meta.value()?.parse()?);
390                    return Ok(());
391                }
392
393                if meta.path.is_ident("message") {
394                    out.message = Some(meta.value()?.parse()?);
395                    return Ok(());
396                }
397
398                if meta.path.is_ident("identity") {
399                    out.identity = Some(meta.value()?.parse()?);
400                    return Ok(());
401                }
402
403                if meta.path.is_ident("category") {
404                    let expr: Expr = meta.value()?.parse()?;
405                    out.category = Some(category_expr(expr)?);
406                    return Ok(());
407                }
408
409                Err(meta.error("unsupported orion_error attribute"))
410            })?;
411        }
412        Ok(out)
413    }
414
415    fn display_message(&self) -> Option<LitStr> {
416        self.message
417            .as_ref()
418            .and_then(expr_lit_str)
419            .cloned()
420            .or_else(|| {
421                self.identity
422                    .as_ref()
423                    .and_then(expr_lit_str)
424                    .map(message_from_identity)
425            })
426    }
427
428    fn error_category(&self) -> Result<Option<TokenStream2>> {
429        if let Some(category) = self.category.clone() {
430            return Ok(Some(category));
431        }
432
433        let Some(identity) = self.identity.as_ref().and_then(expr_lit_str) else {
434            return Ok(None);
435        };
436
437        identity_category(identity).transpose()
438    }
439}
440
441fn expr_lit_str(expr: &Expr) -> Option<&LitStr> {
442    match expr {
443        Expr::Lit(ExprLit {
444            lit: Lit::Str(lit), ..
445        }) => Some(lit),
446        _ => None,
447    }
448}
449
450fn message_from_identity(identity: &LitStr) -> LitStr {
451    let message = identity
452        .value()
453        .rsplit('.')
454        .next()
455        .unwrap_or_default()
456        .replace('_', " ");
457    LitStr::new(&message, identity.span())
458}
459
460fn identity_category(identity: &LitStr) -> Option<Result<TokenStream2>> {
461    let value = identity.value();
462    let prefix = value.split('.').next().unwrap_or_default();
463    match prefix {
464        "conf" => Some(Ok(quote! { ::orion_error::ErrorCategory::Conf })),
465        "biz" => Some(Ok(quote! { ::orion_error::ErrorCategory::Biz })),
466        "logic" => Some(Ok(quote! { ::orion_error::ErrorCategory::Logic })),
467        "sys" => Some(Ok(quote! { ::orion_error::ErrorCategory::Sys })),
468        value => Some(Err(Error::new(
469            identity.span(),
470            format!(
471                "unknown identity category prefix `{value}`; expected one of: conf, biz, logic, sys"
472            ),
473        ))),
474    }
475}
476
477fn category_expr(expr: Expr) -> Result<TokenStream2> {
478    match expr {
479        Expr::Lit(ExprLit {
480            lit: Lit::Str(lit), ..
481        }) => match lit.value().as_str() {
482            "conf" => Ok(quote! { ::orion_error::ErrorCategory::Conf }),
483            "biz" => Ok(quote! { ::orion_error::ErrorCategory::Biz }),
484            "logic" => Ok(quote! { ::orion_error::ErrorCategory::Logic }),
485            "sys" => Ok(quote! { ::orion_error::ErrorCategory::Sys }),
486            value => Err(Error::new(
487                lit.span(),
488                format!("unknown error category `{value}`; expected one of: conf, biz, logic, sys"),
489            )),
490        },
491        Expr::Path(ExprPath { path, .. })
492            if path.leading_colon.is_none() && path.segments.len() == 1 =>
493        {
494            let ident = &path.segments[0].ident;
495            match ident.to_string().as_str() {
496                "Conf" => Ok(quote! { ::orion_error::ErrorCategory::Conf }),
497                "Biz" => Ok(quote! { ::orion_error::ErrorCategory::Biz }),
498                "Logic" => Ok(quote! { ::orion_error::ErrorCategory::Logic }),
499                "Sys" => Ok(quote! { ::orion_error::ErrorCategory::Sys }),
500                _ => Ok(path.to_token_stream()),
501            }
502        }
503        other => Ok(other.to_token_stream()),
504    }
505}
506
507fn variant_pattern(variant: &Variant) -> TokenStream2 {
508    let ident = &variant.ident;
509    match &variant.fields {
510        Fields::Unit => quote! { Self::#ident },
511        Fields::Unnamed(_) => quote! { Self::#ident(..) },
512        Fields::Named(_) => quote! { Self::#ident { .. } },
513    }
514}
515
516fn transparent_variant_pattern(variant: &Variant) -> Result<(TokenStream2, TokenStream2)> {
517    let ident = &variant.ident;
518    match &variant.fields {
519        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
520            Ok((quote! { Self::#ident(__inner) }, quote! { __inner }))
521        }
522        Fields::Named(fields) if fields.named.len() == 1 => {
523            let field = fields
524                .named
525                .iter()
526                .next()
527                .and_then(|field| field.ident.as_ref())
528                .unwrap();
529            Ok((quote! { Self::#ident { #field } }, quote! { #field }))
530        }
531        _ => Err(Error::new(
532            variant.span(),
533            "#[orion_error(transparent)] requires exactly one field",
534        )),
535    }
536}
537
538fn transparent_struct_binding(fields: &Fields) -> Result<TokenStream2> {
539    match fields {
540        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => Ok(quote! { __inner }),
541        Fields::Named(fields) if fields.named.len() == 1 => {
542            let field = fields
543                .named
544                .iter()
545                .next()
546                .and_then(|field| field.ident.as_ref())
547                .unwrap();
548            Ok(quote! { #field })
549        }
550        _ => Err(Error::new(
551            fields.span(),
552            "#[orion_error(transparent)] requires exactly one field",
553        )),
554    }
555}
556
557fn struct_pattern(ident: &syn::Ident, fields: &Fields) -> TokenStream2 {
558    match fields {
559        Fields::Unit => quote! { #ident },
560        Fields::Unnamed(_) => quote! { #ident(__inner) },
561        Fields::Named(fields) => {
562            let field = fields
563                .named
564                .iter()
565                .next()
566                .and_then(|field| field.ident.as_ref());
567            quote! { #ident { #field } }
568        }
569    }
570}