ext-crud-derive 0.1.1

Derive macros for ext-crud-rs
Documentation
use proc_macro::TokenStream;
use quote::format_ident;
use quote::quote;
use syn::{
    parse_macro_input, punctuated::Punctuated, token::Comma, Data, DeriveInput, Field, Fields,
    Ident, Type,
};

#[proc_macro_derive(ExtendedCrud, attributes(table_name, primary_key))]
pub fn derive_extended_crud(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;

    let table_name = input
        .attrs
        .iter()
        .find(|attr| attr.path().is_ident("table_name"))
        .map(|attr| attr.parse_args::<syn::LitStr>().unwrap().value())
        .expect("table_name attribute is required");

    let fields = fields_named_from_input(&input);

    let (primary_key_field, primary_key_type, primary_key_name) = extract_primary_keys(fields);

    let expanded = quote! {
        impl #name {
            fn primary_key_name() -> &'static str {
                #primary_key_name
            }
            fn primary_key(&self) -> &#primary_key_type {
                &self.#primary_key_field
            }
        }
        impl<C: Client> ExtendedCrud<C> for #name {
            type PrimaryKey = #primary_key_type;

            const TABLE_NAME: &'static str = #table_name;

            const PRIMARY_KEY_NAME: &'static str = #primary_key_name;

            fn primary_key(&self) -> &Self::PrimaryKey {
                 &self.#primary_key_field
             }
        }
        impl TryFromError<serde_json::Value, serde_json::Error> for #name {

            fn try_from_err(value: serde_json::Value) -> Result<Self, serde_json::Error> {
                serde_json::from_value(value).map_err(Into::into)
            }
        }
    };

    TokenStream::from(expanded)
}

#[proc_macro_derive(PartialEntity, attributes(partial_entity_name, primary_key))]
pub fn partial_entity(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let original_name = &input.ident;

    let partial_name = input
        .attrs
        .iter()
        .find(|attr| attr.path().is_ident("partial_entity_name"))
        .map(|attr| format_ident!("{}", attr.parse_args::<syn::LitStr>().unwrap().value()))
        .unwrap_or_else(|| format_ident!("Partial{}", original_name.to_string()));

    let fields = fields_named_from_input(&input);

    let (primary_key_field, primary_key_type, primary_key_name) = extract_primary_keys(fields);

    let partial_fields = fields.iter().map(|f| {
        let name = &f.ident;
        let ty = &f.ty;
        quote! {
            #name: Option<#ty>
        }
    });

    let to_partial_fields = fields.iter().map(|f| {
        let name = &f.ident;
        quote! {
            #name: Some(self.#name.clone())
        }
    });

    let apply_fields = fields.iter().map(|f| {
        let name = &f.ident;
        quote! {
            if let Some(ref value) = self.#name {
                original.#name = value.clone();
            }
        }
    });

    let new_fields = fields.iter().map(|f| {
        let name = &f.ident;
        quote! {
            #name: None
        }
    });

    let expanded = quote! {
        #[derive(Debug, Clone, Serialize, Deserialize)]
        pub struct #partial_name {
            #(#partial_fields,)*
        }

        impl #original_name {
            fn to_partial(&self) -> #partial_name {
                #partial_name {
                    #(#to_partial_fields,)*
                }
            }
        }

        impl PartialEntity<#original_name> for #partial_name {
            type PrimaryKey = #primary_key_type;

            const PRIMARY_KEY_NAME: &'static str = #primary_key_name;

            fn new() -> Self {
                 Self {
                     #(#new_fields,)*
                 }
             }

            fn apply_to(&self, original: &#original_name) -> #original_name {
                let mut original = original.clone();
                #(#apply_fields)*
                original
            }

            fn primary_key(&self) -> Option<Self::PrimaryKey> {
                self.#primary_key_field.clone()
            }
        }
    };

    TokenStream::from(expanded)
}

fn fields_named_from_input(input: &DeriveInput) -> &Punctuated<Field, Comma> {
    match &input.data {
        Data::Struct(data) => fields_named(&data.fields),
        _ => panic!("Only structs are supported"),
    }
}

fn fields_named(fields: &Fields) -> &Punctuated<Field, Comma> {
    match fields {
        Fields::Named(fields) => &fields.named,
        _ => panic!("Only named fields are supported"),
    }
}

fn extract_primary_keys(fields: &Punctuated<Field, Comma>) -> (&Ident, &Type, String) {
    fields
        .iter()
        .find(|f| {
            f.attrs
                .iter()
                .any(|attr| attr.path().is_ident("primary_key"))
        })
        .map(|f| {
            let primary_key_field = f.ident.as_ref().unwrap();
            let primary_key_name = f
                .attrs
                .iter()
                .find(|attr| attr.path().is_ident("primary_key"))
                .map(|attr| {
                    if let Ok(attr) = attr.parse_args::<syn::LitStr>() {
                        attr.value()
                    } else {
                        primary_key_field.to_string()
                    }
                })
                .unwrap_or(primary_key_field.to_string());
            (primary_key_field, &f.ty, primary_key_name)
        })
        .or_else(|| {
            fields
                .iter()
                .find(|f| f.ident.as_ref().unwrap() == "id")
                .map(|f| (f.ident.as_ref().unwrap(), &f.ty, "id".to_string()))
        })
        .expect("A field named 'id' or with #[primary_key] attribute is required")
}