palette_derive 0.7.6

Automatically implement traits from the palette crate.
Documentation
use proc_macro::TokenStream;
use proc_macro2::Span;

use quote::{quote, ToTokens};
use syn::{
    punctuated::Punctuated, token::Comma, Attribute, Data, DeriveInput, Fields, Meta, Path, Type,
};

use crate::meta::{self, FieldAttributes, IdentOrIndex, TypeItemAttributes};
use crate::util;

pub fn derive(tokens: TokenStream) -> std::result::Result<TokenStream, Vec<syn::Error>> {
    let DeriveInput {
        ident,
        attrs,
        generics,
        data,
        ..
    } = syn::parse(tokens).map_err(|error| vec![error])?;

    let allowed_repr = is_allowed_repr(&attrs)?;
    let (item_meta, item_errors) = meta::parse_namespaced_attributes::<TypeItemAttributes>(attrs);

    let mut number_of_channels = 0usize;
    let mut field_type: Option<Type> = None;

    let (all_fields, fields_meta, field_errors) = match data {
        Data::Struct(struct_item) => {
            let (fields_meta, field_errors) =
                meta::parse_field_attributes::<FieldAttributes>(struct_item.fields.clone());
            let all_fields = match struct_item.fields {
                Fields::Named(fields) => fields.named,
                Fields::Unnamed(fields) => fields.unnamed,
                Fields::Unit => Default::default(),
            };

            (all_fields, fields_meta, field_errors)
        }
        Data::Enum(_) => {
            return Err(vec![syn::Error::new(
                Span::call_site(),
                "`ArrayCast` cannot be derived for enums, because of the discriminant",
            )]);
        }
        Data::Union(_) => {
            return Err(vec![syn::Error::new(
                Span::call_site(),
                "`ArrayCast` cannot be derived for unions",
            )]);
        }
    };

    let fields = all_fields
        .into_iter()
        .enumerate()
        .map(|(index, field)| {
            (
                field
                    .ident
                    .map(IdentOrIndex::Ident)
                    .unwrap_or_else(|| IdentOrIndex::Index(index.into())),
                field.ty,
            )
        })
        .filter(|(field, _)| !fields_meta.zero_size_fields.contains(field));

    let mut errors = Vec::new();

    for (field, ty) in fields {
        let ty = fields_meta
            .type_substitutes
            .get(&field)
            .cloned()
            .unwrap_or(ty);
        number_of_channels += 1;

        if let Some(field_type) = field_type.clone() {
            if field_type != ty {
                errors.push(syn::Error::new_spanned(
                    &field,
                    format!(
                        "expected fields to have type `{}`",
                        field_type.into_token_stream()
                    ),
                ));
            }
        } else {
            field_type = Some(ty);
        }
    }

    if !allowed_repr {
        errors.push(syn::Error::new(
            Span::call_site(),
            format!(
                "a `#[repr(C)]` or `#[repr(transparent)]` attribute is required to give `{}` a fixed memory layout",
                ident
            ),
        ));
    }

    let array_cast_trait_path = util::path(["cast", "ArrayCast"], item_meta.internal);

    let mut implementation = if let Some(field_type) = field_type {
        let (impl_generics, type_generics, where_clause) = generics.split_for_impl();

        quote! {
            #[automatically_derived]
            unsafe impl #impl_generics #array_cast_trait_path for #ident #type_generics #where_clause {
                type Array = [#field_type; #number_of_channels];
            }
        }
    } else {
        errors.push(syn::Error::new(
            Span::call_site(),
            "`ArrayCast` can only be derived for structs with one or more fields".to_string(),
        ));

        return Err(errors);
    };

    implementation.extend(errors.iter().map(syn::Error::to_compile_error));

    let item_errors = item_errors
        .into_iter()
        .map(|error| error.into_compile_error());
    let field_errors = field_errors
        .into_iter()
        .map(|error| error.into_compile_error());

    Ok(quote! {
        #(#item_errors)*
        #(#field_errors)*

        #implementation
    }
    .into())
}

fn is_allowed_repr(attributes: &[Attribute]) -> std::result::Result<bool, Vec<syn::Error>> {
    let mut errors = Vec::new();

    for attribute in attributes {
        let attribute_name = attribute.path().get_ident().map(ToString::to_string);

        if let Some("repr") = attribute_name.as_deref() {
            let meta_list = match attribute.meta.require_list() {
                Ok(list) => list,
                Err(error) => {
                    errors.push(error);
                    continue;
                }
            };

            let items = match meta_list.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated)
            {
                Ok(items) => items,
                Err(error) => {
                    errors.push(error);
                    continue;
                }
            };

            let contains_allowed_repr = items.iter().any(|item| {
                item.require_path_only()
                    .ok()
                    .and_then(Path::get_ident)
                    .map_or(false, |ident| ident == "C" || ident == "transparent")
            });

            if contains_allowed_repr {
                return Ok(true);
            }
        }
    }

    if errors.is_empty() {
        Ok(false)
    } else {
        Err(errors)
    }
}