closure-tree-macros 0.0.1

Procedural macros for closure-tree.
Documentation
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::ext::IdentExt;
use syn::{
    parse_macro_input, spanned::Spanned, Attribute, Data, DeriveInput, Fields, Ident, Path, Type,
};

#[proc_macro_derive(ClosureTreeModel, attributes(closure_tree))]
pub fn derive_closure_tree_model(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    match impl_closure_tree_model(&input) {
        Ok(tokens) => tokens,
        Err(err) => err.to_compile_error().into(),
    }
}

#[derive(Default)]
struct Options {
    id_field: Option<String>,
    id_type: Option<Type>,
    parent_field: Option<String>,
    hierarchy_module: Option<Path>,
    hierarchy_table: Option<String>,
    name_field: Option<String>,
    entity_name: Option<String>,
    hierarchy_name: Option<String>,
    ancestor_field: Option<String>,
    descendant_field: Option<String>,
    generations_field: Option<String>,
}

fn impl_closure_tree_model(input: &DeriveInput) -> syn::Result<TokenStream> {
    let struct_ident = &input.ident;

    let data_struct = match &input.data {
        Data::Struct(data) => data,
        _ => {
            return Err(syn::Error::new(
                input.span(),
                "ClosureTreeModel can only be derived for structs",
            ))
        }
    };

    let mut options = Options::default();
    let mut table_name: Option<String> = None;

    for attr in &input.attrs {
        if attr.path().is_ident("closure_tree") {
            parse_closure_tree_attr(attr, &mut options)?;
        }

        if attr.path().is_ident("sea_orm") {
            if let Some(name) = parse_sea_orm_table_name(attr)? {
                table_name = Some(name);
            }
        }
    }

    let id_field_name = options.id_field.unwrap_or_else(|| "id".to_string());
    let parent_field_name = options
        .parent_field
        .unwrap_or_else(|| "parent_id".to_string());
    let name_field_name = options.name_field.unwrap_or_else(|| "name".to_string());
    let ancestor_field_name = options
        .ancestor_field
        .unwrap_or_else(|| "ancestor_id".to_string());
    let descendant_field_name = options
        .descendant_field
        .unwrap_or_else(|| "descendant_id".to_string());
    let generations_field_name = options
        .generations_field
        .unwrap_or_else(|| "generations".to_string());

    let id_field_ident = Ident::new(&id_field_name, struct_ident.span());
    let parent_field_ident = Ident::new(&parent_field_name, struct_ident.span());
    let name_field_ident = Ident::new(&name_field_name, struct_ident.span());
    let ancestor_field_ident = Ident::new(&ancestor_field_name, struct_ident.span());
    let descendant_field_ident = Ident::new(&descendant_field_name, struct_ident.span());
    let generations_field_ident = Ident::new(&generations_field_name, struct_ident.span());

    let mut id_field_type: Option<Type> = options.id_type.clone();

    if let Fields::Named(ref fields) = data_struct.fields {
        for field in &fields.named {
            if let Some(ident) = &field.ident {
                if ident == &id_field_ident && id_field_type.is_none() {
                    id_field_type = Some(field.ty.clone());
                }
            }
        }
    } else {
        return Err(syn::Error::new(
            data_struct.fields.span(),
            "ClosureTreeModel requires named fields",
        ));
    }

    let id_type = id_field_type.ok_or_else(|| {
        syn::Error::new(
            struct_ident.span(),
            "Unable to determine id field type; specify `id_type = ...` in #[closure_tree]",
        )
    })?;

    let hierarchy_module_path = options
        .hierarchy_module
        .ok_or_else(|| syn::Error::new(struct_ident.span(), "`hierarchy_module` must be set"))?;

    let entity_name = options
        .entity_name
        .unwrap_or_else(|| struct_ident.unraw().to_string());
    let hierarchy_name = options.hierarchy_name.unwrap_or_else(|| {
        if entity_name.ends_with("Hierarchy") {
            entity_name.clone()
        } else {
            format!("{}Hierarchy", entity_name)
        }
    });

    let base_table = table_name.unwrap_or_else(|| struct_ident.unraw().to_string());
    let hierarchy_table = options
        .hierarchy_table
        .unwrap_or_else(|| format!("{}_hierarchies", base_table));

    let id_column_variant = format_ident!("{}", to_pascal_case(&id_field_name));
    let parent_column_variant = format_ident!("{}", to_pascal_case(&parent_field_name));
    let name_column_variant = format_ident!("{}", to_pascal_case(&name_field_name));
    let ancestor_column_variant = format_ident!("{}", to_pascal_case(&ancestor_field_name));
    let descendant_column_variant = format_ident!("{}", to_pascal_case(&descendant_field_name));
    let generations_column_variant = format_ident!("{}", to_pascal_case(&generations_field_name));

    let parent_column_literal = syn::LitStr::new(&parent_field_name, struct_ident.span());
    let name_column_literal = syn::LitStr::new(&name_field_name, struct_ident.span());
    let hierarchy_table_literal = syn::LitStr::new(&hierarchy_table, struct_ident.span());
    let entity_name_literal = syn::LitStr::new(&entity_name, struct_ident.span());
    let hierarchy_name_literal = syn::LitStr::new(&hierarchy_name, struct_ident.span());

    let generated = quote! {
        impl ::closure_tree::ClosureTreeModel for #struct_ident {
            type Entity = Entity;
            type ActiveModel = ActiveModel;
            type Id = #id_type;

            type HierarchyEntity = #hierarchy_module_path::Entity;
            type HierarchyModel = #hierarchy_module_path::Model;
            type HierarchyActiveModel = #hierarchy_module_path::ActiveModel;

            fn closure_tree_config() -> &'static ::closure_tree::ClosureTreeConfig {
                static CONFIG: ::once_cell::sync::Lazy<::closure_tree::ClosureTreeConfig> =
                    ::once_cell::sync::Lazy::new(|| {
                        let base = ::closure_tree::ClosureTreeConfig::new(
                            #entity_name_literal,
                            #hierarchy_name_literal,
                        );
                        ::closure_tree::ClosureTreeOptions::default()
                            .parent_column(#parent_column_literal)
                            .name_column(#name_column_literal)
                            .hierarchy_table(#hierarchy_table_literal)
                            .apply(base)
                    });
                &CONFIG
            }

            fn id(&self) -> Self::Id {
                self.#id_field_ident.clone()
            }

            fn parent_id(&self) -> Option<Self::Id> {
                self.#parent_field_ident.clone()
            }

            fn set_parent(active: &mut Self::ActiveModel, parent: Option<Self::Id>) {
                active.#parent_field_ident = ::sea_orm::ActiveValue::Set(parent);
            }

            fn id_to_value(id: &Self::Id) -> ::sea_orm::Value {
                ::sea_orm::Value::from(id.clone())
            }

            fn name(&self) -> &str {
                self.#name_field_ident.as_str()
            }

            fn set_name(active: &mut Self::ActiveModel, name: &str) {
                active.#name_field_ident = ::sea_orm::ActiveValue::Set(name.to_owned());
            }

            fn parent_column() -> <Self::Entity as ::sea_orm::EntityTrait>::Column {
                Column::#parent_column_variant
            }

            fn id_column() -> <Self::Entity as ::sea_orm::EntityTrait>::Column {
                Column::#id_column_variant
            }

            fn name_column() -> <Self::Entity as ::sea_orm::EntityTrait>::Column {
                Column::#name_column_variant
            }

            fn hierarchy_ancestor_column() -> <Self::HierarchyEntity as ::sea_orm::EntityTrait>::Column {
                #hierarchy_module_path::Column::#ancestor_column_variant
            }

            fn hierarchy_descendant_column() -> <Self::HierarchyEntity as ::sea_orm::EntityTrait>::Column {
                #hierarchy_module_path::Column::#descendant_column_variant
            }

            fn hierarchy_generations_column() -> <Self::HierarchyEntity as ::sea_orm::EntityTrait>::Column {
                #hierarchy_module_path::Column::#generations_column_variant
            }

            fn hierarchy_id_to_value(id: &Self::Id) -> ::sea_orm::Value {
                ::sea_orm::Value::from(id.clone())
            }

            fn hierarchy_model_ancestor(model: &Self::HierarchyModel) -> Self::Id {
                model.#ancestor_field_ident.clone()
            }

            fn hierarchy_model_descendant(model: &Self::HierarchyModel) -> Self::Id {
                model.#descendant_field_ident.clone()
            }

            fn hierarchy_model_generations(model: &Self::HierarchyModel) -> i32 {
                model.#generations_field_ident
            }

            fn hierarchy_build_row(
                ancestor: Self::Id,
                descendant: Self::Id,
                generations: i32,
            ) -> Self::HierarchyActiveModel {
                #[allow(clippy::needless_update)]
                {
                    #hierarchy_module_path::ActiveModel {
                        #ancestor_field_ident: ::sea_orm::ActiveValue::Set(ancestor),
                        #descendant_field_ident: ::sea_orm::ActiveValue::Set(descendant),
                        #generations_field_ident: ::sea_orm::ActiveValue::Set(generations),
                        ..::core::default::Default::default()
                    }
                }
            }
        }
    };

    Ok(generated.into())
}

