enum_try_as_inner/
lib.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8#![doc = include_str!("../README.md")]
9#![warn(
10    clippy::default_trait_access,
11    clippy::dbg_macro,
12    clippy::print_stdout,
13    clippy::unimplemented,
14    clippy::use_self,
15    missing_copy_implementations,
16    missing_docs,
17    non_snake_case,
18    non_upper_case_globals,
19    rust_2018_idioms,
20    unreachable_pub
21)]
22
23use heck::ToSnakeCase;
24use proc_macro2::{Ident, Span, TokenStream};
25use quote::quote;
26use syn::{parse_macro_input, DataEnum, DeriveInput, Visibility};
27
28/// returns first the types to return, the match names, and then tokens to the field accesses
29fn unit_fields_return(
30    variant_name: &syn::Ident,
31    err_name: &syn::Ident,
32    ty_generics: &syn::TypeGenerics<'_>,
33    (function_name_is, doc_is): (&Ident, &str),
34    (function_name_ref, doc_ref): (&Ident, &str),
35    (function_name_val, doc_val): (&Ident, &str),
36) -> TokenStream {
37    quote!(
38        #[doc = #doc_is]
39        #[inline]
40        pub fn #function_name_is(&self) -> bool {
41            matches!(self, Self::#variant_name)
42        }
43
44        #[doc = #doc_ref ]
45        #[inline]
46        pub fn #function_name_ref(&self) -> ::core::result::Result<&(), #err_name #ty_generics> {
47            match self {
48                Self::#variant_name => {
49                    ::core::result::Result::Ok(&())
50                }
51                _ => {
52                    ::core::result::Result::Err(#err_name::new(
53                        stringify!(#variant_name),
54                        self.variant_name(),
55                        ::core::option::Option::None,
56                    ))
57                }
58            }
59        }
60
61        #[doc = #doc_val ]
62        #[inline]
63        pub fn #function_name_val(self) -> ::core::result::Result<(), #err_name #ty_generics> {
64            match self {
65                Self::#variant_name => {
66                    ::core::result::Result::Ok(())
67                }
68                _ => {
69                    ::core::result::Result::Err(#err_name::new(
70                        stringify!(#variant_name),
71                        self.variant_name(),
72                        ::core::option::Option::Some(self),
73                    ))
74                }
75            }
76        }
77    )
78}
79
80/// returns first the types to return, the match names, and then tokens to the field accesses
81#[allow(clippy::too_many_arguments)]
82fn unnamed_fields_return(
83    variant_name: &syn::Ident,
84    err_name: &syn::Ident,
85    ty_generics: &syn::TypeGenerics<'_>,
86    (function_name_is, doc_is): (&Ident, &str),
87    (function_name_mut_ref, doc_mut_ref): (&Ident, &str),
88    (function_name_ref, doc_ref): (&Ident, &str),
89    (function_name_val, doc_val): (&Ident, &str),
90    fields: &syn::FieldsUnnamed,
91) -> TokenStream {
92    let (returns_mut_ref, returns_ref, returns_val, matches) = match fields.unnamed.len() {
93        1 => {
94            let field = fields.unnamed.first().expect("no fields on type");
95
96            let returns = &field.ty;
97            let returns_mut_ref = quote!(&mut #returns);
98            let returns_ref = quote!(&#returns);
99            let returns_val = quote!(#returns);
100            let matches = quote!(inner);
101
102            (returns_mut_ref, returns_ref, returns_val, matches)
103        }
104        0 => (quote!(()), quote!(()), quote!(()), quote!()),
105        _ => {
106            let mut returns_mut_ref = TokenStream::new();
107            let mut returns_ref = TokenStream::new();
108            let mut returns_val = TokenStream::new();
109            let mut matches = TokenStream::new();
110
111            for (i, field) in fields.unnamed.iter().enumerate() {
112                let rt = &field.ty;
113                let match_name = Ident::new(&format!("match_{}", i), Span::call_site());
114                returns_mut_ref.extend(quote!(&mut #rt,));
115                returns_ref.extend(quote!(&#rt,));
116                returns_val.extend(quote!(#rt,));
117                matches.extend(quote!(#match_name,));
118            }
119
120            (
121                quote!((#returns_mut_ref)),
122                quote!((#returns_ref)),
123                quote!((#returns_val)),
124                quote!(#matches),
125            )
126        }
127    };
128
129    quote!(
130        #[doc = #doc_is ]
131        #[inline]
132        #[allow(unused_variables)]
133        pub fn #function_name_is(&self) -> bool {
134            matches!(self, Self::#variant_name(#matches))
135        }
136
137        #[doc = #doc_mut_ref ]
138        #[inline]
139        pub fn #function_name_mut_ref(&mut self) -> ::core::result::Result<#returns_mut_ref, #err_name #ty_generics> {
140            match self {
141                Self::#variant_name(#matches) => {
142                    ::core::result::Result::Ok((#matches))
143                }
144                _ => {
145                    ::core::result::Result::Err(#err_name::new(
146                        stringify!(#variant_name),
147                        self.variant_name(),
148                        ::core::option::Option::None,
149                    ))
150                }
151            }
152        }
153
154        #[doc = #doc_ref ]
155        #[inline]
156        pub fn #function_name_ref(&self) -> ::core::result::Result<#returns_ref, #err_name #ty_generics> {
157            match self {
158                Self::#variant_name(#matches) => {
159                    ::core::result::Result::Ok((#matches))
160                }
161                _ => {
162                    ::core::result::Result::Err(#err_name::new(
163                        stringify!(#variant_name),
164                        self.variant_name(),
165                        ::core::option::Option::None,
166                    ))
167                }
168            }
169        }
170
171        #[doc = #doc_val ]
172        #[inline]
173        pub fn #function_name_val(self) -> ::core::result::Result<#returns_val, #err_name #ty_generics> {
174            match self {
175                Self::#variant_name(#matches) => {
176                    ::core::result::Result::Ok((#matches))
177                }
178                _ => {
179                    ::core::result::Result::Err(#err_name::new(
180                        stringify!(#variant_name),
181                        self.variant_name(),
182                        ::core::option::Option::Some(self),
183                    ))
184                }
185            }
186        }
187    )
188}
189
190/// returns first the types to return, the match names, and then tokens to the field accesses
191#[allow(clippy::too_many_arguments)]
192fn named_fields_return(
193    variant_name: &syn::Ident,
194    err_name: &syn::Ident,
195    ty_generics: &syn::TypeGenerics<'_>,
196    (function_name_is, doc_is): (&Ident, &str),
197    (function_name_mut_ref, doc_mut_ref): (&Ident, &str),
198    (function_name_ref, doc_ref): (&Ident, &str),
199    (function_name_val, doc_val): (&Ident, &str),
200    fields: &syn::FieldsNamed,
201) -> TokenStream {
202    let (returns_mut_ref, returns_ref, returns_val, matches) = match fields.named.len() {
203        1 => {
204            let field = fields.named.first().expect("no fields on type");
205            let match_name = field.ident.as_ref().expect("expected a named field");
206
207            let returns = &field.ty;
208            let returns_mut_ref = quote!(&mut #returns);
209            let returns_ref = quote!(&#returns);
210            let returns_val = quote!(#returns);
211            let matches = quote!(#match_name);
212
213            (returns_mut_ref, returns_ref, returns_val, matches)
214        }
215        0 => (quote!(()), quote!(()), quote!(()), quote!(())),
216        _ => {
217            let mut returns_mut_ref = TokenStream::new();
218            let mut returns_ref = TokenStream::new();
219            let mut returns_val = TokenStream::new();
220            let mut matches = TokenStream::new();
221
222            for field in fields.named.iter() {
223                let rt = &field.ty;
224                let match_name = field.ident.as_ref().expect("expected a named field");
225
226                returns_mut_ref.extend(quote!(&mut #rt,));
227                returns_ref.extend(quote!(&#rt,));
228                returns_val.extend(quote!(#rt,));
229                matches.extend(quote!(#match_name,));
230            }
231
232            (
233                quote!((#returns_mut_ref)),
234                quote!((#returns_ref)),
235                quote!((#returns_val)),
236                quote!(#matches),
237            )
238        }
239    };
240
241    quote!(
242        #[doc = #doc_is ]
243        #[inline]
244        #[allow(unused_variables)]
245        pub fn #function_name_is(&self) -> bool {
246            matches!(self, Self::#variant_name{ #matches })
247        }
248
249        #[doc = #doc_mut_ref ]
250        #[inline]
251        pub fn #function_name_mut_ref(&mut self) -> ::core::result::Result<#returns_mut_ref, #err_name #ty_generics> {
252            match self {
253                Self::#variant_name{ #matches } => {
254                    ::core::result::Result::Ok((#matches))
255                }
256                _ => {
257                    ::core::result::Result::Err(#err_name::new(
258                        stringify!(#variant_name),
259                        self.variant_name(),
260                        ::core::option::Option::None,
261                    ))
262                }
263            }
264        }
265
266        #[doc = #doc_ref ]
267        #[inline]
268        pub fn #function_name_ref(&self) -> ::core::result::Result<#returns_ref, #err_name #ty_generics> {
269            match self {
270                Self::#variant_name{ #matches } => {
271                    ::core::result::Result::Ok((#matches))
272                }
273                _ => {
274                    ::core::result::Result::Err(#err_name::new(
275                        stringify!(#variant_name),
276                        self.variant_name(),
277                        ::core::option::Option::None,
278                    ))
279                }
280            }
281        }
282
283        #[doc = #doc_val ]
284        #[inline]
285        pub fn #function_name_val(self) -> ::core::result::Result<#returns_val, #err_name #ty_generics> {
286            match self {
287                Self::#variant_name{ #matches } => {
288                    ::core::result::Result::Ok((#matches))
289                }
290                _ => {
291                    ::core::result::Result::Err(#err_name::new(
292                        stringify!(#variant_name),
293                        self.variant_name(),
294                        ::core::option::Option::Some(self),
295                    ))
296                }
297            }
298        }
299    )
300}
301
302fn impl_all_as_fns(
303    name: &Ident,
304    err_name: &Ident,
305    generics: &syn::Generics,
306    data: &DataEnum,
307) -> TokenStream {
308    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
309
310    let mut stream = TokenStream::new();
311    let mut variant_names = TokenStream::new();
312    for variant_data in &data.variants {
313        let variant_name = &variant_data.ident;
314        let function_name_ref = Ident::new(
315            &format!("try_as_{}", variant_name).to_snake_case(),
316            Span::call_site(),
317        );
318        let doc_ref = format!(
319            "Returns references to the inner fields if this is a `{}::{}`, otherwise an `{}`",
320            name, variant_name, &err_name,
321        );
322        let function_name_mut_ref = Ident::new(
323            &format!("try_as_{}_mut", variant_name).to_snake_case(),
324            Span::call_site(),
325        );
326        let doc_mut_ref = format!(
327            "Returns mutable references to the inner fields if this is a `{}::{}`, otherwise an `{}`",
328            name,
329            variant_name,
330            &err_name,
331        );
332
333        let function_name_val = Ident::new(
334            &format!("try_into_{}", variant_name).to_snake_case(),
335            Span::call_site(),
336        );
337        let doc_val = format!(
338            "Returns the inner fields if this is a `{}::{}`, otherwise returns back the enum in the `Err` case of the result",
339            name,
340            variant_name,
341        );
342
343        let function_name_is = Ident::new(
344            &format!("is_{}", variant_name).to_snake_case(),
345            Span::call_site(),
346        );
347        let doc_is = format!(
348            "Returns true if this is a `{}::{}`, otherwise false",
349            name, variant_name,
350        );
351
352        let tokens = match &variant_data.fields {
353            syn::Fields::Unit => unit_fields_return(
354                variant_name,
355                err_name,
356                &ty_generics,
357                (&function_name_is, &doc_is),
358                (&function_name_ref, &doc_ref),
359                (&function_name_val, &doc_val),
360            ),
361            syn::Fields::Unnamed(unnamed) => unnamed_fields_return(
362                variant_name,
363                err_name,
364                &ty_generics,
365                (&function_name_is, &doc_is),
366                (&function_name_mut_ref, &doc_mut_ref),
367                (&function_name_ref, &doc_ref),
368                (&function_name_val, &doc_val),
369                unnamed,
370            ),
371            syn::Fields::Named(named) => named_fields_return(
372                variant_name,
373                err_name,
374                &ty_generics,
375                (&function_name_is, &doc_is),
376                (&function_name_mut_ref, &doc_mut_ref),
377                (&function_name_ref, &doc_ref),
378                (&function_name_val, &doc_val),
379                named,
380            ),
381        };
382
383        stream.extend(tokens);
384
385        let variant_name = match &variant_data.fields {
386            syn::Fields::Unit => quote!(Self::#variant_name => stringify!(#variant_name),),
387            syn::Fields::Unnamed(_) => {
388                quote!(Self::#variant_name(..) => stringify!(#variant_name),)
389            }
390            syn::Fields::Named(_) => quote!(Self::#variant_name{..} => stringify!(#variant_name),),
391        };
392
393        variant_names.extend(variant_name);
394    }
395
396    quote!(
397        impl #impl_generics #name #ty_generics #where_clause {
398            #stream
399
400            /// Returns the name of the variant.
401            fn variant_name(&self) -> &'static str {
402                match self {
403                    #variant_names
404                    _ => unreachable!(),
405                }
406            }
407        }
408    )
409}
410
411fn impl_err(
412    name: &Ident,
413    err_name: &Ident,
414    vis: &Visibility,
415    generics: &syn::Generics,
416    attrs: &[syn::Attribute],
417) -> TokenStream {
418    let doc_err = format!("An error type for the `{}::try_as_*` functions", name);
419
420    // get the derives for the error type
421    let mut derives = Vec::new();
422    let mut derive_debug = false;
423    for attr in attrs {
424        if attr.path().is_ident("derive_err") {
425            attr.parse_nested_meta(|meta| {
426                if meta.path.is_ident("Debug") {
427                    derive_debug = true;
428                } else {
429                    derives.push(meta.path);
430                }
431
432                Ok(())
433            })
434            .expect("failed to parse derive nested meta");
435        }
436    }
437
438    let derive_err = if derives.is_empty() {
439        quote!()
440    } else {
441        quote!(#[derive(#(#derives),*)])
442    };
443
444    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
445
446    let mut err_impl = quote!(
447        #[doc = #doc_err ]
448        #derive_err
449        #vis struct #err_name #generics {
450            expected: &'static str,
451            actual: &'static str,
452            value: ::core::option::Option<#name #ty_generics>,
453        }
454
455        impl #impl_generics #err_name #ty_generics #where_clause {
456            /// Creates a new error indicating the expected variant and the actual variant.
457            fn new(
458                expected: &'static str,
459                actual: &'static str,
460                value: ::core::option::Option<#name #ty_generics>
461            ) -> Self {
462                Self {
463                    expected,
464                    actual,
465                    value,
466                }
467            }
468
469            /// Returns the name of the variant that was expected.
470            pub fn expected(&self) -> &'static str {
471                self.expected
472            }
473
474            /// Returns the name of the actual variant.
475            pub fn actual(&self) -> &'static str {
476                self.actual
477            }
478
479            /// Returns a reference to the actual value, if present.
480            pub fn value(&self) -> ::core::option::Option<&#name #ty_generics> {
481                self.value.as_ref()
482            }
483
484            /// Returns the actual value, if present.
485            pub fn into_value(self) -> ::core::option::Option<#name #ty_generics> {
486                self.value
487            }
488        }
489    );
490
491    if derive_debug {
492        let impl_debug_body = {
493            let where_clause = if let Some(where_clause) = where_clause {
494                quote!(#where_clause, #name #ty_generics: ::core::fmt::Debug)
495            } else {
496                quote!(where #name #ty_generics: ::core::fmt::Debug)
497            };
498
499            quote!(
500                impl #impl_generics ::core::fmt::Debug for #err_name #ty_generics #where_clause {
501                    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
502                        f.debug_struct(stringify!(#err_name))
503                            .field("expected", &self.expected)
504                            .field("actual", &self.actual)
505                            .field("value", &self.value)
506                            .finish()
507                    }
508                }
509            )
510        };
511
512        let impl_display_body = {
513            let display_fmt = format!("expected {name}::{{}}, but got {name}::{{}}");
514            quote!(
515                impl #impl_generics ::core::fmt::Display for #err_name #ty_generics #where_clause {
516                    fn fmt(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
517                        write!(
518                            formatter,
519                            #display_fmt,
520                            self.expected(),
521                            self.actual(),
522                        )
523                    }
524                }
525            )
526        };
527
528        let impl_err_body = {
529            let where_clause = if let Some(where_clause) = where_clause {
530                quote!(#where_clause, #name #ty_generics: ::core::fmt::Debug)
531            } else {
532                quote!(where #name #ty_generics: ::core::fmt::Debug)
533            };
534
535            quote!(
536                impl #impl_generics ::std::error::Error for #err_name #ty_generics #where_clause {}
537            )
538        };
539
540        err_impl.extend(quote!(
541            #impl_debug_body
542
543            #impl_display_body
544
545            #impl_err_body
546        ))
547    }
548
549    err_impl
550}
551
552/// Derive functions on an Enum for easily accessing individual items in the Enum
553#[proc_macro_derive(EnumTryAsInner, attributes(derive_err))]
554pub fn enum_try_as_inner(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
555    // get a usable token stream
556    let ast: DeriveInput = parse_macro_input!(input as DeriveInput);
557
558    let name = &ast.ident;
559    let err_name = Ident::new(&format!("{}Error", name), Span::call_site());
560    let generics = &ast.generics;
561    let vis = &ast.vis;
562
563    let enum_data = if let syn::Data::Enum(data) = &ast.data {
564        data
565    } else {
566        panic!("{} is not an enum", name);
567    };
568
569    let mut expanded = TokenStream::new();
570
571    // Build the impl
572    let fns = impl_all_as_fns(name, &err_name, generics, enum_data);
573
574    // Build the error
575    let err = impl_err(name, &err_name, vis, generics, &ast.attrs);
576
577    expanded.extend(fns);
578    expanded.extend(err);
579
580    proc_macro::TokenStream::from(expanded)
581}