fack_codegen/
expand.rs

1//! Final code expansion for error definitions.
2
3use alloc::{format, vec::Vec};
4
5use proc_macro2::TokenStream;
6
7use quote::{ToTokens, quote};
8
9use syn::Ident;
10
11use super::{
12    Target,
13    common::{FieldRef, Format, ImportRoot, InlineOptions, Transparent},
14    enumerate::{Enumeration, Variant},
15    structure::{Structure, StructureOptions},
16};
17
/// Stateless namespace for the code generation of error definitions.
///
/// `Expand` has no fields; all of its functionality lives in associated
/// functions that turn a parsed [`Target`] into the generated trait
/// implementations (`Error`, `Display`, and for forwarding definitions `From`).
pub struct Expand {}
20
21impl Expand {
22    /// Expand the [`Target`] into the appropriate error implementation.
23    #[inline]
24    pub fn target(target_value: Target) -> syn::Result<TokenStream> {
25        match target_value {
26            Target::Struct(structure) => Self::structure(structure),
27            Target::Enum(enumeration) => Self::enumeration(enumeration),
28        }
29    }
30}
31
32impl Expand {
33    /// Expand to the error implementation for the target [`Structure`].
34    pub fn structure(target_value: Structure) -> syn::Result<TokenStream> {
35        let Structure {
36            inline_opts,
37            root_import,
38            name_ident,
39            generics,
40            options,
41            field_list,
42        } = target_value;
43
44        let inline_expand = match inline_opts {
45            Some(InlineOptions::Neutral) => quote! { #[inline] },
46            Some(InlineOptions::Always) => quote! { #[inline(always)] },
47            Some(InlineOptions::Never) => quote! { #[inline(never)] },
48            None => quote! {},
49        };
50
51        let root_expand = root_import
52            .map(|ImportRoot(root)| root.to_token_stream())
53            .unwrap_or_else(|| quote! { ::core });
54
55        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
56
57        let field_pat = field_list.pattern()?;
58
59        match options {
60            StructureOptions::Standalone {
61                source_field,
62                format_args: Format { format, format_args },
63            } => {
64                let source_expand = match source_field {
65                    Some(FieldRef::Named(ref field)) => quote! { Some(self.#field) },
66                    Some(FieldRef::Indexed(ref field)) => quote! { Some(self.#field) },
67                    None => quote! { None },
68                };
69
70                Ok(quote! {
71                    #[automatically_derived]
72                    impl #impl_generics #root_expand::error::Error for #name_ident #ty_generics #where_clause {
73                        #inline_expand
74                        fn source(&self) -> Option<&(dyn #root_expand::error::Error + 'static)> {
75                            #source_expand
76                        }
77                    }
78
79                    #[automatically_derived]
80                    impl #impl_generics #root_expand::fmt::Display for #name_ident #ty_generics #where_clause {
81                        #inline_expand
82                        fn fmt(&self, f: &mut #root_expand::fmt::Formatter<'_>) -> #root_expand::fmt::Result {
83                            let Self #field_pat = self;
84
85                            #root_expand::write!(f, #format, #format_args)
86                        }
87                    }
88                })
89            }
90            StructureOptions::Transparent(Transparent(target_field)) => {
91                let field_expand = match target_field {
92                    FieldRef::Named(ref field) => quote! { self.#field },
93                    FieldRef::Indexed(ref field) => quote! { self.#field },
94                };
95
96                Ok(quote! {
97                    #[automatically_derived]
98                    impl #impl_generics #root_expand::error::Error for #name_ident #ty_generics #where_clause {
99                        #inline_expand
100                        fn source(&self) -> Option<&(dyn #root_expand::error::Error + 'static)> {
101
102                            #root_expand::error::Error::source(#field_expand)
103                        }
104                    }
105
106                    #[automatically_derived]
107                    impl #impl_generics #root_expand::fmt::Display for #name_ident #ty_generics #where_clause {
108                        #inline_expand
109                        fn fmt(&self, f: &mut #root_expand::fmt::Formatter<'_>) -> #root_expand::fmt::Result {
110                            let Self #field_pat = self;
111
112                            #root_expand::fmt::Display::fmt(#field_expand, f)
113                        }
114                    }
115                })
116            }
117            StructureOptions::Forward {
118                field_ref,
119                field_type,
120                source_field,
121                format_args: Format { format, format_args },
122            } => {
123                let source_expand = match source_field {
124                    Some(FieldRef::Named(ref field)) => quote! { Some(self.#field) },
125                    Some(FieldRef::Indexed(ref field)) => quote! { Some(self.#field) },
126                    None => quote! { None },
127                };
128
129                let from_expand = match field_ref {
130                    FieldRef::Named(ref field) => quote! {
131                        #[automatically_derived]
132                        impl #impl_generics From<#field_type> #name_ident #ty_generics #where_clause {
133                            #inline_expand
134                            fn from(#field: #field_type) -> Self {
135                               Self { #field }
136                            }
137                        }
138                    },
139                    FieldRef::Indexed(..) => quote! {
140                        #[automatically_derived]
141                        impl #impl_generics From<#field_type> #name_ident #ty_generics #where_clause {
142                            #inline_expand
143                            fn from(field: #field_type) -> Self {
144                                Self(field)
145                            }
146                        }
147                    },
148                };
149
150                Ok(quote! {
151                    #from_expand
152
153                    #[automatically_derived]
154                    impl #impl_generics #root_expand::error::Error for #name_ident #ty_generics #where_clause {
155                        #inline_expand
156                        fn source(&self) -> Option<&(dyn #root_expand::error::Error + 'static)> {
157                            #source_expand
158                        }
159                    }
160
161                    #[automatically_derived]
162                    impl #impl_generics #root_expand::fmt::Display for #name_ident #ty_generics #where_clause {
163                        #inline_expand
164                        fn fmt(&self, f: &mut #root_expand::fmt::Formatter<'_>) -> #root_expand::fmt::Result {
165                            let Self #field_pat = self;
166
167                            #root_expand::write!(f, #format, #format_args)
168                        }
169                    }
170                })
171            }
172        }
173    }
174
175    /// Expand to the error implementation for the target [`Enumeration`].
176    pub fn enumeration(target_value: Enumeration) -> syn::Result<TokenStream> {
177        let Enumeration {
178            inline_opts,
179            root_import,
180            name_ident,
181            generics,
182            variant_list,
183        } = target_value;
184
185        let inline_expand = match inline_opts {
186            Some(InlineOptions::Neutral) => quote! { #[inline] },
187            Some(InlineOptions::Always) => quote! { #[inline(always)] },
188            Some(InlineOptions::Never) => quote! { #[inline(never)] },
189            None => quote! {},
190        };
191
192        let root_expand = root_import
193            .map(|ImportRoot(root)| root.to_token_stream())
194            .unwrap_or_else(|| quote! { ::core });
195
196        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
197
198        let mut source_expand = Vec::new();
199
200        let mut from_expand = Vec::new();
201
202        let mut display_expand = Vec::new();
203
204        for variant in variant_list {
205            match variant {
206                Variant::Struct {
207                    format: Format { format, format_args },
208                    variant_name,
209                    field_list,
210                    source_field,
211                } => {
212                    let field_pat = field_list.pattern()?;
213
214                    if let Some(source_field) = source_field {
215                        let field_expand = match source_field {
216                            FieldRef::Named(ref field) => quote! { Some(self.#field) },
217                            FieldRef::Indexed(index) => Ident::new(&format!("_{index}"), variant_name.span()).into_token_stream(),
218                        };
219
220                        source_expand.push(quote! {
221                            Self::#variant_name #field_pat => #field_expand,
222                        });
223                    } else {
224                        source_expand.push(quote! {
225                            Self::#variant_name #field_pat => None,
226                        });
227                    }
228
229                    display_expand.push(quote! {
230                        Self::#variant_name #field_pat  => #root_expand::write!(f, #format, #format_args),
231                    });
232                }
233                Variant::Unit {
234                    format: Format { format, format_args },
235                    variant_name,
236                } => {
237                    source_expand.push(quote! {
238                        Self::#variant_name => None,
239                    });
240
241                    display_expand.push(quote! {
242                        Self::#variant_name => #root_expand::write!(f, #format, #format_args),
243                    });
244                }
245                Variant::Transparent {
246                    transparent: Transparent(trans_field),
247                    variant_name,
248                    field_list,
249                } => {
250                    let trans_expand = match trans_field {
251                        FieldRef::Named(ref field) => quote! { #field },
252                        FieldRef::Indexed(index) => Ident::new(&format!("_{index}"), variant_name.span()).into_token_stream(),
253                    };
254
255                    let field_pat = field_list.pattern()?;
256
257                    source_expand.push(quote! {
258                        Self::#variant_name #field_pat => #root_expand::error::Error::source(#trans_expand),
259                    });
260
261                    display_expand.push(quote! {
262                        Self::#variant_name #field_pat => #root_expand::fmt::Display::fmt(#trans_expand, f),
263                    });
264                }
265                Variant::Forward {
266                    format: Format { format, format_args },
267                    variant_name,
268                    field_ref,
269                    field_type,
270                } => {
271                    let bare_field_name = match field_ref {
272                        FieldRef::Named(ref field) => quote! { #field },
273                        FieldRef::Indexed(ref field) => quote! { #field },
274                    };
275
276                    source_expand.push(quote! {
277                        Self::#variant_name { #bare_field_name: target_field }  => #root_expand::error::Error::source(target_field),
278                    });
279
280                    from_expand.push(match field_ref {
281                        FieldRef::Named(ref field) => quote! {
282                            #[automatically_derived]
283                            impl #impl_generics From<#field_type> for #name_ident #ty_generics #where_clause {
284                                #inline_expand
285                                fn from(#field: #field_type) -> Self {
286                                    Self::#variant_name { #field }
287                                }
288                            }
289                        },
290                        FieldRef::Indexed(..) => quote! {
291                            #[automatically_derived]
292                            impl #impl_generics From<#field_type> for #name_ident #ty_generics #where_clause {
293                                #inline_expand
294                                fn from(field: #field_type) -> Self {
295                                    Self::#variant_name(field)
296                                }
297                            }
298                        },
299                    });
300
301                    let replaced_field_name = match field_ref {
302                        FieldRef::Named(..) => None,
303                        FieldRef::Indexed(ref field) => Some({
304                            let ident = Ident::new(&format!("_{}", field), variant_name.span());
305
306                            quote! {
307                                : #ident
308                            }
309                        }),
310                    };
311
312                    display_expand.push(quote! {
313                        Self::#variant_name { #bare_field_name #replaced_field_name }  => #root_expand::write!(f, #format, #format_args),
314                    });
315                }
316            }
317        }
318
319        Ok(quote! {
320            #(#from_expand)*
321
322            #[automatically_derived]
323            impl #impl_generics #root_expand::error::Error for #name_ident #ty_generics #where_clause {
324                #inline_expand
325                fn source(&self) -> Option<&(dyn #root_expand::error::Error + 'static)> {
326                    match self {
327                        #(#source_expand)*
328                    }
329                }
330            }
331
332            #[automatically_derived]
333            impl #impl_generics #root_expand::fmt::Display for #name_ident #ty_generics #where_clause {
334                #inline_expand
335                fn fmt(&self, f: &mut #root_expand::fmt::Formatter<'_>) -> #root_expand::fmt::Result {
336                    match self {
337                        #(#display_expand)*
338                    }
339                }
340            }
341        })
342    }
343}