picoserve_derive 0.1.2

Macros for picoserve
Documentation
use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::{parse_macro_input, DeriveInput};

fn single_field(fields: &syn::Fields) -> Option<proc_macro2::TokenStream> {
    match fields {
        syn::Fields::Named(fields) => {
            let mut fields = fields.named.iter();
            let field = fields.next()?;
            fields
                .next()
                .is_none()
                .then(|| quote! { { #field: ref field } })
        }
        syn::Fields::Unnamed(fields) => {
            let mut fields = fields.unnamed.iter();
            let _field = fields.next()?;
            fields.next().is_none().then(|| quote! { (ref field) })
        }
        syn::Fields::Unit => None,
    }
}

enum StatusCodeAttr {
    StatusCode(proc_macro2::TokenStream),
    Transparent,
}

fn status_code_attr(attrs: &[syn::Attribute]) -> Result<Option<StatusCodeAttr>, &'static str> {
    let mut attrs = attrs.iter();

    loop {
        let Some(attr) = attrs.next() else {
            return Ok(None);
        };

        if !attr.path().is_ident("status_code") {
            continue;
        };

        let syn::Meta::List(meta_list) = &attr.meta else {
            return Err("#[status_code(...)] must be a `StatusCode`");
        };

        let status_code = &meta_list.tokens;

        return Ok(Some(if status_code.to_string() == "transparent" {
            StatusCodeAttr::Transparent
        } else {
            StatusCodeAttr::StatusCode(quote! { picoserve::response::StatusCode::#status_code })
        }));
    }
}

fn try_derive_error_with_status_code(
    input: &DeriveInput,
) -> Result<proc_macro2::TokenStream, syn::Error> {
    let ident = &input.ident;

    let default_status_code = status_code_attr(&input.attrs)
        .map_err(|message| syn::Error::new_spanned(input, message))?;

    let status_code = match &input.data {
        syn::Data::Struct(data_struct) => match default_status_code
            .ok_or_else(|| syn::Error::new_spanned(input, "Missing #[status_code(..)]"))?
        {
            StatusCodeAttr::StatusCode(token_stream) => token_stream,
            StatusCodeAttr::Transparent => {
                let fields = single_field(&data_struct.fields).ok_or_else(|| {
                    syn::Error::new_spanned(input, "Transparent errors must have a single field")
                })?;

                quote! {
                    let Self #fields = self;
                    picoserve::response::ErrorWithStatusCode::status_code(field)
                }
            }
        },
        syn::Data::Enum(data_enum) => {
            let cases = data_enum
                .variants
                .iter()
                .map(|variant| {
                    let variant_status_code = status_code_attr(&variant.attrs)
                        .map_err(|message| syn::Error::new_spanned(ident, message))?;

                    let selected_status_code = variant_status_code
                        .as_ref()
                        .or(default_status_code.as_ref())
                        .ok_or_else(|| {
                            syn::Error::new_spanned(
                                variant,
                                "Either the enum or this variant must have an attribute of `status_code`",
                            )
                        })?;

                    let ident = &variant.ident;
                    let fields;
                    let status_code;

                    match selected_status_code {
                        StatusCodeAttr::StatusCode(selected_status_code) => {
                            fields = match variant.fields {
                                syn::Fields::Named(..) => quote! { {..} },
                                syn::Fields::Unnamed(..) => quote! { (..) },
                                syn::Fields::Unit => quote! {},
                            };

                            status_code = selected_status_code.clone();
                        }
                        StatusCodeAttr::Transparent => {
                            fields = single_field(&variant.fields).ok_or_else(|| {
                                syn::Error::new_spanned(
                                    variant,
                                    "Transparent errors must have a single field",
                                )
                            })?;

                            status_code = quote! {
                                picoserve::response::ErrorWithStatusCode::status_code(field)
                            };
                        }
                    }

                    Ok(quote! { Self::#ident #fields => #status_code, })
                })
                .collect::<Result<proc_macro2::TokenStream, syn::Error>>()?;

            quote! {
                match *self {
                    #cases
                }
            }
        }
        syn::Data::Union(..) => {
            return Err(syn::Error::new_spanned(
                input,
                "Must be a struct or an enum",
            ))
        }
    };

    let syn::Generics {
        lt_token,
        params: generics_params,
        gt_token,
        where_clause,
    } = &input.generics;

    let where_clause_predicates = where_clause
        .as_ref()
        .map(|where_clause| where_clause.predicates.iter())
        .into_iter()
        .flatten()
        .map(ToTokens::to_token_stream)
        .chain(std::iter::once(quote! { Self: core::fmt::Display }))
        .collect::<syn::punctuated::Punctuated<_, syn::token::Comma>>();

    let param_names = generics_params
        .iter()
        .map(|param| match param {
            syn::GenericParam::Lifetime(syn::LifetimeParam { lifetime, .. }) => {
                lifetime.to_token_stream()
            }
            syn::GenericParam::Type(type_param) => type_param.ident.to_token_stream(),
            syn::GenericParam::Const(const_param) => const_param.ident.to_token_stream(),
        })
        .collect::<syn::punctuated::Punctuated<proc_macro2::TokenStream, syn::token::Comma>>();

    Ok(quote! {
        #[allow(unused_qualifications)]
        #[automatically_derived]
        impl #lt_token #generics_params #gt_token picoserve::response::ErrorWithStatusCode for #ident #lt_token #param_names #gt_token where #where_clause_predicates {
            fn status_code(&self) -> picoserve::response::StatusCode {
                #status_code
            }
        }

        #[allow(unused_qualifications)]
        #[automatically_derived]
        impl #lt_token #generics_params #gt_token picoserve::response::IntoResponse for #ident #lt_token #param_names #gt_token where #where_clause_predicates {
            async fn write_to<R: picoserve::io::Read, W: picoserve::response::ResponseWriter<Error = R::Error>>(
                self,
                connection: picoserve::response::Connection<'_, R>,
                response_writer: W,
            ) -> Result<picoserve::ResponseSent, W::Error> {
                (picoserve::response::ErrorWithStatusCode::status_code(&self), format_args!("{self}\n"))
                    .write_to(connection, response_writer)
                    .await
            }
        }
    })
}

/// Derive `ErrorWithStatusCode` for a struct or an enum.
///
/// This will also derive `IntoResponse`, returning a `Response` with the given status code and a `text/plain` body of the `Display` implementation.
///
/// # Structs
///
/// There must be an attribute `status_code` containing the StatusCode of the error, e.g. `#[status_code(INTERNAL_SERVER_ERROR)]`.
///
/// If the `status_code` is `transparent`, the struct must contain a single field which implements `ErrorWithStatusCode`.
///
/// # Enums
///
/// There may be an attribute `status_code` on the enum itself containing the default StatusCode of the error.
///
/// There may also be an attribute `status_code` on a variant, which overrides the default StatusCode.
/// If all variants have their own attribute `status_code`, the default may be omitted.
///
/// Variants with a `status_code` of `transparent` must contain a single field which implements `ErrorWithStatusCode`.
#[proc_macro_derive(ErrorWithStatusCode, attributes(status_code))]
pub fn derive_error_with_status_code(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    match try_derive_error_with_status_code(&input) {
        Ok(tokens) => tokens.into(),
        Err(error) => error.into_compile_error().into(),
    }
}