predawn-macro 0.9.0

Macros for predawn
Documentation
use from_attr::{AttrsValue, FromAttr};
use proc_macro2::TokenStream;
use quote::quote;
use quote_use::quote_use;
use syn::{DeriveInput, Field, Ident};

use crate::{schema_attr::SchemaAttr, serde_attr::SerdeAttr, util};

/// Expands a derive input into the multipart request-extraction and OpenAPI
/// trait implementations for a struct with named fields.
///
/// The generated code implements `FromRequest<'a>`, `ApiRequest`, `MediaType`,
/// `RequestMediaType` and `SingleMediaType` for the input type.
///
/// # Errors
///
/// Returns an error if the input is rejected by
/// `util::extract_named_struct_fields`, or if any field fails in
/// [`generate_single_field`]. Per-field errors are combined into a single
/// `syn::Error` so that all of them are reported at once.
pub(crate) fn generate(input: DeriveInput) -> syn::Result<TokenStream> {
    let DeriveInput {
        attrs,
        ident,
        generics,
        data,
        ..
    } = input;

    let named = util::extract_named_struct_fields(data, "Multipart")?;

    let mut struct_field_idents = Vec::new();
    let mut define_vars = Vec::new();
    let mut parse_fields = Vec::new();
    let mut extract_vars = Vec::new();
    let mut errors = Vec::new();

    // Keep going after a failing field so every field error is collected.
    for field in named {
        match generate_single_field(field) {
            Ok((struct_field, define_var, parse_field, extract_var)) => {
                struct_field_idents.push(struct_field);
                define_vars.push(define_var);
                parse_fields.push(parse_field);
                extract_vars.push(extract_var);
            }
            Err(e) => errors.push(e),
        }
    }

    if let Some(e) = errors.into_iter().reduce(|mut a, b| {
        a.combine(b);
        a
    }) {
        return Err(e);
    }

    // The `FromRequest<'a>` impl needs an extra `'a` lifetime in front of the
    // type's own generic parameters. Insert it structurally instead of
    // round-tripping through a string: this keeps the spans of the user's
    // generics intact and cannot panic on re-parsing.
    //
    // NOTE: `'a` is inserted unconditionally, so a struct that already
    // declares a lifetime named `'a` would end up with a duplicate parameter.
    let mut generics_with_lifetime = generics.clone();
    generics_with_lifetime
        .params
        .insert(0, syn::parse_quote!('a));
    let (impl_generics_with_lifetime, _, _) = generics_with_lifetime.split_for_impl();

    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    // An empty description is emitted as `None` in the generated request body.
    let description = util::extract_description(&attrs);
    let description = if description.is_empty() {
        quote! { None }
    } else {
        let description = util::generate_string_expr(&description);
        quote! { Some(#description) }
    };

    let expand = quote_use! {
        # use core::default::Default;
        # use std::vec::Vec;
        # use std::collections::BTreeMap;
        # use predawn::{MultiRequestMediaType, ToSchema};
        # use predawn::media_type::{MediaType, RequestMediaType, has_media_type, SingleMediaType};
        # use predawn::from_request::FromRequest;
        # use predawn::response_error::MultipartError;
        # use predawn::request::Head;
        # use predawn::body::RequestBody;
        # use predawn::extract::multipart::Multipart;
        # use predawn::api_request::ApiRequest;
        # use predawn::openapi::{self, Schema, Parameter};

        impl #impl_generics_with_lifetime FromRequest<'a> for #ident #ty_generics #where_clause {
            type Error = MultipartError;

            async fn from_request(head: &'a mut Head, body: RequestBody) -> Result<Self, Self::Error> {
                let mut multipart = <Multipart as FromRequest<_>>::from_request(head, body).await?;

                #(#define_vars)*

                while let Some(field) = multipart.next_field().await? {
                    #(#parse_fields)*
                }

                #(#extract_vars)*

                Ok(Self { #(#struct_field_idents),* })
            }
        }

        impl #impl_generics ApiRequest for #ident #ty_generics #where_clause {
            fn parameters(_: &mut BTreeMap<String, Schema>, _: &mut Vec<String>) -> Option<Vec<openapi::Parameter>> {
                None
            }

            fn request_body(schemas: &mut BTreeMap<String, Schema>, schemas_in_progress: &mut Vec<String>) -> Option<openapi::RequestBody> {
                Some(openapi::RequestBody {
                    description: #description,
                    content: <Self as MultiRequestMediaType>::content(schemas, schemas_in_progress),
                    required: true,
                    extensions: Default::default(),
                })
            }
        }

        impl #impl_generics MediaType for #ident #ty_generics #where_clause {
            const MEDIA_TYPE: &'static str = "multipart/form-data";
        }

        impl #impl_generics RequestMediaType for #ident #ty_generics #where_clause {
            fn check_content_type(content_type: &str) -> bool {
                has_media_type(content_type, "multipart", "form-data", "form-data", None)
            }
        }

        impl #impl_generics SingleMediaType for #ident #ty_generics #where_clause {
            fn media_type(schemas: &mut BTreeMap<String, Schema>, schemas_in_progress: &mut Vec<String>) -> openapi::MediaType {
                openapi::MediaType {
                    schema: Some(<Self as ToSchema>::schema_ref(schemas, schemas_in_progress)),
                    example: Default::default(),
                    examples: Default::default(),
                    encoding: Default::default(),
                    extensions: Default::default(),
                }
            }
        }
    };

    Ok(expand)
}

/// Builds the code fragments needed to extract one struct field from a
/// multipart request.
///
/// On success the returned tuple holds, in order:
/// 1. the struct field identifier,
/// 2. the statement declaring the field's mutable holder variable,
/// 3. the `if` block that parses an incoming multipart field with a matching name,
/// 4. the statement that turns the holder into the final field value.
///
/// # Errors
///
/// Returns the error produced while parsing the schema attributes, or the one
/// returned by `util::generate_default_expr`.
fn generate_single_field(
    field: Field,
) -> syn::Result<(Ident, TokenStream, TokenStream, TokenStream)> {
    let Field {
        attrs, ident, ty, ..
    } = field;

    let serde_attr = SerdeAttr::new(&attrs);

    // Absent schema attributes fall back to the attribute type's default.
    let schema_attr = match SchemaAttr::from_attributes(&attrs) {
        Ok(parsed) => parsed
            .map(|AttrsValue { value, .. }| value)
            .unwrap_or_default(),
        Err(AttrsValue { value: err, .. }) => return Err(err),
    };

    // `flatten` is read from neither attribute here.
    let SerdeAttr {
        rename: serde_rename,
        flatten: _,
        default: serde_default,
    } = serde_attr;
    let SchemaAttr {
        rename: schema_rename,
        flatten: _,
        default: schema_default,
    } = schema_attr;

    let field_ident = ident.expect("unreachable: named field must have an identifier");

    // Name precedence: schema rename, then serde rename, then the Rust field name.
    let field_name = schema_rename
        .or(serde_rename)
        .unwrap_or_else(|| field_ident.to_string());

    // When a default expression exists, a missing field yields that default
    // instead of an error.
    let fallback_arm = util::generate_default_expr(&ty, serde_default, schema_default)?.map(
        |default_value| {
            quote_use! {
                # use predawn::response_error::MultipartError;

                Err(MultipartError::MissingField { .. }) => #default_value,
            }
        },
    );

    let holder_decl = quote_use! {
        # use predawn::extract::multipart::ParseField;

        let mut #field_ident = <#ty as ParseField>::default_holder(#field_name);
    };

    let parse_branch = quote_use! {
        # use predawn::extract::multipart::ParseField;

        if field.name() == Some(#field_name) {
            #field_ident = <#ty as ParseField>::parse_field(#field_ident, field, #field_name).await?;
            continue;
        }
    };

    let finalize_stmt = quote_use! {
        # use predawn::extract::multipart::ParseField;

        let #field_ident = match <#ty as ParseField>::extract(#field_ident, #field_name) {
            Ok(v) => v,
            #fallback_arm
            Err(e) => return Err(e),
        };
    };

    Ok((field_ident, holder_decl, parse_branch, finalize_stmt))
}