pogo-masterfile-macros 0.1.0

Procedural derive macros backing the pogo-masterfile-types crate (AllVariants, AsStr, FromStrEnum, TemplateId).
Documentation
//! Procedural derive macros backing the `pogo-masterfile-types` crate:
//! `AllVariants`, `AsStr`, `FromStrEnum`, and `TemplateId`.
//!
//! This crate is normally consumed transparently via re-exports from
//! `pogo-masterfile-types`. Depending on it directly is fine but not required.

use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields, parse_macro_input};

/// Derives `pub const ALL: [Self; N]` and `pub const SIZE: usize` for a
/// unit-only enum. Visibility of the constants follows the enum's own
/// visibility.
///
/// ```
/// use pogo_masterfile_macros::AllVariants;
///
/// #[derive(AllVariants)]
/// enum E { A, B, C }
///
/// assert_eq!(E::SIZE, 3);
/// ```
#[proc_macro_derive(AllVariants)]
pub fn derive_all_variants(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;
    let vis = &input.vis;

    let Data::Enum(data_enum) = &input.data else {
        return syn::Error::new_spanned(name, "AllVariants only applies to enums")
            .to_compile_error()
            .into();
    };

    let mut errors: Vec<syn::Error> = Vec::new();
    let mut variant_idents: Vec<&syn::Ident> = Vec::new();
    for v in &data_enum.variants {
        match &v.fields {
            Fields::Unit => variant_idents.push(&v.ident),
            _ => errors.push(syn::Error::new_spanned(
                v,
                "AllVariants requires all variants to be unit (no fields)",
            )),
        }
    }

    if !errors.is_empty() {
        let combined = errors
            .into_iter()
            .reduce(|mut a, b| {
                a.combine(b);
                a
            })
            .unwrap();
        return combined.to_compile_error().into();
    }

    let count = variant_idents.len();
    let qualified = variant_idents.iter().map(|v| quote! { #name::#v });

    quote! {
        impl #name {
            #vis const SIZE: usize = #count;
            #vis const ALL: [Self; #count] = [ #(#qualified),* ];
        }
    }
    .into()
}

/// Derives `pub const fn as_str(&self) -> &'static str` and
/// `impl core::fmt::Display` for a unit-only enum.
///
/// Each variant's string is taken from a `#[serde(rename = "...")]`
/// attribute. If that is absent, the variant identifier is used verbatim,
/// which is the same text `stringify!(VariantIdent)` would produce.
/// `Display` writes exactly the `as_str` string.
///
/// Declares `serde` as a helper attribute so `#[serde(rename = "...")]`
/// is syntactically recognized even when serde itself is not derived on
/// the same enum. Production use will always co-derive Serialize /
/// Deserialize too, but the macro stays usable in isolation (e.g. tests).
///
/// An enum with no variants is accepted; its generated `match` is empty.
///
/// Compile errors are emitted when the input is not an enum, when any
/// variant carries fields, or when reading a variant's `#[serde(...)]`
/// attribute fails. All per-variant errors are reported together.
///
/// ```
/// use pogo_masterfile_macros::AsStr;
///
/// #[derive(AsStr)]
/// enum Kind {
///     #[serde(rename = "POKEMON")]
///     Pokemon,
///     Item,
/// }
///
/// assert_eq!(Kind::Pokemon.as_str(), "POKEMON");
/// assert_eq!(Kind::Item.to_string(), "Item");
/// ```
#[proc_macro_derive(AsStr, attributes(serde))]
pub fn derive_as_str(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;

    let Data::Enum(data_enum) = &input.data else {
        return syn::Error::new_spanned(name, "AsStr only applies to enums")
            .to_compile_error()
            .into();
    };

    let mut errors: Vec<syn::Error> = Vec::new();
    let mut arms: Vec<proc_macro2::TokenStream> = Vec::new();
    for v in &data_enum.variants {
        if !matches!(v.fields, Fields::Unit) {
            errors.push(syn::Error::new_spanned(
                v,
                "AsStr requires all variants to be unit (no fields)",
            ));
            continue;
        }
        let ident = &v.ident;
        let lit = match extract_serde_rename(&v.attrs) {
            Ok(Some(s)) => s,
            Ok(None) => ident.to_string(),
            Err(e) => {
                errors.push(e);
                continue;
            }
        };
        arms.push(quote! { Self::#ident => #lit });
    }

    // Report every bad variant at once rather than stopping at the first.
    if let Some(combined) = errors.into_iter().reduce(|mut a, b| {
        a.combine(b);
        a
    }) {
        return combined.to_compile_error().into();
    }

    // Match on the place `*self`, not on the reference `self`. Unit patterns
    // bind nothing, so nothing is moved (and it is fine in a `const fn`).
    // For an enum with no variants this expands to `match *self {}`, which
    // compiles; `match self {}` on `&Empty` is rejected as non-exhaustive.
    quote! {
        impl #name {
            pub const fn as_str(&self) -> &'static str {
                match *self {
                    #(#arms,)*
                }
            }
        }
        impl ::core::fmt::Display for #name {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.write_str(self.as_str())
            }
        }
    }
    .into()
}

