// predawn-macro 0.9.0
//
// Macros for predawn — see the crate documentation.
use std::collections::HashSet;

use from_attr::{AttrsValue, FromAttr};
use http::StatusCode;
use proc_macro2::TokenStream;
use quote_use::quote_use;
use syn::{spanned::Spanned, Attribute, DeriveInput, Expr, ExprLit, Ident, Lit, Type, Variant};

use crate::util;

// Parsed form of the enum-level `#[multi_response(error = ...)]` attribute, produced by the
// `FromAttr` derive (from the `from_attr` crate).
#[derive(FromAttr)]
#[attribute(idents = [multi_response])]
struct EnumAttr {
    // Type emitted as `IntoResponse::Error` in the generated impl; each variant's own
    // `IntoResponse` error is converted into it through `From`.
    error: Type,
}

/// Expands `#[derive(MultiResponse)]` for an enum.
///
/// The enum must carry `#[multi_response(error = SomeIntoResponseError)]`. Each variant is
/// validated by `handle_single_variant`. The expansion contains three impls:
///
/// - `MultiResponse`: `responses` collects, for every variant, its status code mapped to
///   `<FieldType as SingleResponse>::response(..)`.
/// - `IntoResponse` with `Error = <the configured error type>`: converts the wrapped value with
///   its own `IntoResponse` impl, then overwrites the response status with the variant's
///   status code.
/// - `ApiResponse`: delegates to `MultiResponse::responses`, wrapped in `Some`.
///
/// # Errors
///
/// Returns a `syn::Error` when the `multi_response` attribute is missing or cannot be parsed,
/// when `util::extract_variants` rejects the input, or when any variant fails validation.
/// Errors from all variants are combined so they are reported in a single pass.
pub(crate) fn generate(input: DeriveInput) -> syn::Result<TokenStream> {
    let DeriveInput {
        attrs,
        ident,
        generics,
        data,
        ..
    } = input;

    let EnumAttr {
        error: into_response_error,
    } = match EnumAttr::from_attributes(&attrs) {
        Ok(Some(AttrsValue {
            value: enum_attr, ..
        })) => enum_attr,
        Ok(None) => {
            return Err(syn::Error::new(
                ident.span(),
                "missing `#[multi_response(error = SomeIntoResponseError)]` attribute",
            ))
        }
        Err(AttrsValue { value: e, .. }) => return Err(e),
    };

    let variants = util::extract_variants(data, "MultiResponse")?;

    // Shared across variants so a status code can be claimed by only one of them.
    let mut status_codes = HashSet::new();
    let mut responses_bodies = Vec::new();
    let mut into_response_arms = Vec::new();
    let mut errors = Vec::new();

    for variant in variants {
        match handle_single_variant(variant, &ident, &into_response_error, &mut status_codes) {
            Ok((responses_body, into_response_arm)) => {
                responses_bodies.push(responses_body);
                into_response_arms.push(into_response_arm);
            }
            Err(e) => errors.push(e),
        }
    }

    // Report every invalid variant at once rather than stopping at the first.
    if let Some(e) = errors.into_iter().reduce(|mut a, b| {
        a.combine(b);
        a
    }) {
        return Err(e);
    }

    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    let expand = quote_use! {
        # use std::collections::BTreeMap;
        # use predawn::MultiResponse;
        # use predawn::openapi::{self, Schema};
        # use predawn::response::Response;
        # use predawn::into_response::IntoResponse;
        # use predawn::api_response::ApiResponse;
        # use predawn::http::StatusCode;

        impl #impl_generics MultiResponse for #ident #ty_generics #where_clause {
            fn responses(schemas: &mut BTreeMap<String, Schema>, schemas_in_progress: &mut Vec<String>) -> BTreeMap<StatusCode, openapi::Response> {
                let mut map = BTreeMap::new();

                #(#responses_bodies)*

                map
            }
        }

        impl #impl_generics IntoResponse for #ident #ty_generics #where_clause {
            type Error = #into_response_error;

            fn into_response(self) -> Result<Response, <Self as IntoResponse>::Error> {
                // The explicit annotation keeps this well-typed even when the enum has no
                // variants: `match self {}` then has type `!`, and nothing else would pin
                // the tuple's element types for the `status_mut` call below.
                let (mut response, status): (Response, u16) = match self {
                    #(#into_response_arms)*
                };

                // `status` was validated with `StatusCode::from_u16` at expansion time.
                *response.status_mut() = StatusCode::from_u16(status).unwrap();

                Ok(response)
            }
        }

        impl #impl_generics ApiResponse for #ident #ty_generics #where_clause {
            fn responses(schemas: &mut BTreeMap<String, Schema>, schemas_in_progress: &mut Vec<String>) -> Option<BTreeMap<StatusCode, openapi::Response>> {
                Some(<Self as MultiResponse>::responses(schemas, schemas_in_progress))
            }
        }
    };

    Ok(expand)
}

