Skip to main content

optimizer_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
5/// Derive macro for the `Categorical` trait on fieldless enums.
6///
7/// Generates an implementation of `optimizer::parameter::Categorical` that maps
8/// enum variants to/from sequential indices.
9///
10/// # Example
11///
12/// ```ignore
13/// use optimizer::parameter::Categorical;
14///
15/// #[derive(Clone, Categorical)]
16/// enum Color {
17///     Red,
18///     Green,
19///     Blue,
20/// }
21/// ```
22#[proc_macro_derive(Categorical)]
23pub fn derive_categorical(input: TokenStream) -> TokenStream {
24    let input = parse_macro_input!(input as DeriveInput);
25    let name = &input.ident;
26
27    let Data::Enum(data_enum) = &input.data else {
28        return syn::Error::new_spanned(&input, "Categorical can only be derived for enums")
29            .to_compile_error()
30            .into();
31    };
32
33    // Validate all variants are fieldless
34    for variant in &data_enum.variants {
35        if !matches!(variant.fields, Fields::Unit) {
36            return syn::Error::new_spanned(
37                variant,
38                "Categorical can only be derived for enums with unit variants (no fields)",
39            )
40            .to_compile_error()
41            .into();
42        }
43    }
44
45    let n_choices = data_enum.variants.len();
46    let variant_names: Vec<_> = data_enum.variants.iter().map(|v| &v.ident).collect();
47    let indices: Vec<usize> = (0..n_choices).collect();
48
49    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
50
51    let expanded = quote! {
52        impl #impl_generics optimizer::parameter::Categorical for #name #ty_generics #where_clause {
53            const N_CHOICES: usize = #n_choices;
54
55            fn from_index(index: usize) -> Self {
56                match index {
57                    #(#indices => #name::#variant_names,)*
58                    _ => panic!("invalid index {} for {} with {} variants", index, stringify!(#name), #n_choices),
59                }
60            }
61
62            fn to_index(&self) -> usize {
63                match self {
64                    #(#name::#variant_names => #indices,)*
65                }
66            }
67        }
68    };
69
70    expanded.into()
71}