1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 parse::{Parse, ParseStream},
5 parse_macro_input, Data, DeriveInput, Fields, FnArg, ImplItem, ItemImpl, LitStr, Pat, Token,
6 Type,
7};
8
/// Parsed arguments of `#[controller(...)]`, e.g.
/// `#[controller(prefix = "/users", role = "admin")]`.
struct ControllerArgs {
    /// Controller-wide default role; used for methods that carry no
    /// `#[require_role(...)]` of their own. `None` when `role` is omitted.
    role: Option<String>,
    /// Value returned by the generated `prefix()`; an empty string when the
    /// `prefix` argument is omitted.
    prefix: String,
}
21 let mut role = None;
22
23 while !input.is_empty() {
24 let ident: syn::Ident = input.parse()?;
25 let _: Token![=] = input.parse()?;
26 let lit: LitStr = input.parse()?;
27
28 match ident.to_string().as_str() {
29 "prefix" => prefix = Some(lit.value()),
30 "role" => role = Some(lit.value()),
31 _ => return Err(syn::Error::new(ident.span(), "expected `prefix` or `role`")),
32 }
33
34 let _ = input.parse::<Token![,]>();
36 }
37
38 Ok(ControllerArgs {
39 prefix: prefix.unwrap_or_default(),
40 role,
41 })
42 }
43}
44
45struct RouteInfo {
47 method: String,
48 path: String,
49 fn_name: syn::Ident,
50}
51
52fn find_auth_guard_param(method: &syn::ImplItemFn) -> Option<syn::Ident> {
54 for arg in &method.sig.inputs {
55 if let FnArg::Typed(pat_type) = arg {
56 if let Type::Path(type_path) = pat_type.ty.as_ref() {
57 let last_seg = type_path.path.segments.last();
58 if let Some(seg) = last_seg {
59 if seg.ident == "AuthGuard" {
60 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
62 return Some(pat_ident.ident.clone());
63 }
64 }
65 }
66 }
67 }
68 }
69 None
70}
71
72fn extract_route_info(item: &ImplItem) -> Option<RouteInfo> {
74 let method_item = match item {
75 ImplItem::Fn(m) => m,
76 _ => return None,
77 };
78
79 let fn_name = method_item.sig.ident.clone();
80 let mut http_method = None;
81 let mut path = None;
82
83 for attr in &method_item.attrs {
84 let seg = attr.path().segments.last()?;
85 let name = seg.ident.to_string();
86
87 match name.as_str() {
88 "get" | "post" | "put" | "delete" | "patch" => {
89 http_method = Some(name);
90 if let Ok(lit) = attr.parse_args::<LitStr>() {
91 path = Some(lit.value());
92 }
93 }
94 _ => {}
95 }
96 }
97
98 Some(RouteInfo {
99 method: http_method?,
100 path: path.unwrap_or_else(|| "/".to_string()),
101 fn_name,
102 })
103}
104
105#[proc_macro_attribute]
128pub fn controller(args: TokenStream, input: TokenStream) -> TokenStream {
129 let args = parse_macro_input!(args as ControllerArgs);
130 let mut impl_block = parse_macro_input!(input as ItemImpl);
131 let prefix = &args.prefix;
132
133 let routes: Vec<RouteInfo> = impl_block
135 .items
136 .iter()
137 .filter_map(extract_route_info)
138 .collect();
139
140 let controller_role = args.role.clone();
141
142 for item in &mut impl_block.items {
144 if let ImplItem::Fn(method) = item {
145 let mut role_value = None;
147 for attr in &method.attrs {
148 let name = attr
149 .path()
150 .segments
151 .last()
152 .map(|s| s.ident.to_string())
153 .unwrap_or_default();
154 if name == "require_role" {
155 if let Ok(lit) = attr.parse_args::<LitStr>() {
156 role_value = Some(lit.value());
157 }
158 }
159 }
160
161 let effective_role = role_value.or_else(|| controller_role.clone());
163
164 if let Some(role) = effective_role {
166 if let Some(auth_param) = find_auth_guard_param(method) {
167 let check = syn::parse2::<syn::Stmt>(quote! {
168 #auth_param.require_role(#role)?;
169 })
170 .expect("Failed to parse role check statement");
171
172 method.block.stmts.insert(0, check);
173 }
174 }
175
176 method.attrs.retain(|attr| {
178 let name = attr
179 .path()
180 .segments
181 .last()
182 .map(|s| s.ident.to_string())
183 .unwrap_or_default();
184 !matches!(
185 name.as_str(),
186 "get" | "post" | "put" | "delete" | "patch" | "require_role"
187 )
188 });
189 }
190 }
191
192 let mut path_groups: std::collections::BTreeMap<String, Vec<&RouteInfo>> =
194 std::collections::BTreeMap::new();
195 for route in &routes {
196 path_groups
197 .entry(route.path.clone())
198 .or_default()
199 .push(route);
200 }
201
202 let route_registrations: Vec<proc_macro2::TokenStream> = path_groups
203 .iter()
204 .map(|(path, methods)| {
205 let mut chain = Vec::new();
206 for (i, route) in methods.iter().enumerate() {
207 let method_ident = format_ident!("{}", route.method);
208 let fn_name = &route.fn_name;
209
210 if i == 0 {
211 chain.push(quote! {
212 axum::routing::#method_ident(Self::#fn_name)
213 });
214 } else {
215 chain.push(quote! {
216 .#method_ident(Self::#fn_name)
217 });
218 }
219 }
220
221 quote! {
222 .route(#path, #(#chain)*)
223 }
224 })
225 .collect();
226
227 let self_ty = &impl_block.self_ty;
228
229 let expanded = quote! {
230 #impl_block
231
232 impl #self_ty {
233 pub fn router() -> axum::Router<karbon::http::AppState> {
235 axum::Router::new()
236 #(#route_registrations)*
237 }
238
239 pub fn prefix() -> &'static str {
241 #prefix
242 }
243 }
244 };
245
246 TokenStream::from(expanded)
247}
248
249#[proc_macro_attribute]
255pub fn get(_args: TokenStream, input: TokenStream) -> TokenStream {
256 input
257}
258
259#[proc_macro_attribute]
261pub fn post(_args: TokenStream, input: TokenStream) -> TokenStream {
262 input
263}
264
265#[proc_macro_attribute]
267pub fn put(_args: TokenStream, input: TokenStream) -> TokenStream {
268 input
269}
270
271#[proc_macro_attribute]
273pub fn delete(_args: TokenStream, input: TokenStream) -> TokenStream {
274 input
275}
276
277#[proc_macro_attribute]
279pub fn patch(_args: TokenStream, input: TokenStream) -> TokenStream {
280 input
281}
282
283#[proc_macro_attribute]
287pub fn require_role(_args: TokenStream, input: TokenStream) -> TokenStream {
288 input
289}
290
291fn extract_table_name(attrs: &[syn::Attribute]) -> Option<String> {
296 for attr in attrs {
297 if attr.path().is_ident("table_name") {
298 if let Ok(lit) = attr.parse_args::<LitStr>() {
299 return Some(lit.value());
300 }
301 }
302 }
303 None
304}
305
306fn has_field_attr(field: &syn::Field, name: &str) -> bool {
308 field.attrs.iter().any(|attr| attr.path().is_ident(name))
309}
310
311fn has_struct_attr(attrs: &[syn::Attribute], name: &str) -> bool {
313 attrs.iter().any(|attr| attr.path().is_ident(name))
314}
315
316fn extract_slug_from(field: &syn::Field) -> Option<String> {
318 for attr in &field.attrs {
319 if attr.path().is_ident("slug_from") {
320 if let Ok(lit) = attr.parse_args::<LitStr>() {
321 return Some(lit.value());
322 }
323 }
324 }
325 None
326}
327
328fn is_option_type(ty: &Type) -> bool {
330 if let Type::Path(type_path) = ty {
331 if let Some(seg) = type_path.path.segments.last() {
332 return seg.ident == "Option";
333 }
334 }
335 false
336}
337
338#[proc_macro_derive(Insertable, attributes(table_name, auto_increment, skip_insert, timestamps, slug_from))]
368pub fn derive_insertable(input: TokenStream) -> TokenStream {
369 let input = parse_macro_input!(input as DeriveInput);
370 let name = &input.ident;
371
372 let table = match extract_table_name(&input.attrs) {
373 Some(t) => t,
374 None => {
375 return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
376 .to_compile_error()
377 .into();
378 }
379 };
380
381 let fields = match &input.data {
382 Data::Struct(data) => match &data.fields {
383 Fields::Named(f) => &f.named,
384 _ => {
385 return syn::Error::new_spanned(name, "Insertable only works on structs with named fields")
386 .to_compile_error()
387 .into();
388 }
389 },
390 _ => {
391 return syn::Error::new_spanned(name, "Insertable only works on structs")
392 .to_compile_error()
393 .into();
394 }
395 };
396
397 let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
398
399 let insert_fields: Vec<_> = fields
401 .iter()
402 .filter(|f| !has_field_attr(f, "auto_increment") && !has_field_attr(f, "skip_insert"))
403 .collect();
404
405 let mut column_names: Vec<String> = insert_fields
406 .iter()
407 .map(|f| f.ident.as_ref().unwrap().to_string())
408 .collect();
409
410 if has_timestamps {
412 column_names.push("created_at".to_string());
413 }
414
415
416 let mut bind_calls: Vec<proc_macro2::TokenStream> = Vec::new();
418 let mut slug_lets: Vec<proc_macro2::TokenStream> = Vec::new();
419
420 for field in &insert_fields {
421 let ident = field.ident.as_ref().unwrap();
422
423 if let Some(source_field) = extract_slug_from(field) {
424 let source_ident = format_ident!("{}", source_field);
425 let var_name = format_ident!("__slug_{}", ident);
426 slug_lets.push(quote! {
427 let #var_name = if self.#ident.is_empty() {
428 slug::slugify(&self.#source_ident)
429 } else {
430 self.#ident.clone()
431 };
432 });
433 bind_calls.push(quote! { .bind(&#var_name) });
434 } else {
435 bind_calls.push(quote! { .bind(&self.#ident) });
436 }
437 }
438
439 if has_timestamps {
441 bind_calls.push(quote! { .bind(chrono::Utc::now()) });
442 }
443
444 let columns_str = column_names.join(", ");
446 let placeholders_str: String = (1..=column_names.len())
447 .map(|_| "?".to_string())
448 .collect::<Vec<_>>()
449 .join(", ");
450 let sql_literal = format!("INSERT INTO {} ({}) VALUES ({})", table, columns_str, placeholders_str);
451
452 let expanded = quote! {
453 impl #name {
454 pub async fn insert<'e, E>(&self, executor: E) -> karbon::error::AppResult<u64>
456 where
457 E: sqlx::Executor<'e, Database = karbon::db::Db>,
458 {
459 #(#slug_lets)*
460 let result = sqlx::query(#sql_literal)
461 #(#bind_calls)*
462 .execute(executor)
463 .await?;
464 Ok(karbon::db::last_insert_id(&result))
465 }
466 }
467 };
468
469 TokenStream::from(expanded)
470}
471
472#[proc_macro_derive(Updatable, attributes(table_name, primary_key, skip_update, timestamps))]
503pub fn derive_updatable(input: TokenStream) -> TokenStream {
504 let input = parse_macro_input!(input as DeriveInput);
505 let name = &input.ident;
506
507 let table = match extract_table_name(&input.attrs) {
508 Some(t) => t,
509 None => {
510 return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
511 .to_compile_error()
512 .into();
513 }
514 };
515
516 let fields = match &input.data {
517 Data::Struct(data) => match &data.fields {
518 Fields::Named(f) => &f.named,
519 _ => {
520 return syn::Error::new_spanned(name, "Updatable only works on structs with named fields")
521 .to_compile_error()
522 .into();
523 }
524 },
525 _ => {
526 return syn::Error::new_spanned(name, "Updatable only works on structs")
527 .to_compile_error()
528 .into();
529 }
530 };
531
532 let pk_field = fields.iter().find(|f| has_field_attr(f, "primary_key"));
534 let pk_field = match pk_field {
535 Some(f) => f,
536 None => {
537 return syn::Error::new_spanned(name, "Updatable requires exactly one #[primary_key] field")
538 .to_compile_error()
539 .into();
540 }
541 };
542 let pk_ident = pk_field.ident.as_ref().unwrap();
543 let pk_col = pk_ident.to_string();
544
545 let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
546
547 let update_fields: Vec<_> = fields
549 .iter()
550 .filter(|f| !has_field_attr(f, "primary_key") && !has_field_attr(f, "skip_update"))
551 .collect();
552
553 let mut set_pushes = Vec::new();
555 let mut bind_pushes = Vec::new();
556
557 for field in &update_fields {
558 let ident = field.ident.as_ref().unwrap();
559 let col_name = ident.to_string();
560
561 if is_option_type(&field.ty) {
562 set_pushes.push(quote! {
563 if self.#ident.is_some() {
564 param_idx += 1;
565 set_clauses.push(format!("{} = {}", #col_name, karbon::db::placeholder(param_idx)));
566 }
567 });
568 bind_pushes.push(quote! {
569 if let Some(ref val) = self.#ident {
570 query = query.bind(val);
571 }
572 });
573 } else {
574 set_pushes.push(quote! {
575 param_idx += 1;
576 set_clauses.push(format!("{} = {}", #col_name, karbon::db::placeholder(param_idx)));
577 });
578 bind_pushes.push(quote! {
579 query = query.bind(&self.#ident);
580 });
581 }
582 }
583
584 let timestamps_set_push = if has_timestamps {
586 quote! {
587 param_idx += 1;
588 set_clauses.push(format!("{} = {}", "updated_at", karbon::db::placeholder(param_idx)));
589 }
590 } else {
591 quote! {}
592 };
593
594 let timestamps_bind_push = if has_timestamps {
595 quote! {
596 query = query.bind(chrono::Utc::now());
597 }
598 } else {
599 quote! {}
600 };
601
602 let expanded = quote! {
603 impl #name {
604 pub async fn update<'e, E>(&self, executor: E) -> karbon::error::AppResult<u64>
613 where
614 E: sqlx::Executor<'e, Database = karbon::db::Db>,
615 {
616 let mut set_clauses: Vec<String> = Vec::new();
617 let mut param_idx: usize = 0;
618
619 #(#set_pushes)*
620 #timestamps_set_push
621
622 if set_clauses.is_empty() {
623 return Ok(0); }
625
626 param_idx += 1;
627 let sql = format!(
628 "UPDATE {} SET {} WHERE {} = {}",
629 #table,
630 set_clauses.join(", "),
631 #pk_col,
632 karbon::db::placeholder(param_idx),
633 );
634
635 let mut query = sqlx::query(&sql);
636
637 #(#bind_pushes)*
638 #timestamps_bind_push
639
640 query = query.bind(&self.#pk_ident);
642
643 let result = query.execute(executor).await?;
644 Ok(result.rows_affected())
645 }
646 }
647 };
648
649 TokenStream::from(expanded)
650}