server-less-macros 0.6.0

Proc macros for server-less
Documentation
//! Error derive macro for generating `IntoErrorCode`, `Display`, and `std::error::Error` implementations.
//!
//! ```ignore
//! #[derive(Debug, ServerlessError)]
//! enum MyError {
//!     #[error(code = NotFound, message = "User not found")]
//!     UserNotFound,
//!     #[error(code = InvalidInput)]
//!     ValidationFailed(String),
//!     // Code inferred from variant name
//!     Unauthorized,
//! }
//! ```
//!
//! # Conflict with `thiserror`
//!
//! `ServerlessError` uses `#[error(...)]` as its per-variant helper attribute, which **conflicts
//! with `thiserror`'s `#[error("...")]`**: both derives claim the same attribute name, and each
//! generates its own `Display` and `std::error::Error` impls.  If you add
//! `#[derive(thiserror::Error)]` to the same enum, each derive will try to parse the other's
//! `#[error]` attributes, and the two sets of impls will conflict.
//!
//! **Recommendation:** remove `#[derive(thiserror::Error)]` when using `ServerlessError`.
//! `ServerlessError` already generates `Display` and `std::error::Error` impls for you; the enum
//! still needs `#[derive(Debug)]`, because `std::error::Error` requires it.

use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Ident, Token, parse::Parse, punctuated::Punctuated};

/// Arguments for the #[error(...)] attribute on variants.
///
/// Every field is optional. `Default` yields all-`None`, which is also what a
/// variant without any `#[error(...)]` attribute gets.
#[derive(Default)]
struct ErrorVariantArgs {
    /// Error code (e.g., NotFound, InvalidInput, or numeric 404).
    /// When `None`, the code is inferred from the variant name at runtime via
    /// `ErrorCode::infer_from_name`.
    code: Option<ErrorCodeSpec>,
    /// Custom message. When `None`, the variant name is converted from
    /// CamelCase into a sentence and used instead.
    message: Option<String>,
    /// JSON-RPC numeric error code override (e.g. -32602).
    /// When `None`, the selected error code's own `jsonrpc_code()` is used.
    jsonrpc_code: Option<i32>,
}

/// The value of a `code = ...` argument, as written in the attribute.
enum ErrorCodeSpec {
    /// Named error code: NotFound, InvalidInput, etc.
    /// Emitted as `::server_less::ErrorCode::<name>`; the name is not checked
    /// here, so a misspelling surfaces as an error in the generated code.
    Named(Ident),
    /// Numeric HTTP status: 404, 500, etc.
    /// Mapped to an `ErrorCode` variant during expansion; statuses without a
    /// known mapping are reported as a compile error.
    Numeric(u16),
}

impl Parse for ErrorVariantArgs {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut args = ErrorVariantArgs::default();

        let pairs = Punctuated::<syn::Meta, Token![,]>::parse_terminated(input)?;

