derive_enum_rotate/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use proc_macro_error::{abort, proc_macro_error};
4use quote::{quote, ToTokens};
5use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident, Token};
6
7struct IterationOrder {
8    idents: Vec<Ident>,
9    idents_span: Span,
10}
11
12impl syn::parse::Parse for IterationOrder {
13    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
14        let attr_content;
15        input.parse::<Token![#]>()?;
16        syn::bracketed!(attr_content in input);
17        assert_eq!(attr_content.parse::<Ident>()?, "iteration_order");
18        let paren_content;
19
20        // TODO: output a nice error message if this fails ("#[iteration_order]")
21        syn::parenthesized!(paren_content in attr_content);
22
23        let mut idents = vec![];
24        while !paren_content.is_empty() {
25            let ident = match paren_content.parse::<Ident>() {
26                Ok(ident) => ident,
27                Err(_) => abort!(paren_content.span(), "Expected identifier",),
28            };
29            idents.push(ident);
30            if !paren_content.is_empty() {
31                match paren_content.parse::<Token![,]>() {
32                    Ok(_) => {}
33                    Err(_) => abort!(paren_content.span(), "Expected comma (,)",),
34                };
35            }
36        }
37        Ok(Self {
38            idents,
39            idents_span: paren_content.span().into(),
40        })
41    }
42}
43
44#[proc_macro_error]
45#[proc_macro_derive(EnumRotate, attributes(iteration_order))]
46pub fn derive_enum_rotate(input: TokenStream) -> TokenStream {
47    let input = parse_macro_input!(input as DeriveInput);
48
49    let mut attr_iter = input.attrs.iter().filter(|a| {
50        a.path().segments.len() == 1 && a.path().segments[0].ident == "iteration_order"
51    });
52
53    let iteration_order: Option<IterationOrder> = attr_iter
54        .next()
55        .map(|a| syn::parse2(a.to_token_stream()).expect("Failed to parse"));
56
57    // Crash if multiple #[iteration_order(...)] attributes are present
58    if let Some(repeated_attr) = attr_iter.next() {
59        abort!(
60            repeated_attr,
61            "Duplicate \"iteration_order\" attribute, please specify at most one iteration order",
62        );
63    }
64
65    let enum_data = match &input.data {
66        Data::Enum(data) => Ok(data),
67        Data::Struct(_) => Err("Struct"),
68        Data::Union(_) => Err("Union"),
69    };
70    let enum_data = match enum_data {
71        Ok(data) => data,
72        Err(item) => abort!(
73            input,
74            "{item} {} is not an enum, EnumRotate can only be derived for enums",
75            input.ident,
76        ),
77    };
78
79    let variants: Vec<_> = enum_data.variants.iter().collect();
80
81    // Validate custom iteration order
82    if let Some(iteration_order) = &iteration_order {
83        let expected_len = variants.len();
84        let got_len = iteration_order.idents.len();
85        if got_len != expected_len {
86            abort!(
87                iteration_order.idents_span,
88                "Expected {} items in the iteration order but got {}",
89                expected_len, got_len;
90                note = "Enum `{}` has {} variants", input.ident, expected_len;
91                note = "Each variant should appear exactly once in the iteration order";
92            );
93        }
94
95        if let Some(invalid) = iteration_order
96            .idents
97            .iter()
98            .filter(|ident| !variants.iter().any(|var| var.ident == **ident))
99            .next()
100        {
101            abort!(
102                iteration_order.idents_span,
103                "Invalid variant for enum `{}`: {}",
104                input.ident, invalid;
105                note = "The iteration order can only contain variants of `{}`",
106                input.ident;
107            );
108        }
109
110        if let Some(missing) = variants
111            .iter()
112            .filter(|var| !iteration_order.idents.contains(&var.ident))
113            .next()
114        {
115            abort!(
116                iteration_order.idents_span,
117                "Variant {} not covered",
118                missing.ident;
119                note = "Each variant of `{}` should appear exactly once in the iteration order",
120                input.ident;
121            );
122        }
123    }
124
125    // TODO: support empty variants: A(), A {}
126    for variant in &variants {
127        if !matches!(variant.fields, Fields::Unit) {
128            abort!(
129                variant,
130                "Variant {} is not a unit variant, all variants must be unit variants to derive EnumRotate",
131                variant.ident,
132            );
133        }
134    }
135
136    let name = input.ident;
137    let tokens = if variants.is_empty() {
138        // Special case for empty enums
139        quote! {
140            impl ::enum_rotate::EnumRotate for #name {
141                fn next(&self) -> Self {
142                    unsafe {
143                        ::std::hint::unreachable_unchecked()
144                    }
145                }
146
147                fn prev(&self) -> Self {
148                    unsafe {
149                        ::std::hint::unreachable_unchecked()
150                    }
151                }
152
153                fn iter() -> impl Iterator<Item=Self> {
154                    ::std::iter::empty()
155                }
156
157                fn iter_from(&self) -> impl Iterator<Item=Self> {
158                    unsafe {
159                        ::std::hint::unreachable_unchecked();
160                    }
161                    // This is necessary because "() is not an iterator"
162                    #[allow(unreachable_code)]
163                    ::std::iter::empty()
164                }
165            }
166        }
167    } else {
168        // Base case for non-empty enums
169        let map_from = iteration_order
170            .map(|io| io.idents)
171            .unwrap_or_else(|| variants.iter().map(|var| var.ident.clone()).collect());
172        let map_to = {
173            let mut vec = map_from.clone();
174            vec.rotate_left(1);
175            vec
176        };
177
178        quote! {
179            impl ::enum_rotate::EnumRotate for #name {
180                fn next(&self) -> Self {
181                    match self {
182                        #( Self::#map_from => Self::#map_to, )*
183                    }
184                }
185
186                fn prev(&self) -> Self {
187                    match self {
188                        #( Self::#map_to => Self::#map_from, )*
189                    }
190                }
191
192                fn iter() -> impl Iterator<Item=Self> {
193                    vec![ #( Self::#map_from ),* ].into_iter()
194                }
195
196                fn iter_from(&self) -> impl Iterator<Item=Self> {
197                    let mut vars = vec![ #( Self::#map_from ),* ];
198                    let index = vars.iter().position(|var| {
199                        ::std::mem::discriminant(var) == ::std::mem::discriminant(self)
200                    }).unwrap();
201
202                    vars.rotate_left(index);
203                    vars.into_iter()
204                }
205            }
206        }
207    };
208
209    tokens.into()
210}