miden_thiserror_impl/
expand.rs

1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::generics::InferredBounds;
4use crate::span::MemberSpan;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote, quote_spanned, ToTokens};
7use std::collections::BTreeSet as Set;
8use syn::{DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type};
9
10pub fn derive(input: &DeriveInput) -> TokenStream {
11    match try_expand(input) {
12        Ok(expanded) => expanded,
13        // If there are invalid attributes in the input, expand to an Error impl
14        // anyway to minimize spurious knock-on errors in other code that uses
15        // this type as an Error.
16        Err(error) => fallback(input, error),
17    }
18}
19
20fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
21    let input = Input::from_syn(input)?;
22    input.validate()?;
23    Ok(match input {
24        Input::Struct(input) => impl_struct(input),
25        Input::Enum(input) => impl_enum(input),
26    })
27}
28
29fn fallback(input: &DeriveInput, error: syn::Error) -> TokenStream {
30    let ty = &input.ident;
31    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32
33    let error = error.to_compile_error();
34
35    quote! {
36        #error
37
38        #[allow(unused_qualifications)]
39        impl #impl_generics thiserror::StdError for #ty #ty_generics #where_clause
40        where
41            // Work around trivial bounds being unstable.
42            // https://github.com/rust-lang/rust/issues/48214
43            for<'workaround> #ty #ty_generics: ::core::fmt::Debug,
44        {}
45
46        #[allow(unused_qualifications)]
47        impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
48            fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
49                ::core::unreachable!()
50            }
51        }
52    }
53}
54
/// Expands `#[derive(Error)]` for a struct into:
///
/// - a `thiserror::StdError` impl, with a `source` method when the struct is
///   `#[error(transparent)]` or has a source field, and (only with the `std` feature)
///   a `provide` method when it has a backtrace field;
/// - a `Display` impl when the struct is transparent or carries a display attribute;
/// - a `From` impl when `input.from_field()` returns a field (presumably the one
///   marked `#[from]` — see `crate::ast`).
///
/// Bounds needed by fields whose types mention a generic parameter are collected in
/// `InferredBounds` and appended to the `where` clause of the impl that needs them,
/// instead of being required on the type itself.
fn impl_struct(input: Struct) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Bounds that end up on the `StdError` impl's `where` clause.
    let mut error_inferred_bounds = InferredBounds::new();

    // Body of `StdError::source`, if any:
    // - transparent: delegate to the single field's own `source()`;
    // - source field: return that field as `&dyn StdError`; for an `Option` field,
    //   `.as_ref()?` returns `None` early when the field is `None`.
    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
        let only_field = &input.fields[0];
        if only_field.contains_generic {
            error_inferred_bounds.insert(only_field.ty, quote!(thiserror::StdError));
        }
        let member = &only_field.member;
        Some(quote_spanned! {transparent_attr.span=>
            thiserror::StdError::source(self.#member.as_dyn_error())
        })
    } else if let Some(source_field) = input.source_field() {
        let source = &source_field.member;
        if source_field.contains_generic {
            // The bound goes on `T`, not `Option<T>`.
            let ty = unoptional_type(source_field.ty);
            error_inferred_bounds.insert(ty, quote!(thiserror::StdError + 'static));
        }
        let asref = if type_is_option(source_field.ty) {
            Some(quote_spanned!(source.member_span()=> .as_ref()?))
        } else {
            None
        };
        let dyn_error = quote_spanned! {source_field.source_span()=>
            self.#source #asref.as_dyn_error()
        };
        Some(quote! {
            ::core::option::Option::Some(#dyn_error)
        })
    } else {
        None
    };
    let source_method = source_body.map(|body| {
        quote! {
            fn source(&self) -> ::core::option::Option<&(dyn thiserror::StdError + 'static)> {
                use thiserror::__private::AsDynError as _;
                #body
            }
        }
    });

    // `provide` is only generated with the `std` feature: its signature and body name
    // `std::error::Request` and `std::backtrace::Backtrace`.
    #[cfg(feature = "std")]
    let provide_method = input.backtrace_field().map(|backtrace_field| {
        let request = quote!(request);
        let backtrace = &backtrace_field.member;
        let body = if let Some(source_field) = input.source_field() {
            let source = &source_field.member;
            // Forward the request to the source field via `thiserror_provide`
            // (skipped when an `Option` source is `None`).
            let source_provide = if type_is_option(source_field.ty) {
                quote_spanned! {source.member_span()=>
                    if let ::core::option::Option::Some(source) = &self.#source {
                        source.thiserror_provide(#request);
                    }
                }
            } else {
                quote_spanned! {source.member_span()=>
                    self.#source.thiserror_provide(#request);
                }
            };
            // When the backtrace field is the source field itself, only the forwarding
            // call above is emitted; otherwise the struct's own backtrace is provided too.
            let self_provide = if source == backtrace {
                None
            } else if type_is_option(backtrace_field.ty) {
                Some(quote! {
                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                        #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                    }
                })
            } else {
                Some(quote! {
                    #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
                })
            };
            quote! {
                use thiserror::__private::ThiserrorProvide as _;
                #source_provide
                #self_provide
            }
        } else if type_is_option(backtrace_field.ty) {
            quote! {
                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                    #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                }
            }
        } else {
            quote! {
                #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
            }
        };
        quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #body
            }
        }
    });

    #[cfg(not(feature = "std"))]
    let provide_method: Option<TokenStream> = None;

    // `(field index, trait)` pairs the Display body relies on; turned into `where`
    // bounds below for fields whose type contains a generic parameter.
    let mut display_implied_bounds = Set::new();
    let display_body = if input.attrs.transparent.is_some() {
        let only_field = &input.fields[0].member;
        display_implied_bounds.insert((0, Trait::Display));
        Some(quote! {
            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
        })
    } else if let Some(display) = &input.attrs.display {
        display_implied_bounds.clone_from(&display.implied_bounds);
        let use_as_display = use_as_display(display.has_bonus_display);
        // Destructure `self` so the format arguments can refer to fields by name
        // (or as `_0`, `_1`, ... for tuple structs; see `fields_pat`).
        let pat = fields_pat(&input.fields);
        Some(quote! {
            #use_as_display
            #[allow(unused_variables, deprecated)]
            let Self #pat = self;
            #display
        })
    } else {
        None
    };
    let display_impl = display_body.map(|body| {
        let mut display_inferred_bounds = InferredBounds::new();
        for (field, bound) in display_implied_bounds {
            let field = &input.fields[field];
            if field.contains_generic {
                display_inferred_bounds.insert(field.ty, bound);
            }
        }
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                #[allow(clippy::used_underscore_binding)]
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #body
                }
            }
        }
    });

    // `From<T>` impl for the from field; a distinct backtrace field, if any, is
    // initialized by `from_initializer`.
    let from_impl = input.from_field().map(|from_field| {
        let backtrace_field = input.distinct_backtrace_field();
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty #body
                }
            }
        }
    });

    // For generic types, the `StdError` impl additionally requires
    // `Self: Debug + Display`.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics thiserror::StdError for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #from_impl
    }
}
226
/// Expands `#[derive(Error)]` for an enum; the enum counterpart of `impl_struct`.
///
/// Each generated method matches on `self` with one arm per variant:
/// - `source`: emitted when `input.has_source()`; transparent variants delegate to
///   their single field's `source()`, variants with a source field return it, and all
///   other variants return `None`;
/// - `provide` (only with the `std` feature): emitted when `input.has_backtrace()`;
/// - `Display`: emitted when `input.has_display()`; a variant without a display
///   attribute formats its first field with `Display`.
///
/// One `From` impl is emitted per variant for which `variant.from_field()` returns a
/// field. Bounds for fields with generic types are inferred as in `impl_struct`.
fn impl_enum(input: Enum) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Bounds that end up on the `StdError` impl's `where` clause. The `source` arms
    // closure below borrows this mutably; the borrow ends once the arms are consumed.
    let mut error_inferred_bounds = InferredBounds::new();

    let source_method = if input.has_source() {
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            if let Some(transparent_attr) = &variant.attrs.transparent {
                let only_field = &variant.fields[0];
                if only_field.contains_generic {
                    error_inferred_bounds.insert(only_field.ty, quote!(thiserror::StdError));
                }
                let member = &only_field.member;
                let source = quote_spanned! {transparent_attr.span=>
                    thiserror::StdError::source(transparent.as_dyn_error())
                };
                quote! {
                    #ty::#ident {#member: transparent} => #source,
                }
            } else if let Some(source_field) = variant.source_field() {
                let source = &source_field.member;
                if source_field.contains_generic {
                    // The bound goes on `T`, not `Option<T>`.
                    let ty = unoptional_type(source_field.ty);
                    error_inferred_bounds.insert(ty, quote!(thiserror::StdError + 'static));
                }
                // For an `Option` source, `.as_ref()?` returns `None` early when unset.
                let asref = if type_is_option(source_field.ty) {
                    Some(quote_spanned!(source.member_span()=> .as_ref()?))
                } else {
                    None
                };
                let varsource = quote!(source);
                let dyn_error = quote_spanned! {source_field.source_span()=>
                    #varsource #asref.as_dyn_error()
                };
                quote! {
                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
                }
            } else {
                quote! {
                    #ty::#ident {..} => ::core::option::Option::None,
                }
            }
        });
        Some(quote! {
            fn source(&self) -> ::core::option::Option<&(dyn thiserror::StdError + 'static)> {
                use thiserror::__private::AsDynError as _;
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    // `provide` is only generated with the `std` feature: it names
    // `std::error::Request` and `std::backtrace::Backtrace`.
    #[cfg(feature = "std")]
    let provide_method = if input.has_backtrace() {
        let request = quote!(request);
        // Per-variant arm shapes, tried in order:
        // 1. backtrace + source fields, backtrace field's `attrs.backtrace` is `None`:
        //    forward to the source, then provide the variant's own backtrace;
        // 2. the source field is also the backtrace field: only forward to it;
        // 3. any other variant with a backtrace field: provide that field directly;
        // 4. no backtrace field: provide nothing.
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            match (variant.backtrace_field(), variant.source_field()) {
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.attrs.backtrace.is_none() =>
                {
                    let backtrace = &backtrace_field.member;
                    let source = &source_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {source.member_span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {source.member_span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    let self_provide = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {
                            #backtrace: backtrace,
                            #source: #varsource,
                            ..
                        } => {
                            use thiserror::__private::ThiserrorProvide as _;
                            #source_provide
                            #self_provide
                        }
                    }
                }
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.member == source_field.member =>
                {
                    let backtrace = &backtrace_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {backtrace.member_span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {backtrace.member_span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: #varsource, ..} => {
                            use thiserror::__private::ThiserrorProvide as _;
                            #source_provide
                        }
                    }
                }
                (Some(backtrace_field), _) => {
                    let backtrace = &backtrace_field.member;
                    let body = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: backtrace, ..} => {
                            #body
                        }
                    }
                }
                (None, _) => quote! {
                    #ty::#ident {..} => {}
                },
            }
        });
        Some(quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    #[cfg(not(feature = "std"))]
    let provide_method: Option<TokenStream> = None;

    let display_impl = if input.has_display() {
        let mut display_inferred_bounds = InferredBounds::new();
        // `AsDisplay` is imported once for the whole `fmt` body if any variant needs it.
        let has_bonus_display = input.variants.iter().any(|v| {
            v.attrs
                .display
                .as_ref()
                .map_or(false, |display| display.has_bonus_display)
        });
        let use_as_display = use_as_display(has_bonus_display);
        // An empty `match` on `&Self` is not accepted as exhaustive, so for an enum
        // with no variants match on `*self` (the uninhabited enum itself) instead.
        let void_deref = if input.variants.is_empty() {
            Some(quote!(*))
        } else {
            None
        };
        let arms = input.variants.iter().map(|variant| {
            // `(field index, trait)` pairs this arm's display code relies on.
            let mut display_implied_bounds = Set::new();
            let display = match &variant.attrs.display {
                Some(display) => {
                    display_implied_bounds.clone_from(&display.implied_bounds);
                    display.to_token_stream()
                }
                None => {
                    // No display attribute: format the first field, using the binding
                    // name that `fields_pat` gives it (`name` or `_0`).
                    let only_field = match &variant.fields[0].member {
                        Member::Named(ident) => ident.clone(),
                        Member::Unnamed(index) => format_ident!("_{}", index),
                    };
                    display_implied_bounds.insert((0, Trait::Display));
                    quote!(::core::fmt::Display::fmt(#only_field, __formatter))
                }
            };
            for (field, bound) in display_implied_bounds {
                let field = &variant.fields[field];
                if field.contains_generic {
                    display_inferred_bounds.insert(field.ty, bound);
                }
            }
            let ident = &variant.ident;
            let pat = fields_pat(&variant.fields);
            quote! {
                #ty::#ident #pat => #display
            }
        });
        // Collect eagerly: the closure fills `display_inferred_bounds`, which must be
        // complete before the `where` clause is built on the next line.
        let arms = arms.collect::<Vec<_>>();
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #use_as_display
                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
                    match #void_deref self {
                        #(#arms,)*
                    }
                }
            }
        })
    } else {
        None
    };

    // One `From<T>` impl per variant that has a from field.
    let from_impls = input.variants.iter().filter_map(|variant| {
        let from_field = variant.from_field()?;
        let backtrace_field = variant.distinct_backtrace_field();
        let variant = &variant.ident;
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty::#variant #body
                }
            }
        })
    });

    // For generic types, the `StdError` impl additionally requires
    // `Self: Debug + Display`.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics thiserror::StdError for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #(#from_impls)*
    }
}
486
487fn fields_pat(fields: &[Field]) -> TokenStream {
488    let mut members = fields.iter().map(|field| &field.member).peekable();
489    match members.peek() {
490        Some(Member::Named(_)) => quote!({ #(#members),* }),
491        Some(Member::Unnamed(_)) => {
492            let vars = members.map(|member| match member {
493                Member::Unnamed(member) => format_ident!("_{}", member),
494                Member::Named(_) => unreachable!(),
495            });
496            quote!((#(#vars),*))
497        }
498        None => quote!({}),
499    }
500}
501
502fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
503    if needs_as_display {
504        Some(quote! {
505            use thiserror::__private::AsDisplay as _;
506        })
507    } else {
508        None
509    }
510}
511
512fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
513    let from_member = &from_field.member;
514    let some_source = if type_is_option(from_field.ty) {
515        quote!(::core::option::Option::Some(source))
516    } else {
517        quote!(source)
518    };
519    let backtrace = backtrace_field.map(|backtrace_field| {
520        let backtrace_member = &backtrace_field.member;
521        if type_is_option(backtrace_field.ty) {
522            quote! {
523                #backtrace_member: ::core::option::Option::Some(std::backtrace::Backtrace::capture()),
524            }
525        } else {
526            quote! {
527                #backtrace_member: ::core::convert::From::from(std::backtrace::Backtrace::capture()),
528            }
529        }
530    });
531    quote!({
532        #from_member: #some_source,
533        #backtrace
534    })
535}
536
537fn type_is_option(ty: &Type) -> bool {
538    type_parameter_of_option(ty).is_some()
539}
540
541fn unoptional_type(ty: &Type) -> TokenStream {
542    let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
543    quote!(#unoptional)
544}
545
546fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
547    let path = match ty {
548        Type::Path(ty) => &ty.path,
549        _ => return None,
550    };
551
552    let last = path.segments.last().unwrap();
553    if last.ident != "Option" {
554        return None;
555    }
556
557    let bracketed = match &last.arguments {
558        PathArguments::AngleBracketed(bracketed) => bracketed,
559        _ => return None,
560    };
561
562    if bracketed.args.len() != 1 {
563        return None;
564    }
565
566    match &bracketed.args[0] {
567        GenericArgument::Type(arg) => Some(arg),
568        _ => None,
569    }
570}