fn handle_single_variant<'a>(
    variant: Variant,
    enum_ident: &'a Ident,
    into_response_error: &'a Type,
    status_codes: &'a mut HashSet<u16>,
) -> syn::Result<(TokenStream, TokenStream)> {
    let variant_span = variant.span();

    let Variant {
        attrs,
        ident: variant_ident,
        fields,
        ..
    } = variant;

    let Some(status_code) = extract_status_code(&attrs, status_codes)? else {
        let e = syn::Error::new(variant_span, "missing `#[status = xxx]` attribute");
        return Err(e);
    };

    let ty = util::extract_single_unnamed_field_type_from_variant(fields, variant_span)?;

    let responses_body = quote_use! {
        # use predawn::SingleResponse;
        # use predawn::http::StatusCode;

        map.insert(
            StatusCode::from_u16(#status_code).unwrap(),
            <#ty as SingleResponse>::response(schemas, schemas_in_progress),
        );
    };

    let into_response_arm = quote_use! {
        # use core::convert::From;
        # use predawn::into_response::IntoResponse;

        #enum_ident::#variant_ident(a) => match <#ty as IntoResponse>::into_response(a) {
            Ok(response) => (response, #status_code),
            Err(e) => return Err(<#into_response_error as From<_>>::from(e)),
        },
    };

    Ok((responses_body, into_response_arm))
}

fn extract_status_code<'a>(
    attrs: &'a [Attribute],
    status_codes: &'a mut HashSet<u16>,
) -> syn::Result<Option<u16>> {
    let mut errors = Vec::new();
    let mut found = None;

    for attr in attrs {
        if !attr.path().is_ident("status") {
            continue;
        }

        let value = match attr.meta.require_name_value() {
            Ok(name_value) => &name_value.value,
            Err(e) => {
                errors.push(e);
                continue;
            }
        };

        let Expr::Lit(ExprLit {
            lit: Lit::Int(lit_int),
            ..
        }) = value
        else {
            let e = syn::Error::new(value.span(), "only int literal is allowed");
            errors.push(e);
            continue;
        };

        if found.is_some() {
            let e = syn::Error::new(attr.span(), "only one `status` attribute is allowed");
            errors.push(e);
            continue;
        }

        let status_code = match lit_int.base10_parse::<u16>() {
            Ok(a) => a,
            Err(e) => {
                errors.push(e);
                continue;
            }
        };

        let lit_int_span = lit_int.span();

        if StatusCode::from_u16(status_code).is_err() {
            let e = syn::Error::new(lit_int_span, "it is not a valid status code");
            errors.push(e);
            continue;
        }

        if !status_codes.contains(&status_code) {
            status_codes.insert(status_code);
            found = Some(status_code);
        } else {
            let e = syn::Error::new(lit_int_span, "duplicate status code");
            errors.push(e);
        }
    }

    if let Some(e) = errors.into_iter().reduce(|mut a, b| {
        a.combine(b);
        a
    }) {
        return Err(e);
    }

    Ok(found)
}