thiserror_impl/
expand.rs

1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::generics::InferredBounds;
4use crate::span::MemberSpan;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote, quote_spanned, ToTokens};
7use std::collections::BTreeSet as Set;
8use syn::{DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type};
9
10pub fn derive(input: &DeriveInput) -> TokenStream {
11    match try_expand(input) {
12        Ok(expanded) => expanded,
13        // If there are invalid attributes in the input, expand to an Error impl
14        // anyway to minimize spurious knock-on errors in other code that uses
15        // this type as an Error.
16        Err(error) => fallback(input, error),
17    }
18}
19
20fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
21    let input = Input::from_syn(input)?;
22    input.validate()?;
23    Ok(match input {
24        Input::Struct(input) => impl_struct(input),
25        Input::Enum(input) => impl_enum(input),
26    })
27}
28
29fn fallback(input: &DeriveInput, error: syn::Error) -> TokenStream {
30    let ty = &input.ident;
31    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32
33    let error = error.to_compile_error();
34
35    quote! {
36        #error
37
38        #[allow(unused_qualifications)]
39        impl #impl_generics std::error::Error for #ty #ty_generics #where_clause
40        where
41            // Work around trivial bounds being unstable.
42            // https://github.com/rust-lang/rust/issues/48214
43            for<'workaround> #ty #ty_generics: ::core::fmt::Debug,
44        {}
45
46        #[allow(unused_qualifications)]
47        impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
48            fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
49                ::core::unreachable!()
50            }
51        }
52    }
53}
54
/// Expands a struct into its `std::error::Error`, `Display`, and `From` impls.
///
/// Each impl is emitted only in certain cases:
/// - The `Error` impl, with optional `source` and `provide` methods, only when
///   `crate::use_std()` returns true.
/// - `Display` only when the struct is `#[error(transparent)]` or carries a
///   display attribute.
/// - `From` only when the struct has a from field.
fn impl_struct(input: Struct) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Bounds on generic field types that the `Error` impl needs. They are
    // merged into the struct's where clause near the end of this function.
    let mut error_inferred_bounds = InferredBounds::new();

    // Body of `Error::source`, or `None` when the struct has no source at all.
    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
        // Transparent: delegate to the `source` of the single wrapped field.
        let only_field = &input.fields[0];
        if only_field.contains_generic {
            error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
        }
        let member = &only_field.member;
        // Spanned at the `transparent` attribute so type errors in this
        // expression are reported there.
        Some(quote_spanned! {transparent_attr.span=>
            std::error::Error::source(self.#member.as_dyn_error())
        })
    } else if let Some(source_field) = input.source_field() {
        let source = &source_field.member;
        if source_field.contains_generic {
            // The bound goes on the inner type when the field is `Option<T>`.
            let ty = unoptional_type(source_field.ty);
            error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
        }
        // For `Option` source fields, `.as_ref()?` makes `source()` return
        // `None` when the field is `None`.
        let asref = if type_is_option(source_field.ty) {
            Some(quote_spanned!(source.member_span()=> .as_ref()?))
        } else {
            None
        };
        let dyn_error = quote_spanned! {source_field.source_span()=>
            self.#source #asref.as_dyn_error()
        };
        Some(quote! {
            ::core::option::Option::Some(#dyn_error)
        })
    } else {
        None
    };
    // Wrap the body in `fn source`; `AsDynError` supplies `.as_dyn_error()`.
    let source_method = source_body.map(|body| {
        quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError as _;
                #body
            }
        }
    });

    // `Error::provide` is only generated when the struct has a backtrace field.
    let provide_method = input.backtrace_field().map(|backtrace_field| {
        let request = quote!(request);
        let backtrace = &backtrace_field.member;
        let body = if let Some(source_field) = input.source_field() {
            // With a source: forward the request to the source first, then
            // offer this struct's own backtrace. `Option` fields only
            // contribute when they are `Some`.
            let source = &source_field.member;
            let source_provide = if type_is_option(source_field.ty) {
                quote_spanned! {source.member_span()=>
                    if let ::core::option::Option::Some(source) = &self.#source {
                        source.thiserror_provide(#request);
                    }
                }
            } else {
                quote_spanned! {source.member_span()=>
                    self.#source.thiserror_provide(#request);
                }
            };
            // When the backtrace and the source are the same field, forwarding
            // to the source is all that is emitted.
            let self_provide = if source == backtrace {
                None
            } else if type_is_option(backtrace_field.ty) {
                Some(quote! {
                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                        #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                    }
                })
            } else {
                Some(quote! {
                    #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
                })
            };
            quote! {
                use thiserror::__private::ThiserrorProvide as _;
                #source_provide
                #self_provide
            }
        } else if type_is_option(backtrace_field.ty) {
            // No source: only this struct's (optional) backtrace is offered.
            quote! {
                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                    #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                }
            }
        } else {
            quote! {
                #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
            }
        };
        quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #body
            }
        }
    });

    // (field index, trait) pairs that the Display body requires of fields.
    let mut display_implied_bounds = Set::new();
    let display_body = if input.attrs.transparent.is_some() {
        // Transparent: forward `Display` to the single field (index 0).
        let only_field = &input.fields[0].member;
        display_implied_bounds.insert((0, Trait::Display));
        Some(quote! {
            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
        })
    } else if let Some(display) = &input.attrs.display {
        // Destructure `self` into field bindings so the display attribute's
        // tokens (`#display`) can refer to fields by name.
        display_implied_bounds = display.implied_bounds.clone();
        let use_as_display = use_as_display(display.has_bonus_display);
        let pat = fields_pat(&input.fields);
        Some(quote! {
            #use_as_display
            #[allow(unused_variables, deprecated)]
            let Self #pat = self;
            #display
        })
    } else {
        None
    };
    let display_impl = display_body.map(|body| {
        // Turn the implied per-field bounds into where-clause predicates, but
        // only for fields flagged `contains_generic`.
        let mut display_inferred_bounds = InferredBounds::new();
        for (field, bound) in display_implied_bounds {
            let field = &input.fields[field];
            if field.contains_generic {
                display_inferred_bounds.insert(field.ty, bound);
            }
        }
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                #[allow(clippy::used_underscore_binding)]
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #body
                }
            }
        }
    });

    // `From<T>` where `T` is the from field's type, with `Option<..>` removed.
    // A distinct backtrace field, if any, is filled in by `from_initializer`.
    let from_impl = input.from_field().map(|from_field| {
        let backtrace_field = input.distinct_backtrace_field();
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty #body
                }
            }
        }
    });

    // For generic structs, `Error` additionally requires `Self: Debug + Display`.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    // Without std, only the Display and From impls are emitted.
    if crate::use_std() {
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics std::error::Error for #ty #ty_generics #error_where_clause {
                #source_method
                #provide_method
            }
            #display_impl
            #from_impl
        }
    } else {
        quote! {
            #display_impl
            #from_impl
        }
    }
}
229
/// Expands an enum into its `std::error::Error`, `Display`, and `From` impls.
///
/// `source`, `provide`, and `fmt` each `match` on `self` with one arm per
/// variant. `From` impls are generated per variant that has a from field. The
/// `Error` impl is only emitted when `crate::use_std()` returns true.
fn impl_enum(input: Enum) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Bounds on generic field types that the `Error` impl needs.
    let mut error_inferred_bounds = InferredBounds::new();

    // `Error::source` is only generated when at least one variant has a source.
    let source_method = if input.has_source() {
        // NOTE: this closure inserts into `error_inferred_bounds`. The lazy
        // `arms` iterator is consumed by the `quote!` below within this block,
        // so those inserts happen before the Error where clause is built.
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            if let Some(transparent_attr) = &variant.attrs.transparent {
                // Transparent variant: delegate to the wrapped field's `source`.
                let only_field = &variant.fields[0];
                if only_field.contains_generic {
                    error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
                }
                let member = &only_field.member;
                let source = quote_spanned! {transparent_attr.span=>
                    std::error::Error::source(transparent.as_dyn_error())
                };
                quote! {
                    #ty::#ident {#member: transparent} => #source,
                }
            } else if let Some(source_field) = variant.source_field() {
                let source = &source_field.member;
                if source_field.contains_generic {
                    let ty = unoptional_type(source_field.ty);
                    error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
                }
                // `.as_ref()?` yields `None` for an `Option` source that is `None`.
                let asref = if type_is_option(source_field.ty) {
                    Some(quote_spanned!(source.member_span()=> .as_ref()?))
                } else {
                    None
                };
                let varsource = quote!(source);
                let dyn_error = quote_spanned! {source_field.source_span()=>
                    #varsource #asref.as_dyn_error()
                };
                quote! {
                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
                }
            } else {
                // Variant without a source.
                quote! {
                    #ty::#ident {..} => ::core::option::Option::None,
                }
            }
        });
        Some(quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError as _;
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    // `Error::provide` is only generated when at least one variant has a backtrace.
    let provide_method = if input.has_backtrace() {
        let request = quote!(request);
        // Arm order matters:
        // 1. A backtrace field without an explicit `#[backtrace]` attribute,
        //    plus a source: forward to the source, then offer the variant's
        //    own backtrace.
        // 2. Backtrace and source are the same field: only forward to the
        //    source.
        // 3. Any other variant with a backtrace field: only offer that
        //    backtrace.
        // 4. No backtrace field: the arm does nothing.
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            match (variant.backtrace_field(), variant.source_field()) {
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.attrs.backtrace.is_none() =>
                {
                    let backtrace = &backtrace_field.member;
                    let source = &source_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {source.member_span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {source.member_span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    let self_provide = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {
                            #backtrace: backtrace,
                            #source: #varsource,
                            ..
                        } => {
                            use thiserror::__private::ThiserrorProvide as _;
                            #source_provide
                            #self_provide
                        }
                    }
                }
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.member == source_field.member =>
                {
                    let backtrace = &backtrace_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {backtrace.member_span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {backtrace.member_span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: #varsource, ..} => {
                            use thiserror::__private::ThiserrorProvide as _;
                            #source_provide
                        }
                    }
                }
                (Some(backtrace_field), _) => {
                    let backtrace = &backtrace_field.member;
                    let body = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: backtrace, ..} => {
                            #body
                        }
                    }
                }
                (None, _) => quote! {
                    #ty::#ident {..} => {}
                },
            }
        });
        Some(quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let display_impl = if input.has_display() {
        let mut display_inferred_bounds = InferredBounds::new();
        // Import the `AsDisplay` helper once if any variant's format needs it.
        let has_bonus_display = input.variants.iter().any(|v| {
            v.attrs
                .display
                .as_ref()
                .map_or(false, |display| display.has_bonus_display)
        });
        let use_as_display = use_as_display(has_bonus_display);
        // With no variants, the generated code is `match *self {}`.
        // `match self {}` on a reference to an empty enum is rejected as
        // non-exhaustive.
        let void_deref = if input.variants.is_empty() {
            Some(quote!(*))
        } else {
            None
        };
        let arms = input.variants.iter().map(|variant| {
            let mut display_implied_bounds = Set::new();
            let display = match &variant.attrs.display {
                Some(display) => {
                    display_implied_bounds = display.implied_bounds.clone();
                    display.to_token_stream()
                }
                None => {
                    // No display attribute: forward `Display` to the first
                    // field, bound to the same name `fields_pat` gives it.
                    // Presumably only `#[error(transparent)]` variants reach
                    // here (validation is not visible in this file).
                    let only_field = match &variant.fields[0].member {
                        Member::Named(ident) => ident.clone(),
                        Member::Unnamed(index) => format_ident!("_{}", index),
                    };
                    display_implied_bounds.insert((0, Trait::Display));
                    quote!(::core::fmt::Display::fmt(#only_field, __formatter))
                }
            };
            for (field, bound) in display_implied_bounds {
                let field = &variant.fields[field];
                if field.contains_generic {
                    display_inferred_bounds.insert(field.ty, bound);
                }
            }
            let ident = &variant.ident;
            let pat = fields_pat(&variant.fields);
            quote! {
                #ty::#ident #pat => #display
            }
        });
        // Collect eagerly so the closure's inserts into
        // `display_inferred_bounds` happen before the where clause is built.
        let arms = arms.collect::<Vec<_>>();
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #use_as_display
                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
                    match #void_deref self {
                        #(#arms,)*
                    }
                }
            }
        })
    } else {
        None
    };

    // One `From` impl per variant that has a from field.
    let from_impls = input.variants.iter().filter_map(|variant| {
        let from_field = variant.from_field()?;
        let backtrace_field = variant.distinct_backtrace_field();
        let variant = &variant.ident;
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty::#variant #body
                }
            }
        })
    });

    // For generic enums, `Error` additionally requires `Self: Debug + Display`.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    // Without std, only the Display and From impls are emitted.
    if crate::use_std() {
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics std::error::Error for #ty #ty_generics #error_where_clause {
                #source_method
                #provide_method
            }
            #display_impl
            #(#from_impls)*
        }
    } else {
        quote! {
            #display_impl
            #(#from_impls)*
        }
    }
}
492
493fn fields_pat(fields: &[Field]) -> TokenStream {
494    let mut members = fields.iter().map(|field| &field.member).peekable();
495    match members.peek() {
496        Some(Member::Named(_)) => quote!({ #(#members),* }),
497        Some(Member::Unnamed(_)) => {
498            let vars = members.map(|member| match member {
499                Member::Unnamed(member) => format_ident!("_{}", member),
500                Member::Named(_) => unreachable!(),
501            });
502            quote!((#(#vars),*))
503        }
504        None => quote!({}),
505    }
506}
507
508fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
509    if needs_as_display {
510        Some(quote! {
511            use thiserror::__private::AsDisplay as _;
512        })
513    } else {
514        None
515    }
516}
517
518fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
519    let from_member = &from_field.member;
520    let some_source = if type_is_option(from_field.ty) {
521        quote!(::core::option::Option::Some(source))
522    } else {
523        quote!(source)
524    };
525    let backtrace = backtrace_field.map(|backtrace_field| {
526        let backtrace_member = &backtrace_field.member;
527        if type_is_option(backtrace_field.ty) {
528            quote! {
529                #backtrace_member: ::core::option::Option::Some(std::backtrace::Backtrace::capture()),
530            }
531        } else {
532            quote! {
533                #backtrace_member: ::core::convert::From::from(std::backtrace::Backtrace::capture()),
534            }
535        }
536    });
537    quote!({
538        #from_member: #some_source,
539        #backtrace
540    })
541}
542
543fn type_is_option(ty: &Type) -> bool {
544    type_parameter_of_option(ty).is_some()
545}
546
547fn unoptional_type(ty: &Type) -> TokenStream {
548    let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
549    quote!(#unoptional)
550}
551
552fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
553    let path = match ty {
554        Type::Path(ty) => &ty.path,
555        _ => return None,
556    };
557
558    let last = path.segments.last().unwrap();
559    if last.ident != "Option" {
560        return None;
561    }
562
563    let bracketed = match &last.arguments {
564        PathArguments::AngleBracketed(bracketed) => bracketed,
565        _ => return None,
566    };
567
568    if bracketed.args.len() != 1 {
569        return None;
570    }
571
572    match &bracketed.args[0] {
573        GenericArgument::Type(arg) => Some(arg),
574        _ => None,
575    }
576}