/// Derives an inherent `template_id(&self) -> &str` method for an enum
/// whose every variant is a single-field tuple wrapping a value with a
/// `template_id` field.
///
/// Eliminates boilerplate `match` arms when dispatching on a wide enum
/// (e.g. `MasterfileEntry`) to read the inner `template_id`. Every arm
/// is mechanically `Self::Variant(ref inner) => inner.template_id.as_str()`.
///
/// # Requirements
///
/// - The type is an enum. Its generic parameters and where-clause are
///   forwarded to the generated impl, so e.g.
///   `enum Entry<'a> { A(&'a Inner) }` derives too.
/// - Every variant is a single-field tuple variant: `Variant(Inner)`.
/// - The wrapped value (after auto-deref) has a `template_id` field whose
///   type provides an `.as_str()` method returning `&str`, e.g. `String`.
///
/// The generated method is always `pub`. An enum with no variants is
/// accepted and gets an empty `match`.
///
/// # Example
///
/// ```ignore
/// use pogo_masterfile_macros::TemplateId;
///
/// struct Inner { template_id: String }
///
/// #[derive(TemplateId)]
/// enum E { A(Inner), B(Inner) }
///
/// let e = E::A(Inner { template_id: "X".into() });
/// assert_eq!(e.template_id(), "X");
/// ```
#[proc_macro_derive(TemplateId)]
pub fn derive_template_id(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;
    // Forward generics so enums such as `Entry<'a>` get a well-formed impl.
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    let Data::Enum(data_enum) = &input.data else {
        return syn::Error::new_spanned(name, "TemplateId only applies to enums")
            .to_compile_error()
            .into();
    };

    let mut errors: Vec<syn::Error> = Vec::new();
    let mut arms: Vec<proc_macro2::TokenStream> = Vec::new();

    for v in &data_enum.variants {
        let ident = &v.ident;
        match &v.fields {
            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
                // `ref` because we match on the place `*self` (see below).
                arms.push(quote! { Self::#ident(ref inner) => inner.template_id.as_str() });
            }
            _ => errors.push(syn::Error::new_spanned(
                v,
                "TemplateId requires every variant to be a single-field tuple wrapping a struct with `template_id: String`",
            )),
        }
    }

    // Report every bad variant at once rather than stopping at the first.
    if let Some(combined) = errors.into_iter().reduce(|mut a, b| {
        a.combine(b);
        a
    }) {
        return combined.to_compile_error().into();
    }

    // Matching on `*self` (with `ref` bindings) instead of `self` keeps an
    // enum with no variants compiling: `match *self {}` is accepted, while
    // `match self {}` on `&Empty` is rejected as non-exhaustive.
    quote! {
        impl #impl_generics #name #ty_generics #where_clause {
            /// Read the `template_id` of whichever variant `self` is.
            pub fn template_id(&self) -> &str {
                match *self {
                    #(#arms,)*
                }
            }
        }
    }
    .into()
}