        for meta in pairs {
            match meta {
                syn::Meta::NameValue(nv) if nv.path.is_ident("code") => {
                    // code = NotFound or code = 404
                    match &nv.value {
                        syn::Expr::Path(path) => {
                            if let Some(ident) = path.path.get_ident() {
                                args.code = Some(ErrorCodeSpec::Named(ident.clone()));
                            }
                        }
                        syn::Expr::Lit(syn::ExprLit {
                            lit: syn::Lit::Int(lit),
                            ..
                        }) => {
                            let value: u16 = lit.base10_parse()?;
                            args.code = Some(ErrorCodeSpec::Numeric(value));
                        }
                        _ => {
                            return Err(syn::Error::new_spanned(
                                &nv.value,
                                "expected error code name or HTTP status\n\
                                 \n\
                                 Valid names: NotFound, InvalidInput, Unauthenticated, Forbidden, Internal\n\
                                 Or use HTTP status: IANA 4xx/5xx plus Cloudflare/nginx/IIS vendor codes.\n\
                                 \n\
                                 Example: #[error(code = NotFound)]",
                            ));
                        }
                    }
                }
                syn::Meta::NameValue(nv) if nv.path.is_ident("message") => {
                    if let syn::Expr::Lit(syn::ExprLit {
                        lit: syn::Lit::Str(s),
                        ..
                    }) = &nv.value
                    {
                        args.message = Some(s.value());
                    } else {
                        return Err(syn::Error::new_spanned(
                            &nv.value,
                            "message must be a string literal\n\
                             \n\
                             Example: #[error(code = NotFound, message = \"Resource not found\")]",
                        ));
                    }
                }
                syn::Meta::NameValue(nv) if nv.path.is_ident("jsonrpc_code") => {
                    // jsonrpc_code = -32602 (negative integer literal)
                    let parsed_code: i32 = match &nv.value {
                        syn::Expr::Unary(syn::ExprUnary {
                            op: syn::UnOp::Neg(_),
                            expr,
                            ..
                        }) => {
                            if let syn::Expr::Lit(syn::ExprLit {
                                lit: syn::Lit::Int(lit),
                                ..
                            }) = expr.as_ref()
                            {
                                let val: i32 = lit.base10_parse()?;
                                -val
                            } else {
                                return Err(syn::Error::new_spanned(
                                    &nv.value,
                                    "jsonrpc_code must be an integer (e.g. -32602)",
                                ));
                            }
                        }
                        syn::Expr::Lit(syn::ExprLit {
                            lit: syn::Lit::Int(lit),
                            ..
                        }) => lit.base10_parse()?,
                        _ => {
                            return Err(syn::Error::new_spanned(
                                &nv.value,
                                "jsonrpc_code must be an integer\n\
                                 \n\
                                 Example: #[error(jsonrpc_code = -32602)]",
                            ));
                        }
                    };
                    args.jsonrpc_code = Some(parsed_code);
                }
                other => {
                    return Err(syn::Error::new_spanned(
                        other,
                        "unknown attribute. Valid: code, message, jsonrpc_code",
                    ));
                }
            }
        }

        Ok(args)
    }
}

