// mkutils-macros 0.1.145
//
// Utility methods, traits, and types.
// Documentation
use crate::{error::Error, utils::CommaPunctuated};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use syn::{
    Data, DeriveInput, Error as SynError, Expr, Field, Fields, FieldsNamed, FieldsUnnamed, Generics, Ident, ImplItem,
    Index, Lit, Meta, Path, WherePredicate, spanned::Spanned,
};

/// Stateless namespace for the helpers below, which generate a trait `impl` for a struct
/// whose single method is applied field by field to `self` and `other`.
pub struct Basic;

impl Basic {
    /// Attribute argument name whose contents are parsed as an extra item for the generated
    /// `impl` block (see `parse_associated_item`).
    const ASSOCIATED_ITEM_ATTRIBUTE_NAME: &str = "assoc";
    /// Attribute argument name whose string value supplies explicit `where` predicates and
    /// turns off the default per-type-parameter bounds (see `parse_attribute_args`).
    const BOUND_ATTRIBUTE_NAME: &str = "bound";

    fn field_assignment(field: &Field, trait_path: &Path, method: &Ident) -> Result<TokenStream2, SynError> {
        let Some(field_name) = &field.ident else {
            return Err(Error::c_struct_field_missing_name(field));
        };
        let field_assignment =
            quote::quote! { #field_name: #trait_path::#method(&self.#field_name, &other.#field_name) };

        Ok(field_assignment)
    }

    /// Builds the `Self { .. }` expression for a struct with named fields, with one
    /// initializer per field produced by `field_assignment`.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error reported by `field_assignment`.
    fn value_for_c_struct(
        fields_named: &FieldsNamed,
        trait_path: &Path,
        method: &Ident,
    ) -> Result<TokenStream2, SynError> {
        let mut initializers = Vec::with_capacity(fields_named.named.len());

        for field in &fields_named.named {
            initializers.push(Self::field_assignment(field, trait_path, method)?);
        }

        Ok(quote::quote! { Self { #(#initializers),* } })
    }

    /// Builds the `Self(..)` expression for a tuple struct: one
    /// `trait_path::method(&self.N, &other.N)` call per positional field, in order.
    fn value_for_tuple_struct(fields_unnamed: &FieldsUnnamed, trait_path: &Path, method: &Ident) -> TokenStream2 {
        let calls = fields_unnamed.unnamed.iter().enumerate().map(|(position, _)| {
            // Only the position matters; the field itself is never inspected.
            let index = Index::from(position);

            quote::quote! { #trait_path::#method(&self.#index, &other.#index) }
        });

        quote::quote! { Self(#(#calls),*) }
    }

    /// Builds the result expression for a unit struct, which is just `Self`.
    fn value_for_unit_struct() -> TokenStream2 {
        quote::quote!(Self)
    }

    fn value(input: &DeriveInput, trait_path: &Path, method: &Ident) -> Result<TokenStream2, SynError> {
        let Data::Struct(data_struct) = &input.data else {
            return Err(Error::unsupported_item_type(input));
        };
        let value = match &data_struct.fields {
            Fields::Named(fields_named) => Self::value_for_c_struct(fields_named, trait_path, method)?,
            Fields::Unnamed(fields_unnamed) => Self::value_for_tuple_struct(fields_unnamed, trait_path, method),
            Fields::Unit => Self::value_for_unit_struct(),
        };

        Ok(value)
    }

    /// Parses the tokens inside a list-style meta (`name(...)`) as an `ImplItem`.
    ///
    /// The tokens are first parsed as written; if that fails they are parsed a second time
    /// with a `;` appended, so the trailing semicolon may be omitted in the attribute.
    ///
    /// # Errors
    ///
    /// Returns an error when `meta` is not a list, or the error from the second parse when
    /// both attempts fail.
    fn parse_associated_item(meta: &Meta) -> Result<ImplItem, SynError> {
        let tokens = &meta.require_list()?.tokens;

        syn::parse2::<ImplItem>(tokens.clone()).or_else(|_| syn::parse2::<ImplItem>(quote::quote! { #tokens; }))
    }

    fn parse_attribute_args(
        attribute_args: CommaPunctuated<Meta>,
        generics: &mut Generics,
        associated_items: &mut Vec<ImplItem>,
        use_default_bounds: &mut bool,
    ) -> Result<(), SynError> {
        for meta in attribute_args {
            if meta.path().is_ident(Self::ASSOCIATED_ITEM_ATTRIBUTE_NAME) {
                let associated_item = Self::parse_associated_item(&meta)?;

                associated_items.push(associated_item);
            } else if meta.path().is_ident(Self::BOUND_ATTRIBUTE_NAME) {
                let meta_name_value = meta.require_name_value()?;
                let bounds_str = crate::utils::parse!(Expr::Lit, &meta_name_value.value)?;
                let bounds_str = crate::utils::parse!(Lit::Str, &bounds_str.lit)?;
                let bounds = bounds_str.parse_with(CommaPunctuated::<WherePredicate>::parse_terminated)?;

                generics.make_where_clause().predicates.extend(bounds);

                *use_default_bounds = false;
            } else {
                let expected_values = &[Self::ASSOCIATED_ITEM_ATTRIBUTE_NAME, Self::BOUND_ATTRIBUTE_NAME];
                let error = Error::unexpected_value_multi(&meta, expected_values);

                return Err(error);
            }
        }

        Ok(())
    }

    /// Collects the generics and extra associated items for the generated `impl`.
    ///
    /// Starts from a clone of `input.generics` and processes every attribute on `input`
    /// whose path is `attribute_name` with `parse_attribute_args`. If no `bound` argument
    /// was encountered, a `T: trait_path` predicate is added to the where clause for each
    /// type parameter.
    ///
    /// # Errors
    ///
    /// Forwards the first failure from parsing an attribute's arguments, from
    /// `parse_attribute_args`, or from parsing a default bound.
    fn get_generics_and_associated_items(
        input: &DeriveInput,
        trait_path: &Path,
        attribute_name: &str,
    ) -> Result<(Generics, Vec<ImplItem>), SynError> {
        let mut generics = input.generics.clone();
        let mut associated_items = Vec::new();
        let mut use_default_bounds = true;

        let helper_attributes = input
            .attrs
            .iter()
            .filter(|attribute| attribute.path().is_ident(attribute_name));

        for attribute in helper_attributes {
            let attribute_args = attribute.parse_args_with(CommaPunctuated::<Meta>::parse_terminated)?;

            Self::parse_attribute_args(
                attribute_args,
                &mut generics,
                &mut associated_items,
                &mut use_default_bounds,
            )?;
        }

        if use_default_bounds {
            // Collected up front: `type_params` borrows `generics`, which must be released
            // before `make_where_clause` can borrow it mutably.
            let default_bounds = generics
                .type_params()
                .map(|type_param| {
                    let type_param_ident = &type_param.ident;

                    syn::parse2::<WherePredicate>(quote::quote! { #type_param_ident: #trait_path })
                })
                .collect::<Result<Vec<WherePredicate>, SynError>>()?;

            generics.make_where_clause().predicates.extend(default_bounds);
        }

        Ok((generics, associated_items))
    }

    /// Generates the full `impl trait_path for Type` block.
    ///
    /// The block contains any associated items gathered from `attribute_name` attributes,
    /// followed by `fn method(&self, other: &other_type) -> Self` whose body is the
    /// expression built by `value`.
    ///
    /// # Errors
    ///
    /// Fails if `trait_path` or `other_type` do not parse as a path, if `method` does not
    /// parse as an identifier, or if attribute processing or `value` reports an error.
    fn derive_impl(
        input: &DeriveInput,
        trait_path: &str,
        method: &str,
        other_type: &str,
        attribute_name: &str,
    ) -> Result<TokenStream2, SynError> {
        let type_ident = &input.ident;
        let trait_path = syn::parse_str::<Path>(trait_path)?;
        let (generics, associated_items) = Self::get_generics_and_associated_items(input, &trait_path, attribute_name)?;
        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
        let method = syn::parse_str::<Ident>(method)?;
        let other_type = syn::parse_str::<Path>(other_type)?;
        let body = Self::value(input, &trait_path, &method)?;

        Ok(quote::quote! {
            impl #impl_generics #trait_path for #type_ident #ty_generics #where_clause {
                #(#associated_items)*

                fn #method(&self, other: &#other_type) -> Self {
                    #body
                }
            }
        })
    }

    pub fn derive(
        input_token_stream: TokenStream,
        trait_path: &str,
        method: &str,
        other_type: &str,
        attribute_name: &str,
    ) -> TokenStream {
        let input = syn::parse_macro_input!(input_token_stream);

        Self::derive_impl(&input, trait_path, method, other_type, attribute_name)
            .unwrap_or_else(SynError::into_compile_error)
            .into()
    }
}