apollo-errors-derive 0.5.0

Proc macro for deriving the apollo-errors::Error trait
Documentation
//! Variant parsing

use syn::{Attribute, Fields, Lit, Result};

use crate::ir::{RegularVariantDefinition, TransparentVariantDefinition, VariantDefinition};

use super::diagnostic::parse_diagnostic_attrs;
use super::field::parse_field;
use super::field::parse_transparent_field;
use super::http_status::parse_http_status;
use super::jsonrpc_code::parse_jsonrpc_code;

/// Parse a single variant
/// Parse a single enum variant into its intermediate representation.
///
/// Variants carrying `#[diagnostic(transparent)]` are parsed into
/// [`VariantDefinition::Transparent`]; every other variant is parsed into
/// [`VariantDefinition::Regular`].
///
/// # Errors
///
/// Propagates any error produced while inspecting the attributes or parsing
/// the variant's shape.
pub(crate) fn parse_variant(variant: syn::Variant) -> Result<VariantDefinition> {
    let transparent = is_transparent(&variant.attrs)?;
    if transparent {
        parse_transparent_variant(variant)
    } else {
        parse_regular_variant(variant)
    }
}

/// Parse a regular (non-transparent) variant
/// Parse a variant that carries its own error metadata (i.e. not transparent).
///
/// Requires `#[error("...")]` and `#[diagnostic(code(...))]`. The optional
/// `help`, `url` and `severity` from `#[diagnostic(...)]`, together with
/// `#[http_status(...)]` and `#[jsonrpc_code(...)]`, are carried over as parsed.
/// Only unit and named-field variants are accepted, and at most one field may
/// be marked `#[from]`.
///
/// # Errors
///
/// Returns a spanned error when a required attribute is missing, when the
/// variant is a tuple variant, when more than one field is `#[from]`, or when
/// any attribute/field parser fails.
fn parse_regular_variant(variant: syn::Variant) -> Result<VariantDefinition> {
    let Some(error_message) = parse_error_message(&variant.attrs)? else {
        return Err(syn::Error::new_spanned(
            &variant,
            "variant must have #[error(\"message\")] attribute",
        ));
    };

    let diagnostics = parse_diagnostic_attrs(&variant.attrs)?;
    let Some(diagnostic_code) = diagnostics.code else {
        return Err(syn::Error::new_spanned(
            &variant,
            "variant must have #[diagnostic(code(...))] attribute",
        ));
    };

    let http_status = parse_http_status(&variant.attrs)?;
    let jsonrpc_code = parse_jsonrpc_code(&variant.attrs)?;

    let fields = match variant.fields {
        Fields::Unit => Vec::new(),
        Fields::Named(named) => named
            .named
            .into_iter()
            .map(parse_field)
            .collect::<Result<Vec<_>>>()?,
        Fields::Unnamed(_) => {
            return Err(syn::Error::new_spanned(
                variant,
                "tuple variants are not supported, use named fields instead",
            ));
        }
    };

    // Reject variants in which more than one field is marked `#[from]`.
    if fields.iter().filter(|f| f.is_from).count() > 1 {
        return Err(syn::Error::new_spanned(
            &variant.ident,
            "variant can have at most one #[from] field",
        ));
    }

    Ok(VariantDefinition::Regular(RegularVariantDefinition {
        name: variant.ident,
        error_message,
        diagnostic_code,
        help_text: diagnostics.help,
        url: diagnostics.url,
        severity: diagnostics.severity,
        http_status,
        jsonrpc_code,
        fields,
    }))
}

/// Parse a transparent variant
/// Parse a `#[diagnostic(transparent)]` variant.
///
/// The variant must be a tuple variant wrapping exactly one field, e.g.
/// `Variant(InnerError)`; that single field is handed to
/// `parse_transparent_field`.
///
/// # Errors
///
/// Returns a spanned error for any other variant shape, or propagates the
/// error from `parse_transparent_field`.
fn parse_transparent_variant(variant: syn::Variant) -> Result<VariantDefinition> {
    let field = match variant.fields {
        Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) if unnamed.len() == 1 => {
            // The guard above guarantees there is exactly one element.
            let inner = unnamed.into_iter().next().unwrap();
            parse_transparent_field(inner)?
        }
        _ => {
            return Err(syn::Error::new_spanned(
                &variant,
                "transparent variant must have exactly one unnamed field, e.g., `Variant(InnerError)`",
            ));
        }
    };

    let name = variant.ident;
    Ok(VariantDefinition::Transparent(TransparentVariantDefinition { name, field }))
}

/// Check if the variant has #[diagnostic(transparent)]
/// Report whether any attribute is exactly `#[diagnostic(transparent)]`.
///
/// `#[diagnostic(...)]` attributes whose arguments are not a single identifier
/// (for example `#[diagnostic(code(...))]`) are skipped rather than treated as
/// errors, so this function currently always returns `Ok`.
fn is_transparent(attrs: &[Attribute]) -> Result<bool> {
    let found = attrs
        .iter()
        .filter(|attr| attr.path().is_ident("diagnostic"))
        .any(|attr| {
            attr.parse_args::<syn::Ident>()
                .is_ok_and(|ident| ident == "transparent")
        });
    Ok(found)
}

/// Parse #[error("message")] attribute
fn parse_error_message(attrs: &[Attribute]) -> Result<Option<String>> {
    for attr in attrs {
        if attr.path().is_ident("error") {
            // Parse the string literal directly from the attribute
            let lit: Lit = attr.parse_args()?;
            if let Lit::Str(s) = lit {
                return Ok(Some(s.value()));
            }
        }
    }
    Ok(None)
}