quickerr/
lib.rs

1#![deny(missing_docs, rustdoc::all)]
2#![doc = include_str!("../README.md")]
3
4use proc_macro::TokenStream;
5use syn::__private::quote::quote;
6use syn::__private::{ToTokens, TokenStream2};
7use syn::parse::discouraged::Speculative;
8use syn::parse::{Parse, ParseStream};
9use syn::punctuated::Punctuated;
10use syn::token::{self, Colon, Comma};
11use syn::{bracketed, Attribute, Generics, Ident, LitStr, Result, Type, Visibility};
12
13/// This macro allows quickly defining errors in the format that this crate produces.
14///
15/// It has 5 major forms:
16/// - Unit struct:
17/// ```
18/// # use quickerr::error;
19/// error! {
20///     MyUnitError
21///     "it's a unit error"
22/// }
23/// ```
24/// - Record struct:
25/// ```
26/// # use quickerr::error;
27/// # #[derive(Debug)]
28/// # struct Type;
29/// # #[derive(Debug)]
30/// # struct Type2;
31/// error! {
32///     MyStructError
33///     "it's a struct! Field 2 is {field2:?}"
34///     field: Type,
35///     field2: Type2,
36/// }
37/// ```
38/// - Enum:
39/// ```
40/// # use quickerr::error;
41/// # error! { SourceError1 "" }
42/// # error! { MyUnitError "" }
43/// # error! { MyStructError "" }
44/// error! {
45///     MyEnumError
46///     "it's a whole enum"
47///     SourceError1,
48///     MyUnitError,
49///     MyStructError,
50/// }
51/// ```
52/// - Transparent enum:
53/// ```
54/// # use quickerr::error;
55/// # error! { MyEnumError "uh oh" }
56/// # error! { REALLY_LOUD_ERROR "uh oh" }
57/// error! {
58///     QuietAsAMouse
59///     MyEnumError,
60///     REALLY_LOUD_ERROR,
61/// }
62/// ```
63/// - Array:
64/// ```
65/// # use quickerr::error;
66/// # error! { SomeError "" }
67/// error! {
68///     ManyProblems
69///     "encountered many problems"
70///     [SomeError]
71/// }
72/// ```
73///
74/// Each form implements `Debug`, `Error`, and `From` as appropriate. The enum forms implement
75/// [`std::error::Error::source()`] for each of their variants, and each variant must be the name
76/// of an existing error type. The struct form exposes the fields for use in the error message.
77/// The transparent enum form does not append a message, and simply passes the source along
78/// directly. All forms are `#[non_exhaustive]` and all fields are public. They can be made public
79/// by adding `pub` to the name like `pub MyError`.
80///
81/// Additional attributes can be added before the name to add them to the error type,
82/// like so:
83/// ```
84/// # use quickerr::error;
85/// error! {
86///     #[derive(PartialEq, Eq)]
87///     AttrsError
88///     "has attributes!"
89///     /// a number for something
90///     num: i32
91/// }
92/// ```
93///
94/// Attributes can be added to fields and variants of struct/enum/array errors, and they can be
95/// made generic:
96/// ```
97/// # use quickerr::error;
98/// error! {
99///     /// In case of emergency
100///     BreakGlass<BreakingTool: std::fmt::Debug>
101///     "preferably with a blunt object"
102///     like_this_one: BreakingTool,
103/// }
104/// ```
105///
106/// If cfg attributes are used, they're copied to relevant places to ensure it compiles properly:
107/// ```
108/// # use quickerr::error;
109/// # error!{ Case1 "" }
110/// # error!{ Case2 "" }
111/// # struct Foo;
112/// # struct Bar;
113/// error! {
114///     #[cfg(feature = "drop_the_whole_error")]
115///     EnumErr
116///     "foo"
117///     #[cfg(feature = "foo")]
118///     Case1,
119///     #[cfg(feature = "bar")]
120///     Case2,
121/// }
122///
123/// error! {
124///     StructErr
125///     "bar"
126///     #[cfg(feature = "foo")]
127///     field1: Foo,
128///     #[cfg(feature = "bar")]
129///     field2: Bar,
130/// }
131/// ```
132/// Make sure not to use cfg'd fields in the error message string if those fields can ever be not
133/// present.
134#[proc_macro]
135pub fn error(tokens: TokenStream) -> TokenStream {
136    match error_impl(tokens.into()) {
137        Ok(toks) => toks.into(),
138        Err(err) => err.to_compile_error().into(),
139    }
140}
141
142fn error_impl(tokens: TokenStream2) -> Result<TokenStream2> {
143    let Error {
144        attrs,
145        vis,
146        name,
147        generics,
148        msg,
149        contents,
150    } = syn::parse2(tokens)?;
151
152    let (impl_gen, ty_gen, where_gen) = generics.split_for_impl();
153
154    let item_cfgs: Vec<&Attribute> = attrs
155        .iter()
156        .filter(|attr| attr.meta.path().is_ident("cfg"))
157        .collect();
158    let item_cfgs = quote! { #(#item_cfgs)* };
159
160    Ok(match contents {
161        ErrorContents::Unit => quote! {
162            #(#attrs)*
163            #[derive(Debug)]
164            #[non_exhaustive]
165            #vis struct #name #generics;
166
167            #item_cfgs
168            impl #impl_gen ::std::fmt::Display for #name #ty_gen #where_gen {
169                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
170                    f.write_str(#msg)
171                }
172            }
173
174            #item_cfgs
175            impl #impl_gen ::std::error::Error for #name #ty_gen #where_gen {}
176        },
177        ErrorContents::Struct { fields } => {
178            let cfgs: Vec<Vec<&Attribute>> = fields
179                .iter()
180                .map(|field| {
181                    field
182                        .attrs
183                        .iter()
184                        .filter(|attr| attr.meta.path().is_ident("cfg"))
185                        .collect()
186                })
187                .collect();
188            let field_names: Vec<&Ident> = fields.iter().map(|field| &field.name).collect();
189            quote! {
190                #(#attrs)*
191                #[derive(Debug)]
192                #[non_exhaustive]
193                #vis struct #name #generics {
194                    #fields
195                }
196
197                #item_cfgs
198                impl #impl_gen ::std::fmt::Display for #name #ty_gen #where_gen {
199                    #[allow(unused_variables)]
200                    fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
201                        let Self {
202                            #(
203                                #(#cfgs)*
204                                #field_names,
205                            )*
206                        } = self;
207                        f.write_fmt(format_args!(#msg))
208                    }
209                }
210
211                #item_cfgs
212                impl #impl_gen ::std::error::Error for #name #ty_gen #where_gen {}
213            }
214        }
215        ErrorContents::Enum { sources } => {
216            let source_attrs: Vec<&Vec<Attribute>> =
217                sources.iter().map(|source| &source.attrs).collect();
218            let cfgs: Vec<Vec<Attribute>> = source_attrs
219                .iter()
220                .map(|&attrs| {
221                    let mut attrs = attrs.clone();
222                    attrs.retain(|attr| attr.meta.path().is_ident("cfg"));
223                    attrs
224                })
225                .collect();
226            let source_idents: Vec<&Ident> = sources.iter().map(|source| &source.ident).collect();
227            let write_msg = match &msg {
228                Some(msg) => quote! {
229                    f.write_str(#msg)
230                },
231                None => {
232                    quote! {
233                        match self {
234                            #(
235                                #(#cfgs)*
236                                Self::#source_idents(err) => ::std::fmt::Display::fmt(err, f),
237                            )*
238                            _ => unreachable!(),
239                        }
240                    }
241                }
242            };
243            quote! {
244                #(#attrs)*
245                #[derive(Debug)]
246                #[non_exhaustive]
247                #vis enum #name #generics {
248                    #(
249                        #(#source_attrs)*
250                        #source_idents(#source_idents),
251                    )*
252                }
253
254                #item_cfgs
255                impl #impl_gen ::std::fmt::Display for #name #ty_gen #where_gen {
256                    fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
257                        #write_msg
258                    }
259                }
260
261                #item_cfgs
262                impl #impl_gen ::std::error::Error for #name #ty_gen #where_gen {
263                    fn source(&self) -> ::std::option::Option<&(dyn ::std::error::Error + 'static)> {
264                        Some(match self {
265                            #(
266                                #(#cfgs)*
267                                #name::#source_idents(err) => err,
268                            )*
269                            _ => unreachable!(),
270                        })
271                    }
272                }
273
274                #(
275                    #item_cfgs
276                    #(#cfgs)*
277                    impl #impl_gen ::std::convert::From<#source_idents> for #name #ty_gen #where_gen {
278                        fn from(source: #source_idents) -> Self {
279                            Self::#source_idents(source)
280                        }
281                    }
282                )*
283            }
284        }
285        ErrorContents::Array {
286            inner_attrs, inner, ..
287        } => quote! {
288            #(#attrs)*
289            #[derive(Debug)]
290            #[non_exhaustive]
291            #vis struct #name #generics (#(#inner_attrs)* pub Vec<#inner>);
292
293            #item_cfgs
294            impl #impl_gen ::std::fmt::Display for #name #ty_gen #where_gen {
295                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
296                    f.write_str(#msg)?;
297                    f.write_str(":")?;
298                    for err in &self.0 {
299                        f.write_str("\n")?;
300                        f.write_fmt(format_args!("{}", err))?;
301                    }
302                    Ok(())
303                }
304            }
305
306            #item_cfgs
307            impl #impl_gen ::std::error::Error for #name #ty_gen #where_gen {}
308        },
309    })
310}
311
312struct Field {
313    attrs: Vec<Attribute>,
314    vis: Visibility,
315    name: Ident,
316    colon: Colon,
317    ty: Type,
318}
319
320impl Parse for Field {
321    fn parse(input: ParseStream) -> Result<Self> {
322        Ok(Self {
323            attrs: input.call(Attribute::parse_outer)?,
324            vis: input.parse()?,
325            name: input.parse()?,
326            colon: input.parse()?,
327            ty: input.parse()?,
328        })
329    }
330}
331
332impl ToTokens for Field {
333    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
334        for attr in &self.attrs {
335            attr.to_tokens(tokens);
336        }
337        self.vis.to_tokens(tokens);
338        self.name.to_tokens(tokens);
339        self.colon.to_tokens(tokens);
340        self.ty.to_tokens(tokens);
341    }
342}
343
344struct ErrorVariant {
345    attrs: Vec<Attribute>,
346    ident: Ident,
347}
348
349impl Parse for ErrorVariant {
350    fn parse(input: ParseStream) -> Result<Self> {
351        Ok(Self {
352            attrs: input.call(Attribute::parse_outer)?,
353            ident: input.parse()?,
354        })
355    }
356}
357
358enum ErrorContents {
359    Unit,
360    Struct {
361        fields: Punctuated<Field, Comma>,
362    },
363    Enum {
364        sources: Punctuated<ErrorVariant, Comma>,
365    },
366    Array {
367        inner_attrs: Vec<Attribute>,
368        inner: Type,
369    },
370}
371
372impl Parse for ErrorContents {
373    fn parse(input: ParseStream) -> Result<Self> {
374        if input.is_empty() {
375            return Ok(Self::Unit);
376        }
377
378        let fork = input.fork();
379        if let Ok(fields) = fork.call(Punctuated::parse_terminated) {
380            input.advance_to(&fork);
381            return Ok(Self::Struct { fields });
382        }
383
384        let fork = input.fork();
385        if let Ok(sources) = fork.call(Punctuated::parse_terminated) {
386            input.advance_to(&fork);
387            return Ok(Self::Enum { sources });
388        }
389
390        if input.peek(token::Bracket) {
391            let content;
392            let _ = bracketed!(content in input);
393            let attrs = content.call(Attribute::parse_outer)?;
394            let inner = content.parse::<Type>()?;
395            return Ok(Self::Array {
396                inner_attrs: attrs,
397                inner,
398            });
399        }
400
401        Err(input.error("invalid error contents"))
402    }
403}
404
405struct Error {
406    attrs: Vec<Attribute>,
407    vis: Visibility,
408    name: Ident,
409    generics: Generics,
410    msg: Option<LitStr>,
411    contents: ErrorContents,
412}
413
414impl Parse for Error {
415    fn parse(input: ParseStream) -> Result<Self> {
416        let attrs = input.call(Attribute::parse_outer)?;
417        let vis = input.parse::<Visibility>()?;
418        let name = input.parse::<Ident>()?;
419        let generics = input.parse::<Generics>()?;
420        let msg = input.parse::<LitStr>().ok();
421        let contents = input.parse::<ErrorContents>()?;
422
423        if msg.is_none() && !matches!(contents, ErrorContents::Enum { .. }) {
424            return Err(input.error("any non-enum error must have a display message"));
425        }
426
427        Ok(Self {
428            attrs,
429            vis,
430            name,
431            generics,
432            msg,
433            contents,
434        })
435    }
436}