enum_extract_macro/
lib.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2// Copyright 2023 James La Novara-Gsell <james.lanovara.gsell@gmail.com>
3//
4// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
5// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
6// http://opensource.org/licenses/MIT>, at your option. This file may not be
7// copied, modified, or distributed except according to those terms.
8
9//! Derive functions on an Enum for easily accessing individual items in the Enum.
10//! This crate is intended to be used with the [enum-extract-error](https://crates.io/crates/enum-extract-error) crate.
11//!
12//! # Summary
13//!
//! This crate provides an `EnumExtract` derive macro that adds the following functions for each variant in your enum:
15//!
//! 1. `is_[variant]`: Returns a bool indicating whether the actual variant matches the expected variant.
17//! 2. `as_[variant]`: Returns a Result with a reference to the data contained by the variant, or an error if the actual variant is not the expected variant type.
18//! 3. `as_[variant]_mut`: Like `as_[variant]` but returns a mutable reference.
19//! 4. `into_[variant]`: Like `as_[variant]` but consumes the value and returns an owned value instead of a reference.
20//! 5. `extract_as_[variant]`: Calls `as_[variant]` and returns the data or panics if there was an error.
21//! 6. `extract_as_[variant]_mut`: Calls `as_[variant]_mut` and returns the data or panics if there was an error.
22//! 7. `extract_into_[variant]`: Calls `into_[variant]` and returns the data or panics if there was an error.
23//!
24//! ## Notes on the `extract` functions
25//!
26//! These functions are slightly different from calling `as_[variant]().unwrap()` because they panic with the `Display` output of `EnumExtractError` rather than the `Debug` output.
27//!
28//! Since these functions can panic they are not recommended for production code.
29//! Their main use is in tests, in which they can simplify and flatten tests significantly.
30//!
31//! # Examples
32//!
33//! ## Unit Variants
34//!
35//! Check if the variant is the expected variant:
36//!
37//! ```rust
38//! use enum_extract_macro::EnumExtract;
39//!
40//! #[derive(Debug, EnumExtract)]
41//! enum UnitVariants {
42//!     One,
43//!     Two,
44//! }
45//!
46//! let unit = UnitVariants::One;
47//! assert!(unit.is_one());
48//! assert!(!unit.is_two());
49//! ```
50//!
51//! ## Unnamed Variants
52//!
53//! Check if the variant is the expected variant:
54//!
55//! ```rust
56//! use enum_extract_macro::EnumExtract;
57//!
58//! #[derive(Debug, EnumExtract)]
59//! enum UnnamedVariants {
60//!     One(u32),
61//!     Two(u32, i32),
62//! }
63//!
64//! let unnamed = UnnamedVariants::One(1);
65//! assert!(unnamed.is_one());
66//! assert!(!unnamed.is_two());
67//! ```
68//!
69//! Get the variant's value:
70//!
71//! ```rust
72//! use enum_extract_macro::EnumExtract;
73//!
74//! #[derive(Debug, EnumExtract)]
75//! enum UnnamedVariants {
76//!     One(u32),
77//!     Two(u32, i32),
78//! }
79//!
80//! fn main() -> Result<(), enum_extract_error::EnumExtractError> {
81//!     let mut unnamed = UnnamedVariants::One(1);
82//!
83//!     // returns a reference to the value
84//!     let one = unnamed.as_one()?;
85//!     assert_eq!(*one, 1);
86//!
87//!     // returns a mutable reference to the value
88//!     let one = unnamed.as_one_mut()?;
89//!     assert_eq!(*one, 1);
90//!
91//!     // returns the value by consuming the enum
92//!     let one = unnamed.into_one()?;
93//!     assert_eq!(one, 1);
94//!
95//!     Ok(())
96//! }
97//! ```
98//!
99//! If the variant has multiple values, a tuple will be returned:
100//!
101//! ```rust
102//! use enum_extract_macro::EnumExtract;
103//!
104//! #[derive(Debug, EnumExtract)]
105//! enum UnnamedVariants {
106//!     One(u32),
107//!     Two(u32, i32),
108//! }
109//!
110//! fn main() -> Result<(), enum_extract_error::EnumExtractError> {
111//!     let mut unnamed = UnnamedVariants::Two(1, 2);
112//!
113//!     // returns a reference to the value
114//!     let two = unnamed.as_two()?;
115//!     assert_eq!(two, (&1, &2));
116//!
117//!     // returns a mutable reference to the value
118//!     let two = unnamed.as_two_mut()?;
119//!     assert_eq!(two, (&mut 1, &mut 2));
120//!
121//!     // returns the value by consuming the enum
122//!     let two = unnamed.into_two()?;
123//!     assert_eq!(two, (1, 2));
124//!
125//!     Ok(())
126//! }
127//! ```
128//!
//! The `extract_*` versions of the above methods panic with a descriptive message if the variant is not the expected variant.
130//! Very useful for testing, but not recommended for production code.
131//!
132//! See the [enum-extract-error](https://crates.io/crates/enum-extract-error) crate for more information on the error type.
133//!
134//! ```rust
135//! use enum_extract_macro::EnumExtract;
136//!
137//! #[derive(Debug, EnumExtract)]
138//! enum UnnamedVariants {
139//!     One(u32),
140//!     Two(u32, i32),
141//! }
142//!
143//! let mut unnamed = UnnamedVariants::One(1);
144//!
145//! // returns a reference to the value
146//! let one = unnamed.extract_as_one();
147//! assert_eq!(*one, 1);
148//!
149//! // returns a mutable reference to the value
150//! let one = unnamed.extract_as_one_mut();
151//! assert_eq!(*one, 1);
152//!
153//! // returns the value by consuming the enum
154//! let one = unnamed.extract_into_one();
155//! assert_eq!(one, 1);
156//! ```
157//!
158//! ```should_panic
159//! use enum_extract_macro::EnumExtract;
160//!
161//! #[derive(Debug, EnumExtract)]
162//! enum UnnamedVariants {
163//!     One(u32),
164//!     Two(u32, i32),
165//! }
166//!
167//! let unnamed = UnnamedVariants::One(1);
168//!
169//! // panics with a decent message
170//! let one = unnamed.extract_as_two();
171//! ```
172//!
173//! ## Named Variants
174//!
175//! Check if the variant is the expected variant:
176//!
177//! ```rust
178//! use enum_extract_macro::EnumExtract;
179//!
180//! #[derive(Debug, EnumExtract)]
181//! enum NamedVariants {
182//!     One {
183//!         first: u32
184//!     },
185//!     Two {
186//!         first: u32,
187//!         second: i32
188//!     },
189//! }
190//!
191//! let named = NamedVariants::One { first: 1 };
192//! assert!(named.is_one());
193//! assert!(!named.is_two());
194//! ```
195//!
196//! Get the variant's value:
197//!
198//! ```rust
199//! use enum_extract_macro::EnumExtract;
200//!
201//! #[derive(Debug, EnumExtract)]
202//! enum NamedVariants {
203//!     One {
204//!         first: u32
205//!     },
206//!     Two {
207//!         first: u32,
208//!         second: i32
209//!     },
210//! }
211//!
212//! fn main() -> Result<(), enum_extract_error::EnumExtractError> {
213//!     let mut named = NamedVariants::One { first: 1 };
214//!
215//!     // returns a reference to the value
216//!     let one = named.as_one()?;
217//!     assert_eq!(*one, 1);
218//!
219//!     // returns a mutable reference to the value
220//!     let one = named.as_one_mut()?;
221//!     assert_eq!(*one, 1);
222//!
223//!     // returns the value by consuming the enum
224//!     let one = named.into_one()?;
225//!     assert_eq!(one, 1);
226//!
227//!     Ok(())
228//! }
229//! ```
230//!
231//! If the variant has multiple values, a tuple will be returned:
232//!
233//! ```rust
234//! use enum_extract_macro::EnumExtract;
235//!
236//! #[derive(Debug, EnumExtract)]
237//! enum NamedVariants {
238//!     One {
239//!         first: u32
240//!     },
241//!     Two {
242//!         first: u32,
243//!         second: i32
244//!     },
245//! }
246//!
247//! fn main() -> Result<(), enum_extract_error::EnumExtractError> {
248//!     let mut unnamed = NamedVariants::Two { first: 1, second: 2 };
249//!
250//!     // returns a reference to the value
251//!     let two = unnamed.as_two()?;
252//!     assert_eq!(two, (&1, &2));
253//!
254//!     // returns a mutable reference to the value
255//!     let two = unnamed.as_two_mut()?;
256//!     assert_eq!(two, (&mut 1, &mut 2));
257//!
258//!     // returns the value by consuming the enum
259//!     let two = unnamed.into_two()?;
260//!     assert_eq!(two, (1, 2));
261//!
262//!     Ok(())
263//! }
264//! ```
265//!
//! The `extract_*` versions of the above methods panic with a descriptive message if the variant is not the expected variant.
267//! Very useful for testing, but not recommended for production code.
268//!
269//! See the [enum-extract-error](https://crates.io/crates/enum-extract-error) crate for more information on the error type.
270//!
271//! ```rust
272//! use enum_extract_macro::EnumExtract;
273//!
274//! #[derive(Debug, EnumExtract)]
275//! enum NamedVariants {
276//!     One {
277//!         first: u32
278//!     },
279//!     Two {
280//!         first: u32,
281//!         second: i32
282//!     },
283//! }
284//!
285//! let mut named = NamedVariants::One { first: 1 };
286//!
287//! // returns a reference to the value
288//! let one = named.extract_as_one();
289//! assert_eq!(*one, 1);
290//!
291//! // returns a mutable reference to the value
292//! let one = named.extract_as_one_mut();
293//! assert_eq!(*one, 1);
294//!
295//! // returns the value by consuming the enum
296//! let one = named.extract_into_one();
297//! assert_eq!(one, 1);
298//! ```
299//!
300//! ```should_panic
301//! use enum_extract_macro::EnumExtract;
302//!
303//! #[derive(Debug, EnumExtract)]
304//! enum NamedVariants {
305//!     One {
306//!         first: u32
307//!     },
308//!     Two {
309//!         first: u32,
310//!         second: i32
311//!     },
312//! }
313//!
314//! let named = NamedVariants::One { first: 1 };
315//!
316//! // panics with a decent message
317//! let one = named.extract_as_two();
318//! ```
319
320#![warn(missing_docs)]
321
322use proc_macro2::{Ident, Span, TokenStream};
323use quote::quote;
324use syn::{parse_macro_input, DataEnum, DeriveInput};
325
326mod function_def;
327mod named_enum_functions;
328mod unit_enum_functions;
329mod unnamed_enum_functions;
330
331/// Derive functions on an Enum for easily accessing individual items in the Enum
332#[proc_macro_derive(EnumExtract, attributes(derive_err))]
333pub fn enum_extract(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
334    // get a usable token stream
335    let ast: DeriveInput = parse_macro_input!(input as DeriveInput);
336
337    let name = &ast.ident;
338    let generics = &ast.generics;
339
340    let enum_data = if let syn::Data::Enum(data) = &ast.data {
341        data
342    } else {
343        panic!("{} is not an enum", name);
344    };
345
346    let mut expanded = TokenStream::new();
347
348    // Build the impl
349    let fns = impl_all_as_fns(name, generics, enum_data);
350
351    expanded.extend(fns);
352
353    proc_macro::TokenStream::from(expanded)
354}
355
356/// Returns an impl block for all of the enum's functions.
357fn impl_all_as_fns(enum_name: &Ident, generics: &syn::Generics, data: &DataEnum) -> TokenStream {
358    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
359
360    let err_path = syn::Path::from(syn::PathSegment::from(syn::Ident::new(
361        "enum_extract_error",
362        Span::call_site(),
363    )));
364    let err_name = syn::Ident::new("EnumExtractError", Span::call_site());
365    let err_type = get_error_type(&err_name, &err_path);
366
367    let err_value_name = syn::Ident::new("EnumExtractValueError", Span::call_site());
368    let err_value_type = get_error_type(&err_value_name, &err_path);
369    let err_value_type_with_generics =
370        get_error_type_with_generics(err_value_name, err_path, enum_name, generics);
371
372    let mut stream = TokenStream::new();
373    let mut variant_names = TokenStream::new();
374    for variant_data in &data.variants {
375        let variant_name = &variant_data.ident;
376
377        let tokens = match &variant_data.fields {
378            syn::Fields::Unit => unit_enum_functions::all_unit_functions(enum_name, variant_name),
379            syn::Fields::Unnamed(unnamed) => unnamed_enum_functions::all_unnamed_functions(
380                enum_name,
381                variant_name,
382                &err_type,
383                &err_value_type,
384                &err_value_type_with_generics,
385                unnamed,
386            ),
387            syn::Fields::Named(named) => named_enum_functions::all_named_functions(
388                enum_name,
389                variant_name,
390                &err_type,
391                &err_value_type,
392                &err_value_type_with_generics,
393                named,
394            ),
395        };
396
397        stream.extend(tokens);
398
399        let variant_name = match &variant_data.fields {
400            syn::Fields::Unit => quote!(Self::#variant_name => stringify!(#variant_name),),
401            syn::Fields::Unnamed(_) => {
402                quote!(Self::#variant_name(..) => stringify!(#variant_name),)
403            }
404            syn::Fields::Named(_) => quote!(Self::#variant_name{..} => stringify!(#variant_name),),
405        };
406
407        variant_names.extend(variant_name);
408    }
409
410    quote!(
411        impl #impl_generics #enum_name #ty_generics #where_clause {
412            #stream
413
414            /// Returns the name of the variant.
415            fn variant_name(&self) -> &'static str {
416                match self {
417                    #variant_names
418                    _ => unreachable!(),
419                }
420            }
421        }
422    )
423}
424
425/// Returns the error type. ex: `EnumExtractError`
426fn get_error_type(err_name: &Ident, err_path: &syn::Path) -> syn::Type {
427    let err_type = {
428        let last_segment = syn::PathSegment::from(err_name.clone());
429        let mut path = err_path.clone();
430        path.segments.push(last_segment);
431        syn::Type::Path(syn::TypePath {
432            qself: None,
433            path: path,
434        })
435    };
436    err_type
437}
438
439/// Returns the error type with generics. ex: `EnumExtractError<T>`
440fn get_error_type_with_generics(
441    err_name: Ident,
442    err_path: syn::Path,
443    enum_name: &Ident,
444    generics: &syn::Generics,
445) -> syn::Type {
446    let err_type_with_generics = {
447        let mut last_segment = syn::PathSegment::from(err_name.clone());
448        let mut path = err_path.clone();
449
450        let mut inner_type_path = syn::Path::from(syn::PathSegment::from(enum_name.clone()));
451        let inner_type_segment = inner_type_path.segments.last_mut().unwrap();
452        let mut generic_args = syn::punctuated::Punctuated::new();
453        for param in generics.params.iter() {
454            match param {
455                syn::GenericParam::Lifetime(lifetime_param) => {
456                    generic_args.push(syn::GenericArgument::Lifetime(syn::Lifetime::new(
457                        &format!("'{}", lifetime_param.lifetime.ident),
458                        Span::call_site(),
459                    )));
460                }
461                syn::GenericParam::Const(const_param) => {
462                    generic_args.push(syn::GenericArgument::Const(syn::Expr::Path(
463                        syn::ExprPath {
464                            attrs: vec![],
465                            qself: None,
466                            path: syn::Path::from(syn::PathSegment::from(
467                                const_param.ident.clone(),
468                            )),
469                        },
470                    )));
471                }
472                syn::GenericParam::Type(type_param) => {
473                    generic_args.push(syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
474                        qself: None,
475                        path: syn::Path::from(syn::PathSegment::from(type_param.ident.clone())),
476                    })));
477                }
478            }
479        }
480        inner_type_segment.arguments =
481            syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
482                colon2_token: None,
483                lt_token: syn::token::Lt::default(),
484                args: generic_args,
485                gt_token: syn::token::Gt::default(),
486            });
487
488        last_segment.arguments =
489            syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
490                colon2_token: None,
491                lt_token: syn::token::Lt::default(),
492                args: syn::punctuated::Punctuated::from_iter(vec![syn::GenericArgument::Type(
493                    syn::Type::Path(syn::TypePath {
494                        qself: None,
495                        path: inner_type_path,
496                    }),
497                )]),
498                gt_token: syn::token::Gt::default(),
499            });
500        path.segments.push(last_segment);
501        syn::Type::Path(syn::TypePath {
502            qself: None,
503            path: path,
504        })
505    };
506    err_type_with_generics
507}