bloom_web_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, format_ident};
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident, LitStr, ItemStruct, ItemImpl, ImplItem, TypePath};
4use syn::spanned::Spanned;
5
6/// Helper struct to parse relation attributes
7#[derive(Default)]
8#[allow(dead_code)]
9struct Relation {
10    kind: Option<String>,
11    target: Option<LitStr>,
12    foreign_key: Option<LitStr>,
13    column: Option<LitStr>,
14}
15
16/// Derives Entity functionality for structs, generating database table creation and CRUD operations.
17///
18/// This macro generates implementations for database entity management including:
19/// - Table creation SQL generation
20/// - Insert, update, upsert operations
21/// - Migration registration
22///
23/// # Attributes
24/// - `#[table("table_name")]` - Specifies the database table name
25/// - `#[id]` - Marks a field as the primary key
26/// - `#[column("column_name")]` - Specifies the database column name
27/// - Relation attributes: `#[one_to_many]`, `#[many_to_one]`, `#[many_to_many]`, `#[one_to_one]`
28/// - `#[join_column("column_name")]` - Specifies the foreign key column name
29#[proc_macro_derive(Entity, attributes(table, id, column, one_to_many, many_to_one, many_to_many, one_to_one, join_column))]
30pub fn derive_entity(input: TokenStream) -> TokenStream {
31    let input = parse_macro_input!(input as DeriveInput);
32    let struct_name = &input.ident;
33
34    let fields = match input.data {
35        Data::Struct(data_struct) => match data_struct.fields {
36            Fields::Named(fields) => fields.named,
37            _ => panic!("Entity must be a struct with named fields"),
38        },
39        _ => panic!("Entity must be a struct"),
40    };
41
42    let table_attr = input.attrs.iter().find(|attr| attr.path().is_ident("table"));
43    let table_name: LitStr = match table_attr {
44        Some(attr) => attr.parse_args().expect("Invalid table attribute"),
45        None => panic!("Missing #[table(\"table_name\")]"),
46    };
47
48    let mut id_field: Option<Ident> = None;
49    let mut id_column_name_opt: Option<String> = None;
50    let mut column_defs: Vec<LitStr> = Vec::new();
51    let mut index_defs: Vec<LitStr> = Vec::new();
52    let mut insert_cols: Vec<LitStr> = Vec::new();
53    let mut insert_idents: Vec<Ident> = Vec::new();
54    let mut update_cols: Vec<LitStr> = Vec::new();
55    let mut update_idents: Vec<Ident> = Vec::new();
56
57    for field in fields.iter() {
58        let field_name = field.ident.as_ref().unwrap().clone();
59        let field_ty = &field.ty;
60
61        let is_id = field.attrs.iter().any(|attr| attr.path().is_ident("id"));
62        if is_id {
63            if id_field.is_some() {
64                panic!("Only one #[id] field allowed");
65            }
66            id_field = Some(field_name.clone());
67        }
68
69        let has_one_to_many = field.attrs.iter().any(|a| a.path().is_ident("one_to_many"));
70        let has_many_to_many = field.attrs.iter().any(|a| a.path().is_ident("many_to_many"));
71        let has_one_to_one  = field.attrs.iter().any(|a| a.path().is_ident("one_to_one"));
72        let is_virtual_relation = has_one_to_many || has_many_to_many || has_one_to_one;
73        if is_virtual_relation {
74            continue;
75        }
76
77        let has_many_to_one = field.attrs.iter().any(|a| a.path().is_ident("many_to_one"));
78        if has_many_to_one {
79            if let syn::Type::Path(_type_path) = field_ty {
80                if let Some(attr) = field.attrs.iter().find(|a| a.path().is_ident("join_column")) {
81                    let join_name = attr.parse_args::<LitStr>().expect("Invalid join_column attribute").value();
82                    let col_def = format!("{} {}", join_name, "INTEGER");
83                    column_defs.push(LitStr::new(&col_def, field.span()));
84                    let idx_def = format!("INDEX ({})", join_name);
85                    index_defs.push(LitStr::new(&idx_def, field.span()));
86                    continue;
87                }
88            }
89        }
90
91        let column_attr = field.attrs.iter().find(|a| a.path().is_ident("column"));
92        let column_name = match column_attr {
93            Some(attr) => attr.parse_args::<LitStr>().expect("Invalid column attribute").value(),
94            None => field_name.to_string(),
95        };
96        if is_id {
97            id_column_name_opt = Some(column_name.clone());
98        }
99
100        let sql_type = if let syn::Type::Path(type_path) = field_ty {
101            let type_name = type_path
102                .path
103                .segments
104                .last()
105                .map(|s| s.ident.to_string())
106                .unwrap_or_default();
107            match type_name.as_str() {
108                "i32" => "INTEGER",
109                "i64" => "BIGINT",
110                "f32" => "FLOAT",
111                "f64" => "DOUBLE PRECISION",
112                "String" => "VARCHAR(255)",
113                "bool" => "BOOLEAN",
114                _ => panic!("Unsupported type for field {}: {:?}", field_name, field_ty),
115            }
116        } else {
117            panic!("Unsupported type for field {}: {:?}", field_name, field_ty);
118        };
119
120        let col_def = if is_id && (sql_type == "INTEGER" || sql_type == "BIGINT") {
121            format!("{} {} PRIMARY KEY AUTO_INCREMENT", column_name, sql_type)
122        } else if is_id {
123            format!("{} {} PRIMARY KEY", column_name, sql_type)
124        } else {
125            format!("{} {}", column_name, sql_type)
126        };
127        column_defs.push(LitStr::new(&col_def, field.span()));
128
129
130        if !is_id {
131            insert_cols.push(LitStr::new(&column_name, field.span()));
132            insert_idents.push(field_name.clone());
133            update_cols.push(LitStr::new(&column_name, field.span()));
134            update_idents.push(field_name.clone());
135        }
136    }
137
138    let id_field = id_field.expect("Missing #[id] field for primary key");
139    let _ = id_field;
140
141    let id_column_name = id_column_name_opt.expect("Missing id column name (ensure #[id] is present)");
142
143    let expanded = quote! {
144        #[allow(dead_code)]
145        impl #struct_name {
146            pub const TABLE_NAME: &'static str = #table_name;
147
148            pub async fn create_table(pool: &sqlx::MySqlPool) -> anyhow::Result<()> {
149                let columns: &[&str] = &[#(#column_defs),*];
150                let indexes: &[&str] = &[#(#index_defs),*];
151                let mut parts: Vec<&str> = Vec::new();
152                parts.extend_from_slice(columns);
153                parts.extend_from_slice(indexes);
154                let sql = format!(
155                    "CREATE TABLE IF NOT EXISTS {} ({})",
156                    Self::TABLE_NAME,
157                    parts.join(", ")
158                );
159                sqlx::query(&sql)
160                    .execute(pool)
161                    .await?;
162                Ok(())
163            }
164
165            pub async fn insert(&self, pool: &sqlx::MySqlPool) -> anyhow::Result<u64> {
166                let cols: &[&str] = &[#(#insert_cols),*];
167                let placeholders = vec!["?"; cols.len()].join(", ");
168                let sql = format!(
169                    "INSERT INTO {} ({}) VALUES ({})",
170                    Self::TABLE_NAME,
171                    cols.join(", "),
172                    placeholders
173                );
174                let mut q = sqlx::query(&sql);
175                #( q = q.bind(&self.#insert_idents); )*
176                let res = q.execute(pool).await?;
177                Ok(res.rows_affected())
178            }
179
180            pub async fn update_by_id(&self, pool: &sqlx::MySqlPool) -> anyhow::Result<u64> {
181                let cols: &[&str] = &[#(#update_cols),*];
182                let set_clause = cols.iter().map(|c| format!("{} = ?", c)).collect::<Vec<_>>().join(", ");
183                let sql = format!(
184                    "UPDATE {} SET {} WHERE {} = ?",
185                    Self::TABLE_NAME,
186                    set_clause,
187                    #id_column_name
188                );
189                let mut q = sqlx::query(&sql);
190                #( q = q.bind(&self.#update_idents); )*
191                q = q.bind(&self.#id_field);
192                let res = q.execute(pool).await?;
193                Ok(res.rows_affected())
194            }
195
196            pub async fn upsert(&self, pool: &sqlx::MySqlPool) -> anyhow::Result<u64> {
197                let cols: &[&str] = &[#(#insert_cols),*];
198                let placeholders = vec!["?"; cols.len()].join(", ");
199                let update_part = cols.iter().map(|c| format!("{}=VALUES({})", c, c)).collect::<Vec<_>>().join(", ");
200                let sql = format!(
201                    "INSERT INTO {} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
202                    Self::TABLE_NAME,
203                    cols.join(", "),
204                    placeholders,
205                    update_part
206                );
207                let mut q = sqlx::query(&sql);
208                #( q = q.bind(&self.#insert_idents); )*
209                let res = q.execute(pool).await?;
210                Ok(res.rows_affected())
211            }
212
213            fn __run_migration<'a>(pool: &'a sqlx::MySqlPool) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>> {
214                Box::pin(async move { Self::create_table(pool).await })
215            }
216        }
217
218
219        inventory::submit! {
220            bloom_web_core::entity_registry::Migration {
221                name: #table_name,
222                run: <#struct_name>::__run_migration,
223            }
224        }
225    };
226
227    TokenStream::from(expanded)
228}
229
230#[proc_macro_attribute]
231pub fn repository(attr: TokenStream, item: TokenStream) -> TokenStream {
232    let entity_ty: TypePath = syn::parse(attr).expect("Expected entity type in #[repository(Entity)]");
233    let item_struct: ItemStruct = syn::parse(item.clone()).expect("#[repository] must be used on a unit struct");
234    let repo_name = &item_struct.ident;
235
236    let expanded = quote! {
237        #item_struct
238
239        impl #repo_name {
240            pub async fn find_all_raw(pool: &sqlx::MySqlPool) -> anyhow::Result<Vec<sqlx::mysql::MySqlRow>> {
241                let sql = format!("SELECT * FROM {}", <#entity_ty>::TABLE_NAME);
242                let rows = sqlx::query(&sql).fetch_all(pool).await?;
243                Ok(rows)
244            }
245
246            pub async fn find_by_id_raw(pool: &sqlx::MySqlPool, id: i64) -> anyhow::Result<Option<sqlx::mysql::MySqlRow>> {
247                let sql = format!("SELECT * FROM {} WHERE id = ? LIMIT 1", <#entity_ty>::TABLE_NAME);
248                let row = sqlx::query(&sql).bind(id).fetch_optional(pool).await?;
249                Ok(row)
250            }
251
252            pub async fn exists_by_id(pool: &sqlx::MySqlPool, id: i64) -> anyhow::Result<bool> {
253                let sql = format!("SELECT 1 FROM {} WHERE id = ? LIMIT 1", <#entity_ty>::TABLE_NAME);
254                let row = sqlx::query(&sql).bind(id).fetch_optional(pool).await?;
255                Ok(row.is_some())
256            }
257
258            pub async fn count(pool: &sqlx::MySqlPool) -> anyhow::Result<i64> {
259                let sql = format!("SELECT COUNT(*) as cnt FROM {}", <#entity_ty>::TABLE_NAME);
260                let row: sqlx::mysql::MySqlRow = sqlx::query(&sql).fetch_one(pool).await?;
261                let cnt_by_alias: Result<i64, _> = <sqlx::mysql::MySqlRow as sqlx::Row>::try_get(&row, "cnt");
262                if let Ok(v) = cnt_by_alias { return Ok(v); }
263                let cnt_by_idx: i64 = <sqlx::mysql::MySqlRow as sqlx::Row>::try_get(&row, 0)?;
264                Ok(cnt_by_idx)
265            }
266
267            pub async fn delete_by_id(pool: &sqlx::MySqlPool, id: i64) -> anyhow::Result<u64> {
268                let sql = format!("DELETE FROM {} WHERE id = ?", <#entity_ty>::TABLE_NAME);
269                let res = sqlx::query(&sql).bind(id).execute(pool).await?;
270                Ok(res.rows_affected())
271            }
272
273            pub async fn create(pool: &sqlx::MySqlPool, entity: &#entity_ty) -> anyhow::Result<u64> {
274                entity.insert(pool).await
275            }
276
277            pub async fn update(pool: &sqlx::MySqlPool, entity: &#entity_ty) -> anyhow::Result<u64> {
278                entity.update_by_id(pool).await
279            }
280
281            pub async fn insert_or_update(pool: &sqlx::MySqlPool, entity: &#entity_ty) -> anyhow::Result<u64> {
282                entity.upsert(pool).await
283            }
284
285            pub async fn delete(pool: &sqlx::MySqlPool, id: i64) -> anyhow::Result<u64> {
286                Self::delete_by_id(pool, id).await
287            }
288
289            pub async fn find_all<T>(pool: &sqlx::MySqlPool) -> anyhow::Result<Vec<T>>
290            where
291                for<'r> T: sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
292            {
293                let sql = format!("SELECT * FROM {}", <#entity_ty>::TABLE_NAME);
294                let rows = sqlx::query_as::<_, T>(&sql).fetch_all(pool).await?;
295                Ok(rows)
296            }
297
298            pub async fn find_by_id<T>(pool: &sqlx::MySqlPool, id: i64) -> anyhow::Result<Option<T>>
299            where
300                for<'r> T: sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
301            {
302                let sql = format!("SELECT * FROM {} WHERE id = ? LIMIT 1", <#entity_ty>::TABLE_NAME);
303                let row = sqlx::query_as::<_, T>(&sql).bind(id).fetch_optional(pool).await?;
304                Ok(row)
305            }
306        }
307    };
308
309    TokenStream::from(expanded)
310}
311
312#[proc_macro_attribute]
313pub fn controller(attr: TokenStream, item: TokenStream) -> TokenStream {
314    let base_path: LitStr = syn::parse(attr).expect("Expected base path string in #[controller(\"/path\")] ");
315    let item_impl: ItemImpl = syn::parse(item.clone()).expect("#[controller] must be used on an impl block");
316
317    let self_ty_ident = match *item_impl.self_ty.clone() {
318        syn::Type::Path(ref tp) => tp.path.segments.last().unwrap().ident.clone(),
319        _ => panic!("#[controller] impl must target a concrete named type"),
320    };
321
322    let mut has_get_all = false;
323    let mut has_get_by_id = false;
324    let mut has_create = false;
325    let mut has_update = false;
326    let mut has_delete = false;
327
328    for it in &item_impl.items {
329        if let ImplItem::Fn(f) = it {
330            let name = f.sig.ident.to_string();
331            match name.as_str() {
332                "get_all" => has_get_all = true,
333                "get_by_id" => has_get_by_id = true,
334                "create" => has_create = true,
335                "update" => has_update = true,
336                "delete" => has_delete = true,
337                _ => {}
338            }
339        }
340    }
341
342    let mut routes: Vec<proc_macro2::TokenStream> = Vec::new();
343    if has_get_all {
344        routes.push(quote! { scope = scope.route("", actix_web::web::get().to(<#self_ty_ident>::get_all)); });
345    }
346    if has_get_by_id {
347        routes.push(quote! { scope = scope.route("/{id}", actix_web::web::get().to(<#self_ty_ident>::get_by_id)); });
348    }
349    if has_create {
350        routes.push(quote! { scope = scope.route("", actix_web::web::post().to(<#self_ty_ident>::create)); });
351    }
352    if has_update {
353        routes.push(quote! { scope = scope.route("/{id}", actix_web::web::put().to(<#self_ty_ident>::update)); });
354    }
355    if has_delete {
356        routes.push(quote! { scope = scope.route("/{id}", actix_web::web::delete().to(<#self_ty_ident>::delete)); });
357    }
358
359    let expanded = quote! {
360        #item_impl
361
362        impl #self_ty_ident {
363            fn __configure(cfg: &mut actix_web::web::ServiceConfig) {
364                let mut scope = actix_web::web::scope(#base_path);
365                #(#routes)*
366                cfg.service(scope);
367            }
368        }
369
370        inventory::submit! {
371            bloom_web_core::controller_registry::Controller {
372                name: #base_path,
373                configure: <#self_ty_ident>::__configure,
374            }
375        }
376    };
377
378    TokenStream::from(expanded)
379}
380
381
382#[proc_macro_attribute]
383pub fn auto_register(_attr: TokenStream, item: TokenStream) -> TokenStream {
384    let func: syn::ItemFn = syn::parse(item).expect("#[auto_register] must be used on a free function");
385    let func_name = func.sig.ident.clone();
386    let reg_ty_ident = format_ident!("__AutoReg_{}", func_name);
387
388    let expanded = quote! {
389        #func
390
391        struct #reg_ty_ident;
392        impl #reg_ty_ident {
393            fn __configure(cfg: &mut actix_web::web::ServiceConfig) {
394                cfg.service(#func_name);
395            }
396        }
397
398        inventory::submit! {
399            bloom_web_core::controller_registry::Controller {
400                name: stringify!(#func_name),
401                configure: <#reg_ty_ident>::__configure,
402            }
403        }
404    };
405
406    TokenStream::from(expanded)
407}
408
409#[proc_macro_attribute]
410pub fn get(attr: TokenStream, item: TokenStream) -> TokenStream {
411    http_request("get", attr, item)
412}
413
414#[proc_macro_attribute]
415pub fn post(attr: TokenStream, item: TokenStream) -> TokenStream {
416    http_request("post", attr, item)
417}
418
419#[proc_macro_attribute]
420pub fn put(attr: TokenStream, item: TokenStream) -> TokenStream {
421    http_request("put", attr, item)
422}
423
424#[proc_macro_attribute]
425pub fn delete(attr: TokenStream, item: TokenStream) -> TokenStream {
426    http_request("delete", attr, item)
427}
428
429#[proc_macro_attribute]
430pub fn patch(attr: TokenStream, item: TokenStream) -> TokenStream {
431    http_request("patch", attr, item)
432}
433
434#[proc_macro_attribute]
435pub fn get_mapping(attr: TokenStream, item: TokenStream) -> TokenStream {
436    http_request("get", attr, item)
437}
438
439#[proc_macro_attribute]
440pub fn post_mapping(attr: TokenStream, item: TokenStream) -> TokenStream {
441    http_request("post", attr, item)
442}
443
444#[proc_macro_attribute]
445pub fn put_mapping(attr: TokenStream, item: TokenStream) -> TokenStream {
446    http_request("put", attr, item)
447}
448
449#[proc_macro_attribute]
450pub fn delete_mapping(attr: TokenStream, item: TokenStream) -> TokenStream {
451    http_request("delete", attr, item)
452}
453
454#[proc_macro_attribute]
455pub fn patch_mapping(attr: TokenStream, item: TokenStream) -> TokenStream {
456    http_request("patch", attr, item)
457}
458
459#[proc_macro_attribute]
460pub fn scheduled(attr: TokenStream, item: TokenStream) -> TokenStream {
461    let interval_lit: syn::LitInt = match syn::parse(attr.clone()) {
462        Ok(v) => v,
463        Err(_) => panic!("#[scheduled] expects a millisecond interval literal, e.g., #[scheduled(60000)]"),
464    };
465    let interval_ms: u64 = interval_lit.base10_parse().expect("Invalid millisecond value for #[scheduled]");
466
467    let func: syn::ItemFn = syn::parse(item).expect("#[scheduled] must be used on a free async function");
468    let func_name = func.sig.ident.clone();
469    let reg_ty_ident = format_ident!("__SchedReg_{}", func_name);
470
471    let expanded = quote! {
472        #func
473
474        struct #reg_ty_ident;
475        impl #reg_ty_ident {
476            fn __spawn(pool: &sqlx::MySqlPool) {
477                let pool = pool.clone();
478                tokio::spawn(async move {
479                    loop {
480                        #func_name(&pool).await;
481                        tokio::time::sleep(std::time::Duration::from_millis(#interval_ms)).await;
482                    }
483                });
484            }
485        }
486
487        inventory::submit! {
488            bloom_web_core::scheduler_registry::Scheduled {
489                name: stringify!(#func_name),
490                spawn: <#reg_ty_ident>::__spawn,
491            }
492        }
493    };
494
495    TokenStream::from(expanded)
496}
497
498#[proc_macro_derive(ApiSchema, attributes(schema))]
499pub fn derive_api_schema(input: TokenStream) -> TokenStream {
500    let input = parse_macro_input!(input as DeriveInput);
501    let struct_name = &input.ident;
502    let struct_name_str = struct_name.to_string();
503
504    let expanded = quote! {
505        inventory::submit! {
506            bloom_web_core::swagger_registry::SchemaInfo {
507                name: #struct_name_str,
508            }
509        }
510    };
511
512    TokenStream::from(expanded)
513}
514
515fn http_request(verb: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
516    let path: LitStr = syn::parse(attr).expect("Expected path string in mapping attribute, e.g. \"/path\"");
517    let func: syn::ItemFn = syn::parse(item).expect("Mapping attribute must be used on a free function");
518    let func_name = func.sig.ident.clone();
519    let reg_ty_ident = format_ident!("__MapReg_{}__{}", verb, func_name);
520
521    let verb_attr = match verb {
522        "get" => quote! { #[actix_web::get(#path)] },
523        "post" => quote! { #[actix_web::post(#path)] },
524        "put" => quote! { #[actix_web::put(#path)] },
525        "delete" => quote! { #[actix_web::delete(#path)] },
526        "patch" => quote! { #[actix_web::patch(#path)] },
527        _ => quote! {},
528    };
529
530    let method_str = verb.to_uppercase();
531    let summary = format!("{} {}", method_str, path.value());
532
533    let request_schema_option = if matches!(verb, "post" | "put" | "patch") {
534        let mut found_schema = None;
535        for input in &func.sig.inputs {
536            if let syn::FnArg::Typed(pat_type) = input {
537                if let syn::Type::Path(type_path) = &*pat_type.ty {
538                    for segment in &type_path.path.segments {
539                        if segment.ident == "Json" {
540                            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
541                                if let Some(syn::GenericArgument::Type(syn::Type::Path(inner_type))) = args.args.first() {
542                                    if let Some(last_segment) = inner_type.path.segments.last() {
543                                        found_schema = Some(last_segment.ident.to_string());
544                                        break;
545                                    }
546                                }
547                            }
548                        }
549                    }
550                }
551            }
552        }
553        found_schema
554    } else {
555        None
556    };
557
558    let request_schema_literal = match request_schema_option {
559        Some(schema_name) => {
560            let schema_str = LitStr::new(&schema_name, path.span());
561            quote! { Some(#schema_str) }
562        },
563        None => quote! { None },
564    };
565
566    let expanded = quote! {
567        #verb_attr
568        #func
569
570        #[allow(non_camel_case_types)]
571        struct #reg_ty_ident;
572        impl #reg_ty_ident {
573            fn __configure(cfg: &mut actix_web::web::ServiceConfig) {
574                cfg.service(#func_name);
575            }
576        }
577
578        inventory::submit! {
579            bloom_web_core::swagger_registry::PathOperation {
580                path: #path,
581                method: #verb,
582                operation_id: stringify!(#func_name),
583                summary: #summary,
584                request_schema: #request_schema_literal,
585            }
586        }
587
588        inventory::submit! {
589            bloom_web_core::controller_registry::Controller {
590                name: #path,
591                configure: <#reg_ty_ident>::__configure,
592            }
593        }
594    };
595
596    TokenStream::from(expanded)
597}