errore_impl/expand/display.rs

1use proc_macro2::TokenStream;
2use std::collections::BTreeSet as Set;
3
4use quote::{format_ident, quote, ToTokens};
5use syn::{DeriveInput, Ident, ImplGenerics, Member, Result, TypeGenerics};
6
7use crate::ast::{DeriveType, Enum, Input, Struct};
8use crate::attr::Trait;
9use crate::generics::InferredBounds;
10use crate::util::{fields_pat, use_as_display};
11
12pub fn derive(input: &DeriveInput) -> TokenStream {
13    match try_expand(input) {
14        Ok(expanded) => expanded,
15        // If there are invalid attributes in the input, expand to an Error impl
16        // anyway to minimize spurious knock-on errors in other code that uses
17        // this type as an Error.
18        Err(error) => fallback(input, error),
19    }
20}
21
22fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
23    let input = Input::from_syn(input, DeriveType::Display)?;
24    input.validate()?;
25    match input {
26        Input::Enum(input) => Ok(impl_enum(input)),
27        Input::Struct(input) => Ok(impl_struct(input)),
28    }
29}
30
31fn fallback(input: &DeriveInput, error: syn::Error) -> TokenStream {
32    let ty = &input.ident;
33    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
34
35    let error = error.to_compile_error();
36
37    quote! {
38        #error
39
40        #[allow(unused_qualifications)]
41        #[automatically_derived]
42        impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
43            fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
44                ::core::unreachable!()
45            }
46        }
47    }
48}
49
50fn impl_enum(input: Enum) -> TokenStream {
51    let ty = &input.ident;
52    let mut error = Option::<TokenStream>::None;
53    let (impl_generics, ty_generics, _where_clause) = input.generics.split_for_impl();
54
55    for variant in &input.variants {
56        if let Some(display) = &variant.attrs.display {
57            error = display.recursing();
58        }
59    }
60
61    let mut display_inferred_bounds = InferredBounds::new();
62    let has_bonus_display = input.variants.iter().any(|v| {
63        v.attrs
64            .display
65            .as_ref()
66            .map_or(false, |display| display.has_bonus_display)
67    });
68    let use_as_display = use_as_display(has_bonus_display);
69    let void_deref = if input.variants.is_empty() {
70        Some(quote!(*))
71    } else {
72        None
73    };
74
75    let arms = input.variants.iter().map(|variant| {
76        let ident = &variant.ident;
77        let mut display_implied_bounds = Set::<(usize, Trait)>::new();
78        let display = match &variant.attrs.display {
79            Some(display) => {
80                display_implied_bounds.clone_from(&display.implied_bounds);
81                display.to_token_stream()
82            }
83            None => {
84                if variant.fields.len() > 0 {
85                    let only_field = match &variant.fields[0].member {
86                        Member::Named(ident) => ident.clone(),
87                        Member::Unnamed(index) => format_ident!("_{}", index),
88                    };
89                    display_implied_bounds.insert((0, Trait::Display));
90                    quote!(::core::fmt::Display::fmt(#only_field, __formatter))
91                } else {
92                    // if no #[display("...")] is found, fallback to '<enum_name>::<enum_field_name>'
93                    quote! {::core::fmt::Display::fmt(concat!(stringify!(#ty), "::", stringify!(#ident)), __formatter)}
94                }
95            }
96        };
97        for (field, bound) in display_implied_bounds {
98            let field = &variant.fields[field];
99            if field.contains_generic {
100                display_inferred_bounds.insert(field.ty, bound);
101            }
102        }
103        let pat = fields_pat(&variant.fields);
104        quote! {
105            #ty::#ident #pat => #display,
106        }
107    });
108
109    let arms = arms.collect::<Vec<_>>();
110    let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
111    let display_impl = quote! {
112        #[allow(unused_qualifications)]
113        #[automatically_derived]
114        impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
115            fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
116                #use_as_display
117                #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
118                match #void_deref self {
119                    #(#arms)*
120                }
121            }
122        }
123    };
124
125    quote! {
126        #error
127
128        #[allow(unused_qualifications)]
129        #[automatically_derived]
130        impl #impl_generics ::core::error::Error for #ty #ty_generics #display_where_clause {}
131
132        #display_impl
133    }
134}
135
136pub fn impl_struct_display_body(
137    input: &Struct,
138    display_implied_bounds: &mut Set<(usize, Trait)>,
139) -> Option<TokenStream> {
140    return if input.attrs.transparent.is_some() {
141        let only_field = &input.fields[0].member;
142        display_implied_bounds.insert((0, Trait::Display));
143        Some(quote! {
144            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
145        })
146    } else if let Some(display) = &input.attrs.display {
147        display_implied_bounds.clone_from(&display.implied_bounds);
148        let use_as_display = use_as_display(display.has_bonus_display);
149        let pat = fields_pat(&input.fields);
150        Some(quote! {
151            #use_as_display
152            #[allow(unused_variables, deprecated)]
153            let Self #pat = self;
154            #display
155        })
156    } else {
157        None
158    };
159}
160
161pub fn impl_struct_display(
162    input: &Struct,
163    ty: &Ident,
164    ty_generics: &TypeGenerics<'_>,
165    display_inferred_bounds: &mut InferredBounds,
166    display_implied_bounds: Set<(usize, Trait)>,
167    display_body: Option<TokenStream>,
168    impl_generics: &ImplGenerics<'_>,
169) -> Option<TokenStream> {
170    for (field, bound) in display_implied_bounds {
171        let field = &input.fields[field];
172        if field.contains_generic {
173            display_inferred_bounds.insert(field.ty, bound);
174        }
175    }
176    let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
177    return display_body.as_ref().map(|body| {
178        quote! {
179            #[allow(unused_qualifications)]
180            #[automatically_derived]
181            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
182                #[allow(clippy::used_underscore_binding)]
183                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
184                    #body
185                }
186            }
187        }
188    });
189}
190
191fn impl_struct(input: Struct) -> TokenStream {
192    let ty = &input.ident;
193    let mut error = Option::<TokenStream>::None;
194    let (impl_generics, ty_generics, _where_clause) = input.generics.split_for_impl();
195
196    if let Some(display) = &input.attrs.display {
197        error = display.recursing();
198    }
199
200    // implement display body
201    let mut display_implied_bounds = Set::new();
202    let mut display_body = impl_struct_display_body(&input, &mut display_implied_bounds);
203    if display_body.is_none() {
204        display_body = Some(quote! {
205            ::core::fmt::Display::fmt(&format!("{:#?}", self), __formatter)
206        });
207    }
208
209    // implement display
210    let mut display_inferred_bounds = InferredBounds::new();
211    let display_impl = impl_struct_display(
212        &input,
213        ty,
214        &ty_generics,
215        &mut display_inferred_bounds,
216        display_implied_bounds,
217        display_body,
218        &impl_generics,
219    );
220    let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
221
222    let error_impl = quote! {
223        #[allow(unused_qualifications)]
224        #[automatically_derived]
225        impl #impl_generics core::error::Error for #ty #ty_generics #display_where_clause {}
226    };
227
228    quote! {
229        #error
230
231        #display_impl
232        #error_impl
233    }
234}