Skip to main content

enum_as_inner/
lib.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # enum-as-inner
9//!
//! A derive proc-macro that generates functions giving access to the inner fields of an enum's variants.
11//!
12//! ## Basic unnamed field case
13//!
//! The basic case is meant for variants holding a single unnamed field, like:
15//!
16//! ```rust
17//! use enum_as_inner::EnumAsInner;
18//!
19//! #[derive(Debug, EnumAsInner)]
20//! enum OneEnum {
21//!     One(u32),
22//! }
23//!
24//! let one = OneEnum::One(1);
25//!
26//! assert_eq!(*one.as_one().unwrap(), 1);
27//! assert_eq!(one.into_one().unwrap(), 1);
28//! ```
29//!
//! where `as_*` returns an `Option` holding a reference to the inner value, and `into_*` returns a `Result` holding the inner value itself (or the original enum in the `Err` case).
31//!
32//! ## Unit case
33//!
//! Unit variants only get an `is_*` method, which returns `true` if the enum's variant matches the expected one:
35//!
36//! ```rust
37//! use enum_as_inner::EnumAsInner;
38//!
39//! #[derive(EnumAsInner)]
40//! enum UnitVariants {
41//!     Zero,
42//!     One,
43//!     Two,
44//! }
45//!
46//! let unit = UnitVariants::Two;
47//!
48//! assert!(unit.is_two());
49//! ```
50//!
//! ## Multiple, unnamed field case
52//!
53//! This will return a tuple of the inner types:
54//!
55//! ```rust
56//! use enum_as_inner::EnumAsInner;
57//!
58//! #[derive(Debug, EnumAsInner)]
59//! enum ManyVariants {
60//!     One(u32),
61//!     Two(u32, i32),
62//!     Three(bool, u32, i64),
63//! }
64//!
65//! let many = ManyVariants::Three(true, 1, 2);
66//!
67//! assert!(many.is_three());
68//! assert_eq!(many.as_three().unwrap(), (&true, &1_u32, &2_i64));
69//! assert_eq!(many.into_three().unwrap(), (true, 1_u32, 2_i64));
70//! ```
71//!
72//! ## Multiple, named field case
73//!
//! This will return a tuple of the inner types, like the unnamed case:
75//!
76//! ```rust
77//! use enum_as_inner::EnumAsInner;
78//!
79//! #[derive(Debug, EnumAsInner)]
80//! enum ManyVariants {
81//!     One { one: u32 },
82//!     Two { one: u32, two: i32 },
83//!     Three { one: bool, two: u32, three: i64 },
84//! }
85//!
86//! let many = ManyVariants::Three { one: true, two: 1, three: 2 };
87//!
88//! assert!(many.is_three());
89//! assert_eq!(many.as_three().unwrap(), (&true, &1_u32, &2_i64));
90//! assert_eq!(many.into_three().unwrap(), (true, 1_u32, 2_i64));
91//! ```
92
93use heck::ToSnakeCase;
94use proc_macro2::{Ident, Span, TokenStream};
95use quote::quote;
96use syn::{parse_macro_input, DeriveInput};
97
98/// returns first the types to return, the match names, and then tokens to the field accesses
99fn unit_fields_return(variant_name: &syn::Ident, function_name: &Ident, doc: &str) -> TokenStream {
100    quote!(
101        #[doc = #doc]
102        #[inline]
103        pub fn #function_name(&self) -> bool {
104            matches!(self, Self::#variant_name)
105        }
106    )
107}
108
109/// returns first the types to return, the match names, and then tokens to the field accesses
110#[allow(clippy::too_many_arguments)]
111fn unnamed_fields_return(
112    variant_name: &syn::Ident,
113    (function_name_is, doc_is): (&Ident, &str),
114    (function_name_mut_ref, doc_mut_ref): (&Ident, &str),
115    (function_name_ref, doc_ref): (&Ident, &str),
116    (function_name_val, doc_val): (&Ident, &str),
117    (function_name_val_unchecked, doc_val_unchecked): (&Ident, &str),
118    (function_name_ref_unchecked, doc_ref_unchecked): (&Ident, &str),
119    (function_name_mut_ref_unchecked, doc_mut_ref_unchecked): (&Ident, &str),
120    fields: &syn::FieldsUnnamed,
121) -> TokenStream {
122    let (returns_mut_ref, returns_ref, returns_val, matches) = match fields.unnamed.len() {
123        1 => {
124            let field = fields.unnamed.first().expect("no fields on type");
125
126            let returns = &field.ty;
127            let returns_mut_ref = quote!(&mut #returns);
128            let returns_ref = quote!(&#returns);
129            let returns_val = quote!(#returns);
130            let matches = quote!(inner);
131
132            (returns_mut_ref, returns_ref, returns_val, matches)
133        }
134        0 => (quote!(()), quote!(()), quote!(()), quote!()),
135        _ => {
136            let mut returns_mut_ref = TokenStream::new();
137            let mut returns_ref = TokenStream::new();
138            let mut returns_val = TokenStream::new();
139            let mut matches = TokenStream::new();
140
141            for (i, field) in fields.unnamed.iter().enumerate() {
142                let rt = &field.ty;
143                let match_name = Ident::new(&format!("match_{}", i), Span::call_site());
144                returns_mut_ref.extend(quote!(&mut #rt,));
145                returns_ref.extend(quote!(&#rt,));
146                returns_val.extend(quote!(#rt,));
147                matches.extend(quote!(#match_name,));
148            }
149
150            (
151                quote!((#returns_mut_ref)),
152                quote!((#returns_ref)),
153                quote!((#returns_val)),
154                quote!(#matches),
155            )
156        }
157    };
158
159    quote!(
160        #[doc = #doc_is ]
161        #[inline]
162        #[allow(unused_variables)]
163        pub fn #function_name_is(&self) -> bool {
164            matches!(self, Self::#variant_name(..))
165        }
166
167        #[doc = #doc_mut_ref ]
168        #[inline]
169        pub fn #function_name_mut_ref(&mut self) -> ::core::option::Option<#returns_mut_ref> {
170            match self {
171                Self::#variant_name(#matches) => {
172                    ::core::option::Option::Some((#matches))
173                }
174                _ => ::core::option::Option::None
175            }
176        }
177
178        #[doc = #doc_ref ]
179        #[inline]
180        pub fn #function_name_ref(&self) -> ::core::option::Option<#returns_ref> {
181            match self {
182                Self::#variant_name(#matches) => {
183                    ::core::option::Option::Some((#matches))
184                }
185                _ => ::core::option::Option::None
186            }
187        }
188
189        #[doc = #doc_val ]
190        #[inline]
191        pub fn #function_name_val(self) -> ::core::result::Result<#returns_val, Self> {
192            match self {
193                Self::#variant_name(#matches) => {
194                    ::core::result::Result::Ok((#matches))
195                },
196                _ => ::core::result::Result::Err(self)
197            }
198        }
199
200        #[doc = #doc_val_unchecked ]
201        #[inline]
202        pub unsafe fn #function_name_val_unchecked(self) -> #returns_val {
203            match self {
204                Self::#variant_name(#matches) => (#matches),
205                _ => std::hint::unreachable_unchecked(),
206            }
207        }
208
209        #[doc = #doc_ref_unchecked ]
210        #[inline]
211        pub unsafe fn #function_name_ref_unchecked(&self) -> #returns_ref {
212            match self {
213                Self::#variant_name(#matches) => (#matches),
214                _ => std::hint::unreachable_unchecked(),
215            }
216        }
217
218        #[doc = #doc_mut_ref_unchecked ]
219        #[inline]
220        pub unsafe fn #function_name_mut_ref_unchecked(&mut self) -> #returns_mut_ref {
221            match self {
222                Self::#variant_name(#matches) => (#matches),
223                _ => std::hint::unreachable_unchecked(),
224            }
225        }
226    )
227}
228
229/// returns first the types to return, the match names, and then tokens to the field accesses
230#[allow(clippy::too_many_arguments)]
231fn named_fields_return(
232    variant_name: &syn::Ident,
233    (function_name_is, doc_is): (&Ident, &str),
234    (function_name_mut_ref, doc_mut_ref): (&Ident, &str),
235    (function_name_ref, doc_ref): (&Ident, &str),
236    (function_name_val, doc_val): (&Ident, &str),
237    (function_name_val_unchecked, doc_val_unchecked): (&Ident, &str),
238    (function_name_ref_unchecked, doc_ref_unchecked): (&Ident, &str),
239    (function_name_mut_ref_unchecked, doc_mut_ref_unchecked): (&Ident, &str),
240    fields: &syn::FieldsNamed,
241) -> TokenStream {
242    let (returns_mut_ref, returns_ref, returns_val, matches) = match fields.named.len() {
243        1 => {
244            let field = fields.named.first().expect("no fields on type");
245            let match_name = field.ident.as_ref().expect("expected a named field");
246
247            let returns = &field.ty;
248            let returns_mut_ref = quote!(&mut #returns);
249            let returns_ref = quote!(&#returns);
250            let returns_val = quote!(#returns);
251            let matches = quote!(#match_name);
252
253            (returns_mut_ref, returns_ref, returns_val, matches)
254        }
255        0 => (quote!(()), quote!(()), quote!(()), quote!(())),
256        _ => {
257            let mut returns_mut_ref = TokenStream::new();
258            let mut returns_ref = TokenStream::new();
259            let mut returns_val = TokenStream::new();
260            let mut matches = TokenStream::new();
261
262            for field in fields.named.iter() {
263                let rt = &field.ty;
264                let match_name = field.ident.as_ref().expect("expected a named field");
265
266                returns_mut_ref.extend(quote!(&mut #rt,));
267                returns_ref.extend(quote!(&#rt,));
268                returns_val.extend(quote!(#rt,));
269                matches.extend(quote!(#match_name,));
270            }
271
272            (
273                quote!((#returns_mut_ref)),
274                quote!((#returns_ref)),
275                quote!((#returns_val)),
276                quote!(#matches),
277            )
278        }
279    };
280
281    quote!(
282        #[doc = #doc_is ]
283        #[inline]
284        #[allow(unused_variables)]
285        pub fn #function_name_is(&self) -> bool {
286            matches!(self, Self::#variant_name{ .. })
287        }
288
289        #[doc = #doc_mut_ref ]
290        #[inline]
291        pub fn #function_name_mut_ref(&mut self) -> ::core::option::Option<#returns_mut_ref> {
292            match self {
293                Self::#variant_name{ #matches } => {
294                    ::core::option::Option::Some((#matches))
295                }
296                _ => ::core::option::Option::None
297            }
298        }
299
300        #[doc = #doc_ref ]
301        #[inline]
302        pub fn #function_name_ref(&self) -> ::core::option::Option<#returns_ref> {
303            match self {
304                Self::#variant_name{ #matches } => {
305                    ::core::option::Option::Some((#matches))
306                }
307                _ => ::core::option::Option::None
308            }
309        }
310
311        #[doc = #doc_val ]
312        #[inline]
313        pub fn #function_name_val(self) -> ::core::result::Result<#returns_val, Self> {
314            match self {
315                Self::#variant_name{ #matches } => {
316                    ::core::result::Result::Ok((#matches))
317                }
318                _ => ::core::result::Result::Err(self)
319            }
320        }
321
322        #[doc = #doc_val_unchecked ]
323        #[inline]
324        pub unsafe fn #function_name_val_unchecked(self) -> #returns_val {
325            match self {
326                Self::#variant_name{ #matches } => (#matches),
327                _ => std::hint::unreachable_unchecked(),
328            }
329        }
330
331        #[doc = #doc_ref_unchecked ]
332        #[inline]
333        pub unsafe fn #function_name_ref_unchecked(&self) -> #returns_ref {
334            match self {
335                Self::#variant_name{ #matches } => (#matches),
336                _ => std::hint::unreachable_unchecked(),
337            }
338        }
339
340        #[doc = #doc_mut_ref_unchecked ]
341        #[inline]
342        pub unsafe fn #function_name_mut_ref_unchecked(&mut self) -> #returns_mut_ref {
343            match self {
344                Self::#variant_name{ #matches } => (#matches),
345                _ => std::hint::unreachable_unchecked(),
346            }
347        }
348    )
349}
350
351fn impl_all_as_fns(ast: &DeriveInput) -> TokenStream {
352    let name = &ast.ident;
353    let generics = &ast.generics;
354
355    let enum_data = if let syn::Data::Enum(data) = &ast.data {
356        data
357    } else {
358        panic!("{} is not an enum", name);
359    };
360
361    let mut stream = TokenStream::new();
362
363    for variant_data in &enum_data.variants {
364        let variant_name = &variant_data.ident;
365        let function_name_ref = Ident::new(
366            &format!("as_{}", variant_name).to_snake_case(),
367            Span::call_site(),
368        );
369        let doc_ref = format!(
370            "Optionally returns references to the inner fields if this is a `{}::{}`, otherwise `None`",
371            name,
372            variant_name,
373        );
374        let function_name_mut_ref = Ident::new(
375            &format!("as_{}_mut", variant_name).to_snake_case(),
376            Span::call_site(),
377        );
378        let doc_mut_ref = format!(
379            "Optionally returns mutable references to the inner fields if this is a `{}::{}`, otherwise `None`",
380            name,
381            variant_name,
382        );
383
384        let function_name_val = Ident::new(
385            &format!("into_{}", variant_name).to_snake_case(),
386            Span::call_site(),
387        );
388        let doc_val = format!(
389            "Returns the inner fields if this is a `{}::{}`, otherwise returns back the enum in the `Err` case of the result",
390            name,
391            variant_name,
392        );
393
394        let function_name_is = Ident::new(
395            &format!("is_{}", variant_name).to_snake_case(),
396            Span::call_site(),
397        );
398        let doc_is = format!(
399            "Returns true if this is a `{}::{}`, otherwise false",
400            name, variant_name,
401        );
402
403        let function_name_val_unchecked = Ident::new(
404            &format!("into_{}_unchecked", variant_name).to_snake_case(),
405            Span::call_site(),
406        );
407        let doc_val_unchecked = format!(
408            r#"Unchecked return of the inner fields of `{}::{}`.
409# Safety
410Results in undefined behavior when it is the incorrect variant."#,
411            name, variant_name
412        );
413
414        let function_name_ref_unchecked = Ident::new(
415            &format!("as_{}_unchecked", variant_name).to_snake_case(),
416            Span::call_site(),
417        );
418        let doc_ref_unchecked = format!(
419            r#"Unchecked reference of the inner fields of `{}::{}`.
420# Safety
421Results in undefined behavior when it is the incorrect variant."#,
422            name, variant_name
423        );
424
425        let function_name_mut_ref_unchecked = Ident::new(
426            &format!("as_{}_mut_unchecked", variant_name).to_snake_case(),
427            Span::call_site(),
428        );
429        let doc_mut_ref_unchecked = format!(
430            r#"Unchecked mutable reference of the inner fields of `{}::{}`.
431# Safety
432Results in undefined behavior when it is the incorrect variant."#,
433            name, variant_name
434        );
435
436        let tokens = match &variant_data.fields {
437            syn::Fields::Unit => unit_fields_return(variant_name, &function_name_is, &doc_is),
438            syn::Fields::Unnamed(unnamed) => unnamed_fields_return(
439                variant_name,
440                (&function_name_is, &doc_is),
441                (&function_name_mut_ref, &doc_mut_ref),
442                (&function_name_ref, &doc_ref),
443                (&function_name_val, &doc_val),
444                (&function_name_val_unchecked, &doc_val_unchecked),
445                (&function_name_ref_unchecked, &doc_ref_unchecked),
446                (&function_name_mut_ref_unchecked, &doc_mut_ref_unchecked),
447                unnamed,
448            ),
449            syn::Fields::Named(named) => named_fields_return(
450                variant_name,
451                (&function_name_is, &doc_is),
452                (&function_name_mut_ref, &doc_mut_ref),
453                (&function_name_ref, &doc_ref),
454                (&function_name_val, &doc_val),
455                (&function_name_val_unchecked, &doc_val_unchecked),
456                (&function_name_ref_unchecked, &doc_ref_unchecked),
457                (&function_name_mut_ref_unchecked, &doc_mut_ref_unchecked),
458                named,
459            ),
460        };
461
462        stream.extend(tokens);
463    }
464
465    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
466
467    quote!(
468        impl #impl_generics #name #ty_generics #where_clause {
469            #stream
470        }
471    )
472}
473
474/// Derive functions on an Enum for easily accessing individual items in the Enum
475#[proc_macro_derive(EnumAsInner)]
476pub fn enum_as_inner(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
477    // get a usable token stream
478    let ast: DeriveInput = parse_macro_input!(input as DeriveInput);
479
480    // Build the impl
481    let expanded: TokenStream = impl_all_as_fns(&ast);
482
483    // Return the generated impl
484    proc_macro::TokenStream::from(expanded)
485}