fn parse_closure_tree_attr(attr: &Attribute, options: &mut Options) -> syn::Result<()> {
    attr.parse_nested_meta(|meta| {
        let ident = meta
            .path
            .get_ident()
            .ok_or_else(|| syn::Error::new(meta.path.span(), "Invalid option key"))?
            .to_string();

        match ident.as_str() {
            "id_field" => {
                let value: syn::LitStr = meta.value()?.parse()?;
                options.id_field = Some(value.value());
            }
            "parent_field" => {
                let value: syn::LitStr = meta.value()?.parse()?;
                options.parent_field = Some(value.value());
            }
            "name_field" => {
                let value: syn::LitStr = meta.value()?.parse()?;
                options.name_field = Some(value.value());
            }
            "hierarchy_module" => {
                let value: syn::LitStr = meta.value()?.parse()?;
                options.hierarchy_module = Some(parse_path(&value.value(), value.span())?);
            }
            "hierarchy_table" => {
                let value: syn::LitStr = meta.value()?.parse()?;
                options.hierarchy_table = Some(value.value());
            }
            "entity_name" => {
                let value: syn::LitStr = meta.value()?.parse()?;
                options.entity_name = Some(value.value());
            }
            "hierarchy_name" => {
                let value: syn::LitStr = meta.value()?.parse()?;
                options.hierarchy_name = Some(value.value());
            }
            "ancestor_field" => {
                let value: syn::LitStr = meta.value()?.parse()?;
                options.ancestor_field = Some(value.value());
            }
            "descendant_field" => {
                let value: syn::LitStr = meta.value()?.parse()?;
                options.descendant_field = Some(value.value());
            }
            "generations_field" => {
                let value: syn::LitStr = meta.value()?.parse()?;
                options.generations_field = Some(value.value());
            }
            "id_type" => {
                let ty: Type = meta.value()?.parse()?;
                options.id_type = Some(ty);
            }
            other => {
                return Err(syn::Error::new(
                    meta.path.span(),
                    format!("Unsupported closure_tree option `{other}`"),
                ));
            }
        }

        Ok(())
    })
}

fn parse_sea_orm_table_name(attr: &Attribute) -> syn::Result<Option<String>> {
    let mut table_name: Option<String> = None;
    attr.parse_nested_meta(|meta| {
        if meta.path.is_ident("table_name") {
            let value: syn::LitStr = meta.value()?.parse()?;
            table_name = Some(value.value());
        }
        Ok(())
    })?;
    Ok(table_name)
}

fn parse_path(value: &str, span: proc_macro2::Span) -> syn::Result<Path> {
    syn::parse_str::<Path>(value).map_err(|_| syn::Error::new(span, "Invalid path"))
}

fn to_pascal_case(value: &str) -> String {
    value
        .split('_')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let mut chars = segment.chars();
            match chars.next() {
                Some(first) => first.to_ascii_uppercase().to_string() + chars.as_str(),
                None => String::new(),
            }
        })
        .collect()
}