// derive_enum_rotate/lib.rs
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro_error::{abort, proc_macro_error};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields, Ident, Token, Variant};
6
7struct IterationOrder {
8 idents: Vec<Ident>,
9 idents_span: Span,
10}
11
12impl syn::parse::Parse for IterationOrder {
13 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
14 let attr_content;
15 input.parse::<Token![#]>()?;
16 syn::bracketed!(attr_content in input);
17 assert_eq!(attr_content.parse::<Ident>()?, "iteration_order");
18 let paren_content;
19
20 syn::parenthesized!(paren_content in attr_content);
22
23 let mut idents = vec![];
24 while !paren_content.is_empty() {
25 let ident = match paren_content.parse::<Ident>() {
26 Ok(ident) => ident,
27 Err(_) => abort!(paren_content.span(), "Expected identifier",),
28 };
29 idents.push(ident);
30 if !paren_content.is_empty() {
31 match paren_content.parse::<Token![,]>() {
32 Ok(_) => {}
33 Err(_) => abort!(paren_content.span(), "Expected comma (,)",),
34 };
35 }
36 }
37 Ok(Self {
38 idents,
39 idents_span: paren_content.span().into(),
40 })
41 }
42}
43
44#[proc_macro_error]
45#[proc_macro_derive(EnumRotate, attributes(iteration_order))]
46pub fn derive_enum_rotate(input: TokenStream) -> TokenStream {
47 let input = parse_macro_input!(input as DeriveInput);
48
49 let mut attr_iter = input.attrs.iter().filter(|a| {
50 a.path().segments.len() == 1 && a.path().segments[0].ident == "iteration_order"
51 });
52
53 let iteration_order: Option<IterationOrder> = attr_iter
54 .next()
55 .map(|a| syn::parse2(a.to_token_stream()).expect("Failed to parse"));
56
57 if let Some(repeated_attr) = attr_iter.next() {
59 abort!(
60 repeated_attr,
61 "Duplicate \"iteration_order\" attribute, please specify at most one iteration order",
62 );
63 }
64
65 let enum_data = match &input.data {
66 Data::Enum(data) => Ok(data),
67 Data::Struct(_) => Err("Struct"),
68 Data::Union(_) => Err("Union"),
69 };
70 let enum_data = match enum_data {
71 Ok(data) => data,
72 Err(item) => abort!(
73 input,
74 "{item} {} is not an enum, EnumRotate can only be derived for enums",
75 input.ident,
76 ),
77 };
78
79 let variants: Vec<_> = enum_data.variants.iter().collect();
80
81 if let Some(iteration_order) = &iteration_order {
83 let expected_len = variants.len();
84 let got_len = iteration_order.idents.len();
85 if got_len != expected_len {
86 abort!(
87 iteration_order.idents_span,
88 "Expected {} items in the iteration order but got {}",
89 expected_len, got_len;
90 note = "Enum `{}` has {} variants", input.ident, expected_len;
91 note = "Each variant should appear exactly once in the iteration order";
92 );
93 }
94
95 if let Some(invalid) = iteration_order
96 .idents
97 .iter()
98 .filter(|ident| !variants.iter().any(|var| var.ident == **ident))
99 .next()
100 {
101 abort!(
102 iteration_order.idents_span,
103 "Invalid variant for enum `{}`: {}",
104 input.ident, invalid;
105 note = "The iteration order can only contain variants of `{}`",
106 input.ident;
107 );
108 }
109
110 if let Some(missing) = variants
111 .iter()
112 .filter(|var| !iteration_order.idents.contains(&var.ident))
113 .next()
114 {
115 abort!(
116 iteration_order.idents_span,
117 "Variant {} not covered",
118 missing.ident;
119 note = "Each variant of `{}` should appear exactly once in the iteration order",
120 input.ident;
121 );
122 }
123 }
124
125 for variant in &variants {
127 if !matches!(variant.fields, Fields::Unit) {
128 abort!(
129 variant,
130 "Variant {} is not a unit variant, all variants must be unit variants to derive EnumRotate",
131 variant.ident,
132 );
133 }
134 }
135
136 let name = input.ident;
137 let tokens = if variants.is_empty() {
138 quote! {
140 impl ::enum_rotate::EnumRotate for #name {
141 fn next(&self) -> Self {
142 unsafe {
143 ::std::hint::unreachable_unchecked()
144 }
145 }
146
147 fn prev(&self) -> Self {
148 unsafe {
149 ::std::hint::unreachable_unchecked()
150 }
151 }
152
153 fn iter() -> impl Iterator<Item=Self> {
154 ::std::iter::empty()
155 }
156
157 fn iter_from(&self) -> impl Iterator<Item=Self> {
158 unsafe {
159 ::std::hint::unreachable_unchecked();
160 }
161 #[allow(unreachable_code)]
163 ::std::iter::empty()
164 }
165 }
166 }
167 } else {
168 let map_from = iteration_order
170 .map(|io| io.idents)
171 .unwrap_or_else(|| variants.iter().map(|var| var.ident.clone()).collect());
172 let map_to = {
173 let mut vec = map_from.clone();
174 vec.rotate_left(1);
175 vec
176 };
177
178 quote! {
179 impl ::enum_rotate::EnumRotate for #name {
180 fn next(&self) -> Self {
181 match self {
182 #( Self::#map_from => Self::#map_to, )*
183 }
184 }
185
186 fn prev(&self) -> Self {
187 match self {
188 #( Self::#map_to => Self::#map_from, )*
189 }
190 }
191
192 fn iter() -> impl Iterator<Item=Self> {
193 vec![ #( Self::#map_from ),* ].into_iter()
194 }
195
196 fn iter_from(&self) -> impl Iterator<Item=Self> {
197 let mut vars = vec![ #( Self::#map_from ),* ];
198 let index = vars.iter().position(|var| {
199 ::std::mem::discriminant(var) == ::std::mem::discriminant(self)
200 }).unwrap();
201
202 vars.rotate_left(index);
203 vars.into_iter()
204 }
205 }
206 }
207 };
208
209 tokens.into()
210}