/// Expand the ServerlessError derive macro
pub fn expand_serverless_error(input: DeriveInput) -> syn::Result<TokenStream> {
    let name = &input.ident;

    let Data::Enum(data_enum) = &input.data else {
        return Err(syn::Error::new_spanned(
            &input,
            "ServerlessError can only be derived for enums\n\
             \n\
             Hint: Define your errors as an enum:\n\
             \n\
             #[derive(Debug, ServerlessError)]\n\
             enum MyError {{\n\
                 #[error(code = NotFound)]\n\
                 NotFound,\n\
             }}\n\
             \n\
             ServerlessError maps to HTTP status codes (e.g. code = 404) and \
             JSON-RPC error codes automatically across protocols.",
        ));
    };

    let mut error_code_arms = Vec::new();
    let mut jsonrpc_code_arms = Vec::new();
    let mut message_arms = Vec::new();
    let mut display_arms = Vec::new();

    for variant in &data_enum.variants {
        let variant_name = &variant.ident;
        let variant_name_str = variant_name.to_string();

        // Parse #[error(...)] attribute if present
        let args = variant
            .attrs
            .iter()
            .find(|attr| attr.path().is_ident("error"))
            .map(|attr| attr.parse_args::<ErrorVariantArgs>())
            .transpose()?
            .unwrap_or_default();

        // Determine error code
        let error_code = match args.code {
            Some(ErrorCodeSpec::Named(ident)) => {
                quote! { ::server_less::ErrorCode::#ident }
            }
            Some(ErrorCodeSpec::Numeric(status)) => {
                // Map HTTP status codes (IANA standard + vendor extensions) to ErrorCode.
                // Unknown codes produce a compile error — use a named ErrorCode variant instead.
                let span = variant_name.span();
                match status {
                    // ── Standard IANA 4xx ──────────────────────────────────────────
                    400 => quote! { ::server_less::ErrorCode::InvalidInput },
                    401 => quote! { ::server_less::ErrorCode::Unauthenticated },
                    402 => quote! { ::server_less::ErrorCode::Internal },
                    403 => quote! { ::server_less::ErrorCode::Forbidden },
                    404 => quote! { ::server_less::ErrorCode::NotFound },
                    405 => quote! { ::server_less::ErrorCode::InvalidInput },
                    406 => quote! { ::server_less::ErrorCode::InvalidInput },
                    407 => quote! { ::server_less::ErrorCode::Unauthenticated },
                    408 => quote! { ::server_less::ErrorCode::Unavailable },
                    409 => quote! { ::server_less::ErrorCode::Conflict },
                    410 => quote! { ::server_less::ErrorCode::NotFound },
                    411 => quote! { ::server_less::ErrorCode::InvalidInput },
                    412 => quote! { ::server_less::ErrorCode::UnprocessableEntity },
                    413 => quote! { ::server_less::ErrorCode::InvalidInput },
                    414 => quote! { ::server_less::ErrorCode::InvalidInput },
                    415 => quote! { ::server_less::ErrorCode::InvalidInput },
                    416 => quote! { ::server_less::ErrorCode::InvalidInput },
                    417 => quote! { ::server_less::ErrorCode::InvalidInput },
                    418 => quote! { ::server_less::ErrorCode::Internal },
                    421 => quote! { ::server_less::ErrorCode::InvalidInput },
                    422 => quote! { ::server_less::ErrorCode::UnprocessableEntity },
                    423 => quote! { ::server_less::ErrorCode::Conflict },
                    424 => quote! { ::server_less::ErrorCode::UnprocessableEntity },
                    425 => quote! { ::server_less::ErrorCode::Internal },
                    426 => quote! { ::server_less::ErrorCode::InvalidInput },
                    428 => quote! { ::server_less::ErrorCode::UnprocessableEntity },
                    429 => quote! { ::server_less::ErrorCode::RateLimited },
                    431 => quote! { ::server_less::ErrorCode::InvalidInput },
                    451 => quote! { ::server_less::ErrorCode::Forbidden },
                    // ── Standard IANA 5xx ──────────────────────────────────────────
                    500 => quote! { ::server_less::ErrorCode::Internal },
                    501 => quote! { ::server_less::ErrorCode::NotImplemented },
                    502 => quote! { ::server_less::ErrorCode::Unavailable },
                    503 => quote! { ::server_less::ErrorCode::Unavailable },
                    504 => quote! { ::server_less::ErrorCode::Unavailable },
                    505 => quote! { ::server_less::ErrorCode::Internal },
                    506 => quote! { ::server_less::ErrorCode::Internal },
                    507 => quote! { ::server_less::ErrorCode::Internal },
                    508 => quote! { ::server_less::ErrorCode::Internal },
                    510 => quote! { ::server_less::ErrorCode::Internal },
                    511 => quote! { ::server_less::ErrorCode::Unauthenticated },
                    // ── Cloudflare vendor 5xx ──────────────────────────────────────
                    520 => quote! { ::server_less::ErrorCode::Internal },
                    521 => quote! { ::server_less::ErrorCode::Unavailable },
                    522 => quote! { ::server_less::ErrorCode::Unavailable },
                    523 => quote! { ::server_less::ErrorCode::Unavailable },
                    524 => quote! { ::server_less::ErrorCode::Unavailable },
                    525 => quote! { ::server_less::ErrorCode::Internal },
                    526 => quote! { ::server_less::ErrorCode::Internal },
                    527 => quote! { ::server_less::ErrorCode::Unavailable },
                    530 => quote! { ::server_less::ErrorCode::Unavailable },
                    // ── nginx vendor ───────────────────────────────────────────────
                    444 => quote! { ::server_less::ErrorCode::Internal },
                    494 => quote! { ::server_less::ErrorCode::InvalidInput },
                    495 => quote! { ::server_less::ErrorCode::Internal },
                    496 => quote! { ::server_less::ErrorCode::Unauthenticated },
                    497 => quote! { ::server_less::ErrorCode::InvalidInput },
                    499 => quote! { ::server_less::ErrorCode::Internal },
                    // ── IIS vendor ─────────────────────────────────────────────────
                    440 => quote! { ::server_less::ErrorCode::Unauthenticated },
                    449 => quote! { ::server_less::ErrorCode::Internal },
                    // ── Unknown ────────────────────────────────────────────────────
                    other => {
                        return Err(syn::Error::new(
                            span,
                            format!(
                                "unknown HTTP status code {other}; \
                                 use a named ErrorCode variant instead\n\
                                 \n\
                                 Valid named variants: NotFound, InvalidInput, Unauthenticated, \
                                 Forbidden, Conflict, UnprocessableEntity, RateLimited, Internal, \
                                 NotImplemented, Unavailable\n\
                                 \n\
                                 Known numeric codes: IANA 4xx/5xx plus Cloudflare/nginx/IIS vendor codes.\n\
                                 \n\
                                 Example: #[error(code = NotFound)]"
                            ),
                        ));
                    }
                }
            }
            None => {
                // Infer from variant name
                quote! { ::server_less::ErrorCode::infer_from_name(#variant_name_str) }
            }
        };

        // Determine message
        let message_expr = if let Some(msg) = args.message {
            quote! { #msg.to_string() }
        } else {
            // Use variant name, converting CamelCase to "Camel case"
            let readable = camel_to_sentence(&variant_name_str);
            quote! { #readable.to_string() }
        };

        // Generate match arms based on variant fields
        let (pattern, display_format) = match &variant.fields {
            Fields::Unit => (
                quote! { Self::#variant_name },
                quote! { write!(f, "{}", ::server_less::IntoErrorCode::message(self)) },
            ),
            Fields::Unnamed(fields) => {
                let field_names: Vec<_> = (0..fields.unnamed.len())
                    .map(|i| quote::format_ident!("_{}", i))
                    .collect();
                let pattern = quote! { Self::#variant_name(#(#field_names),*) };

                // If single String field, include it in display
                if fields.unnamed.len() == 1 {
                    (
                        pattern.clone(),
                        quote! { write!(f, "{}: {}", #variant_name_str, _0) },
                    )
                } else {
                    (pattern, quote! { write!(f, "{}", ::server_less::IntoErrorCode::message(self)) })
                }
            }
            Fields::Named(fields) => {
                let field_names: Vec<_> = fields
                    .named
                    .iter()
                    .map(|f| f.ident.as_ref().unwrap())
                    .collect();
                let pattern = quote! { Self::#variant_name { #(#field_names),* } };
                (pattern, quote! { write!(f, "{}", ::server_less::IntoErrorCode::message(self)) })
            }
        };

        error_code_arms.push(quote! {
            #pattern => #error_code
        });

        // Generate jsonrpc_code arm: use explicit override if provided, otherwise delegate to error_code
        let jsonrpc_code_expr = if let Some(code) = args.jsonrpc_code {
            quote! { #code }
        } else {
            quote! { #error_code.jsonrpc_code() }
        };
        jsonrpc_code_arms.push(quote! {
            #pattern => #jsonrpc_code_expr
        });

        message_arms.push(quote! {
            #pattern => #message_expr
        });

        display_arms.push(quote! {
            #pattern => #display_format
        });
    }

    Ok(quote! {
        impl ::server_less::IntoErrorCode for #name {
            fn error_code(&self) -> ::server_less::ErrorCode {
                match self {
                    #(#error_code_arms,)*
                }
            }

            fn jsonrpc_code(&self) -> i32 {
                match self {
                    #(#jsonrpc_code_arms,)*
                }
            }

            fn message(&self) -> String {
                match self {
                    #(#message_arms,)*
                }
            }
        }

        impl ::std::fmt::Display for #name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                match self {
                    #(#display_arms,)*
                }
            }
        }

        impl ::std::error::Error for #name {}
    })
}

/// Convert a CamelCase identifier into a sentence-cased phrase.
///
/// A word boundary goes before every uppercase character that follows a
/// non-uppercase one. Inside a run of capitals, a boundary goes before the
/// capital that is followed by a lowercase letter, so `IOError` splits into
/// `IO` + `Error`. The first word is kept verbatim. Each later word is
/// lower-cased unless it contains two or more capitals; such words are
/// treated as acronyms and left as-is.
///
/// `UserNotFound` → `"User not found"`, `IOError` → `"IO error"`,
/// `NotAnHTTPRequest` → `"Not an HTTP request"`, `""` → `""`.
fn camel_to_sentence(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();

    // Split into words at CamelCase boundaries.
    let mut words: Vec<&[char]> = Vec::new();
    let mut start = 0;
    for i in 1..chars.len() {
        if !chars[i].is_uppercase() {
            continue;
        }
        let prev_is_upper = chars[i - 1].is_uppercase();
        let next_is_lower = matches!(chars.get(i + 1), Some(next) if next.is_lowercase());
        if !prev_is_upper || next_is_lower {
            words.push(&chars[start..i]);
            start = i;
        }
    }
    if start < chars.len() {
        words.push(&chars[start..]);
    }

    let mut result = String::with_capacity(s.len() + words.len());
    for (index, word) in words.iter().enumerate() {
        if index == 0 {
            result.extend(word.iter());
            continue;
        }
        result.push(' ');
        let is_acronym = word.iter().filter(|c| c.is_uppercase()).count() > 1;
        for &c in word.iter() {
            if c.is_uppercase() && !is_acronym {
                result.extend(c.to_lowercase());
            } else {
                result.push(c);
            }
        }
    }
    result
}