foxerror/
lib.rs

1//! derive macro for implementing Display and Error on enums
2//!
3//! ```rust
4//! #[derive(Debug, foxerror::FoxError)]
5//! enum Error {
6//!     NamedFields { a: i32, b: i32 },
7//!     #[err(msg = "a custom message")]
8//!     WithMessage(String),
9//!     /// or the first line of the doc comment
10//!     DocWorksToo,
11//! }
12//! ```
13
14use proc_macro2::{Span, TokenStream};
15use quote::quote;
16use syn::{
17    parse::{Parse, ParseStream},
18    DeriveInput, Token,
19};
20
21struct ParsedErrors {
22    ident: syn::Ident,
23    generics: syn::Generics,
24    variants: Vec<Variant>,
25}
26
27struct Variant {
28    ident: syn::Ident,
29    fields: syn::Fields,
30    msg: Option<String>,
31}
32
33struct AttrArg {
34    ident: syn::Ident,
35    value: Option<syn::Expr>,
36}
37
38impl Parse for AttrArg {
39    fn parse(input: ParseStream) -> syn::Result<Self> {
40        let ident = input.parse()?;
41        let value = if input.parse::<Token![=]>().is_ok() {
42            input.parse::<syn::Expr>().ok()
43        } else {
44            None
45        };
46        Ok(AttrArg { ident, value })
47    }
48}
49
50struct AttrArgs(Vec<AttrArg>);
51
52impl Parse for AttrArgs {
53    fn parse(input: ParseStream) -> syn::Result<Self> {
54        let mut args = vec![];
55        loop {
56            args.push(input.parse()?);
57            if input.parse::<Token![,]>().is_err() {
58                return Ok(Self(args));
59            }
60        }
61    }
62}
63
64fn parse_attr_doc(a: &syn::Attribute) -> Option<&syn::Expr> {
65    if !matches!(a.style, syn::AttrStyle::Outer) {
66        return None;
67    }
68    let syn::Meta::NameValue(ref nameval) = a.meta else {
69        return None;
70    };
71    if !nameval.path.is_ident("doc") {
72        return None;
73    }
74    Some(&nameval.value)
75}
76
77fn parse_attr(a: &syn::Attribute) -> Option<AttrArgs> {
78    if !matches!(a.style, syn::AttrStyle::Outer) {
79        return None;
80    }
81    let syn::Meta::List(ref list) = a.meta else {
82        return None;
83    };
84    if !matches!(list.delimiter, syn::MacroDelimiter::Paren(_)) {
85        return None;
86    }
87    if !list.path.is_ident("err") {
88        return None;
89    }
90    Some(list.parse_args().expect("could not parse attr args"))
91}
92
93fn expr_str(a: &syn::Expr) -> Option<String> {
94    match a {
95        syn::Expr::Lit(syn::ExprLit {
96            lit: syn::Lit::Str(s),
97            ..
98        }) => Some(s.value()),
99        _ => None,
100    }
101    .map(|s| s.strip_prefix(' ').unwrap_or(&s).to_string())
102}
103
104fn parse_variant(v: syn::Variant) -> Variant {
105    let doc = v.attrs.iter().flat_map(parse_attr_doc).next();
106    let args = v.attrs.iter().flat_map(parse_attr).last();
107    let amsg = args.and_then(|a| {
108        a.0.into_iter()
109            .find(|a| a.ident == "msg")
110            .and_then(|a| a.value)
111    });
112    let msg = amsg.as_ref().or(doc).and_then(expr_str);
113    Variant {
114        ident: v.ident,
115        fields: v.fields,
116        msg,
117    }
118}
119
120fn parse_derive(ast: DeriveInput) -> ParsedErrors {
121    let ident = ast.ident;
122    let generics = ast.generics;
123    let syn::Data::Enum(body) = ast.data else {
124        panic!("only enums are supported")
125    };
126    let variants = body.variants.into_iter().map(parse_variant).collect();
127
128    ParsedErrors {
129        ident,
130        generics,
131        variants,
132    }
133}
134
135fn generate(parsed: ParsedErrors) -> TokenStream {
136    let ParsedErrors {
137        ident,
138        generics,
139        variants,
140    } = parsed;
141
142    let arms = variants.into_iter().map(|v| {
143        let Variant {
144            ident: name,
145            fields,
146            msg,
147        } = v;
148        let msg = if let Some(msg) = msg {
149            quote!(#msg)
150        } else {
151            let name = name.to_string();
152            quote!(#name)
153        };
154        let mut set = quote!();
155        let mut get = vec![];
156        let mut fmt = vec![quote!("{}")];
157
158        match fields {
159            syn::Fields::Named(fields) => {
160                fmt.push(quote!(":"));
161                let mut ids = vec![];
162                for (fnum, field) in fields.named.into_iter().enumerate() {
163                    let fid = syn::Ident::new(format!("arg_{}", fnum).as_ref(), Span::call_site());
164                    get.push(quote!(#fid));
165                    let fnm = field.ident.expect("missing ident");
166                    ids.push(quote!(#fnm));
167                    if fnum > 0 {
168                        fmt.push(quote!(","));
169                    }
170                    let fo = format!(" {}: {{}}", fnm);
171                    fmt.push(quote!(#fo));
172                }
173                set = quote!({#(#ids: #get),*});
174            }
175            syn::Fields::Unnamed(fields) => {
176                fmt.push(quote!(":"));
177                for fnum in 0..fields.unnamed.len() {
178                    let fid = syn::Ident::new(format!("arg_{}", fnum).as_ref(), Span::call_site());
179                    get.push(quote!(#fid));
180                    if fnum > 0 {
181                        fmt.push(quote!(","));
182                    }
183                    fmt.push(quote!(" {}"));
184                }
185                set = quote!((#(#get),*));
186            }
187            syn::Fields::Unit => (),
188        };
189
190        quote! {
191            #ident::#name #set => write!(f, concat!(#(#fmt),*), #msg, #(#get),*)
192        }
193    });
194
195    quote! {
196        #[automatically_derived]
197        impl #generics ::core::fmt::Display for #ident #generics {
198            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
199                match self {
200                    #(#arms,)*
201                }
202            }
203        }
204
205        #[automatically_derived]
206        impl #generics ::core::error::Error for #ident #generics {}
207    }
208}
209
210/// the derive macro itself
211///
212/// # more in-depth example
213/// ```rust
214/// #[derive(Debug, foxerror::FoxError)]
215/// enum Error<'a> {
216///     /// i am a doc comment
217///     /// other lines get ignored
218///     NoFields,
219///     /// or override the message with an attribute
220///     #[err(msg = "i have one field")]
221///     OneField(&'a str),
222///     /// my favorite numbers are
223///     ManyFields(i8, i8, i8, i8),
224///     // defaults to the variant name when no doc nor attr
225///     NamedFields {
226///         species: &'a str,
227///         leggies: u64,
228///     },
229/// }
230///
231/// assert_eq!(format!("{}", Error::NoFields), "i am a doc comment");
232/// assert_eq!(
233///     format!("{}", Error::OneField("hello")),
234///     "i have one field: hello",
235/// );
236/// assert_eq!(
237///     format!("{}", Error::ManyFields(3, 6, 2, 1)),
238///     "my favorite numbers are: 3, 6, 2, 1",
239/// );
240/// assert_eq!(
241///     format!("{}", Error::NamedFields { species: "fox", leggies: 4 }),
242///     "NamedFields: species: fox, leggies: 4",
243/// );
244/// ```
245#[proc_macro_derive(FoxError, attributes(err))]
246pub fn foxerror(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
247    let input = syn::parse(input).unwrap();
248    let parsed = parse_derive(input);
249    let output = generate(parsed);
250
251    output.into()
252}