/// Derives `impl FromStr` AND `impl TryFrom<&str>` for a unit-only enum.
/// Both share the same string-matching logic: `#[serde(rename = "...")]`
/// first, variant ident otherwise. Matching is exact (string-literal
/// patterns). The error type is `pogo_masterfile_types::UnknownTemplateId`.
/// The macro emits that path as written (no leading `::`), so consumers
/// must have it in scope, which they do transparently via the parent crate.
/// The unrecognized input is passed to that error as an owned `String`.
///
/// `TryFrom<&str>` is needed for callers using
/// `impl TryInto<TemplateId>`-style polymorphic input (e.g.
/// `pogo-masterfile`'s per-group accessor `get` method): the std blanket
/// `impl<T, U: TryFrom<T>> TryInto<U> for T` lets `&str` and a typed enum
/// both satisfy a single `I: TryInto<TemplateId>` bound at the call site.
///
/// An enum with no variants is accepted; every input then yields the error.
/// If two variants resolve to the same string, the first one declared wins
/// and rustc reports the later arm as unreachable.
///
/// Compile errors are emitted when the input is not an enum, when any
/// variant carries fields, or when reading a variant's `#[serde(...)]`
/// attribute fails. All per-variant errors are reported together.
#[proc_macro_derive(FromStrEnum, attributes(serde))]
pub fn derive_from_str_enum(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;

    let Data::Enum(data_enum) = &input.data else {
        return syn::Error::new_spanned(name, "FromStrEnum only applies to enums")
            .to_compile_error()
            .into();
    };

    let mut errors: Vec<syn::Error> = Vec::new();
    let mut arms: Vec<proc_macro2::TokenStream> = Vec::new();
    for v in &data_enum.variants {
        if !matches!(v.fields, Fields::Unit) {
            errors.push(syn::Error::new_spanned(
                v,
                "FromStrEnum requires all variants to be unit (no fields)",
            ));
            continue;
        }
        let ident = &v.ident;
        let lit = match extract_serde_rename(&v.attrs) {
            Ok(Some(s)) => s,
            Ok(None) => ident.to_string(),
            Err(e) => {
                errors.push(e);
                continue;
            }
        };
        arms.push(quote! { #lit => Ok(Self::#ident) });
    }

    // Report every bad variant at once rather than stopping at the first.
    if let Some(combined) = errors.into_iter().reduce(|mut a, b| {
        a.combine(b);
        a
    }) {
        return combined.to_compile_error().into();
    }

    // `#(#arms,)*` puts the comma after each arm. The previous
    // `#(#arms),*,` emitted a lone leading `,` when there were no arms,
    // which is a syntax error in the generated `match`.
    quote! {
        impl ::core::str::FromStr for #name {
            type Err = pogo_masterfile_types::UnknownTemplateId;
            fn from_str(s: &str) -> ::core::result::Result<Self, Self::Err> {
                match s {
                    #(#arms,)*
                    other => Err(pogo_masterfile_types::UnknownTemplateId(other.to_string())),
                }
            }
        }

        impl ::core::convert::TryFrom<&str> for #name {
            type Error = pogo_masterfile_types::UnknownTemplateId;
            fn try_from(s: &str) -> ::core::result::Result<Self, Self::Error> {
                <Self as ::core::str::FromStr>::from_str(s)
            }
        }
    }
    .into()
}

/// Look for `#[serde(rename = "...")]` on a variant. Returns the string
/// payload if found. Errors only on malformed serde attributes.
fn extract_serde_rename(attrs: &[syn::Attribute]) -> syn::Result<Option<String>> {
    for attr in attrs {
        if !attr.path().is_ident("serde") {
            continue;
        }
        let mut found: Option<String> = None;
        attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("rename") {
                let value = meta.value()?;
                let s: syn::LitStr = value.parse()?;
                found = Some(s.value());
            } else {
                // Skip other serde meta items (e.g., `untagged`, `tag = "..."`).
                let _ = meta.input;
            }
            Ok(())
        })?;
        if let Some(s) = found {
            return Ok(Some(s));
        }
    }
    Ok(None)
}