1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
5#[proc_macro_derive(Categorical)]
23pub fn derive_categorical(input: TokenStream) -> TokenStream {
24 let input = parse_macro_input!(input as DeriveInput);
25 let name = &input.ident;
26
27 let Data::Enum(data_enum) = &input.data else {
28 return syn::Error::new_spanned(&input, "Categorical can only be derived for enums")
29 .to_compile_error()
30 .into();
31 };
32
33 for variant in &data_enum.variants {
35 if !matches!(variant.fields, Fields::Unit) {
36 return syn::Error::new_spanned(
37 variant,
38 "Categorical can only be derived for enums with unit variants (no fields)",
39 )
40 .to_compile_error()
41 .into();
42 }
43 }
44
45 let n_choices = data_enum.variants.len();
46 let variant_names: Vec<_> = data_enum.variants.iter().map(|v| &v.ident).collect();
47 let indices: Vec<usize> = (0..n_choices).collect();
48
49 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
50
51 let expanded = quote! {
52 impl #impl_generics optimizer::parameter::Categorical for #name #ty_generics #where_clause {
53 const N_CHOICES: usize = #n_choices;
54
55 fn from_index(index: usize) -> Self {
56 match index {
57 #(#indices => #name::#variant_names,)*
58 _ => panic!("invalid index {} for {} with {} variants", index, stringify!(#name), #n_choices),
59 }
60 }
61
62 fn to_index(&self) -> usize {
63 match self {
64 #(#name::#variant_names => #indices,)*
65 }
66 }
67 }
68 };
69
70 expanded.into()
71}