closure_tree_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::ext::IdentExt;
4use syn::{
5 parse_macro_input, spanned::Spanned, Attribute, Data, DeriveInput, Fields, Ident, Path, Type,
6};
7
8#[proc_macro_derive(ClosureTreeModel, attributes(closure_tree))]
9pub fn derive_closure_tree_model(input: TokenStream) -> TokenStream {
10 let input = parse_macro_input!(input as DeriveInput);
11
12 match impl_closure_tree_model(&input) {
13 Ok(tokens) => tokens,
14 Err(err) => err.to_compile_error().into(),
15 }
16}
17
18#[derive(Default)]
19struct Options {
20 id_field: Option<String>,
21 id_type: Option<Type>,
22 parent_field: Option<String>,
23 hierarchy_module: Option<Path>,
24 hierarchy_table: Option<String>,
25 name_field: Option<String>,
26 entity_name: Option<String>,
27 hierarchy_name: Option<String>,
28 ancestor_field: Option<String>,
29 descendant_field: Option<String>,
30 generations_field: Option<String>,
31}
32
33fn impl_closure_tree_model(input: &DeriveInput) -> syn::Result<TokenStream> {
34 let struct_ident = &input.ident;
35
36 let data_struct = match &input.data {
37 Data::Struct(data) => data,
38 _ => {
39 return Err(syn::Error::new(
40 input.span(),
41 "ClosureTreeModel can only be derived for structs",
42 ))
43 }
44 };
45
46 let mut options = Options::default();
47 let mut table_name: Option<String> = None;
48
49 for attr in &input.attrs {
50 if attr.path().is_ident("closure_tree") {
51 parse_closure_tree_attr(attr, &mut options)?;
52 }
53
54 if attr.path().is_ident("sea_orm") {
55 if let Some(name) = parse_sea_orm_table_name(attr)? {
56 table_name = Some(name);
57 }
58 }
59 }
60
61 let id_field_name = options.id_field.unwrap_or_else(|| "id".to_string());
62 let parent_field_name = options
63 .parent_field
64 .unwrap_or_else(|| "parent_id".to_string());
65 let name_field_name = options.name_field.unwrap_or_else(|| "name".to_string());
66 let ancestor_field_name = options
67 .ancestor_field
68 .unwrap_or_else(|| "ancestor_id".to_string());
69 let descendant_field_name = options
70 .descendant_field
71 .unwrap_or_else(|| "descendant_id".to_string());
72 let generations_field_name = options
73 .generations_field
74 .unwrap_or_else(|| "generations".to_string());
75
76 let id_field_ident = Ident::new(&id_field_name, struct_ident.span());
77 let parent_field_ident = Ident::new(&parent_field_name, struct_ident.span());
78 let name_field_ident = Ident::new(&name_field_name, struct_ident.span());
79 let ancestor_field_ident = Ident::new(&ancestor_field_name, struct_ident.span());
80 let descendant_field_ident = Ident::new(&descendant_field_name, struct_ident.span());
81 let generations_field_ident = Ident::new(&generations_field_name, struct_ident.span());
82
83 let mut id_field_type: Option<Type> = options.id_type.clone();
84
85 if let Fields::Named(ref fields) = data_struct.fields {
86 for field in &fields.named {
87 if let Some(ident) = &field.ident {
88 if ident == &id_field_ident && id_field_type.is_none() {
89 id_field_type = Some(field.ty.clone());
90 }
91 }
92 }
93 } else {
94 return Err(syn::Error::new(
95 data_struct.fields.span(),
96 "ClosureTreeModel requires named fields",
97 ));
98 }
99
100 let id_type = id_field_type.ok_or_else(|| {
101 syn::Error::new(
102 struct_ident.span(),
103 "Unable to determine id field type; specify `id_type = ...` in #[closure_tree]",
104 )
105 })?;
106
107 let hierarchy_module_path = options
108 .hierarchy_module
109 .ok_or_else(|| syn::Error::new(struct_ident.span(), "`hierarchy_module` must be set"))?;
110
111 let entity_name = options
112 .entity_name
113 .unwrap_or_else(|| struct_ident.unraw().to_string());
114 let hierarchy_name = options.hierarchy_name.unwrap_or_else(|| {
115 if entity_name.ends_with("Hierarchy") {
116 entity_name.clone()
117 } else {
118 format!("{}Hierarchy", entity_name)
119 }
120 });
121
122 let base_table = table_name.unwrap_or_else(|| struct_ident.unraw().to_string());
123 let hierarchy_table = options
124 .hierarchy_table
125 .unwrap_or_else(|| format!("{}_hierarchies", base_table));
126
127 let id_column_variant = format_ident!("{}", to_pascal_case(&id_field_name));
128 let parent_column_variant = format_ident!("{}", to_pascal_case(&parent_field_name));
129 let name_column_variant = format_ident!("{}", to_pascal_case(&name_field_name));
130 let ancestor_column_variant = format_ident!("{}", to_pascal_case(&ancestor_field_name));
131 let descendant_column_variant = format_ident!("{}", to_pascal_case(&descendant_field_name));
132 let generations_column_variant = format_ident!("{}", to_pascal_case(&generations_field_name));
133
134 let parent_column_literal = syn::LitStr::new(&parent_field_name, struct_ident.span());
135 let name_column_literal = syn::LitStr::new(&name_field_name, struct_ident.span());
136 let hierarchy_table_literal = syn::LitStr::new(&hierarchy_table, struct_ident.span());
137 let entity_name_literal = syn::LitStr::new(&entity_name, struct_ident.span());
138 let hierarchy_name_literal = syn::LitStr::new(&hierarchy_name, struct_ident.span());
139
140 let generated = quote! {
141 impl ::closure_tree::ClosureTreeModel for #struct_ident {
142 type Entity = Entity;
143 type ActiveModel = ActiveModel;
144 type Id = #id_type;
145
146 type HierarchyEntity = #hierarchy_module_path::Entity;
147 type HierarchyModel = #hierarchy_module_path::Model;
148 type HierarchyActiveModel = #hierarchy_module_path::ActiveModel;
149
150 fn closure_tree_config() -> &'static ::closure_tree::ClosureTreeConfig {
151 static CONFIG: ::once_cell::sync::Lazy<::closure_tree::ClosureTreeConfig> =
152 ::once_cell::sync::Lazy::new(|| {
153 let base = ::closure_tree::ClosureTreeConfig::new(
154 #entity_name_literal,
155 #hierarchy_name_literal,
156 );
157 ::closure_tree::ClosureTreeOptions::default()
158 .parent_column(#parent_column_literal)
159 .name_column(#name_column_literal)
160 .hierarchy_table(#hierarchy_table_literal)
161 .apply(base)
162 });
163 &CONFIG
164 }
165
166 fn id(&self) -> Self::Id {
167 self.#id_field_ident.clone()
168 }
169
170 fn parent_id(&self) -> Option<Self::Id> {
171 self.#parent_field_ident.clone()
172 }
173
174 fn set_parent(active: &mut Self::ActiveModel, parent: Option<Self::Id>) {
175 active.#parent_field_ident = ::sea_orm::ActiveValue::Set(parent);
176 }
177
178 fn id_to_value(id: &Self::Id) -> ::sea_orm::Value {
179 ::sea_orm::Value::from(id.clone())
180 }
181
182 fn name(&self) -> &str {
183 self.#name_field_ident.as_str()
184 }
185
186 fn set_name(active: &mut Self::ActiveModel, name: &str) {
187 active.#name_field_ident = ::sea_orm::ActiveValue::Set(name.to_owned());
188 }
189
190 fn parent_column() -> <Self::Entity as ::sea_orm::EntityTrait>::Column {
191 Column::#parent_column_variant
192 }
193
194 fn id_column() -> <Self::Entity as ::sea_orm::EntityTrait>::Column {
195 Column::#id_column_variant
196 }
197
198 fn name_column() -> <Self::Entity as ::sea_orm::EntityTrait>::Column {
199 Column::#name_column_variant
200 }
201
202 fn hierarchy_ancestor_column() -> <Self::HierarchyEntity as ::sea_orm::EntityTrait>::Column {
203 #hierarchy_module_path::Column::#ancestor_column_variant
204 }
205
206 fn hierarchy_descendant_column() -> <Self::HierarchyEntity as ::sea_orm::EntityTrait>::Column {
207 #hierarchy_module_path::Column::#descendant_column_variant
208 }
209
210 fn hierarchy_generations_column() -> <Self::HierarchyEntity as ::sea_orm::EntityTrait>::Column {
211 #hierarchy_module_path::Column::#generations_column_variant
212 }
213
214 fn hierarchy_id_to_value(id: &Self::Id) -> ::sea_orm::Value {
215 ::sea_orm::Value::from(id.clone())
216 }
217
218 fn hierarchy_model_ancestor(model: &Self::HierarchyModel) -> Self::Id {
219 model.#ancestor_field_ident.clone()
220 }
221
222 fn hierarchy_model_descendant(model: &Self::HierarchyModel) -> Self::Id {
223 model.#descendant_field_ident.clone()
224 }
225
226 fn hierarchy_model_generations(model: &Self::HierarchyModel) -> i32 {
227 model.#generations_field_ident
228 }
229
230 fn hierarchy_build_row(
231 ancestor: Self::Id,
232 descendant: Self::Id,
233 generations: i32,
234 ) -> Self::HierarchyActiveModel {
235 #[allow(clippy::needless_update)]
236 {
237 #hierarchy_module_path::ActiveModel {
238 #ancestor_field_ident: ::sea_orm::ActiveValue::Set(ancestor),
239 #descendant_field_ident: ::sea_orm::ActiveValue::Set(descendant),
240 #generations_field_ident: ::sea_orm::ActiveValue::Set(generations),
241 ..::core::default::Default::default()
242 }
243 }
244 }
245 }
246 };
247
248 Ok(generated.into())
249}
250
251fn parse_closure_tree_attr(attr: &Attribute, options: &mut Options) -> syn::Result<()> {
252 attr.parse_nested_meta(|meta| {
253 let ident = meta
254 .path
255 .get_ident()
256 .ok_or_else(|| syn::Error::new(meta.path.span(), "Invalid option key"))?
257 .to_string();
258
259 match ident.as_str() {
260 "id_field" => {
261 let value: syn::LitStr = meta.value()?.parse()?;
262 options.id_field = Some(value.value());
263 }
264 "parent_field" => {
265 let value: syn::LitStr = meta.value()?.parse()?;
266 options.parent_field = Some(value.value());
267 }
268 "name_field" => {
269 let value: syn::LitStr = meta.value()?.parse()?;
270 options.name_field = Some(value.value());
271 }
272 "hierarchy_module" => {
273 let value: syn::LitStr = meta.value()?.parse()?;
274 options.hierarchy_module = Some(parse_path(&value.value(), value.span())?);
275 }
276 "hierarchy_table" => {
277 let value: syn::LitStr = meta.value()?.parse()?;
278 options.hierarchy_table = Some(value.value());
279 }
280 "entity_name" => {
281 let value: syn::LitStr = meta.value()?.parse()?;
282 options.entity_name = Some(value.value());
283 }
284 "hierarchy_name" => {
285 let value: syn::LitStr = meta.value()?.parse()?;
286 options.hierarchy_name = Some(value.value());
287 }
288 "ancestor_field" => {
289 let value: syn::LitStr = meta.value()?.parse()?;
290 options.ancestor_field = Some(value.value());
291 }
292 "descendant_field" => {
293 let value: syn::LitStr = meta.value()?.parse()?;
294 options.descendant_field = Some(value.value());
295 }
296 "generations_field" => {
297 let value: syn::LitStr = meta.value()?.parse()?;
298 options.generations_field = Some(value.value());
299 }
300 "id_type" => {
301 let ty: Type = meta.value()?.parse()?;
302 options.id_type = Some(ty);
303 }
304 other => {
305 return Err(syn::Error::new(
306 meta.path.span(),
307 format!("Unsupported closure_tree option `{other}`"),
308 ));
309 }
310 }
311
312 Ok(())
313 })
314}
315
316fn parse_sea_orm_table_name(attr: &Attribute) -> syn::Result<Option<String>> {
317 let mut table_name: Option<String> = None;
318 attr.parse_nested_meta(|meta| {
319 if meta.path.is_ident("table_name") {
320 let value: syn::LitStr = meta.value()?.parse()?;
321 table_name = Some(value.value());
322 }
323 Ok(())
324 })?;
325 Ok(table_name)
326}
327
328fn parse_path(value: &str, span: proc_macro2::Span) -> syn::Result<Path> {
329 syn::parse_str::<Path>(value).map_err(|_| syn::Error::new(span, "Invalid path"))
330}
331
/// Converts a `snake_case` name to `PascalCase` (`parent_id` -> `ParentId`).
///
/// The input is split on `_`, and empty pieces are dropped (leading, trailing or
/// doubled underscores). Each remaining piece has its first character ASCII-uppercased
/// and keeps the rest unchanged. The pieces are then concatenated.
fn to_pascal_case(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    for word in value.split('_').filter(|word| !word.is_empty()) {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            out.push(head.to_ascii_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}