axum_thiserror/
lib.rs

1use proc_macro2::{Ident, TokenStream};
2use quote::{quote, quote_spanned};
3use syn::{
4    parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, Data, DeriveInput,
5    Expr, LitInt, Meta, Path, Variant,
6};
7
8///! # axum_thiserror
9///! `axum_thiserror` is a library that offers a procedural macro to allow `thiserror` error types to be used as `axum` responses.
10///! ## Usage
11///! Add the library to your current project using Cargo:
12///! ```bash
13///! cargo add axum_thiserror
14///! ```
15///! Then you can create a basic `thiserror` error:
16///! ```rust
17///! #[derive(Error, Debug)]
18///! pub enum UserCreateError {
19///!   #[error("User {0} already exists")]
20///!   UserAlreadyExists(String),
21///! }
22///! ```
23///! Now you can use `axum_thiserror` to implement `IntoResponse` on your error:
24///! ```rust
25///! #[derive(Error, Debug, ErrorStatus)]
26///! pub enum UserCreateError {
27///!   #[error("User {0} already exists")]
28///!   #[status(StatusCode::CONFLICT)]
29///!   UserAlreadyExists(String),
30///! }
31///! ```
32///! ## License
33///! This project is licensed under the [MIT License](LICENSE).
34
35/// A derivation that implements the `IntoResponse` trait for a specific attribute.
36#[proc_macro_derive(ErrorStatus, attributes(status))]
37pub fn derive_error_status(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
38    let ast: DeriveInput = parse_macro_input!(input);
39    let enum_ident = ast.ident;
40    let cases: Punctuated<TokenStream, Comma> = match ast.data {
41        Data::Enum(data) => data.variants,
42        _ => panic!(
43            "#[derive(ErrorStatus)] is only available for enums, other types are not supported."
44        ),
45    }
46    .iter()
47    .map(impl_enum_variant)
48    .collect();
49
50    quote! {
51        impl axum::response::IntoResponse for #enum_ident {
52            fn into_response(self) -> axum::response::Response {
53                match self {
54                    #cases
55                }
56            }
57        }
58    }
59    .into()
60}
61
62fn impl_enum_variant(input: &Variant) -> TokenStream {
63    let status_code = find_status_code(input);
64    let case = if input.fields.is_empty() {
65        case_empty_fields(input)
66    } else if input.fields.iter().filter(|f| f.ident.is_none()).count() > 0 {
67        case_unnamed_fields(input)
68    } else {
69        case_named_fields(input)
70    };
71    quote! {
72        Self::#case => (#status_code, format!("{}", self)).into_response()
73    }
74}
75
76fn case_empty_fields(input: &Variant) -> TokenStream {
77    let ident = &input.ident;
78    quote!(#ident)
79}
80
81fn case_unnamed_fields(input: &Variant) -> TokenStream {
82    let ident = &input.ident;
83    quote!(#ident( .. ))
84}
85
86fn case_named_fields(input: &Variant) -> TokenStream {
87    let ident = &input.ident;
88    quote!(#ident { .. })
89}
90
91fn find_status_code(input: &Variant) -> TokenStream {
92    match input
93        .attrs
94        .iter()
95        .find(|attr| attr.path().is_ident("status"))
96    {
97        Some(attr) => match &attr.meta {
98            Meta::List(l) => {
99                if let Ok(number) = l.parse_args::<LitInt>() {
100                    quote! {
101                        axum::http::StatusCode::from_u16(#number as u16).unwrap()
102                    }
103                } else if let Ok(expr) = l.parse_args::<Path>() {
104                    quote! {
105                        #expr
106                    }
107                } else {
108                    quote_spanned!(l.span() => compile_error!("Only #[status(StatusCode)] or #[status(u16)] syntaxe is supported"))
109                }
110            }
111            _ => {
112                quote_spanned! { attr.span() => compile_error!("Only #[status(StatusCode)] or #[status(u16)] syntaxe is supported") }
113            }
114        },
115        None => {
116            quote_spanned! { input.span() => compile_error!("Each enum variant should have a status code provided using the #[status()] attribute") }
117        }
118    }
119}