enum_utils/
iter.rs

1use std::ops::{Range, RangeInclusive};
2
3use failure::format_err;
4use proc_macro2::{Literal, TokenStream};
5use quote::quote;
6
7use crate::attr::{Discriminant, Enum, ErrorList};
8
9enum IterImpl {
10    Empty,
11    Range {
12        repr: syn::Path,
13        range: Range<Discriminant>,
14    },
15    RangeInclusive {
16        repr: syn::Path,
17        range: RangeInclusive<Discriminant>,
18    },
19    Slice(Vec<TokenStream>),
20}
21
22impl IterImpl {
23    /// Constructs the fastest `IterImpl` for the given set of discriminants.
24    ///
25    /// If the discriminants form a single, contiguous, increasing run, we will create a
26    /// `Range` (or `RangeInclusive`) containing the discriminants as the `#[repr(...)]` of the
27    /// enum.
28    fn for_enum(Enum { name, variants, discriminants, primitive_repr, .. }: &Enum) -> Result<Self, ErrorList> {
29        // See if we can generate a fast, transmute-based iterator.
30        if let Some(discriminants) = discriminants {
31            let is_zst = discriminants.len() <= 1;
32
33            if let Ok(Some((repr, repr_path))) = primitive_repr {
34                let unskipped_discriminants: Vec<_> = discriminants
35                    .iter()
36                    .cloned()
37                    .zip(variants.iter())
38                    .filter(|(_, (_, attr))| !attr.skip)
39                    .map(|(d, _)| d)
40                    .collect();
41
42                if unskipped_discriminants.is_empty() {
43                    return Ok(IterImpl::Empty);
44                }
45
46                if !is_zst {
47                    if let Some(range) = detect_contiguous_run(unskipped_discriminants.into_iter()) {
48                        // If range.end() is less than the maximum value of the primitive repr, we can
49                        // use the (faster) non-inclusive `Range`
50                        let end = *range.end();
51                        if end < 0 || repr.max_value().map_or(false, |max| (end as u128) < max) {
52                            return Ok(IterImpl::Range {
53                                repr: repr_path.clone(),
54                                range: *range.start()..(end + 1),
55                            })
56                        }
57
58                        return Ok(IterImpl::RangeInclusive {
59                            repr: repr_path.clone(),
60                            range,
61                        })
62                    }
63                }
64            }
65        }
66
67        // ...if not, fall back to the slice based one.
68        let mut errors = ErrorList::new();
69        let unskipped_variants: Vec<_> = variants
70            .iter()
71            .filter_map(|(v, attr)| {
72                if attr.skip {
73                    return None;
74                }
75
76                if v.fields != syn::Fields::Unit {
77                    errors.push_back(format_err!("An (unskipped) variant cannot have fields"));
78                    return None;
79                }
80
81                let vident = &v.ident;
82                Some(quote!(#name::#vident))
83            })
84            .collect();
85
86        if !errors.is_empty() {
87            return Err(errors);
88        }
89
90        if unskipped_variants.is_empty() {
91            return Ok(IterImpl::Empty);
92        }
93
94        Ok(IterImpl::Slice(unskipped_variants))
95    }
96
97    fn tokens(&self, ty: &syn::Ident) -> TokenStream {
98        let body = match self {
99            IterImpl::Empty => quote! {
100                ::std::iter::empty()
101            },
102
103            IterImpl::Range { range, repr } => {
104                let start = Literal::i128_unsuffixed(range.start);
105                let end = Literal::i128_unsuffixed(range.end);
106
107                quote! {
108                    let start: #repr = #start;
109                    let end: #repr = #end;
110                    (start .. end).map(|discrim| unsafe { ::std::mem::transmute(discrim) })
111                }
112            },
113
114            IterImpl::RangeInclusive { range, repr } => {
115                let start = Literal::i128_unsuffixed(*range.start());
116                let end = Literal::i128_unsuffixed(*range.end());
117                quote! {
118                    let start: #repr = #start;
119                    let end: #repr = #end;
120                    (start ..= end).map(|discrim| unsafe { ::std::mem::transmute(discrim) })
121                }
122            },
123
124            IterImpl::Slice(variants) => quote! {
125                const VARIANTS: &[#ty] = &[#( #variants ),*];
126
127                VARIANTS.iter().cloned()
128            },
129        };
130
131        quote! {
132            impl #ty {
133                fn iter() -> impl Iterator<Item = #ty> + Clone {
134                    #body
135                }
136            }
137        }
138    }
139}
140
141/// Returns a range containing the discriminants of this enum if they comprise a single, contiguous
142/// run. Returns `None` if there were no discriminants or they were not contiguous.
143fn detect_contiguous_run(mut discriminants: impl Iterator<Item = Discriminant>)
144    -> Option<RangeInclusive<Discriminant>>
145{
146    let first = discriminants.next()?;
147
148    let mut last = first;
149    while let Some(next) = discriminants.next() {
150        if last.checked_add(1)? != next {
151            return None;
152        }
153
154        last = next
155    }
156
157    Some(first..=last)
158}
159
160pub fn derive(input: &syn::DeriveInput) -> Result<TokenStream, ErrorList> {
161    let input = Enum::parse(input)?;
162    let imp = IterImpl::for_enum(&input)?;
163    Ok(imp.tokens(&input.name))
164}