closure_tree_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::ext::IdentExt;
4use syn::{
5    parse_macro_input, spanned::Spanned, Attribute, Data, DeriveInput, Fields, Ident, Path, Type,
6};
7
8#[proc_macro_derive(ClosureTreeModel, attributes(closure_tree))]
9pub fn derive_closure_tree_model(input: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(input as DeriveInput);
11
12    match impl_closure_tree_model(&input) {
13        Ok(tokens) => tokens,
14        Err(err) => err.to_compile_error().into(),
15    }
16}
17
18#[derive(Default)]
19struct Options {
20    id_field: Option<String>,
21    id_type: Option<Type>,
22    parent_field: Option<String>,
23    hierarchy_module: Option<Path>,
24    hierarchy_table: Option<String>,
25    name_field: Option<String>,
26    entity_name: Option<String>,
27    hierarchy_name: Option<String>,
28    ancestor_field: Option<String>,
29    descendant_field: Option<String>,
30    generations_field: Option<String>,
31}
32
33fn impl_closure_tree_model(input: &DeriveInput) -> syn::Result<TokenStream> {
34    let struct_ident = &input.ident;
35
36    let data_struct = match &input.data {
37        Data::Struct(data) => data,
38        _ => {
39            return Err(syn::Error::new(
40                input.span(),
41                "ClosureTreeModel can only be derived for structs",
42            ))
43        }
44    };
45
46    let mut options = Options::default();
47    let mut table_name: Option<String> = None;
48
49    for attr in &input.attrs {
50        if attr.path().is_ident("closure_tree") {
51            parse_closure_tree_attr(attr, &mut options)?;
52        }
53
54        if attr.path().is_ident("sea_orm") {
55            if let Some(name) = parse_sea_orm_table_name(attr)? {
56                table_name = Some(name);
57            }
58        }
59    }
60
61    let id_field_name = options.id_field.unwrap_or_else(|| "id".to_string());
62    let parent_field_name = options
63        .parent_field
64        .unwrap_or_else(|| "parent_id".to_string());
65    let name_field_name = options.name_field.unwrap_or_else(|| "name".to_string());
66    let ancestor_field_name = options
67        .ancestor_field
68        .unwrap_or_else(|| "ancestor_id".to_string());
69    let descendant_field_name = options
70        .descendant_field
71        .unwrap_or_else(|| "descendant_id".to_string());
72    let generations_field_name = options
73        .generations_field
74        .unwrap_or_else(|| "generations".to_string());
75
76    let id_field_ident = Ident::new(&id_field_name, struct_ident.span());
77    let parent_field_ident = Ident::new(&parent_field_name, struct_ident.span());
78    let name_field_ident = Ident::new(&name_field_name, struct_ident.span());
79    let ancestor_field_ident = Ident::new(&ancestor_field_name, struct_ident.span());
80    let descendant_field_ident = Ident::new(&descendant_field_name, struct_ident.span());
81    let generations_field_ident = Ident::new(&generations_field_name, struct_ident.span());
82
83    let mut id_field_type: Option<Type> = options.id_type.clone();
84
85    if let Fields::Named(ref fields) = data_struct.fields {
86        for field in &fields.named {
87            if let Some(ident) = &field.ident {
88                if ident == &id_field_ident && id_field_type.is_none() {
89                    id_field_type = Some(field.ty.clone());
90                }
91            }
92        }
93    } else {
94        return Err(syn::Error::new(
95            data_struct.fields.span(),
96            "ClosureTreeModel requires named fields",
97        ));
98    }
99
100    let id_type = id_field_type.ok_or_else(|| {
101        syn::Error::new(
102            struct_ident.span(),
103            "Unable to determine id field type; specify `id_type = ...` in #[closure_tree]",
104        )
105    })?;
106
107    let hierarchy_module_path = options
108        .hierarchy_module
109        .ok_or_else(|| syn::Error::new(struct_ident.span(), "`hierarchy_module` must be set"))?;
110
111    let entity_name = options
112        .entity_name
113        .unwrap_or_else(|| struct_ident.unraw().to_string());
114    let hierarchy_name = options.hierarchy_name.unwrap_or_else(|| {
115        if entity_name.ends_with("Hierarchy") {
116            entity_name.clone()
117        } else {
118            format!("{}Hierarchy", entity_name)
119        }
120    });
121
122    let base_table = table_name.unwrap_or_else(|| struct_ident.unraw().to_string());
123    let hierarchy_table = options
124        .hierarchy_table
125        .unwrap_or_else(|| format!("{}_hierarchies", base_table));
126
127    let id_column_variant = format_ident!("{}", to_pascal_case(&id_field_name));
128    let parent_column_variant = format_ident!("{}", to_pascal_case(&parent_field_name));
129    let name_column_variant = format_ident!("{}", to_pascal_case(&name_field_name));
130    let ancestor_column_variant = format_ident!("{}", to_pascal_case(&ancestor_field_name));
131    let descendant_column_variant = format_ident!("{}", to_pascal_case(&descendant_field_name));
132    let generations_column_variant = format_ident!("{}", to_pascal_case(&generations_field_name));
133
134    let parent_column_literal = syn::LitStr::new(&parent_field_name, struct_ident.span());
135    let name_column_literal = syn::LitStr::new(&name_field_name, struct_ident.span());
136    let hierarchy_table_literal = syn::LitStr::new(&hierarchy_table, struct_ident.span());
137    let entity_name_literal = syn::LitStr::new(&entity_name, struct_ident.span());
138    let hierarchy_name_literal = syn::LitStr::new(&hierarchy_name, struct_ident.span());
139
140    let generated = quote! {
141        impl ::closure_tree::ClosureTreeModel for #struct_ident {
142            type Entity = Entity;
143            type ActiveModel = ActiveModel;
144            type Id = #id_type;
145
146            type HierarchyEntity = #hierarchy_module_path::Entity;
147            type HierarchyModel = #hierarchy_module_path::Model;
148            type HierarchyActiveModel = #hierarchy_module_path::ActiveModel;
149
150            fn closure_tree_config() -> &'static ::closure_tree::ClosureTreeConfig {
151                static CONFIG: ::once_cell::sync::Lazy<::closure_tree::ClosureTreeConfig> =
152                    ::once_cell::sync::Lazy::new(|| {
153                        let base = ::closure_tree::ClosureTreeConfig::new(
154                            #entity_name_literal,
155                            #hierarchy_name_literal,
156                        );
157                        ::closure_tree::ClosureTreeOptions::default()
158                            .parent_column(#parent_column_literal)
159                            .name_column(#name_column_literal)
160                            .hierarchy_table(#hierarchy_table_literal)
161                            .apply(base)
162                    });
163                &CONFIG
164            }
165
166            fn id(&self) -> Self::Id {
167                self.#id_field_ident.clone()
168            }
169
170            fn parent_id(&self) -> Option<Self::Id> {
171                self.#parent_field_ident.clone()
172            }
173
174            fn set_parent(active: &mut Self::ActiveModel, parent: Option<Self::Id>) {
175                active.#parent_field_ident = ::sea_orm::ActiveValue::Set(parent);
176            }
177
178            fn id_to_value(id: &Self::Id) -> ::sea_orm::Value {
179                ::sea_orm::Value::from(id.clone())
180            }
181
182            fn name(&self) -> &str {
183                self.#name_field_ident.as_str()
184            }
185
186            fn set_name(active: &mut Self::ActiveModel, name: &str) {
187                active.#name_field_ident = ::sea_orm::ActiveValue::Set(name.to_owned());
188            }
189
190            fn parent_column() -> <Self::Entity as ::sea_orm::EntityTrait>::Column {
191                Column::#parent_column_variant
192            }
193
194            fn id_column() -> <Self::Entity as ::sea_orm::EntityTrait>::Column {
195                Column::#id_column_variant
196            }
197
198            fn name_column() -> <Self::Entity as ::sea_orm::EntityTrait>::Column {
199                Column::#name_column_variant
200            }
201
202            fn hierarchy_ancestor_column() -> <Self::HierarchyEntity as ::sea_orm::EntityTrait>::Column {
203                #hierarchy_module_path::Column::#ancestor_column_variant
204            }
205
206            fn hierarchy_descendant_column() -> <Self::HierarchyEntity as ::sea_orm::EntityTrait>::Column {
207                #hierarchy_module_path::Column::#descendant_column_variant
208            }
209
210            fn hierarchy_generations_column() -> <Self::HierarchyEntity as ::sea_orm::EntityTrait>::Column {
211                #hierarchy_module_path::Column::#generations_column_variant
212            }
213
214            fn hierarchy_id_to_value(id: &Self::Id) -> ::sea_orm::Value {
215                ::sea_orm::Value::from(id.clone())
216            }
217
218            fn hierarchy_model_ancestor(model: &Self::HierarchyModel) -> Self::Id {
219                model.#ancestor_field_ident.clone()
220            }
221
222            fn hierarchy_model_descendant(model: &Self::HierarchyModel) -> Self::Id {
223                model.#descendant_field_ident.clone()
224            }
225
226            fn hierarchy_model_generations(model: &Self::HierarchyModel) -> i32 {
227                model.#generations_field_ident
228            }
229
230            fn hierarchy_build_row(
231                ancestor: Self::Id,
232                descendant: Self::Id,
233                generations: i32,
234            ) -> Self::HierarchyActiveModel {
235                #[allow(clippy::needless_update)]
236                {
237                    #hierarchy_module_path::ActiveModel {
238                        #ancestor_field_ident: ::sea_orm::ActiveValue::Set(ancestor),
239                        #descendant_field_ident: ::sea_orm::ActiveValue::Set(descendant),
240                        #generations_field_ident: ::sea_orm::ActiveValue::Set(generations),
241                        ..::core::default::Default::default()
242                    }
243                }
244            }
245        }
246    };
247
248    Ok(generated.into())
249}
250
251fn parse_closure_tree_attr(attr: &Attribute, options: &mut Options) -> syn::Result<()> {
252    attr.parse_nested_meta(|meta| {
253        let ident = meta
254            .path
255            .get_ident()
256            .ok_or_else(|| syn::Error::new(meta.path.span(), "Invalid option key"))?
257            .to_string();
258
259        match ident.as_str() {
260            "id_field" => {
261                let value: syn::LitStr = meta.value()?.parse()?;
262                options.id_field = Some(value.value());
263            }
264            "parent_field" => {
265                let value: syn::LitStr = meta.value()?.parse()?;
266                options.parent_field = Some(value.value());
267            }
268            "name_field" => {
269                let value: syn::LitStr = meta.value()?.parse()?;
270                options.name_field = Some(value.value());
271            }
272            "hierarchy_module" => {
273                let value: syn::LitStr = meta.value()?.parse()?;
274                options.hierarchy_module = Some(parse_path(&value.value(), value.span())?);
275            }
276            "hierarchy_table" => {
277                let value: syn::LitStr = meta.value()?.parse()?;
278                options.hierarchy_table = Some(value.value());
279            }
280            "entity_name" => {
281                let value: syn::LitStr = meta.value()?.parse()?;
282                options.entity_name = Some(value.value());
283            }
284            "hierarchy_name" => {
285                let value: syn::LitStr = meta.value()?.parse()?;
286                options.hierarchy_name = Some(value.value());
287            }
288            "ancestor_field" => {
289                let value: syn::LitStr = meta.value()?.parse()?;
290                options.ancestor_field = Some(value.value());
291            }
292            "descendant_field" => {
293                let value: syn::LitStr = meta.value()?.parse()?;
294                options.descendant_field = Some(value.value());
295            }
296            "generations_field" => {
297                let value: syn::LitStr = meta.value()?.parse()?;
298                options.generations_field = Some(value.value());
299            }
300            "id_type" => {
301                let ty: Type = meta.value()?.parse()?;
302                options.id_type = Some(ty);
303            }
304            other => {
305                return Err(syn::Error::new(
306                    meta.path.span(),
307                    format!("Unsupported closure_tree option `{other}`"),
308                ));
309            }
310        }
311
312        Ok(())
313    })
314}
315
316fn parse_sea_orm_table_name(attr: &Attribute) -> syn::Result<Option<String>> {
317    let mut table_name: Option<String> = None;
318    attr.parse_nested_meta(|meta| {
319        if meta.path.is_ident("table_name") {
320            let value: syn::LitStr = meta.value()?.parse()?;
321            table_name = Some(value.value());
322        }
323        Ok(())
324    })?;
325    Ok(table_name)
326}
327
328fn parse_path(value: &str, span: proc_macro2::Span) -> syn::Result<Path> {
329    syn::parse_str::<Path>(value).map_err(|_| syn::Error::new(span, "Invalid path"))
330}
331
/// Converts a snake_case name to PascalCase (`parent_id` -> `ParentId`).
///
/// The input is split on `_` and empty pieces are dropped, so leading, trailing
/// and doubled underscores disappear. In each piece only the first character is
/// ASCII-uppercased; the rest is copied unchanged.
fn to_pascal_case(value: &str) -> String {
    let mut pascal = String::with_capacity(value.len());
    for word in value.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            pascal.push(head.to_ascii_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}