apollo-errors-derive 0.5.0

Proc macro for deriving the `apollo_errors::Error` trait.
Documentation
//! Shared helper functions for code generation

use heck::{ToKebabCase, ToLowerCamelCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase};
use proc_macro2::TokenStream;
use quote::quote;

use crate::ir::FieldDefinition;

/// JSON-RPC error code used when an error does not specify one.
///
/// `-32000` is the first code of the range that JSON-RPC 2.0 reserves for
/// implementation-defined "Server error" values.
pub(crate) const DEFAULT_JSONRPC_CODE: i32 = -32_000;

/// Render an optional string as tokens for an `Option` expression.
///
/// A present value becomes `Some("value")`, with the string emitted as a string
/// literal. An absent value becomes `None`.
pub(crate) fn quote_option_str(opt: &Option<String>) -> TokenStream {
    if let Some(value) = opt {
        quote! { Some(#value) }
    } else {
        quote! { None }
    }
}

/// Generate downcast attempts for a method on the error registry.
///
/// The emitted statements try, in order, to view `error` as `#name`,
/// `Box<#name>`, `Arc<#name>` and `Box<Arc<#name>>` via `downcast_ref`. On the
/// first match they `return Some(::apollo_errors::Error::<method>(value, ..))`.
/// If nothing matches, the trailing expression evaluates to `None`.
///
/// The generated tokens refer to identifiers that must already be in scope at
/// the expansion site: `error` and, when `with_config` is true, `config`.
///
/// When `with_config` is true, appends `config` as an argument to each call site
/// (used for format methods like `to_json`, `to_graphql`, etc.). When false, the
/// method is called without extra arguments (used for `http_status`, `http_headers`).
///
/// Standard-library paths in the output are fully qualified (`::std::boxed::Box`,
/// `::std::sync::Arc`, `::std::option::Option`). The expansion therefore keeps
/// working even if the user's crate shadows `Box`, `Arc`, `Some` or `None`, or
/// declares its own module named `std`.
///
/// # Panics
///
/// Panics if `method_name` is not a valid Rust identifier (via `syn::Ident::new`).
pub(crate) fn generate_downcast_attempts(
    name: &syn::Ident,
    method_name: &str,
    with_config: bool,
) -> TokenStream {
    let method_ident = syn::Ident::new(method_name, name.span());
    // Extra trailing argument spliced after the receiver in every call below.
    let config_arg = if with_config {
        quote! { , config }
    } else {
        quote! {}
    };
    quote! {
        if let ::std::option::Option::Some(concrete) = error.downcast_ref::<#name>() {
            return ::std::option::Option::Some(
                ::apollo_errors::Error::#method_ident(concrete #config_arg),
            );
        }
        if let ::std::option::Option::Some(boxed) =
            error.downcast_ref::< ::std::boxed::Box<#name> >()
        {
            return ::std::option::Option::Some(
                ::apollo_errors::Error::#method_ident(boxed.as_ref() #config_arg),
            );
        }
        if let ::std::option::Option::Some(arc) =
            error.downcast_ref::< ::std::sync::Arc<#name> >()
        {
            return ::std::option::Option::Some(
                ::apollo_errors::Error::#method_ident(arc.as_ref() #config_arg),
            );
        }
        if let ::std::option::Option::Some(box_arc) =
            error.downcast_ref::< ::std::boxed::Box<::std::sync::Arc<#name>> >()
        {
            return ::std::option::Option::Some(
                ::apollo_errors::Error::#method_ident(box_arc.as_ref().as_ref() #config_arg),
            );
        }
        ::std::option::Option::None
    }
}

/// Generate the `::apollo_errors::private::FieldMetadata` struct literal for one field.
///
/// Every naming-case variant is computed here, at expansion time, from
/// `output_name` and embedded as a string literal. `ty` is the field's type
/// rendered through `TokenStream`'s `Display` impl, so its spacing follows that
/// formatter (it may contain spaces between tokens).
pub(crate) fn generate_field_metadata(field: &FieldDefinition) -> TokenStream {
    let name = &field.output_name;
    let (snake_case, camel_case, pascal_case, screaming_snake_case, kebab_case) = (
        name.to_snake_case(),
        name.to_lower_camel_case(),
        name.to_upper_camel_case(),
        name.to_shouty_snake_case(),
        name.to_kebab_case(),
    );

    let field_ty = &field.ty;
    let type_text = quote!(#field_ty).to_string();
    let rust_ident_text = field.rust_name.to_string();
    let extension_flag = field.is_extension;
    let header_tokens = quote_option_str(&field.http_header);

    quote! {
        ::apollo_errors::private::FieldMetadata {
            rust_name: #rust_ident_text,
            output_name: #name,
            snake_case: #snake_case,
            camel_case: #camel_case,
            pascal_case: #pascal_case,
            screaming_snake_case: #screaming_snake_case,
            kebab_case: #kebab_case,
            ty: #type_text,
            is_extension: #extension_flag,
            http_header: #header_tokens,
        }
    }
}

/// Generate code to extract HTTP headers from fields.
///
/// Only fields that carry an `http_header` name contribute. For each one, the
/// emitted code passes the field's value to `ToHeaderValue::to_header_value`.
/// When that returns `Some`, it pushes `(HeaderName, value)` onto a local
/// vector, which is the block's result. Fields flagged `is_option` are first
/// unwrapped with `.as_ref()` and skipped when `None`. If no field declares a
/// header, the whole body is just `Vec::new()`.
///
/// When `use_self_prefix` is true, generates `self.field` access (for structs).
/// When false, generates just `field` access (for enum match patterns, where the
/// surrounding arm is expected to bind each field name).
pub(crate) fn generate_http_headers_body(
    fields: &[FieldDefinition],
    use_self_prefix: bool,
) -> TokenStream {
    let mut extractions: Vec<TokenStream> = Vec::new();

    for field in fields {
        let header_name = match &field.http_header {
            Some(name) => name,
            None => continue,
        };
        let rust_name = &field.rust_name;
        // `HeaderName::from_static` only accepts lower-case header names.
        let header_literal = header_name.to_ascii_lowercase();

        // Token template shared by both branches: convert `value` and push it if valid.
        let push_if_valid = |value: TokenStream| {
            quote! {
                if let Some(__apollo_val) = ::apollo_errors::private::ToHeaderValue::to_header_value(#value) {
                    __apollo_headers.push((
                        ::apollo_errors::http::HeaderName::from_static(#header_literal),
                        __apollo_val,
                    ));
                }
            }
        };

        let place = if use_self_prefix {
            quote! { self.#rust_name }
        } else {
            quote! { #rust_name }
        };

        let extraction = if field.is_option {
            let push = push_if_valid(quote! { __apollo_inner });
            quote! {
                if let Some(__apollo_inner) = #place.as_ref() {
                    #push
                }
            }
        } else if use_self_prefix {
            // Struct fields are owned, so borrow them explicitly.
            push_if_valid(quote! { &#place })
        } else {
            // Match-arm bindings are already references.
            push_if_valid(place)
        };
        extractions.push(extraction);
    }

    if extractions.is_empty() {
        return quote! { Vec::new() };
    }

    quote! {
        {
            let mut __apollo_headers = Vec::new();
            #(#extractions)*
            __apollo_headers
        }
    }
}

/// Generate a runtime `FieldCase` match that picks the field's key from literals
/// computed at expansion time.
///
/// Every arm is derived from `output_name` (the rename value, or the Rust field
/// name when no rename is given), mirroring `FieldMetadata::name_for`. The
/// emitted expression reads `config.field_case`, so a `config` binding must be
/// in scope where it is expanded.
pub(crate) fn generate_field_key(field: &FieldDefinition) -> TokenStream {
    let base = &field.output_name;
    let (snake, camel, pascal, screaming, kebab) = (
        base.to_snake_case(),
        base.to_lower_camel_case(),
        base.to_upper_camel_case(),
        base.to_shouty_snake_case(),
        base.to_kebab_case(),
    );

    quote! {
        match config.field_case {
            ::apollo_errors::private::FieldCase::SnakeCase => #snake,
            ::apollo_errors::private::FieldCase::CamelCase => #camel,
            ::apollo_errors::private::FieldCase::PascalCase => #pascal,
            ::apollo_errors::private::FieldCase::ScreamingSnakeCase => #screaming,
            ::apollo_errors::private::FieldCase::KebabCase => #kebab,
        }
    }
}

/// Generate a runtime `CodeCase` match that resolves the error code from pre-computed literals.
pub(crate) fn generate_code_expr(code: &str) -> TokenStream {
    let code_base = code.replace("::", "_");
    let code_screaming = code_base.to_shouty_snake_case();
    let code_camel = code_base.to_lower_camel_case();
    let code_pascal = code_base.to_upper_camel_case();
    let code_kebab = code_base.to_kebab_case();

    quote! {
        match config.code_case {
            ::apollo_errors::private::CodeCase::Default => #code,
            ::apollo_errors::private::CodeCase::ScreamingSnakeCase => #code_screaming,
            ::apollo_errors::private::CodeCase::CamelCase => #code_camel,
            ::apollo_errors::private::CodeCase::PascalCase => #code_pascal,
            ::apollo_errors::private::CodeCase::KebabCase => #code_kebab,
        }
    }
}

/// Compute the non-default `CodeCase` spellings of an error code.
///
/// `::` path separators are first replaced with `_` so that each path segment
/// is split into separate words by the case conversion. The returned tuple is,
/// in order: SCREAMING_SNAKE_CASE, lowerCamelCase, UpperCamelCase (Pascal),
/// and kebab-case.
pub(crate) fn precompute_code_variants(code: &str) -> (String, String, String, String) {
    let normalized = code.replace("::", "_");
    (
        normalized.to_shouty_snake_case(),
        normalized.to_lower_camel_case(),
        normalized.to_upper_camel_case(),
        normalized.to_kebab_case(),
    )
}