1use proc_macro::TokenStream;
2use quote::{quote, format_ident};
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident, LitStr, ItemStruct, ItemImpl, ImplItem, TypePath};
4use syn::spanned::Spanned;
5
6#[derive(Default)]
8#[allow(dead_code)]
9struct Relation {
10 kind: Option<String>,
11 target: Option<LitStr>,
12 foreign_key: Option<LitStr>,
13 column: Option<LitStr>,
14}
15
16#[proc_macro_derive(Entity, attributes(table, id, column, one_to_many, many_to_one, many_to_many, one_to_one, join_column))]
30pub fn derive_entity(input: TokenStream) -> TokenStream {
31 let input = parse_macro_input!(input as DeriveInput);
32 let struct_name = &input.ident;
33
34 let fields = match input.data {
35 Data::Struct(data_struct) => match data_struct.fields {
36 Fields::Named(fields) => fields.named,
37 _ => panic!("Entity must be a struct with named fields"),
38 },
39 _ => panic!("Entity must be a struct"),
40 };
41
42 let table_attr = input.attrs.iter().find(|attr| attr.path().is_ident("table"));
43 let table_name: LitStr = match table_attr {
44 Some(attr) => attr.parse_args().expect("Invalid table attribute"),
45 None => panic!("Missing #[table(\"table_name\")]"),
46 };
47
48 let mut id_field: Option<Ident> = None;
49 let mut id_column_name_opt: Option<String> = None;
50 let mut column_defs: Vec<LitStr> = Vec::new();
51 let mut index_defs: Vec<LitStr> = Vec::new();
52 let mut insert_cols: Vec<LitStr> = Vec::new();
53 let mut insert_idents: Vec<Ident> = Vec::new();
54 let mut update_cols: Vec<LitStr> = Vec::new();
55 let mut update_idents: Vec<Ident> = Vec::new();
56
57 for field in fields.iter() {
58 let field_name = field.ident.as_ref().unwrap().clone();
59 let field_ty = &field.ty;
60
61 let is_id = field.attrs.iter().any(|attr| attr.path().is_ident("id"));
62 if is_id {
63 if id_field.is_some() {
64 panic!("Only one #[id] field allowed");
65 }
66 id_field = Some(field_name.clone());
67 }
68
69 let has_one_to_many = field.attrs.iter().any(|a| a.path().is_ident("one_to_many"));
70 let has_many_to_many = field.attrs.iter().any(|a| a.path().is_ident("many_to_many"));
71 let has_one_to_one = field.attrs.iter().any(|a| a.path().is_ident("one_to_one"));
72 let is_virtual_relation = has_one_to_many || has_many_to_many || has_one_to_one;
73 if is_virtual_relation {
74 continue;
75 }
76
77 let has_many_to_one = field.attrs.iter().any(|a| a.path().is_ident("many_to_one"));
78 if has_many_to_one {
79 if let syn::Type::Path(_type_path) = field_ty {
80 if let Some(attr) = field.attrs.iter().find(|a| a.path().is_ident("join_column")) {
81 let join_name = attr.parse_args::<LitStr>().expect("Invalid join_column attribute").value();
82 let col_def = format!("{} {}", join_name, "INTEGER");
83 column_defs.push(LitStr::new(&col_def, field.span()));
84 let idx_def = format!("INDEX ({})", join_name);
85 index_defs.push(LitStr::new(&idx_def, field.span()));
86 continue;
87 }
88 }
89 }
90
91 let column_attr = field.attrs.iter().find(|a| a.path().is_ident("column"));
92 let column_name = match column_attr {
93 Some(attr) => attr.parse_args::<LitStr>().expect("Invalid column attribute").value(),
94 None => field_name.to_string(),
95 };
96 if is_id {
97 id_column_name_opt = Some(column_name.clone());
98 }
99
100 let sql_type = if let syn::Type::Path(type_path) = field_ty {
101 let type_name = type_path
102 .path
103 .segments
104 .last()
105 .map(|s| s.ident.to_string())
106 .unwrap_or_default();
107 match type_name.as_str() {
108 "i32" => "INTEGER",
109 "i64" => "BIGINT",
110 "f32" => "FLOAT",
111 "f64" => "DOUBLE PRECISION",
112 "String" => "VARCHAR(255)",
113 "bool" => "BOOLEAN",
114 _ => panic!("Unsupported type for field {}: {:?}", field_name, field_ty),
115 }
116 } else {
117 panic!("Unsupported type for field {}: {:?}", field_name, field_ty);
118 };
119
120 let col_def = if is_id && (sql_type == "INTEGER" || sql_type == "BIGINT") {
121 format!("{} {} PRIMARY KEY AUTO_INCREMENT", column_name, sql_type)
122 } else if is_id {
123 format!("{} {} PRIMARY KEY", column_name, sql_type)
124 } else {
125 format!("{} {}", column_name, sql_type)
126 };
127 column_defs.push(LitStr::new(&col_def, field.span()));
128
129
130 if !is_id {
131 insert_cols.push(LitStr::new(&column_name, field.span()));
132 insert_idents.push(field_name.clone());
133 update_cols.push(LitStr::new(&column_name, field.span()));
134 update_idents.push(field_name.clone());
135 }
136 }
137
138 let id_field = id_field.expect("Missing #[id] field for primary key");
139 let _ = id_field;
140
141 let id_column_name = id_column_name_opt.expect("Missing id column name (ensure #[id] is present)");
142
143 let expanded = quote! {
144 #[allow(dead_code)]
145 impl #struct_name {
146 pub const TABLE_NAME: &'static str = #table_name;
147
148 pub async fn create_table(pool: &sqlx::MySqlPool) -> anyhow::Result<()> {
149 let columns: &[&str] = &[#(#column_defs),*];
150 let indexes: &[&str] = &[#(#index_defs),*];
151 let mut parts: Vec<&str> = Vec::new();
152 parts.extend_from_slice(columns);
153 parts.extend_from_slice(indexes);
154 let sql = format!(
155 "CREATE TABLE IF NOT EXISTS {} ({})",
156 Self::TABLE_NAME,
157 parts.join(", ")
158 );
159 sqlx::query(&sql)
160 .execute(pool)
161 .await?;
162 Ok(())
163 }
164
165 pub async fn insert(&self, pool: &sqlx::MySqlPool) -> anyhow::Result<u64> {
166 let cols: &[&str] = &[#(#insert_cols),*];
167 let placeholders = vec!["?"; cols.len()].join(", ");
168 let sql = format!(
169 "INSERT INTO {} ({}) VALUES ({})",
170 Self::TABLE_NAME,
171 cols.join(", "),
172 placeholders
173 );
174 let mut q = sqlx::query(&sql);
175 #( q = q.bind(&self.#insert_idents); )*
176 let res = q.execute(pool).await?;
177 Ok(res.rows_affected())
178 }
179
180 pub async fn update_by_id(&self, pool: &sqlx::MySqlPool) -> anyhow::Result<u64> {
181 let cols: &[&str] = &[#(#update_cols),*];
182 let set_clause = cols.iter().map(|c| format!("{} = ?", c)).collect::<Vec<_>>().join(", ");
183 let sql = format!(
184 "UPDATE {} SET {} WHERE {} = ?",
185 Self::TABLE_NAME,
186 set_clause,
187 #id_column_name
188 );
189 let mut q = sqlx::query(&sql);
190 #( q = q.bind(&self.#update_idents); )*
191 q = q.bind(&self.#id_field);
192 let res = q.execute(pool).await?;
193 Ok(res.rows_affected())
194 }
195
196 pub async fn upsert(&self, pool: &sqlx::MySqlPool) -> anyhow::Result<u64> {
197 let cols: &[&str] = &[#(#insert_cols),*];
198 let placeholders = vec!["?"; cols.len()].join(", ");
199 let update_part = cols.iter().map(|c| format!("{}=VALUES({})", c, c)).collect::<Vec<_>>().join(", ");
200 let sql = format!(
201 "INSERT INTO {} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
202 Self::TABLE_NAME,
203 cols.join(", "),
204 placeholders,
205 update_part
206 );
207 let mut q = sqlx::query(&sql);
208 #( q = q.bind(&self.#insert_idents); )*
209 let res = q.execute(pool).await?;
210 Ok(res.rows_affected())
211 }
212
213 fn __run_migration<'a>(pool: &'a sqlx::MySqlPool) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>> {
214 Box::pin(async move { Self::create_table(pool).await })
215 }
216 }
217
218
219 inventory::submit! {
220 bloom_web_core::entity_registry::Migration {
221 name: #table_name,
222 run: <#struct_name>::__run_migration,
223 }
224 }
225 };
226
227 TokenStream::from(expanded)
228}
229
230#[proc_macro_attribute]
231pub fn repository(attr: TokenStream, item: TokenStream) -> TokenStream {
232 let entity_ty: TypePath = syn::parse(attr).expect("Expected entity type in #[repository(Entity)]");
233 let item_struct: ItemStruct = syn::parse(item.clone()).expect("#[repository] must be used on a unit struct");
234 let repo_name = &item_struct.ident;
235
236 let expanded = quote! {
237 #item_struct
238
239 impl #repo_name {
240 pub async fn find_all_raw(pool: &sqlx::MySqlPool) -> anyhow::Result<Vec<sqlx::mysql::MySqlRow>> {
241 let sql = format!("SELECT * FROM {}", <#entity_ty>::TABLE_NAME);
242 let rows = sqlx::query(&sql).fetch_all(pool).await?;
243 Ok(rows)
244 }
245
246 pub async fn find_by_id_raw(pool: &sqlx::MySqlPool, id: i64) -> anyhow::Result<Option<sqlx::mysql::MySqlRow>> {
247 let sql = format!("SELECT * FROM {} WHERE id = ? LIMIT 1", <#entity_ty>::TABLE_NAME);
248 let row = sqlx::query(&sql).bind(id).fetch_optional(pool).await?;
249 Ok(row)
250 }
251
252 pub async fn exists_by_id(pool: &sqlx::MySqlPool, id: i64) -> anyhow::Result<bool> {
253 let sql = format!("SELECT 1 FROM {} WHERE id = ? LIMIT 1", <#entity_ty>::TABLE_NAME);
254 let row = sqlx::query(&sql).bind(id).fetch_optional(pool).await?;
255 Ok(row.is_some())
256 }
257
258 pub async fn count(pool: &sqlx::MySqlPool) -> anyhow::Result<i64> {
259 let sql = format!("SELECT COUNT(*) as cnt FROM {}", <#entity_ty>::TABLE_NAME);
260 let row: sqlx::mysql::MySqlRow = sqlx::query(&sql).fetch_one(pool).await?;
261 let cnt_by_alias: Result<i64, _> = <sqlx::mysql::MySqlRow as sqlx::Row>::try_get(&row, "cnt");
262 if let Ok(v) = cnt_by_alias { return Ok(v); }
263 let cnt_by_idx: i64 = <sqlx::mysql::MySqlRow as sqlx::Row>::try_get(&row, 0)?;
264 Ok(cnt_by_idx)
265 }
266
267 pub async fn delete_by_id(pool: &sqlx::MySqlPool, id: i64) -> anyhow::Result<u64> {
268 let sql = format!("DELETE FROM {} WHERE id = ?", <#entity_ty>::TABLE_NAME);
269 let res = sqlx::query(&sql).bind(id).execute(pool).await?;
270 Ok(res.rows_affected())
271 }
272
273 pub async fn create(pool: &sqlx::MySqlPool, entity: &#entity_ty) -> anyhow::Result<u64> {
274 entity.insert(pool).await
275 }
276
277 pub async fn update(pool: &sqlx::MySqlPool, entity: &#entity_ty) -> anyhow::Result<u64> {
278 entity.update_by_id(pool).await
279 }
280
281 pub async fn insert_or_update(pool: &sqlx::MySqlPool, entity: &#entity_ty) -> anyhow::Result<u64> {
282 entity.upsert(pool).await
283 }
284
285 pub async fn delete(pool: &sqlx::MySqlPool, id: i64) -> anyhow::Result<u64> {
286 Self::delete_by_id(pool, id).await
287 }
288
289 pub async fn find_all<T>(pool: &sqlx::MySqlPool) -> anyhow::Result<Vec<T>>
290 where
291 for<'r> T: sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
292 {
293 let sql = format!("SELECT * FROM {}", <#entity_ty>::TABLE_NAME);
294 let rows = sqlx::query_as::<_, T>(&sql).fetch_all(pool).await?;
295 Ok(rows)
296 }
297
298 pub async fn find_by_id<T>(pool: &sqlx::MySqlPool, id: i64) -> anyhow::Result<Option<T>>
299 where
300 for<'r> T: sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
301 {
302 let sql = format!("SELECT * FROM {} WHERE id = ? LIMIT 1", <#entity_ty>::TABLE_NAME);
303 let row = sqlx::query_as::<_, T>(&sql).bind(id).fetch_optional(pool).await?;
304 Ok(row)
305 }
306 }
307 };
308
309 TokenStream::from(expanded)
310}
311
312#[proc_macro_attribute]
313pub fn controller(attr: TokenStream, item: TokenStream) -> TokenStream {
314 let base_path: LitStr = syn::parse(attr).expect("Expected base path string in #[controller(\"/path\")] ");
315 let item_impl: ItemImpl = syn::parse(item.clone()).expect("#[controller] must be used on an impl block");
316
317 let self_ty_ident = match *item_impl.self_ty.clone() {
318 syn::Type::Path(ref tp) => tp.path.segments.last().unwrap().ident.clone(),
319 _ => panic!("#[controller] impl must target a concrete named type"),
320 };
321
322 let mut has_get_all = false;
323 let mut has_get_by_id = false;
324 let mut has_create = false;
325 let mut has_update = false;
326 let mut has_delete = false;
327
328 for it in &item_impl.items {
329 if let ImplItem::Fn(f) = it {
330 let name = f.sig.ident.to_string();
331 match name.as_str() {
332 "get_all" => has_get_all = true,
333 "get_by_id" => has_get_by_id = true,
334 "create" => has_create = true,
335 "update" => has_update = true,
336 "delete" => has_delete = true,
337 _ => {}
338 }
339 }
340 }
341
342 let mut routes: Vec<proc_macro2::TokenStream> = Vec::new();
343 if has_get_all {
344 routes.push(quote! { scope = scope.route("", actix_web::web::get().to(<#self_ty_ident>::get_all)); });
345 }
346 if has_get_by_id {
347 routes.push(quote! { scope = scope.route("/{id}", actix_web::web::get().to(<#self_ty_ident>::get_by_id)); });
348 }
349 if has_create {
350 routes.push(quote! { scope = scope.route("", actix_web::web::post().to(<#self_ty_ident>::create)); });
351 }
352 if has_update {
353 routes.push(quote! { scope = scope.route("/{id}", actix_web::web::put().to(<#self_ty_ident>::update)); });
354 }
355 if has_delete {
356 routes.push(quote! { scope = scope.route("/{id}", actix_web::web::delete().to(<#self_ty_ident>::delete)); });
357 }
358
359 let expanded = quote! {
360 #item_impl
361
362 impl #self_ty_ident {
363 fn __configure(cfg: &mut actix_web::web::ServiceConfig) {
364 let mut scope = actix_web::web::scope(#base_path);
365 #(#routes)*
366 cfg.service(scope);
367 }
368 }
369
370 inventory::submit! {
371 bloom_web_core::controller_registry::Controller {
372 name: #base_path,
373 configure: <#self_ty_ident>::__configure,
374 }
375 }
376 };
377
378 TokenStream::from(expanded)
379}
380
381
382#[proc_macro_attribute]
383pub fn auto_register(_attr: TokenStream, item: TokenStream) -> TokenStream {
384 let func: syn::ItemFn = syn::parse(item).expect("#[auto_register] must be used on a free function");
385 let func_name = func.sig.ident.clone();
386 let reg_ty_ident = format_ident!("__AutoReg_{}", func_name);
387
388 let expanded = quote! {
389 #func
390
391 struct #reg_ty_ident;
392 impl #reg_ty_ident {
393 fn __configure(cfg: &mut actix_web::web::ServiceConfig) {
394 cfg.service(#func_name);
395 }
396 }
397
398 inventory::submit! {
399 bloom_web_core::controller_registry::Controller {
400 name: stringify!(#func_name),
401 configure: <#reg_ty_ident>::__configure,
402 }
403 }
404 };
405
406 TokenStream::from(expanded)
407}
408
409#[proc_macro_attribute]
410pub fn get(attr: TokenStream, item: TokenStream) -> TokenStream {
411 http_request("get", attr, item)
412}
413
414#[proc_macro_attribute]
415pub fn post(attr: TokenStream, item: TokenStream) -> TokenStream {
416 http_request("post", attr, item)
417}
418
419#[proc_macro_attribute]
420pub fn put(attr: TokenStream, item: TokenStream) -> TokenStream {
421 http_request("put", attr, item)
422}
423
424#[proc_macro_attribute]
425pub fn delete(attr: TokenStream, item: TokenStream) -> TokenStream {
426 http_request("delete", attr, item)
427}
428
429#[proc_macro_attribute]
430pub fn patch(attr: TokenStream, item: TokenStream) -> TokenStream {
431 http_request("patch", attr, item)
432}
433
434#[proc_macro_attribute]
435pub fn get_mapping(attr: TokenStream, item: TokenStream) -> TokenStream {
436 http_request("get", attr, item)
437}
438
439#[proc_macro_attribute]
440pub fn post_mapping(attr: TokenStream, item: TokenStream) -> TokenStream {
441 http_request("post", attr, item)
442}
443
444#[proc_macro_attribute]
445pub fn put_mapping(attr: TokenStream, item: TokenStream) -> TokenStream {
446 http_request("put", attr, item)
447}
448
449#[proc_macro_attribute]
450pub fn delete_mapping(attr: TokenStream, item: TokenStream) -> TokenStream {
451 http_request("delete", attr, item)
452}
453
454#[proc_macro_attribute]
455pub fn patch_mapping(attr: TokenStream, item: TokenStream) -> TokenStream {
456 http_request("patch", attr, item)
457}
458
459#[proc_macro_attribute]
460pub fn scheduled(attr: TokenStream, item: TokenStream) -> TokenStream {
461 let interval_lit: syn::LitInt = match syn::parse(attr.clone()) {
462 Ok(v) => v,
463 Err(_) => panic!("#[scheduled] expects a millisecond interval literal, e.g., #[scheduled(60000)]"),
464 };
465 let interval_ms: u64 = interval_lit.base10_parse().expect("Invalid millisecond value for #[scheduled]");
466
467 let func: syn::ItemFn = syn::parse(item).expect("#[scheduled] must be used on a free async function");
468 let func_name = func.sig.ident.clone();
469 let reg_ty_ident = format_ident!("__SchedReg_{}", func_name);
470
471 let expanded = quote! {
472 #func
473
474 struct #reg_ty_ident;
475 impl #reg_ty_ident {
476 fn __spawn(pool: &sqlx::MySqlPool) {
477 let pool = pool.clone();
478 tokio::spawn(async move {
479 loop {
480 #func_name(&pool).await;
481 tokio::time::sleep(std::time::Duration::from_millis(#interval_ms)).await;
482 }
483 });
484 }
485 }
486
487 inventory::submit! {
488 bloom_web_core::scheduler_registry::Scheduled {
489 name: stringify!(#func_name),
490 spawn: <#reg_ty_ident>::__spawn,
491 }
492 }
493 };
494
495 TokenStream::from(expanded)
496}
497
498#[proc_macro_derive(ApiSchema, attributes(schema))]
499pub fn derive_api_schema(input: TokenStream) -> TokenStream {
500 let input = parse_macro_input!(input as DeriveInput);
501 let struct_name = &input.ident;
502 let struct_name_str = struct_name.to_string();
503
504 let expanded = quote! {
505 inventory::submit! {
506 bloom_web_core::swagger_registry::SchemaInfo {
507 name: #struct_name_str,
508 }
509 }
510 };
511
512 TokenStream::from(expanded)
513}
514
515fn http_request(verb: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
516 let path: LitStr = syn::parse(attr).expect("Expected path string in mapping attribute, e.g. \"/path\"");
517 let func: syn::ItemFn = syn::parse(item).expect("Mapping attribute must be used on a free function");
518 let func_name = func.sig.ident.clone();
519 let reg_ty_ident = format_ident!("__MapReg_{}__{}", verb, func_name);
520
521 let verb_attr = match verb {
522 "get" => quote! { #[actix_web::get(#path)] },
523 "post" => quote! { #[actix_web::post(#path)] },
524 "put" => quote! { #[actix_web::put(#path)] },
525 "delete" => quote! { #[actix_web::delete(#path)] },
526 "patch" => quote! { #[actix_web::patch(#path)] },
527 _ => quote! {},
528 };
529
530 let method_str = verb.to_uppercase();
531 let summary = format!("{} {}", method_str, path.value());
532
533 let request_schema_option = if matches!(verb, "post" | "put" | "patch") {
534 let mut found_schema = None;
535 for input in &func.sig.inputs {
536 if let syn::FnArg::Typed(pat_type) = input {
537 if let syn::Type::Path(type_path) = &*pat_type.ty {
538 for segment in &type_path.path.segments {
539 if segment.ident == "Json" {
540 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
541 if let Some(syn::GenericArgument::Type(syn::Type::Path(inner_type))) = args.args.first() {
542 if let Some(last_segment) = inner_type.path.segments.last() {
543 found_schema = Some(last_segment.ident.to_string());
544 break;
545 }
546 }
547 }
548 }
549 }
550 }
551 }
552 }
553 found_schema
554 } else {
555 None
556 };
557
558 let request_schema_literal = match request_schema_option {
559 Some(schema_name) => {
560 let schema_str = LitStr::new(&schema_name, path.span());
561 quote! { Some(#schema_str) }
562 },
563 None => quote! { None },
564 };
565
566 let expanded = quote! {
567 #verb_attr
568 #func
569
570 #[allow(non_camel_case_types)]
571 struct #reg_ty_ident;
572 impl #reg_ty_ident {
573 fn __configure(cfg: &mut actix_web::web::ServiceConfig) {
574 cfg.service(#func_name);
575 }
576 }
577
578 inventory::submit! {
579 bloom_web_core::swagger_registry::PathOperation {
580 path: #path,
581 method: #verb,
582 operation_id: stringify!(#func_name),
583 summary: #summary,
584 request_schema: #request_schema_literal,
585 }
586 }
587
588 inventory::submit! {
589 bloom_web_core::controller_registry::Controller {
590 name: #path,
591 configure: <#reg_ty_ident>::__configure,
592 }
593 }
594 };
595
596 TokenStream::from(expanded)
597}