1use std::ops::{Range, RangeInclusive};
2
3use failure::format_err;
4use proc_macro2::{Literal, TokenStream};
5use quote::quote;
6
7use crate::attr::{Discriminant, Enum, ErrorList};
8
9enum IterImpl {
10 Empty,
11 Range {
12 repr: syn::Path,
13 range: Range<Discriminant>,
14 },
15 RangeInclusive {
16 repr: syn::Path,
17 range: RangeInclusive<Discriminant>,
18 },
19 Slice(Vec<TokenStream>),
20}
21
22impl IterImpl {
23 fn for_enum(Enum { name, variants, discriminants, primitive_repr, .. }: &Enum) -> Result<Self, ErrorList> {
29 if let Some(discriminants) = discriminants {
31 let is_zst = discriminants.len() <= 1;
32
33 if let Ok(Some((repr, repr_path))) = primitive_repr {
34 let unskipped_discriminants: Vec<_> = discriminants
35 .iter()
36 .cloned()
37 .zip(variants.iter())
38 .filter(|(_, (_, attr))| !attr.skip)
39 .map(|(d, _)| d)
40 .collect();
41
42 if unskipped_discriminants.is_empty() {
43 return Ok(IterImpl::Empty);
44 }
45
46 if !is_zst {
47 if let Some(range) = detect_contiguous_run(unskipped_discriminants.into_iter()) {
48 let end = *range.end();
51 if end < 0 || repr.max_value().map_or(false, |max| (end as u128) < max) {
52 return Ok(IterImpl::Range {
53 repr: repr_path.clone(),
54 range: *range.start()..(end + 1),
55 })
56 }
57
58 return Ok(IterImpl::RangeInclusive {
59 repr: repr_path.clone(),
60 range,
61 })
62 }
63 }
64 }
65 }
66
67 let mut errors = ErrorList::new();
69 let unskipped_variants: Vec<_> = variants
70 .iter()
71 .filter_map(|(v, attr)| {
72 if attr.skip {
73 return None;
74 }
75
76 if v.fields != syn::Fields::Unit {
77 errors.push_back(format_err!("An (unskipped) variant cannot have fields"));
78 return None;
79 }
80
81 let vident = &v.ident;
82 Some(quote!(#name::#vident))
83 })
84 .collect();
85
86 if !errors.is_empty() {
87 return Err(errors);
88 }
89
90 if unskipped_variants.is_empty() {
91 return Ok(IterImpl::Empty);
92 }
93
94 Ok(IterImpl::Slice(unskipped_variants))
95 }
96
97 fn tokens(&self, ty: &syn::Ident) -> TokenStream {
98 let body = match self {
99 IterImpl::Empty => quote! {
100 ::std::iter::empty()
101 },
102
103 IterImpl::Range { range, repr } => {
104 let start = Literal::i128_unsuffixed(range.start);
105 let end = Literal::i128_unsuffixed(range.end);
106
107 quote! {
108 let start: #repr = #start;
109 let end: #repr = #end;
110 (start .. end).map(|discrim| unsafe { ::std::mem::transmute(discrim) })
111 }
112 },
113
114 IterImpl::RangeInclusive { range, repr } => {
115 let start = Literal::i128_unsuffixed(*range.start());
116 let end = Literal::i128_unsuffixed(*range.end());
117 quote! {
118 let start: #repr = #start;
119 let end: #repr = #end;
120 (start ..= end).map(|discrim| unsafe { ::std::mem::transmute(discrim) })
121 }
122 },
123
124 IterImpl::Slice(variants) => quote! {
125 const VARIANTS: &[#ty] = &[#( #variants ),*];
126
127 VARIANTS.iter().cloned()
128 },
129 };
130
131 quote! {
132 impl #ty {
133 fn iter() -> impl Iterator<Item = #ty> + Clone {
134 #body
135 }
136 }
137 }
138 }
139}
140
141fn detect_contiguous_run(mut discriminants: impl Iterator<Item = Discriminant>)
144 -> Option<RangeInclusive<Discriminant>>
145{
146 let first = discriminants.next()?;
147
148 let mut last = first;
149 while let Some(next) = discriminants.next() {
150 if last.checked_add(1)? != next {
151 return None;
152 }
153
154 last = next
155 }
156
157 Some(first..=last)
158}
159
160pub fn derive(input: &syn::DeriveInput) -> Result<TokenStream, ErrorList> {
161 let input = Enum::parse(input)?;
162 let imp = IterImpl::for_enum(&input)?;
163 Ok(imp.tokens(&input.name))
164}