1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 parse::{Parse, ParseStream},
5 parse_macro_input, Data, DeriveInput, Fields, FnArg, ImplItem, ItemImpl, LitStr, Pat, Token,
6 Type,
7};
8
/// Arguments accepted by `#[controller(...)]`, e.g. `#[controller(prefix = "/users", role = "admin")]`.
struct ControllerArgs {
    /// Value of `prefix = "..."`; an empty string when the argument is omitted.
    prefix: String,
    /// Value of `role = "..."`: the default role for methods without their own `#[require_role]`.
    role: Option<String>,
}
17
18impl Parse for ControllerArgs {
19 fn parse(input: ParseStream) -> syn::Result<Self> {
20 let mut prefix = None;
21 let mut role = None;
22
23 while !input.is_empty() {
24 let ident: syn::Ident = input.parse()?;
25 let _: Token![=] = input.parse()?;
26 let lit: LitStr = input.parse()?;
27
28 match ident.to_string().as_str() {
29 "prefix" => prefix = Some(lit.value()),
30 "role" => role = Some(lit.value()),
31 _ => return Err(syn::Error::new(ident.span(), "expected `prefix` or `role`")),
32 }
33
34 let _ = input.parse::<Token![,]>();
36 }
37
38 Ok(ControllerArgs {
39 prefix: prefix.unwrap_or_default(),
40 role,
41 })
42 }
43}
44
45struct RouteInfo {
47 method: String,
48 path: String,
49 fn_name: syn::Ident,
50}
51
52fn find_auth_guard_param(method: &syn::ImplItemFn) -> Option<syn::Ident> {
54 for arg in &method.sig.inputs {
55 if let FnArg::Typed(pat_type) = arg {
56 if let Type::Path(type_path) = pat_type.ty.as_ref() {
57 let last_seg = type_path.path.segments.last();
58 if let Some(seg) = last_seg {
59 if seg.ident == "AuthGuard" {
60 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
62 return Some(pat_ident.ident.clone());
63 }
64 }
65 }
66 }
67 }
68 }
69 None
70}
71
72fn extract_route_info(item: &ImplItem) -> Option<RouteInfo> {
74 let method_item = match item {
75 ImplItem::Fn(m) => m,
76 _ => return None,
77 };
78
79 let fn_name = method_item.sig.ident.clone();
80 let mut http_method = None;
81 let mut path = None;
82
83 for attr in &method_item.attrs {
84 let seg = attr.path().segments.last()?;
85 let name = seg.ident.to_string();
86
87 match name.as_str() {
88 "get" | "post" | "put" | "delete" | "patch" => {
89 http_method = Some(name);
90 if let Ok(lit) = attr.parse_args::<LitStr>() {
91 path = Some(lit.value());
92 }
93 }
94 _ => {}
95 }
96 }
97
98 Some(RouteInfo {
99 method: http_method?,
100 path: path.unwrap_or_else(|| "/".to_string()),
101 fn_name,
102 })
103}
104
105#[proc_macro_attribute]
128pub fn controller(args: TokenStream, input: TokenStream) -> TokenStream {
129 let args = parse_macro_input!(args as ControllerArgs);
130 let mut impl_block = parse_macro_input!(input as ItemImpl);
131 let prefix = &args.prefix;
132
133 let routes: Vec<RouteInfo> = impl_block
135 .items
136 .iter()
137 .filter_map(extract_route_info)
138 .collect();
139
140 let controller_role = args.role.clone();
141
142 for item in &mut impl_block.items {
144 if let ImplItem::Fn(method) = item {
145 let mut role_value = None;
147 for attr in &method.attrs {
148 let name = attr
149 .path()
150 .segments
151 .last()
152 .map(|s| s.ident.to_string())
153 .unwrap_or_default();
154 if name == "require_role" {
155 if let Ok(lit) = attr.parse_args::<LitStr>() {
156 role_value = Some(lit.value());
157 }
158 }
159 }
160
161 let effective_role = role_value.or_else(|| controller_role.clone());
163
164 if let Some(role) = effective_role {
166 if let Some(auth_param) = find_auth_guard_param(method) {
167 let check = syn::parse2::<syn::Stmt>(quote! {
168 #auth_param.require_role(#role)?;
169 })
170 .expect("Failed to parse role check statement");
171
172 method.block.stmts.insert(0, check);
173 }
174 }
175
176 method.attrs.retain(|attr| {
178 let name = attr
179 .path()
180 .segments
181 .last()
182 .map(|s| s.ident.to_string())
183 .unwrap_or_default();
184 !matches!(
185 name.as_str(),
186 "get" | "post" | "put" | "delete" | "patch" | "require_role"
187 )
188 });
189 }
190 }
191
192 let mut path_groups: std::collections::BTreeMap<String, Vec<&RouteInfo>> =
194 std::collections::BTreeMap::new();
195 for route in &routes {
196 path_groups
197 .entry(route.path.clone())
198 .or_default()
199 .push(route);
200 }
201
202 let route_registrations: Vec<proc_macro2::TokenStream> = path_groups
203 .iter()
204 .map(|(path, methods)| {
205 let mut chain = Vec::new();
206 for (i, route) in methods.iter().enumerate() {
207 let method_ident = format_ident!("{}", route.method);
208 let fn_name = &route.fn_name;
209
210 if i == 0 {
211 chain.push(quote! {
212 axum::routing::#method_ident(Self::#fn_name)
213 });
214 } else {
215 chain.push(quote! {
216 .#method_ident(Self::#fn_name)
217 });
218 }
219 }
220
221 quote! {
222 .route(#path, #(#chain)*)
223 }
224 })
225 .collect();
226
227 let self_ty = &impl_block.self_ty;
228
229 let expanded = quote! {
230 #impl_block
231
232 impl #self_ty {
233 pub fn router() -> axum::Router<framework::http::AppState> {
235 axum::Router::new()
236 #(#route_registrations)*
237 }
238
239 pub fn prefix() -> &'static str {
241 #prefix
242 }
243 }
244 };
245
246 TokenStream::from(expanded)
247}
248
249#[proc_macro_attribute]
255pub fn get(_args: TokenStream, input: TokenStream) -> TokenStream {
256 input
257}
258
259#[proc_macro_attribute]
261pub fn post(_args: TokenStream, input: TokenStream) -> TokenStream {
262 input
263}
264
265#[proc_macro_attribute]
267pub fn put(_args: TokenStream, input: TokenStream) -> TokenStream {
268 input
269}
270
271#[proc_macro_attribute]
273pub fn delete(_args: TokenStream, input: TokenStream) -> TokenStream {
274 input
275}
276
277#[proc_macro_attribute]
279pub fn patch(_args: TokenStream, input: TokenStream) -> TokenStream {
280 input
281}
282
283#[proc_macro_attribute]
287pub fn require_role(_args: TokenStream, input: TokenStream) -> TokenStream {
288 input
289}
290
291fn extract_table_name(attrs: &[syn::Attribute]) -> Option<String> {
296 for attr in attrs {
297 if attr.path().is_ident("table_name") {
298 if let Ok(lit) = attr.parse_args::<LitStr>() {
299 return Some(lit.value());
300 }
301 }
302 }
303 None
304}
305
306fn has_field_attr(field: &syn::Field, name: &str) -> bool {
308 field.attrs.iter().any(|attr| attr.path().is_ident(name))
309}
310
311fn has_struct_attr(attrs: &[syn::Attribute], name: &str) -> bool {
313 attrs.iter().any(|attr| attr.path().is_ident(name))
314}
315
316fn extract_slug_from(field: &syn::Field) -> Option<String> {
318 for attr in &field.attrs {
319 if attr.path().is_ident("slug_from") {
320 if let Ok(lit) = attr.parse_args::<LitStr>() {
321 return Some(lit.value());
322 }
323 }
324 }
325 None
326}
327
328fn is_option_type(ty: &Type) -> bool {
330 if let Type::Path(type_path) = ty {
331 if let Some(seg) = type_path.path.segments.last() {
332 return seg.ident == "Option";
333 }
334 }
335 false
336}
337
338#[proc_macro_derive(Insertable, attributes(table_name, auto_increment, skip_insert, timestamps, slug_from))]
368pub fn derive_insertable(input: TokenStream) -> TokenStream {
369 let input = parse_macro_input!(input as DeriveInput);
370 let name = &input.ident;
371
372 let table = match extract_table_name(&input.attrs) {
373 Some(t) => t,
374 None => {
375 return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
376 .to_compile_error()
377 .into();
378 }
379 };
380
381 let fields = match &input.data {
382 Data::Struct(data) => match &data.fields {
383 Fields::Named(f) => &f.named,
384 _ => {
385 return syn::Error::new_spanned(name, "Insertable only works on structs with named fields")
386 .to_compile_error()
387 .into();
388 }
389 },
390 _ => {
391 return syn::Error::new_spanned(name, "Insertable only works on structs")
392 .to_compile_error()
393 .into();
394 }
395 };
396
397 let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
398
399 let insert_fields: Vec<_> = fields
401 .iter()
402 .filter(|f| !has_field_attr(f, "auto_increment") && !has_field_attr(f, "skip_insert"))
403 .collect();
404
405 let mut column_names: Vec<String> = insert_fields
406 .iter()
407 .map(|f| f.ident.as_ref().unwrap().to_string())
408 .collect();
409
410 if has_timestamps {
412 column_names.push("created_at".to_string());
413 }
414
415 let column_refs_tokens: Vec<proc_macro2::TokenStream> = column_names
416 .iter()
417 .map(|c| quote! { #c })
418 .collect();
419
420 let mut bind_calls: Vec<proc_macro2::TokenStream> = Vec::new();
422 let mut slug_lets: Vec<proc_macro2::TokenStream> = Vec::new();
423
424 for field in &insert_fields {
425 let ident = field.ident.as_ref().unwrap();
426
427 if let Some(source_field) = extract_slug_from(field) {
428 let source_ident = format_ident!("{}", source_field);
429 let var_name = format_ident!("__slug_{}", ident);
430 slug_lets.push(quote! {
431 let #var_name = if self.#ident.is_empty() {
432 slug::slugify(&self.#source_ident)
433 } else {
434 self.#ident.clone()
435 };
436 });
437 bind_calls.push(quote! { .bind(&#var_name) });
438 } else {
439 bind_calls.push(quote! { .bind(&self.#ident) });
440 }
441 }
442
443 if has_timestamps {
445 bind_calls.push(quote! { .bind(chrono::Utc::now()) });
446 }
447
448 let expanded = quote! {
449 impl #name {
450 pub async fn insert<'e, E>(&self, executor: E) -> framework::error::AppResult<u64>
452 where
453 E: sqlx::Executor<'e, Database = framework::db::Db>,
454 {
455 #(#slug_lets)*
456 let sql = framework::db::insert_sql(#table, &[#(#column_refs_tokens),*]);
457 framework::db::execute_insert(
458 sqlx::query(&sql)
459 #(#bind_calls)*,
460 executor
461 ).await
462 }
463 }
464 };
465
466 TokenStream::from(expanded)
467}
468
469#[proc_macro_derive(Updatable, attributes(table_name, primary_key, skip_update, timestamps))]
500pub fn derive_updatable(input: TokenStream) -> TokenStream {
501 let input = parse_macro_input!(input as DeriveInput);
502 let name = &input.ident;
503
504 let table = match extract_table_name(&input.attrs) {
505 Some(t) => t,
506 None => {
507 return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
508 .to_compile_error()
509 .into();
510 }
511 };
512
513 let fields = match &input.data {
514 Data::Struct(data) => match &data.fields {
515 Fields::Named(f) => &f.named,
516 _ => {
517 return syn::Error::new_spanned(name, "Updatable only works on structs with named fields")
518 .to_compile_error()
519 .into();
520 }
521 },
522 _ => {
523 return syn::Error::new_spanned(name, "Updatable only works on structs")
524 .to_compile_error()
525 .into();
526 }
527 };
528
529 let pk_field = fields.iter().find(|f| has_field_attr(f, "primary_key"));
531 let pk_field = match pk_field {
532 Some(f) => f,
533 None => {
534 return syn::Error::new_spanned(name, "Updatable requires exactly one #[primary_key] field")
535 .to_compile_error()
536 .into();
537 }
538 };
539 let pk_ident = pk_field.ident.as_ref().unwrap();
540 let pk_col = pk_ident.to_string();
541
542 let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
543
544 let update_fields: Vec<_> = fields
546 .iter()
547 .filter(|f| !has_field_attr(f, "primary_key") && !has_field_attr(f, "skip_update"))
548 .collect();
549
550 let mut set_pushes = Vec::new();
552 let mut bind_pushes = Vec::new();
553
554 for field in &update_fields {
555 let ident = field.ident.as_ref().unwrap();
556 let col_name = ident.to_string();
557
558 if is_option_type(&field.ty) {
559 set_pushes.push(quote! {
560 if self.#ident.is_some() {
561 param_idx += 1;
562 set_clauses.push(format!("{} = {}", #col_name, framework::db::placeholder(param_idx)));
563 }
564 });
565 bind_pushes.push(quote! {
566 if let Some(ref val) = self.#ident {
567 query = query.bind(val);
568 }
569 });
570 } else {
571 set_pushes.push(quote! {
572 param_idx += 1;
573 set_clauses.push(format!("{} = {}", #col_name, framework::db::placeholder(param_idx)));
574 });
575 bind_pushes.push(quote! {
576 query = query.bind(&self.#ident);
577 });
578 }
579 }
580
581 let timestamps_set_push = if has_timestamps {
583 quote! {
584 param_idx += 1;
585 set_clauses.push(format!("{} = {}", "updated_at", framework::db::placeholder(param_idx)));
586 }
587 } else {
588 quote! {}
589 };
590
591 let timestamps_bind_push = if has_timestamps {
592 quote! {
593 query = query.bind(chrono::Utc::now());
594 }
595 } else {
596 quote! {}
597 };
598
599 let expanded = quote! {
600 impl #name {
601 pub async fn update<'e, E>(&self, executor: E) -> framework::error::AppResult<u64>
610 where
611 E: sqlx::Executor<'e, Database = framework::db::Db>,
612 {
613 let mut set_clauses: Vec<String> = Vec::new();
614 let mut param_idx: usize = 0;
615
616 #(#set_pushes)*
617 #timestamps_set_push
618
619 if set_clauses.is_empty() {
620 return Ok(0); }
622
623 param_idx += 1;
624 let sql = format!(
625 "UPDATE {} SET {} WHERE {} = {}",
626 #table,
627 set_clauses.join(", "),
628 #pk_col,
629 framework::db::placeholder(param_idx),
630 );
631
632 let mut query = sqlx::query(&sql);
633
634 #(#bind_pushes)*
635 #timestamps_bind_push
636
637 query = query.bind(&self.#pk_ident);
639
640 let result = query.execute(executor).await?;
641 Ok(result.rows_affected())
642 }
643 }
644 };
645
646 TokenStream::from(expanded)
647}