#![expect(clippy::needless_continue, reason = "originates in darling macro")]
use darling::FromVariant;
use proc_macro2::TokenStream;
use quote::{ToTokens as _, quote};
use syn::punctuated::Punctuated;
use syn::{DeriveInput, Path, Type};
/// Per-variant options read from `#[enum_tree(...)]` attributes on enum variants.
///
/// Because of the container-level `default`, a variant that carries no
/// `enum_tree` attribute (or omits an option) gets `VariantAttrs::default()`,
/// i.e. `skip == false`.
#[derive(Default, FromVariant)]
#[darling(default, attributes(enum_tree))]
struct VariantAttrs
{
    /// `#[enum_tree(skip)]`: leave this variant out of the generated variant list.
    skip: bool,
}
/// Resolves the path of the runtime crate from `#[enum_tree(crate = path)]`.
///
/// Only attributes named `enum_tree` on the item are inspected, and `crate` is
/// the only option they may contain; anything else is reported as
/// "unknown enum_tree option". The first attribute that sets `crate` wins (a
/// repeated `crate` inside that one attribute keeps the last value), and any
/// later `enum_tree` attributes are not examined. When no attribute sets it,
/// the default `::enum_tree` is returned.
fn parse_crate_path(input: &DeriveInput) -> syn::Result<Path>
{
    let mut selected: Option<Path> = None;

    for attr in input.attrs.iter().filter(|a| a.path().is_ident("enum_tree")) {
        attr.parse_nested_meta(|meta| {
            if !meta.path.is_ident("crate") {
                return Err(meta.error("unknown enum_tree option"));
            }
            selected = Some(meta.value()?.parse::<Path>()?);
            Ok(())
        })?;

        // Stop at the first attribute that provided a path.
        if selected.is_some() {
            break;
        }
    }

    Ok(selected.unwrap_or_else(|| syn::parse_quote!(::enum_tree)))
}
pub fn derive_deep_variants(input: &DeriveInput) -> TokenStream
{
let type_name = &input.ident;
let crate_path = match parse_crate_path(input) {
Ok(p) => p,
Err(e) => return e.to_compile_error(),
};
let syn::Data::Enum(data) = &input.data else {
return syn::Error::new_spanned(&input.ident, "DeepVariants can only be derived for enums")
.to_compile_error();
};
let collected = match collect_variants(type_name, &data.variants) {
Ok(c) => c,
Err(e) => return e,
};
expand_impl(type_name, &crate_path, &collected)
}
/// Data gathered from an enum's variants by `collect_variants`, used by
/// `expand_impl` to emit the impl.
struct CollectedVariants
{
    /// One code fragment per non-skipped variant, in declaration order, that
    /// writes that variant's value(s) into the generated `arr` buffer.
    constructions: Vec<TokenStream>,
    /// One `<Inner as DeepVariants>::DEEP_VARIANTS.len()` expression per
    /// non-skipped single-field tuple variant.
    inner_deep_counts: Vec<TokenStream>,
    /// Number of non-skipped unit variants.
    num_unit: usize,
}
fn collect_variants(
type_name: &syn::Ident,
variants: &Punctuated<syn::Variant, syn::Token![,]>,
) -> Result<CollectedVariants, TokenStream>
{
let mut out = CollectedVariants {
num_unit: 0,
inner_deep_counts: Vec::new(),
constructions: Vec::new(),
};
for variant in variants {
let attrs = VariantAttrs::from_variant(variant).map_err(darling::Error::write_errors)?;
if attrs.skip {
continue;
}
let variant_name = &variant.ident;
let variant_qualname = quote! { #type_name::#variant_name };
match &variant.fields {
syn::Fields::Unit => {
out.num_unit += 1;
out.constructions
.push(extend_array_for_unit_variant(&variant_qualname));
}
syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
let field_type = &fields.unnamed[0].ty;
out.inner_deep_counts.push(quote! {
<#field_type as DeepVariants>::DEEP_VARIANTS.len()
});
out.constructions.push(extend_array_for_singleton_variant(
&variant_qualname,
field_type,
));
}
_ => {
return Err(syn::Error::new_spanned(
variant,
"DeepVariants only supports unit variants and singleton tuple variants",
)
.to_compile_error());
}
}
}
Ok(out)
}
fn expand_impl(
type_name: &syn::Ident,
crate_path: &syn::Path,
collected: &CollectedVariants,
) -> TokenStream
{
let inner_deep_counts = &collected.inner_deep_counts;
let constructions = &collected.constructions;
let num_unit = collected.num_unit;
let gen_deep_count = if inner_deep_counts.is_empty() {
num_unit.to_token_stream()
} else {
quote! { #(#inner_deep_counts)+* + #num_unit }
};
quote! {
impl #crate_path::DeepVariants for #type_name {
const DEEP_VARIANTS: &'static [Self] = const {
use #crate_path::DeepVariants;
const DEEP_COUNT: ::core::primitive::usize = #gen_deep_count;
const VARIANT_ARRAY: [#type_name; DEEP_COUNT] = const {
let mut arr: [::core::mem::MaybeUninit<#type_name>; DEEP_COUNT] =
[const { ::core::mem::MaybeUninit::uninit() }; DEEP_COUNT];
let mut idx = 0_usize;
#(#constructions)*
::core::assert!(
idx == arr.len(),
"Logic error: not all enum variants have been initialized!"
);
unsafe { ::core::mem::transmute(arr) }
};
&VARIANT_ARRAY
};
}
}
}
/// Builds a block statement that stores the unit variant `variant` (a
/// `Type::Variant` path) in slot `arr[idx]` and advances `idx` by one.
///
/// The fragment expects to be spliced where mutable locals `arr` (a
/// `MaybeUninit` array) and `idx` are in scope, as in `expand_impl`.
fn extend_array_for_unit_variant(variant: &TokenStream) -> TokenStream
{
    // Wrapped in its own `{ ... }` so fragments can simply be concatenated.
    quote!({
        arr[idx].write(#variant);
        idx += 1;
    })
}
/// Builds a block statement that wraps every value of
/// `<inner_type as DeepVariants>::DEEP_VARIANTS` in `variant_constructor(...)`
/// and stores the results in consecutive slots `arr[idx]`, advancing `idx`
/// once per value.
///
/// The fragment expects to be spliced where mutable locals `arr` and `idx`
/// and the `DeepVariants` trait are in scope, as in `expand_impl`.
fn extend_array_for_singleton_variant(
    variant_constructor: &TokenStream,
    inner_type: &Type,
) -> TokenStream
{
    quote! {
        #[allow(unused, reason = "the inner type may have no variants")]
        {
            let inner_variants = <#inner_type as DeepVariants>::DEEP_VARIANTS;
            let mut inner_idx = 0_usize;
            // `for` loops are not allowed in const evaluation, hence `while`.
            while inner_idx < inner_variants.len() {
                // SAFETY: this makes a bitwise copy of an element of the inner
                // type's `DEEP_VARIANTS`, since no `Copy` bound is available here.
                // The copy is moved into `arr` via `MaybeUninit::write`, so it is
                // not dropped here. NOTE(review): this is only sound if duplicating
                // the inner value bitwise is harmless; confirm this for hand-written
                // `DeepVariants` impls.
                let inner_variant = unsafe {
                    ::core::ptr::read(&inner_variants[inner_idx])
                };
                arr[idx].write(#variant_constructor(inner_variant));
                inner_idx += 1;
                idx += 1;
            }
        }
    }
}