//! enum-tree-derive 0.1.0
//!
//! Derive macros for the enum-tree crate.
#![expect(clippy::needless_continue, reason = "originates in darling macro")]

use darling::FromVariant;
use proc_macro2::TokenStream;
use quote::{ToTokens as _, quote};
use syn::punctuated::Punctuated;
use syn::{DeriveInput, Path, Type};

/// Per-variant options read from `#[enum_tree(...)]` attributes.
///
/// The container-level `default` means every option falls back to its
/// `Default` value when it is not written on the variant.
#[derive(Default, FromVariant)]
#[darling(default, attributes(enum_tree))]
struct VariantAttrs
{
    /// `#[enum_tree(skip)]`: leave this variant out of the generated list.
    skip: bool,
}

/// Parse an enum-level `#[enum_tree(crate = path)]` attribute, allowing
/// callers to re-export the runtime crate under a different name.
///
/// Returns `::enum_tree` when no `crate` option is given.
///
/// # Errors
///
/// Returns an error when:
/// - an `#[enum_tree(...)]` attribute contains an option other than `crate`;
/// - the value of `crate` is not a valid path;
/// - `crate` is given more than once, either within one attribute or across
///   several. Only one runtime path can be used, so silently picking one
///   would hide a mistake.
fn parse_crate_path(input: &DeriveInput) -> syn::Result<Path>
{
    let mut found: Option<Path> = None;

    for attr in &input.attrs {
        if !attr.path().is_ident("enum_tree") {
            continue;
        }
        // Every `enum_tree` attribute is parsed, even after a `crate` option
        // has been seen, so that later attributes are still validated.
        attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("crate") {
                if found.is_some() {
                    return Err(meta.error("duplicate `crate` option"));
                }
                let path: Path = meta.value()?.parse()?;
                found = Some(path);
                Ok(())
            } else {
                Err(meta.error("unknown enum_tree option"))
            }
        })?;
    }

    Ok(found.unwrap_or_else(|| syn::parse_quote!(::enum_tree)))
}

pub fn derive_deep_variants(input: &DeriveInput) -> TokenStream
{
    let type_name = &input.ident;

    let crate_path = match parse_crate_path(input) {
        Ok(p) => p,
        Err(e) => return e.to_compile_error(),
    };

    let syn::Data::Enum(data) = &input.data else {
        return syn::Error::new_spanned(&input.ident, "DeepVariants can only be derived for enums")
            .to_compile_error();
    };

    let collected = match collect_variants(type_name, &data.variants) {
        Ok(c) => c,
        Err(e) => return e,
    };
    expand_impl(type_name, &crate_path, &collected)
}

/// Token fragments gathered from the enum's variants for the generated impl.
///
/// Variants marked `#[enum_tree(skip)]` contribute nothing here.
struct CollectedVariants
{
    /// One statement group per supported variant, in declaration order.
    /// Each group writes that variant's values into `arr` at `idx` inside
    /// the generated `const` block.
    constructions: Vec<TokenStream>,
    /// How many unit variants were seen.
    num_unit: usize,
    /// For each singleton tuple variant, the expression
    /// `<#field_ty as DeepVariants>::DEEP_VARIANTS.len()`.
    inner_deep_counts: Vec<TokenStream>,
}

fn collect_variants(
    type_name: &syn::Ident,
    variants: &Punctuated<syn::Variant, syn::Token![,]>,
) -> Result<CollectedVariants, TokenStream>
{
    let mut out = CollectedVariants {
        num_unit: 0,
        inner_deep_counts: Vec::new(),
        constructions: Vec::new(),
    };

    for variant in variants {
        let attrs = VariantAttrs::from_variant(variant).map_err(darling::Error::write_errors)?;
        if attrs.skip {
            continue;
        }

        let variant_name = &variant.ident;
        let variant_qualname = quote! { #type_name::#variant_name };

        match &variant.fields {
            syn::Fields::Unit => {
                out.num_unit += 1;
                out.constructions
                    .push(extend_array_for_unit_variant(&variant_qualname));
            }
            syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
                let field_type = &fields.unnamed[0].ty;
                out.inner_deep_counts.push(quote! {
                    <#field_type as DeepVariants>::DEEP_VARIANTS.len()
                });
                out.constructions.push(extend_array_for_singleton_variant(
                    &variant_qualname,
                    field_type,
                ));
            }
            _ => {
                return Err(syn::Error::new_spanned(
                    variant,
                    "DeepVariants only supports unit variants and singleton tuple variants",
                )
                .to_compile_error());
            }
        }
    }
    Ok(out)
}

