1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 parse::{Parse, ParseStream},
5 parse_macro_input, Data, DeriveInput, Fields, FnArg, ImplItem, ItemImpl, LitStr, Pat, Token,
6 Type,
7};
8
/// Arguments of `#[controller(...)]`, written as `key = "value"` pairs, e.g.
/// `#[controller(prefix = "/users", role = "admin", state = "crate::State")]`.
struct ControllerArgs {
    /// Value returned by the generated `prefix()` function; empty when not given.
    prefix: String,
    /// Default role for methods that carry no `#[require_role]` of their own.
    role: Option<String>,
    /// Router state type as source text; `None` selects `karbon::http::AppState`.
    state: Option<String>,
}
18
19impl Parse for ControllerArgs {
20 fn parse(input: ParseStream) -> syn::Result<Self> {
21 let mut prefix = None;
22 let mut role = None;
23 let mut state = None;
24
25 while !input.is_empty() {
26 let ident: syn::Ident = input.parse()?;
27 let _: Token![=] = input.parse()?;
28 let lit: LitStr = input.parse()?;
29
30 match ident.to_string().as_str() {
31 "prefix" => prefix = Some(lit.value()),
32 "role" => role = Some(lit.value()),
33 "state" => state = Some(lit.value()),
34 _ => return Err(syn::Error::new(ident.span(), "expected `prefix`, `role`, or `state`")),
35 }
36
37 let _ = input.parse::<Token![,]>();
39 }
40
41 Ok(ControllerArgs {
42 prefix: prefix.unwrap_or_default(),
43 role,
44 state,
45 })
46 }
47}
48
49struct RouteInfo {
51 method: String,
52 path: String,
53 fn_name: syn::Ident,
54}
55
56fn find_auth_guard_param(method: &syn::ImplItemFn) -> Option<syn::Ident> {
58 for arg in &method.sig.inputs {
59 if let FnArg::Typed(pat_type) = arg {
60 if let Type::Path(type_path) = pat_type.ty.as_ref() {
61 let last_seg = type_path.path.segments.last();
62 if let Some(seg) = last_seg {
63 if seg.ident == "AuthGuard" {
64 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
66 return Some(pat_ident.ident.clone());
67 }
68 }
69 }
70 }
71 }
72 }
73 None
74}
75
76fn extract_route_info(item: &ImplItem) -> Option<RouteInfo> {
78 let method_item = match item {
79 ImplItem::Fn(m) => m,
80 _ => return None,
81 };
82
83 let fn_name = method_item.sig.ident.clone();
84 let mut http_method = None;
85 let mut path = None;
86
87 for attr in &method_item.attrs {
88 let seg = attr.path().segments.last()?;
89 let name = seg.ident.to_string();
90
91 match name.as_str() {
92 "get" | "post" | "put" | "delete" | "patch" => {
93 http_method = Some(name);
94 if let Ok(lit) = attr.parse_args::<LitStr>() {
95 path = Some(lit.value());
96 }
97 }
98 _ => {}
99 }
100 }
101
102 Some(RouteInfo {
103 method: http_method?,
104 path: path.unwrap_or_else(|| "/".to_string()),
105 fn_name,
106 })
107}
108
109#[proc_macro_attribute]
132pub fn controller(args: TokenStream, input: TokenStream) -> TokenStream {
133 let args = parse_macro_input!(args as ControllerArgs);
134 let mut impl_block = parse_macro_input!(input as ItemImpl);
135 let prefix = &args.prefix;
136
137 let routes: Vec<RouteInfo> = impl_block
139 .items
140 .iter()
141 .filter_map(extract_route_info)
142 .collect();
143
144 let controller_role = args.role.clone();
145
146 for item in &mut impl_block.items {
148 if let ImplItem::Fn(method) = item {
149 let mut role_value = None;
151 for attr in &method.attrs {
152 let name = attr
153 .path()
154 .segments
155 .last()
156 .map(|s| s.ident.to_string())
157 .unwrap_or_default();
158 if name == "require_role" {
159 if let Ok(lit) = attr.parse_args::<LitStr>() {
160 role_value = Some(lit.value());
161 }
162 }
163 }
164
165 let effective_role = role_value.or_else(|| controller_role.clone());
167
168 if let Some(role) = effective_role {
170 if let Some(auth_param) = find_auth_guard_param(method) {
171 let check = syn::parse2::<syn::Stmt>(quote! {
172 #auth_param.require_role(#role)?;
173 })
174 .expect("Failed to parse role check statement");
175
176 method.block.stmts.insert(0, check);
177 }
178 }
179
180 method.attrs.retain(|attr| {
182 let name = attr
183 .path()
184 .segments
185 .last()
186 .map(|s| s.ident.to_string())
187 .unwrap_or_default();
188 !matches!(
189 name.as_str(),
190 "get" | "post" | "put" | "delete" | "patch" | "require_role"
191 )
192 });
193 }
194 }
195
196 let mut path_groups: std::collections::BTreeMap<String, Vec<&RouteInfo>> =
198 std::collections::BTreeMap::new();
199 for route in &routes {
200 path_groups
201 .entry(route.path.clone())
202 .or_default()
203 .push(route);
204 }
205
206 let route_registrations: Vec<proc_macro2::TokenStream> = path_groups
207 .iter()
208 .map(|(path, methods)| {
209 let mut chain = Vec::new();
210 for (i, route) in methods.iter().enumerate() {
211 let method_ident = format_ident!("{}", route.method);
212 let fn_name = &route.fn_name;
213
214 if i == 0 {
215 chain.push(quote! {
216 axum::routing::#method_ident(Self::#fn_name)
217 });
218 } else {
219 chain.push(quote! {
220 .#method_ident(Self::#fn_name)
221 });
222 }
223 }
224
225 quote! {
226 .route(#path, #(#chain)*)
227 }
228 })
229 .collect();
230
231 let self_ty = &impl_block.self_ty;
232
233 let state_type: proc_macro2::TokenStream = if let Some(ref state_path) = args.state {
235 state_path.parse().unwrap_or_else(|_| quote! { karbon::http::AppState })
236 } else {
237 quote! { karbon::http::AppState }
238 };
239
240 let expanded = quote! {
241 #impl_block
242
243 impl #self_ty {
244 pub fn router() -> axum::Router<#state_type> {
246 axum::Router::new()
247 #(#route_registrations)*
248 }
249
250 pub fn prefix() -> &'static str {
252 #prefix
253 }
254 }
255 };
256
257 TokenStream::from(expanded)
258}
259
260#[proc_macro_attribute]
266pub fn get(_args: TokenStream, input: TokenStream) -> TokenStream {
267 input
268}
269
270#[proc_macro_attribute]
272pub fn post(_args: TokenStream, input: TokenStream) -> TokenStream {
273 input
274}
275
276#[proc_macro_attribute]
278pub fn put(_args: TokenStream, input: TokenStream) -> TokenStream {
279 input
280}
281
282#[proc_macro_attribute]
284pub fn delete(_args: TokenStream, input: TokenStream) -> TokenStream {
285 input
286}
287
288#[proc_macro_attribute]
290pub fn patch(_args: TokenStream, input: TokenStream) -> TokenStream {
291 input
292}
293
294#[proc_macro_attribute]
298pub fn require_role(_args: TokenStream, input: TokenStream) -> TokenStream {
299 input
300}
301
302fn extract_table_name(attrs: &[syn::Attribute]) -> Option<String> {
307 for attr in attrs {
308 if attr.path().is_ident("table_name") {
309 if let Ok(lit) = attr.parse_args::<LitStr>() {
310 return Some(lit.value());
311 }
312 }
313 }
314 None
315}
316
317fn has_field_attr(field: &syn::Field, name: &str) -> bool {
319 field.attrs.iter().any(|attr| attr.path().is_ident(name))
320}
321
322fn has_struct_attr(attrs: &[syn::Attribute], name: &str) -> bool {
324 attrs.iter().any(|attr| attr.path().is_ident(name))
325}
326
327fn extract_slug_from(field: &syn::Field) -> Option<String> {
329 for attr in &field.attrs {
330 if attr.path().is_ident("slug_from") {
331 if let Ok(lit) = attr.parse_args::<LitStr>() {
332 return Some(lit.value());
333 }
334 }
335 }
336 None
337}
338
339fn is_option_type(ty: &Type) -> bool {
341 if let Type::Path(type_path) = ty {
342 if let Some(seg) = type_path.path.segments.last() {
343 return seg.ident == "Option";
344 }
345 }
346 false
347}
348
349#[proc_macro_derive(Insertable, attributes(table_name, auto_increment, skip_insert, timestamps, slug_from))]
379pub fn derive_insertable(input: TokenStream) -> TokenStream {
380 let input = parse_macro_input!(input as DeriveInput);
381 let name = &input.ident;
382
383 let table = match extract_table_name(&input.attrs) {
384 Some(t) => t,
385 None => {
386 return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
387 .to_compile_error()
388 .into();
389 }
390 };
391
392 let fields = match &input.data {
393 Data::Struct(data) => match &data.fields {
394 Fields::Named(f) => &f.named,
395 _ => {
396 return syn::Error::new_spanned(name, "Insertable only works on structs with named fields")
397 .to_compile_error()
398 .into();
399 }
400 },
401 _ => {
402 return syn::Error::new_spanned(name, "Insertable only works on structs")
403 .to_compile_error()
404 .into();
405 }
406 };
407
408 let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
409
410 let insert_fields: Vec<_> = fields
412 .iter()
413 .filter(|f| !has_field_attr(f, "auto_increment") && !has_field_attr(f, "skip_insert"))
414 .collect();
415
416 let mut column_names: Vec<String> = insert_fields
417 .iter()
418 .map(|f| f.ident.as_ref().unwrap().to_string())
419 .collect();
420
421 if has_timestamps {
423 column_names.push("created_at".to_string());
424 }
425
426
427 let mut bind_calls: Vec<proc_macro2::TokenStream> = Vec::new();
429 let mut slug_lets: Vec<proc_macro2::TokenStream> = Vec::new();
430
431 for field in &insert_fields {
432 let ident = field.ident.as_ref().unwrap();
433
434 if let Some(source_field) = extract_slug_from(field) {
435 let source_ident = format_ident!("{}", source_field);
436 let var_name = format_ident!("__slug_{}", ident);
437 slug_lets.push(quote! {
438 let #var_name = if self.#ident.is_empty() {
439 slug::slugify(&self.#source_ident)
440 } else {
441 self.#ident.clone()
442 };
443 });
444 bind_calls.push(quote! { .bind(&#var_name) });
445 } else {
446 bind_calls.push(quote! { .bind(&self.#ident) });
447 }
448 }
449
450 if has_timestamps {
452 bind_calls.push(quote! { .bind(chrono::Utc::now()) });
453 }
454
455 let columns_str = column_names.join(", ");
457 let placeholders_str: String = (1..=column_names.len())
458 .map(|_| "?".to_string())
459 .collect::<Vec<_>>()
460 .join(", ");
461 let sql_literal = format!("INSERT INTO {} ({}) VALUES ({})", table, columns_str, placeholders_str);
462
463 let expanded = quote! {
464 impl #name {
465 pub async fn insert<'e, E>(&self, executor: E) -> karbon::error::AppResult<u64>
467 where
468 E: sqlx::Executor<'e, Database = karbon::db::Db>,
469 {
470 #(#slug_lets)*
471 let result = sqlx::query(#sql_literal)
472 #(#bind_calls)*
473 .execute(executor)
474 .await?;
475 Ok(karbon::db::last_insert_id(&result))
476 }
477 }
478 };
479
480 TokenStream::from(expanded)
481}
482
483#[proc_macro_derive(Updatable, attributes(table_name, primary_key, skip_update, timestamps))]
514pub fn derive_updatable(input: TokenStream) -> TokenStream {
515 let input = parse_macro_input!(input as DeriveInput);
516 let name = &input.ident;
517
518 let table = match extract_table_name(&input.attrs) {
519 Some(t) => t,
520 None => {
521 return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
522 .to_compile_error()
523 .into();
524 }
525 };
526
527 let fields = match &input.data {
528 Data::Struct(data) => match &data.fields {
529 Fields::Named(f) => &f.named,
530 _ => {
531 return syn::Error::new_spanned(name, "Updatable only works on structs with named fields")
532 .to_compile_error()
533 .into();
534 }
535 },
536 _ => {
537 return syn::Error::new_spanned(name, "Updatable only works on structs")
538 .to_compile_error()
539 .into();
540 }
541 };
542
543 let pk_field = fields.iter().find(|f| has_field_attr(f, "primary_key"));
545 let pk_field = match pk_field {
546 Some(f) => f,
547 None => {
548 return syn::Error::new_spanned(name, "Updatable requires exactly one #[primary_key] field")
549 .to_compile_error()
550 .into();
551 }
552 };
553 let pk_ident = pk_field.ident.as_ref().unwrap();
554 let pk_col = pk_ident.to_string();
555
556 let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
557
558 let update_fields: Vec<_> = fields
560 .iter()
561 .filter(|f| !has_field_attr(f, "primary_key") && !has_field_attr(f, "skip_update"))
562 .collect();
563
564 let mut set_pushes = Vec::new();
566 let mut bind_pushes = Vec::new();
567
568 for field in &update_fields {
569 let ident = field.ident.as_ref().unwrap();
570 let col_name = ident.to_string();
571
572 if is_option_type(&field.ty) {
573 set_pushes.push(quote! {
574 if self.#ident.is_some() {
575 param_idx += 1;
576 set_clauses.push(format!("{} = {}", #col_name, karbon::db::placeholder(param_idx)));
577 }
578 });
579 bind_pushes.push(quote! {
580 if let Some(ref val) = self.#ident {
581 query = query.bind(val);
582 }
583 });
584 } else {
585 set_pushes.push(quote! {
586 param_idx += 1;
587 set_clauses.push(format!("{} = {}", #col_name, karbon::db::placeholder(param_idx)));
588 });
589 bind_pushes.push(quote! {
590 query = query.bind(&self.#ident);
591 });
592 }
593 }
594
595 let timestamps_set_push = if has_timestamps {
597 quote! {
598 param_idx += 1;
599 set_clauses.push(format!("{} = {}", "updated_at", karbon::db::placeholder(param_idx)));
600 }
601 } else {
602 quote! {}
603 };
604
605 let timestamps_bind_push = if has_timestamps {
606 quote! {
607 query = query.bind(chrono::Utc::now());
608 }
609 } else {
610 quote! {}
611 };
612
613 let expanded = quote! {
614 impl #name {
615 pub async fn update<'e, E>(&self, executor: E) -> karbon::error::AppResult<u64>
624 where
625 E: sqlx::Executor<'e, Database = karbon::db::Db>,
626 {
627 let mut set_clauses: Vec<String> = Vec::new();
628 let mut param_idx: usize = 0;
629
630 #(#set_pushes)*
631 #timestamps_set_push
632
633 if set_clauses.is_empty() {
634 return Ok(0); }
636
637 param_idx += 1;
638 let sql = format!(
639 "UPDATE {} SET {} WHERE {} = {}",
640 #table,
641 set_clauses.join(", "),
642 #pk_col,
643 karbon::db::placeholder(param_idx),
644 );
645
646 let mut query = sqlx::query(&sql);
647
648 #(#bind_pushes)*
649 #timestamps_bind_push
650
651 query = query.bind(&self.#pk_ident);
653
654 let result = query.execute(executor).await?;
655 Ok(result.rows_affected())
656 }
657 }
658 };
659
660 TokenStream::from(expanded)
661}