enum_extract_macro/lib.rs
1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2// Copyright 2023 James La Novara-Gsell <james.lanovara.gsell@gmail.com>
3//
4// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
5// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
6// http://opensource.org/licenses/MIT>, at your option. This file may not be
7// copied, modified, or distributed except according to those terms.
8
9//! Derive functions on an Enum for easily accessing individual items in the Enum.
10//! This crate is intended to be used with the [enum-extract-error](https://crates.io/crates/enum-extract-error) crate.
11//!
12//! # Summary
13//!
14//! This crate adds an `EnumExtract` derive macro that adds the following functions for each variant in your enum:
15//!
16//! 1. `is_[variant]`: Returns a bool indicating whether the actual variant matches the expected variant.
17//! 2. `as_[variant]`: Returns a Result with a reference to the data contained by the variant, or an error if the actual variant is not the expected variant type.
18//! 3. `as_[variant]_mut`: Like `as_[variant]` but returns a mutable reference.
19//! 4. `into_[variant]`: Like `as_[variant]` but consumes the value and returns an owned value instead of a reference.
20//! 5. `extract_as_[variant]`: Calls `as_[variant]` and returns the data or panics if there was an error.
21//! 6. `extract_as_[variant]_mut`: Calls `as_[variant]_mut` and returns the data or panics if there was an error.
22//! 7. `extract_into_[variant]`: Calls `into_[variant]` and returns the data or panics if there was an error.
23//!
24//! ## Notes on the `extract` functions
25//!
26//! These functions are slightly different from calling `as_[variant]().unwrap()` because they panic with the `Display` output of `EnumExtractError` rather than the `Debug` output.
27//!
28//! Since these functions can panic they are not recommended for production code.
29//! Their main use is in tests, in which they can simplify and flatten tests significantly.
30//!
31//! # Examples
32//!
33//! ## Unit Variants
34//!
35//! Check if the variant is the expected variant:
36//!
37//! ```rust
38//! use enum_extract_macro::EnumExtract;
39//!
40//! #[derive(Debug, EnumExtract)]
41//! enum UnitVariants {
42//! One,
43//! Two,
44//! }
45//!
46//! let unit = UnitVariants::One;
47//! assert!(unit.is_one());
48//! assert!(!unit.is_two());
49//! ```
50//!
51//! ## Unnamed Variants
52//!
53//! Check if the variant is the expected variant:
54//!
55//! ```rust
56//! use enum_extract_macro::EnumExtract;
57//!
58//! #[derive(Debug, EnumExtract)]
59//! enum UnnamedVariants {
60//! One(u32),
61//! Two(u32, i32),
62//! }
63//!
64//! let unnamed = UnnamedVariants::One(1);
65//! assert!(unnamed.is_one());
66//! assert!(!unnamed.is_two());
67//! ```
68//!
69//! Get the variant's value:
70//!
71//! ```rust
72//! use enum_extract_macro::EnumExtract;
73//!
74//! #[derive(Debug, EnumExtract)]
75//! enum UnnamedVariants {
76//! One(u32),
77//! Two(u32, i32),
78//! }
79//!
80//! fn main() -> Result<(), enum_extract_error::EnumExtractError> {
81//! let mut unnamed = UnnamedVariants::One(1);
82//!
83//! // returns a reference to the value
84//! let one = unnamed.as_one()?;
85//! assert_eq!(*one, 1);
86//!
87//! // returns a mutable reference to the value
88//! let one = unnamed.as_one_mut()?;
89//! assert_eq!(*one, 1);
90//!
91//! // returns the value by consuming the enum
92//! let one = unnamed.into_one()?;
93//! assert_eq!(one, 1);
94//!
95//! Ok(())
96//! }
97//! ```
98//!
99//! If the variant has multiple values, a tuple will be returned:
100//!
101//! ```rust
102//! use enum_extract_macro::EnumExtract;
103//!
104//! #[derive(Debug, EnumExtract)]
105//! enum UnnamedVariants {
106//! One(u32),
107//! Two(u32, i32),
108//! }
109//!
110//! fn main() -> Result<(), enum_extract_error::EnumExtractError> {
111//! let mut unnamed = UnnamedVariants::Two(1, 2);
112//!
113//! // returns a reference to the value
114//! let two = unnamed.as_two()?;
115//! assert_eq!(two, (&1, &2));
116//!
117//! // returns a mutable reference to the value
118//! let two = unnamed.as_two_mut()?;
119//! assert_eq!(two, (&mut 1, &mut 2));
120//!
121//! // returns the value by consuming the enum
122//! let two = unnamed.into_two()?;
123//! assert_eq!(two, (1, 2));
124//!
125//! Ok(())
126//! }
127//! ```
128//!
129//! The `extract_` versions of all of the above methods panic with a descriptive message if the actual variant is not the expected variant.
130//! Very useful for testing, but not recommended for production code.
131//!
132//! See the [enum-extract-error](https://crates.io/crates/enum-extract-error) crate for more information on the error type.
133//!
134//! ```rust
135//! use enum_extract_macro::EnumExtract;
136//!
137//! #[derive(Debug, EnumExtract)]
138//! enum UnnamedVariants {
139//! One(u32),
140//! Two(u32, i32),
141//! }
142//!
143//! let mut unnamed = UnnamedVariants::One(1);
144//!
145//! // returns a reference to the value
146//! let one = unnamed.extract_as_one();
147//! assert_eq!(*one, 1);
148//!
149//! // returns a mutable reference to the value
150//! let one = unnamed.extract_as_one_mut();
151//! assert_eq!(*one, 1);
152//!
153//! // returns the value by consuming the enum
154//! let one = unnamed.extract_into_one();
155//! assert_eq!(one, 1);
156//! ```
157//!
158//! ```should_panic
159//! use enum_extract_macro::EnumExtract;
160//!
161//! #[derive(Debug, EnumExtract)]
162//! enum UnnamedVariants {
163//! One(u32),
164//! Two(u32, i32),
165//! }
166//!
167//! let unnamed = UnnamedVariants::One(1);
168//!
169//! // panics with a decent message
170//! let one = unnamed.extract_as_two();
171//! ```
172//!
173//! ## Named Variants
174//!
175//! Check if the variant is the expected variant:
176//!
177//! ```rust
178//! use enum_extract_macro::EnumExtract;
179//!
180//! #[derive(Debug, EnumExtract)]
181//! enum NamedVariants {
182//! One {
183//! first: u32
184//! },
185//! Two {
186//! first: u32,
187//! second: i32
188//! },
189//! }
190//!
191//! let named = NamedVariants::One { first: 1 };
192//! assert!(named.is_one());
193//! assert!(!named.is_two());
194//! ```
195//!
196//! Get the variant's value:
197//!
198//! ```rust
199//! use enum_extract_macro::EnumExtract;
200//!
201//! #[derive(Debug, EnumExtract)]
202//! enum NamedVariants {
203//! One {
204//! first: u32
205//! },
206//! Two {
207//! first: u32,
208//! second: i32
209//! },
210//! }
211//!
212//! fn main() -> Result<(), enum_extract_error::EnumExtractError> {
213//! let mut named = NamedVariants::One { first: 1 };
214//!
215//! // returns a reference to the value
216//! let one = named.as_one()?;
217//! assert_eq!(*one, 1);
218//!
219//! // returns a mutable reference to the value
220//! let one = named.as_one_mut()?;
221//! assert_eq!(*one, 1);
222//!
223//! // returns the value by consuming the enum
224//! let one = named.into_one()?;
225//! assert_eq!(one, 1);
226//!
227//! Ok(())
228//! }
229//! ```
230//!
231//! If the variant has multiple values, a tuple will be returned:
232//!
233//! ```rust
234//! use enum_extract_macro::EnumExtract;
235//!
236//! #[derive(Debug, EnumExtract)]
237//! enum NamedVariants {
238//! One {
239//! first: u32
240//! },
241//! Two {
242//! first: u32,
243//! second: i32
244//! },
245//! }
246//!
247//! fn main() -> Result<(), enum_extract_error::EnumExtractError> {
248//! let mut named = NamedVariants::Two { first: 1, second: 2 };
249//!
250//! // returns a reference to the value
251//! let two = named.as_two()?;
252//! assert_eq!(two, (&1, &2));
253//!
254//! // returns a mutable reference to the value
255//! let two = named.as_two_mut()?;
256//! assert_eq!(two, (&mut 1, &mut 2));
257//!
258//! // returns the value by consuming the enum
259//! let two = named.into_two()?;
260//! assert_eq!(two, (1, 2));
261//!
262//! Ok(())
263//! }
264//! ```
265//!
266//! The `extract_` versions of all of the above methods panic with a descriptive message if the actual variant is not the expected variant.
267//! Very useful for testing, but not recommended for production code.
268//!
269//! See the [enum-extract-error](https://crates.io/crates/enum-extract-error) crate for more information on the error type.
270//!
271//! ```rust
272//! use enum_extract_macro::EnumExtract;
273//!
274//! #[derive(Debug, EnumExtract)]
275//! enum NamedVariants {
276//! One {
277//! first: u32
278//! },
279//! Two {
280//! first: u32,
281//! second: i32
282//! },
283//! }
284//!
285//! let mut named = NamedVariants::One { first: 1 };
286//!
287//! // returns a reference to the value
288//! let one = named.extract_as_one();
289//! assert_eq!(*one, 1);
290//!
291//! // returns a mutable reference to the value
292//! let one = named.extract_as_one_mut();
293//! assert_eq!(*one, 1);
294//!
295//! // returns the value by consuming the enum
296//! let one = named.extract_into_one();
297//! assert_eq!(one, 1);
298//! ```
299//!
300//! ```should_panic
301//! use enum_extract_macro::EnumExtract;
302//!
303//! #[derive(Debug, EnumExtract)]
304//! enum NamedVariants {
305//! One {
306//! first: u32
307//! },
308//! Two {
309//! first: u32,
310//! second: i32
311//! },
312//! }
313//!
314//! let named = NamedVariants::One { first: 1 };
315//!
316//! // panics with a decent message
317//! let one = named.extract_as_two();
318//! ```
319
320#![warn(missing_docs)]
321
322use proc_macro2::{Ident, Span, TokenStream};
323use quote::quote;
324use syn::{parse_macro_input, DataEnum, DeriveInput};
325
326mod function_def;
327mod named_enum_functions;
328mod unit_enum_functions;
329mod unnamed_enum_functions;
330
331/// Derive functions on an Enum for easily accessing individual items in the Enum
332#[proc_macro_derive(EnumExtract, attributes(derive_err))]
333pub fn enum_extract(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
334 // get a usable token stream
335 let ast: DeriveInput = parse_macro_input!(input as DeriveInput);
336
337 let name = &ast.ident;
338 let generics = &ast.generics;
339
340 let enum_data = if let syn::Data::Enum(data) = &ast.data {
341 data
342 } else {
343 panic!("{} is not an enum", name);
344 };
345
346 let mut expanded = TokenStream::new();
347
348 // Build the impl
349 let fns = impl_all_as_fns(name, generics, enum_data);
350
351 expanded.extend(fns);
352
353 proc_macro::TokenStream::from(expanded)
354}
355
356/// Returns an impl block for all of the enum's functions.
357fn impl_all_as_fns(enum_name: &Ident, generics: &syn::Generics, data: &DataEnum) -> TokenStream {
358 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
359
360 let err_path = syn::Path::from(syn::PathSegment::from(syn::Ident::new(
361 "enum_extract_error",
362 Span::call_site(),
363 )));
364 let err_name = syn::Ident::new("EnumExtractError", Span::call_site());
365 let err_type = get_error_type(&err_name, &err_path);
366
367 let err_value_name = syn::Ident::new("EnumExtractValueError", Span::call_site());
368 let err_value_type = get_error_type(&err_value_name, &err_path);
369 let err_value_type_with_generics =
370 get_error_type_with_generics(err_value_name, err_path, enum_name, generics);
371
372 let mut stream = TokenStream::new();
373 let mut variant_names = TokenStream::new();
374 for variant_data in &data.variants {
375 let variant_name = &variant_data.ident;
376
377 let tokens = match &variant_data.fields {
378 syn::Fields::Unit => unit_enum_functions::all_unit_functions(enum_name, variant_name),
379 syn::Fields::Unnamed(unnamed) => unnamed_enum_functions::all_unnamed_functions(
380 enum_name,
381 variant_name,
382 &err_type,
383 &err_value_type,
384 &err_value_type_with_generics,
385 unnamed,
386 ),
387 syn::Fields::Named(named) => named_enum_functions::all_named_functions(
388 enum_name,
389 variant_name,
390 &err_type,
391 &err_value_type,
392 &err_value_type_with_generics,
393 named,
394 ),
395 };
396
397 stream.extend(tokens);
398
399 let variant_name = match &variant_data.fields {
400 syn::Fields::Unit => quote!(Self::#variant_name => stringify!(#variant_name),),
401 syn::Fields::Unnamed(_) => {
402 quote!(Self::#variant_name(..) => stringify!(#variant_name),)
403 }
404 syn::Fields::Named(_) => quote!(Self::#variant_name{..} => stringify!(#variant_name),),
405 };
406
407 variant_names.extend(variant_name);
408 }
409
410 quote!(
411 impl #impl_generics #enum_name #ty_generics #where_clause {
412 #stream
413
414 /// Returns the name of the variant.
415 fn variant_name(&self) -> &'static str {
416 match self {
417 #variant_names
418 _ => unreachable!(),
419 }
420 }
421 }
422 )
423}
424
425/// Returns the error type. ex: `EnumExtractError`
426fn get_error_type(err_name: &Ident, err_path: &syn::Path) -> syn::Type {
427 let err_type = {
428 let last_segment = syn::PathSegment::from(err_name.clone());
429 let mut path = err_path.clone();
430 path.segments.push(last_segment);
431 syn::Type::Path(syn::TypePath {
432 qself: None,
433 path: path,
434 })
435 };
436 err_type
437}
438
439/// Returns the error type with generics. ex: `EnumExtractError<T>`
440fn get_error_type_with_generics(
441 err_name: Ident,
442 err_path: syn::Path,
443 enum_name: &Ident,
444 generics: &syn::Generics,
445) -> syn::Type {
446 let err_type_with_generics = {
447 let mut last_segment = syn::PathSegment::from(err_name.clone());
448 let mut path = err_path.clone();
449
450 let mut inner_type_path = syn::Path::from(syn::PathSegment::from(enum_name.clone()));
451 let inner_type_segment = inner_type_path.segments.last_mut().unwrap();
452 let mut generic_args = syn::punctuated::Punctuated::new();
453 for param in generics.params.iter() {
454 match param {
455 syn::GenericParam::Lifetime(lifetime_param) => {
456 generic_args.push(syn::GenericArgument::Lifetime(syn::Lifetime::new(
457 &format!("'{}", lifetime_param.lifetime.ident),
458 Span::call_site(),
459 )));
460 }
461 syn::GenericParam::Const(const_param) => {
462 generic_args.push(syn::GenericArgument::Const(syn::Expr::Path(
463 syn::ExprPath {
464 attrs: vec![],
465 qself: None,
466 path: syn::Path::from(syn::PathSegment::from(
467 const_param.ident.clone(),
468 )),
469 },
470 )));
471 }
472 syn::GenericParam::Type(type_param) => {
473 generic_args.push(syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
474 qself: None,
475 path: syn::Path::from(syn::PathSegment::from(type_param.ident.clone())),
476 })));
477 }
478 }
479 }
480 inner_type_segment.arguments =
481 syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
482 colon2_token: None,
483 lt_token: syn::token::Lt::default(),
484 args: generic_args,
485 gt_token: syn::token::Gt::default(),
486 });
487
488 last_segment.arguments =
489 syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
490 colon2_token: None,
491 lt_token: syn::token::Lt::default(),
492 args: syn::punctuated::Punctuated::from_iter(vec![syn::GenericArgument::Type(
493 syn::Type::Path(syn::TypePath {
494 qself: None,
495 path: inner_type_path,
496 }),
497 )]),
498 gt_token: syn::token::Gt::default(),
499 });
500 path.segments.push(last_segment);
501 syn::Type::Path(syn::TypePath {
502 qself: None,
503 path: path,
504 })
505 };
506 err_type_with_generics
507}