/// Emit the `DeepVariants` impl for `type_name` from the collected tokens.
///
/// `DEEP_VARIANTS` is backed by a `const` array of length `DEEP_COUNT`. That
/// length is the number of unit variants plus the lengths of every inner
/// `DEEP_VARIANTS`.
///
/// The array starts as `MaybeUninit` slots, which are filled by the
/// collected construction statements. A compile-time assertion checks that
/// every slot was written, and the array is then transmuted to
/// `[#type_name; DEEP_COUNT]`.
fn expand_impl(
    type_name: &syn::Ident,
    crate_path: &syn::Path,
    collected: &CollectedVariants,
) -> TokenStream
{
    let inner_deep_counts = &collected.inner_deep_counts;
    let constructions = &collected.constructions;
    let num_unit = collected.num_unit;

    let gen_deep_count = if inner_deep_counts.is_empty() {
        num_unit.to_token_stream()
    } else {
        quote! { #(#inner_deep_counts)+* + #num_unit }
    };

    quote! {
        // Mark the impl as derive output so the compiler and lint tools
        // treat it like other automatically derived impls.
        #[automatically_derived]
        impl #crate_path::DeepVariants for #type_name {
            const DEEP_VARIANTS: &'static [Self] = const {
                use #crate_path::DeepVariants;

                const DEEP_COUNT: ::core::primitive::usize = #gen_deep_count;

                const VARIANT_ARRAY: [#type_name; DEEP_COUNT] = const {
                    let mut arr: [::core::mem::MaybeUninit<#type_name>; DEEP_COUNT] =
                        [const { ::core::mem::MaybeUninit::uninit() }; DEEP_COUNT];
                    let mut idx = 0_usize;

                    #(#constructions)*

                    ::core::assert!(
                        idx == arr.len(),
                        "Logic error: not all enum variants have been initialized!"
                    );

                    // SAFETY: all array elements have been initialized, as checked
                    // by the above assertion.
                    unsafe { ::core::mem::transmute(arr) }
                };

                &VARIANT_ARRAY
            };
        }
    }
}

/// Build the statement group that stores the unit variant `variant` in
/// `arr[idx]` and then moves `idx` forward by one slot.
fn extend_array_for_unit_variant(variant: &TokenStream) -> TokenStream
{
    // A unit variant occupies exactly one slot of the generated array.
    let write_then_advance = quote! {
        arr[idx].write(#variant);
        idx += 1;
    };

    // Wrap in braces so each variant's statements form their own block.
    quote! { { #write_then_advance } }
}

/// Build the statement group for a singleton tuple variant.
///
/// The group walks `<inner_type as DeepVariants>::DEEP_VARIANTS`. For each
/// entry it writes `variant_constructor(entry)` into `arr[idx]`, advancing
/// `idx` once per entry.
fn extend_array_for_singleton_variant(
    variant_constructor: &TokenStream,
    inner_type: &Type,
) -> TokenStream
{
    quote! {
        #[allow(unused, reason = "the inner type may have no variants")]
        {
            let inner_variants = <#inner_type as DeepVariants>::DEEP_VARIANTS;

            let mut inner_idx = 0_usize;
            while inner_idx < inner_variants.len() {
                // The value has to be duplicated during const evaluation
                // without requiring `#inner_type: Copy` (or `Clone`, which is
                // not callable in const context). So read it bitwise.
                //
                // SAFETY: `inner_idx < inner_variants.len()`, so the
                // reference is valid and aligned. Neither copy is ever
                // dropped:
                // - the source lives behind the `&'static` slice returned by
                //   `DEEP_VARIANTS`;
                // - the destination ends up in a `const` array that is only
                //   exposed through a `'static` reference.
                // A bitwise duplicate therefore cannot cause a double drop,
                // even if `#inner_type` implements `Drop`.
                let inner_variant = unsafe {
                    ::core::ptr::read(&inner_variants[inner_idx])
                };
                arr[idx].write(#variant_constructor(inner_variant));
                inner_idx += 1;
                idx += 1;
            }
        }
    }
}