1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 parse::{Parse, ParseStream},
5 parse_macro_input, Data, DeriveInput, Fields, FnArg, ImplItem, ItemImpl, LitStr, Pat, Token,
6 Type,
7};
8
/// Arguments accepted by `#[controller(...)]`.
struct ControllerArgs {
    /// URL prefix reported by the generated `prefix()`; empty string when omitted.
    prefix: String,
    /// Default role for the controller's methods; a method-level `#[require_role]` wins.
    role: Option<String>,
    /// Source text of the router state type; `karbon::http::AppState` is used when absent.
    state: Option<String>,
}
18
19impl Parse for ControllerArgs {
20 fn parse(input: ParseStream) -> syn::Result<Self> {
21 let mut prefix = None;
22 let mut role = None;
23 let mut state = None;
24
25 while !input.is_empty() {
26 let ident: syn::Ident = input.parse()?;
27 let _: Token![=] = input.parse()?;
28 let lit: LitStr = input.parse()?;
29
30 match ident.to_string().as_str() {
31 "prefix" => prefix = Some(lit.value()),
32 "role" => role = Some(lit.value()),
33 "state" => state = Some(lit.value()),
34 _ => return Err(syn::Error::new(ident.span(), "expected `prefix`, `role`, or `state`")),
35 }
36
37 let _ = input.parse::<Token![,]>();
39 }
40
41 Ok(ControllerArgs {
42 prefix: prefix.unwrap_or_default(),
43 role,
44 state,
45 })
46 }
47}
48
49struct RouteInfo {
51 method: String,
52 path: String,
53 fn_name: syn::Ident,
54}
55
56fn find_auth_guard_param(method: &syn::ImplItemFn) -> Option<syn::Ident> {
58 for arg in &method.sig.inputs {
59 if let FnArg::Typed(pat_type) = arg {
60 if let Type::Path(type_path) = pat_type.ty.as_ref() {
61 let last_seg = type_path.path.segments.last();
62 if let Some(seg) = last_seg {
63 if seg.ident == "AuthGuard" {
64 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
66 return Some(pat_ident.ident.clone());
67 }
68 }
69 }
70 }
71 }
72 }
73 None
74}
75
76fn extract_route_info(item: &ImplItem) -> Option<RouteInfo> {
78 let method_item = match item {
79 ImplItem::Fn(m) => m,
80 _ => return None,
81 };
82
83 let fn_name = method_item.sig.ident.clone();
84 let mut http_method = None;
85 let mut path = None;
86
87 for attr in &method_item.attrs {
88 let seg = attr.path().segments.last()?;
89 let name = seg.ident.to_string();
90
91 match name.as_str() {
92 "get" | "post" | "put" | "delete" | "patch" => {
93 http_method = Some(name);
94 if let Ok(lit) = attr.parse_args::<LitStr>() {
95 path = Some(lit.value());
96 }
97 }
98 _ => {}
99 }
100 }
101
102 Some(RouteInfo {
103 method: http_method?,
104 path: path.unwrap_or_else(|| "/".to_string()),
105 fn_name,
106 })
107}
108
109#[proc_macro_attribute]
132pub fn controller(args: TokenStream, input: TokenStream) -> TokenStream {
133 let args = parse_macro_input!(args as ControllerArgs);
134 let mut impl_block = parse_macro_input!(input as ItemImpl);
135 let prefix = &args.prefix;
136
137 let routes: Vec<RouteInfo> = impl_block
139 .items
140 .iter()
141 .filter_map(extract_route_info)
142 .collect();
143
144 let controller_role = args.role.clone();
145
146 for item in &mut impl_block.items {
148 if let ImplItem::Fn(method) = item {
149 let mut role_value = None;
151 for attr in &method.attrs {
152 let name = attr
153 .path()
154 .segments
155 .last()
156 .map(|s| s.ident.to_string())
157 .unwrap_or_default();
158 if name == "require_role" {
159 if let Ok(lit) = attr.parse_args::<LitStr>() {
160 role_value = Some(lit.value());
161 }
162 }
163 }
164
165 let effective_role = role_value.or_else(|| controller_role.clone());
167
168 if let Some(role) = effective_role {
170 if let Some(auth_param) = find_auth_guard_param(method) {
171 let check = syn::parse2::<syn::Stmt>(quote! {
172 #auth_param.require_role(#role)?;
173 })
174 .expect("Failed to parse role check statement");
175
176 method.block.stmts.insert(0, check);
177 }
178 }
179
180 method.attrs.retain(|attr| {
182 let name = attr
183 .path()
184 .segments
185 .last()
186 .map(|s| s.ident.to_string())
187 .unwrap_or_default();
188 !matches!(
189 name.as_str(),
190 "get" | "post" | "put" | "delete" | "patch" | "require_role"
191 )
192 });
193 }
194 }
195
196 let mut path_groups: std::collections::BTreeMap<String, Vec<&RouteInfo>> =
198 std::collections::BTreeMap::new();
199 for route in &routes {
200 path_groups
201 .entry(route.path.clone())
202 .or_default()
203 .push(route);
204 }
205
206 let route_registrations: Vec<proc_macro2::TokenStream> = path_groups
207 .iter()
208 .map(|(path, methods)| {
209 let mut chain = Vec::new();
210 for (i, route) in methods.iter().enumerate() {
211 let method_ident = format_ident!("{}", route.method);
212 let fn_name = &route.fn_name;
213
214 if i == 0 {
215 chain.push(quote! {
216 axum::routing::#method_ident(Self::#fn_name)
217 });
218 } else {
219 chain.push(quote! {
220 .#method_ident(Self::#fn_name)
221 });
222 }
223 }
224
225 quote! {
226 .route(#path, #(#chain)*)
227 }
228 })
229 .collect();
230
231 let self_ty = &impl_block.self_ty;
232
233 let state_type: proc_macro2::TokenStream = if let Some(ref state_path) = args.state {
235 state_path.parse().unwrap_or_else(|_| quote! { karbon::http::AppState })
236 } else {
237 quote! { karbon::http::AppState }
238 };
239
240 let expanded = quote! {
241 #impl_block
242
243 impl #self_ty {
244 pub fn router() -> axum::Router<#state_type> {
246 axum::Router::new()
247 #(#route_registrations)*
248 }
249
250 pub fn prefix() -> &'static str {
252 #prefix
253 }
254 }
255 };
256
257 TokenStream::from(expanded)
258}
259
260#[proc_macro_attribute]
266pub fn get(_args: TokenStream, input: TokenStream) -> TokenStream {
267 input
268}
269
270#[proc_macro_attribute]
272pub fn post(_args: TokenStream, input: TokenStream) -> TokenStream {
273 input
274}
275
276#[proc_macro_attribute]
278pub fn put(_args: TokenStream, input: TokenStream) -> TokenStream {
279 input
280}
281
282#[proc_macro_attribute]
284pub fn delete(_args: TokenStream, input: TokenStream) -> TokenStream {
285 input
286}
287
288#[proc_macro_attribute]
290pub fn patch(_args: TokenStream, input: TokenStream) -> TokenStream {
291 input
292}
293
294#[proc_macro_attribute]
298pub fn require_role(_args: TokenStream, input: TokenStream) -> TokenStream {
299 input
300}
301
302fn extract_table_name(attrs: &[syn::Attribute]) -> Option<String> {
307 for attr in attrs {
308 if attr.path().is_ident("table_name") {
309 if let Ok(lit) = attr.parse_args::<LitStr>() {
310 return Some(lit.value());
311 }
312 }
313 }
314 None
315}
316
317fn has_field_attr(field: &syn::Field, name: &str) -> bool {
319 field.attrs.iter().any(|attr| attr.path().is_ident(name))
320}
321
322fn has_struct_attr(attrs: &[syn::Attribute], name: &str) -> bool {
324 attrs.iter().any(|attr| attr.path().is_ident(name))
325}
326
327fn extract_slug_from(field: &syn::Field) -> Option<String> {
329 for attr in &field.attrs {
330 if attr.path().is_ident("slug_from") {
331 if let Ok(lit) = attr.parse_args::<LitStr>() {
332 return Some(lit.value());
333 }
334 }
335 }
336 None
337}
338
339fn is_option_type(ty: &Type) -> bool {
341 if let Type::Path(type_path) = ty {
342 if let Some(seg) = type_path.path.segments.last() {
343 return seg.ident == "Option";
344 }
345 }
346 false
347}
348
349#[proc_macro_derive(Insertable, attributes(table_name, auto_increment, skip_insert, timestamps, slug_from))]
379pub fn derive_insertable(input: TokenStream) -> TokenStream {
380 let input = parse_macro_input!(input as DeriveInput);
381 let name = &input.ident;
382
383 let table = match extract_table_name(&input.attrs) {
384 Some(t) => t,
385 None => {
386 return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
387 .to_compile_error()
388 .into();
389 }
390 };
391
392 let fields = match &input.data {
393 Data::Struct(data) => match &data.fields {
394 Fields::Named(f) => &f.named,
395 _ => {
396 return syn::Error::new_spanned(name, "Insertable only works on structs with named fields")
397 .to_compile_error()
398 .into();
399 }
400 },
401 _ => {
402 return syn::Error::new_spanned(name, "Insertable only works on structs")
403 .to_compile_error()
404 .into();
405 }
406 };
407
408 let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
409
410 let insert_fields: Vec<_> = fields
412 .iter()
413 .filter(|f| !has_field_attr(f, "auto_increment") && !has_field_attr(f, "skip_insert"))
414 .collect();
415
416 let mut column_names: Vec<String> = insert_fields
417 .iter()
418 .map(|f| {
419 let name = f.ident.as_ref().unwrap().to_string();
420 name.strip_prefix("r#").unwrap_or(&name).to_string()
421 })
422 .collect();
423
424 if has_timestamps {
426 column_names.push("created_at".to_string());
427 }
428
429
430 let mut bind_calls: Vec<proc_macro2::TokenStream> = Vec::new();
432 let mut slug_lets: Vec<proc_macro2::TokenStream> = Vec::new();
433
434 for field in &insert_fields {
435 let ident = field.ident.as_ref().unwrap();
436
437 if let Some(source_field) = extract_slug_from(field) {
438 let source_ident = format_ident!("{}", source_field);
439 let var_name = format_ident!("__slug_{}", ident);
440 slug_lets.push(quote! {
441 let #var_name = if self.#ident.is_empty() {
442 slug::slugify(&self.#source_ident)
443 } else {
444 self.#ident.clone()
445 };
446 });
447 bind_calls.push(quote! { .bind(&#var_name) });
448 } else {
449 bind_calls.push(quote! { .bind(&self.#ident) });
450 }
451 }
452
453 if has_timestamps {
455 bind_calls.push(quote! { .bind(chrono::Utc::now()) });
456 }
457
458 let columns_str = column_names.join(", ");
460 let placeholders_str: String = (1..=column_names.len())
461 .map(|_| "?".to_string())
462 .collect::<Vec<_>>()
463 .join(", ");
464 let sql_literal = format!("INSERT INTO {} ({}) VALUES ({})", table, columns_str, placeholders_str);
465
466 let expanded = quote! {
467 impl #name {
468 pub async fn insert<'e, E>(&self, executor: E) -> karbon::error::AppResult<u64>
470 where
471 E: sqlx::Executor<'e, Database = karbon::db::Db>,
472 {
473 #(#slug_lets)*
474 let result = sqlx::query(#sql_literal)
475 #(#bind_calls)*
476 .execute(executor)
477 .await?;
478 Ok(karbon::db::last_insert_id(&result))
479 }
480 }
481 };
482
483 TokenStream::from(expanded)
484}
485
486#[proc_macro_derive(Updatable, attributes(table_name, primary_key, skip_update, timestamps))]
517pub fn derive_updatable(input: TokenStream) -> TokenStream {
518 let input = parse_macro_input!(input as DeriveInput);
519 let name = &input.ident;
520
521 let table = match extract_table_name(&input.attrs) {
522 Some(t) => t,
523 None => {
524 return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
525 .to_compile_error()
526 .into();
527 }
528 };
529
530 let fields = match &input.data {
531 Data::Struct(data) => match &data.fields {
532 Fields::Named(f) => &f.named,
533 _ => {
534 return syn::Error::new_spanned(name, "Updatable only works on structs with named fields")
535 .to_compile_error()
536 .into();
537 }
538 },
539 _ => {
540 return syn::Error::new_spanned(name, "Updatable only works on structs")
541 .to_compile_error()
542 .into();
543 }
544 };
545
546 let pk_field = fields.iter().find(|f| has_field_attr(f, "primary_key"));
548 let pk_field = match pk_field {
549 Some(f) => f,
550 None => {
551 return syn::Error::new_spanned(name, "Updatable requires exactly one #[primary_key] field")
552 .to_compile_error()
553 .into();
554 }
555 };
556 let pk_ident = pk_field.ident.as_ref().unwrap();
557 let pk_col_raw = pk_ident.to_string();
558 let pk_col = pk_col_raw.strip_prefix("r#").unwrap_or(&pk_col_raw).to_string();
559
560 let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
561
562 let update_fields: Vec<_> = fields
564 .iter()
565 .filter(|f| !has_field_attr(f, "primary_key") && !has_field_attr(f, "skip_update"))
566 .collect();
567
568 let mut set_pushes = Vec::new();
570 let mut bind_pushes = Vec::new();
571
572 for field in &update_fields {
573 let ident = field.ident.as_ref().unwrap();
574 let col_name_raw = ident.to_string();
575 let col_name = col_name_raw.strip_prefix("r#").unwrap_or(&col_name_raw).to_string();
576
577 if is_option_type(&field.ty) {
578 set_pushes.push(quote! {
579 if self.#ident.is_some() {
580 param_idx += 1;
581 set_clauses.push(format!("{} = {}", #col_name, karbon::db::placeholder(param_idx)));
582 }
583 });
584 bind_pushes.push(quote! {
585 if let Some(ref val) = self.#ident {
586 query = query.bind(val);
587 }
588 });
589 } else {
590 set_pushes.push(quote! {
591 param_idx += 1;
592 set_clauses.push(format!("{} = {}", #col_name, karbon::db::placeholder(param_idx)));
593 });
594 bind_pushes.push(quote! {
595 query = query.bind(&self.#ident);
596 });
597 }
598 }
599
600 let timestamps_set_push = if has_timestamps {
602 quote! {
603 param_idx += 1;
604 set_clauses.push(format!("{} = {}", "updated_at", karbon::db::placeholder(param_idx)));
605 }
606 } else {
607 quote! {}
608 };
609
610 let timestamps_bind_push = if has_timestamps {
611 quote! {
612 query = query.bind(chrono::Utc::now());
613 }
614 } else {
615 quote! {}
616 };
617
618 let expanded = quote! {
619 impl #name {
620 pub async fn update<'e, E>(&self, executor: E) -> karbon::error::AppResult<u64>
629 where
630 E: sqlx::Executor<'e, Database = karbon::db::Db>,
631 {
632 let mut set_clauses: Vec<String> = Vec::new();
633 let mut param_idx: usize = 0;
634
635 #(#set_pushes)*
636 #timestamps_set_push
637
638 if set_clauses.is_empty() {
639 return Ok(0); }
641
642 param_idx += 1;
643 let sql = format!(
644 "UPDATE {} SET {} WHERE {} = {}",
645 #table,
646 set_clauses.join(", "),
647 #pk_col,
648 karbon::db::placeholder(param_idx),
649 );
650
651 let mut query = sqlx::query(&sql);
652
653 #(#bind_pushes)*
654 #timestamps_bind_push
655
656 query = query.bind(&self.#pk_ident);
658
659 let result = query.execute(executor).await?;
660 Ok(result.rows_affected())
661 }
662 }
663 };
664
665 TokenStream::from(expanded)
666}