Skip to main content

karbon_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    parse::{Parse, ParseStream},
5    parse_macro_input, Data, DeriveInput, Fields, FnArg, ImplItem, ItemImpl, LitStr, Pat, Token,
6    Type,
7};
8
9// ─────────────────────────────────────────────
10// #[controller(prefix = "/api/v1/admin/articles")]
11// ─────────────────────────────────────────────
12
/// Arguments accepted by `#[controller(...)]`.
struct ControllerArgs {
    /// Controller-wide role, applied to methods that carry no `#[require_role]` of their own.
    role: Option<String>,
    /// Route prefix returned by the generated `prefix()` (empty string when omitted).
    prefix: String,
}
17
18impl Parse for ControllerArgs {
19    fn parse(input: ParseStream) -> syn::Result<Self> {
20        let mut prefix = None;
21        let mut role = None;
22
23        while !input.is_empty() {
24            let ident: syn::Ident = input.parse()?;
25            let _: Token![=] = input.parse()?;
26            let lit: LitStr = input.parse()?;
27
28            match ident.to_string().as_str() {
29                "prefix" => prefix = Some(lit.value()),
30                "role" => role = Some(lit.value()),
31                _ => return Err(syn::Error::new(ident.span(), "expected `prefix` or `role`")),
32            }
33
34            // Consume optional comma
35            let _ = input.parse::<Token![,]>();
36        }
37
38        Ok(ControllerArgs {
39            prefix: prefix.unwrap_or_default(),
40            role,
41        })
42    }
43}
44
45/// Route info extracted from method attributes
46struct RouteInfo {
47    method: String,
48    path: String,
49    fn_name: syn::Ident,
50}
51
52/// Find the parameter name that has type AuthGuard in a method signature
53fn find_auth_guard_param(method: &syn::ImplItemFn) -> Option<syn::Ident> {
54    for arg in &method.sig.inputs {
55        if let FnArg::Typed(pat_type) = arg {
56            if let Type::Path(type_path) = pat_type.ty.as_ref() {
57                let last_seg = type_path.path.segments.last();
58                if let Some(seg) = last_seg {
59                    if seg.ident == "AuthGuard" {
60                        // Extract the pattern name
61                        if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
62                            return Some(pat_ident.ident.clone());
63                        }
64                    }
65                }
66            }
67        }
68    }
69    None
70}
71
72/// Extracts route annotations from an impl method.
73fn extract_route_info(item: &ImplItem) -> Option<RouteInfo> {
74    let method_item = match item {
75        ImplItem::Fn(m) => m,
76        _ => return None,
77    };
78
79    let fn_name = method_item.sig.ident.clone();
80    let mut http_method = None;
81    let mut path = None;
82
83    for attr in &method_item.attrs {
84        let seg = attr.path().segments.last()?;
85        let name = seg.ident.to_string();
86
87        match name.as_str() {
88            "get" | "post" | "put" | "delete" | "patch" => {
89                http_method = Some(name);
90                if let Ok(lit) = attr.parse_args::<LitStr>() {
91                    path = Some(lit.value());
92                }
93            }
94            _ => {}
95        }
96    }
97
98    Some(RouteInfo {
99        method: http_method?,
100        path: path.unwrap_or_else(|| "/".to_string()),
101        fn_name,
102    })
103}
104
105/// # Controller macro
106///
107/// Generates a `router()` function from annotated methods.
108/// `#[require_role("ROLE_X")]` auto-injects `auth.require_role("ROLE_X")?;`
109/// at the beginning of the handler body.
110///
111/// ## Usage
112/// ```ignore
113/// #[controller(prefix = "/api/v1/admin/articles")]
114/// impl ArticleController {
115///     #[get("/")]
116///     #[require_role("ROLE_ADMIN")]
117///     async fn list(
118///         auth: AuthGuard,
119///         State(state): State<AppState>,
120///     ) -> AppResult<impl IntoResponse> {
121///         // auth.require_role("ROLE_ADMIN")?; ← auto-injected!
122///         let items = ...;
123///         Ok(Json(items))
124///     }
125/// }
126/// ```
127#[proc_macro_attribute]
128pub fn controller(args: TokenStream, input: TokenStream) -> TokenStream {
129    let args = parse_macro_input!(args as ControllerArgs);
130    let mut impl_block = parse_macro_input!(input as ItemImpl);
131    let prefix = &args.prefix;
132
133    // Extract route info BEFORE modifying the impl block
134    let routes: Vec<RouteInfo> = impl_block
135        .items
136        .iter()
137        .filter_map(extract_route_info)
138        .collect();
139
140    let controller_role = args.role.clone();
141
142    // Inject role checks and strip custom attributes
143    for item in &mut impl_block.items {
144        if let ImplItem::Fn(method) = item {
145            // Find require_role value for this method (method-level overrides controller-level)
146            let mut role_value = None;
147            for attr in &method.attrs {
148                let name = attr
149                    .path()
150                    .segments
151                    .last()
152                    .map(|s| s.ident.to_string())
153                    .unwrap_or_default();
154                if name == "require_role" {
155                    if let Ok(lit) = attr.parse_args::<LitStr>() {
156                        role_value = Some(lit.value());
157                    }
158                }
159            }
160
161            // Fall back to controller-level role if no method-level role
162            let effective_role = role_value.or_else(|| controller_role.clone());
163
164            // If a role is set, inject the check at the start of the body
165            if let Some(role) = effective_role {
166                if let Some(auth_param) = find_auth_guard_param(method) {
167                    let check = syn::parse2::<syn::Stmt>(quote! {
168                        #auth_param.require_role(#role)?;
169                    })
170                    .expect("Failed to parse role check statement");
171
172                    method.block.stmts.insert(0, check);
173                }
174            }
175
176            // Strip custom attributes
177            method.attrs.retain(|attr| {
178                let name = attr
179                    .path()
180                    .segments
181                    .last()
182                    .map(|s| s.ident.to_string())
183                    .unwrap_or_default();
184                !matches!(
185                    name.as_str(),
186                    "get" | "post" | "put" | "delete" | "patch" | "require_role"
187                )
188            });
189        }
190    }
191
192    // Build route registrations grouped by path
193    let mut path_groups: std::collections::BTreeMap<String, Vec<&RouteInfo>> =
194        std::collections::BTreeMap::new();
195    for route in &routes {
196        path_groups
197            .entry(route.path.clone())
198            .or_default()
199            .push(route);
200    }
201
202    let route_registrations: Vec<proc_macro2::TokenStream> = path_groups
203        .iter()
204        .map(|(path, methods)| {
205            let mut chain = Vec::new();
206            for (i, route) in methods.iter().enumerate() {
207                let method_ident = format_ident!("{}", route.method);
208                let fn_name = &route.fn_name;
209
210                if i == 0 {
211                    chain.push(quote! {
212                        axum::routing::#method_ident(Self::#fn_name)
213                    });
214                } else {
215                    chain.push(quote! {
216                        .#method_ident(Self::#fn_name)
217                    });
218                }
219            }
220
221            quote! {
222                .route(#path, #(#chain)*)
223            }
224        })
225        .collect();
226
227    let self_ty = &impl_block.self_ty;
228
229    let expanded = quote! {
230        #impl_block
231
232        impl #self_ty {
233            /// Auto-generated router from #[controller] annotations
234            pub fn router() -> axum::Router<karbon::http::AppState> {
235                axum::Router::new()
236                    #(#route_registrations)*
237            }
238
239            /// Returns the route prefix for this controller
240            pub fn prefix() -> &'static str {
241                #prefix
242            }
243        }
244    };
245
246    TokenStream::from(expanded)
247}
248
249// ─────────────────────────────────────────────
250// Standalone route macros (marker attributes, consumed by #[controller])
251// ─────────────────────────────────────────────
252
253/// Mark a handler as a GET route: `#[get("/path")]`
254#[proc_macro_attribute]
255pub fn get(_args: TokenStream, input: TokenStream) -> TokenStream {
256    input
257}
258
259/// Mark a handler as a POST route: `#[post("/path")]`
260#[proc_macro_attribute]
261pub fn post(_args: TokenStream, input: TokenStream) -> TokenStream {
262    input
263}
264
265/// Mark a handler as a PUT route: `#[put("/path")]`
266#[proc_macro_attribute]
267pub fn put(_args: TokenStream, input: TokenStream) -> TokenStream {
268    input
269}
270
271/// Mark a handler as a DELETE route: `#[delete("/path")]`
272#[proc_macro_attribute]
273pub fn delete(_args: TokenStream, input: TokenStream) -> TokenStream {
274    input
275}
276
277/// Mark a handler as a PATCH route: `#[patch("/path")]`
278#[proc_macro_attribute]
279pub fn patch(_args: TokenStream, input: TokenStream) -> TokenStream {
280    input
281}
282
283/// Require a role — auto-injects `auth.require_role("ROLE")?;` at the start of the handler.
284/// The handler must have an `AuthGuard` parameter.
285/// Usage: `#[require_role("ROLE_ADMIN")]`
286#[proc_macro_attribute]
287pub fn require_role(_args: TokenStream, input: TokenStream) -> TokenStream {
288    input
289}
290
291// ─────────────────────────────────────────────
292// Helper: extract #[table_name = "..."] from attributes
293// ─────────────────────────────────────────────
294
295fn extract_table_name(attrs: &[syn::Attribute]) -> Option<String> {
296    for attr in attrs {
297        if attr.path().is_ident("table_name") {
298            if let Ok(lit) = attr.parse_args::<LitStr>() {
299                return Some(lit.value());
300            }
301        }
302    }
303    None
304}
305
306/// Check if a field has a specific helper attribute like #[auto_increment], #[skip_insert], etc.
307fn has_field_attr(field: &syn::Field, name: &str) -> bool {
308    field.attrs.iter().any(|attr| attr.path().is_ident(name))
309}
310
311/// Check if the struct has a specific attribute like #[timestamps]
312fn has_struct_attr(attrs: &[syn::Attribute], name: &str) -> bool {
313    attrs.iter().any(|attr| attr.path().is_ident(name))
314}
315
316/// Extract #[slug_from("field_name")] value from a field
317fn extract_slug_from(field: &syn::Field) -> Option<String> {
318    for attr in &field.attrs {
319        if attr.path().is_ident("slug_from") {
320            if let Ok(lit) = attr.parse_args::<LitStr>() {
321                return Some(lit.value());
322            }
323        }
324    }
325    None
326}
327
328/// Check if a field type is Option<T>
329fn is_option_type(ty: &Type) -> bool {
330    if let Type::Path(type_path) = ty {
331        if let Some(seg) = type_path.path.segments.last() {
332            return seg.ident == "Option";
333        }
334    }
335    false
336}
337
338// ─────────────────────────────────────────────
339// #[derive(Insertable)]
340// ─────────────────────────────────────────────
341
342/// Derive macro that generates an `insert()` method for a struct.
343///
344/// ## Attributes
345/// - `#[table_name("table")]` — required, specifies the DB table
346/// - `#[auto_increment]` — on a field: exclude from INSERT (auto-generated by DB)
347/// - `#[skip_insert]` — on a field: exclude from INSERT (e.g. created with DEFAULT)
348/// - `#[timestamps]` — on the struct: auto-adds `created_at` column with `Utc::now()`
349/// - `#[slug_from("field")]` — on a field: auto-generates slug from another field if empty
350///
351/// ## Example
352/// ```ignore
353/// #[derive(Insertable)]
354/// #[table_name("posts")]
355/// #[timestamps]
356/// pub struct NewPost {
357///     pub title: String,
358///     #[slug_from("title")]
359///     pub slug: String,
360/// }
361///
362/// // Generated:
363/// // dto.insert(pool).await?;
364/// // → INSERT INTO posts (title, slug, created_at) VALUES (?, ?, ?)
365/// // slug auto-generated from title if empty
366/// ```
367#[proc_macro_derive(Insertable, attributes(table_name, auto_increment, skip_insert, timestamps, slug_from))]
368pub fn derive_insertable(input: TokenStream) -> TokenStream {
369    let input = parse_macro_input!(input as DeriveInput);
370    let name = &input.ident;
371
372    let table = match extract_table_name(&input.attrs) {
373        Some(t) => t,
374        None => {
375            return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
376                .to_compile_error()
377                .into();
378        }
379    };
380
381    let fields = match &input.data {
382        Data::Struct(data) => match &data.fields {
383            Fields::Named(f) => &f.named,
384            _ => {
385                return syn::Error::new_spanned(name, "Insertable only works on structs with named fields")
386                    .to_compile_error()
387                    .into();
388            }
389        },
390        _ => {
391            return syn::Error::new_spanned(name, "Insertable only works on structs")
392                .to_compile_error()
393                .into();
394        }
395    };
396
397    let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
398
399    // Collect fields that participate in INSERT (skip auto_increment and skip_insert)
400    let insert_fields: Vec<_> = fields
401        .iter()
402        .filter(|f| !has_field_attr(f, "auto_increment") && !has_field_attr(f, "skip_insert"))
403        .collect();
404
405    let mut column_names: Vec<String> = insert_fields
406        .iter()
407        .map(|f| f.ident.as_ref().unwrap().to_string())
408        .collect();
409
410    // Auto-timestamps: add created_at column
411    if has_timestamps {
412        column_names.push("created_at".to_string());
413    }
414
415
416    // Build bind calls, handling #[slug_from] fields
417    let mut bind_calls: Vec<proc_macro2::TokenStream> = Vec::new();
418    let mut slug_lets: Vec<proc_macro2::TokenStream> = Vec::new();
419
420    for field in &insert_fields {
421        let ident = field.ident.as_ref().unwrap();
422
423        if let Some(source_field) = extract_slug_from(field) {
424            let source_ident = format_ident!("{}", source_field);
425            let var_name = format_ident!("__slug_{}", ident);
426            slug_lets.push(quote! {
427                let #var_name = if self.#ident.is_empty() {
428                    slug::slugify(&self.#source_ident)
429                } else {
430                    self.#ident.clone()
431                };
432            });
433            bind_calls.push(quote! { .bind(&#var_name) });
434        } else {
435            bind_calls.push(quote! { .bind(&self.#ident) });
436        }
437    }
438
439    // Auto-timestamps: bind chrono::Utc::now()
440    if has_timestamps {
441        bind_calls.push(quote! { .bind(chrono::Utc::now()) });
442    }
443
444    // Build the SQL string at compile time
445    let columns_str = column_names.join(", ");
446    let placeholders_str: String = (1..=column_names.len())
447        .map(|_| "?".to_string())
448        .collect::<Vec<_>>()
449        .join(", ");
450    let sql_literal = format!("INSERT INTO {} ({}) VALUES ({})", table, columns_str, placeholders_str);
451
452    let expanded = quote! {
453        impl #name {
454            /// Auto-generated INSERT query from #[derive(Insertable)]
455            pub async fn insert<'e, E>(&self, executor: E) -> karbon::error::AppResult<u64>
456            where
457                E: sqlx::Executor<'e, Database = karbon::db::Db>,
458            {
459                #(#slug_lets)*
460                let result = sqlx::query(#sql_literal)
461                    #(#bind_calls)*
462                    .execute(executor)
463                    .await?;
464                Ok(karbon::db::last_insert_id(&result))
465            }
466        }
467    };
468
469    TokenStream::from(expanded)
470}
471
472// ─────────────────────────────────────────────
473// #[derive(Updatable)]
474// ─────────────────────────────────────────────
475
476/// Derive macro that generates an `update()` method for a struct.
477///
478/// All `Option<T>` fields are treated as partial updates — only `Some` values
479/// are included in the SET clause. Non-Option fields are always included.
480///
481/// ## Attributes
482/// - `#[table_name("table")]` — required
483/// - `#[primary_key]` — marks the WHERE clause field (required, exactly one)
484/// - `#[skip_update]` — exclude from SET clause
485///
486/// ## Example
487/// ```ignore
488/// #[derive(Updatable)]
489/// #[table_name("user")]
490/// pub struct UpdateUser {
491///     #[primary_key]
492///     pub id: i64,
493///     pub username: Option<String>,
494///     pub email: Option<String>,
495///     pub active: Option<bool>,
496/// }
497///
498/// // Generated:
499/// // UpdateUser { id: 42, username: Some("new"), email: None, active: None }.update(pool).await?;
500/// // → UPDATE user SET username=? WHERE id=?  (only Some fields)
501/// ```
502#[proc_macro_derive(Updatable, attributes(table_name, primary_key, skip_update, timestamps))]
503pub fn derive_updatable(input: TokenStream) -> TokenStream {
504    let input = parse_macro_input!(input as DeriveInput);
505    let name = &input.ident;
506
507    let table = match extract_table_name(&input.attrs) {
508        Some(t) => t,
509        None => {
510            return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
511                .to_compile_error()
512                .into();
513        }
514    };
515
516    let fields = match &input.data {
517        Data::Struct(data) => match &data.fields {
518            Fields::Named(f) => &f.named,
519            _ => {
520                return syn::Error::new_spanned(name, "Updatable only works on structs with named fields")
521                    .to_compile_error()
522                    .into();
523            }
524        },
525        _ => {
526            return syn::Error::new_spanned(name, "Updatable only works on structs")
527                .to_compile_error()
528                .into();
529        }
530    };
531
532    // Find the primary key field
533    let pk_field = fields.iter().find(|f| has_field_attr(f, "primary_key"));
534    let pk_field = match pk_field {
535        Some(f) => f,
536        None => {
537            return syn::Error::new_spanned(name, "Updatable requires exactly one #[primary_key] field")
538                .to_compile_error()
539                .into();
540        }
541    };
542    let pk_ident = pk_field.ident.as_ref().unwrap();
543    let pk_col = pk_ident.to_string();
544
545    let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
546
547    // Collect updatable fields (not primary_key, not skip_update)
548    let update_fields: Vec<_> = fields
549        .iter()
550        .filter(|f| !has_field_attr(f, "primary_key") && !has_field_attr(f, "skip_update"))
551        .collect();
552
553    // Build the dynamic update logic with placeholder tracking
554    let mut set_pushes = Vec::new();
555    let mut bind_pushes = Vec::new();
556
557    for field in &update_fields {
558        let ident = field.ident.as_ref().unwrap();
559        let col_name = ident.to_string();
560
561        if is_option_type(&field.ty) {
562            set_pushes.push(quote! {
563                if self.#ident.is_some() {
564                    param_idx += 1;
565                    set_clauses.push(format!("{} = {}", #col_name, karbon::db::placeholder(param_idx)));
566                }
567            });
568            bind_pushes.push(quote! {
569                if let Some(ref val) = self.#ident {
570                    query = query.bind(val);
571                }
572            });
573        } else {
574            set_pushes.push(quote! {
575                param_idx += 1;
576                set_clauses.push(format!("{} = {}", #col_name, karbon::db::placeholder(param_idx)));
577            });
578            bind_pushes.push(quote! {
579                query = query.bind(&self.#ident);
580            });
581        }
582    }
583
584    // Auto-timestamps: add updated_at
585    let timestamps_set_push = if has_timestamps {
586        quote! {
587            param_idx += 1;
588            set_clauses.push(format!("{} = {}", "updated_at", karbon::db::placeholder(param_idx)));
589        }
590    } else {
591        quote! {}
592    };
593
594    let timestamps_bind_push = if has_timestamps {
595        quote! {
596            query = query.bind(chrono::Utc::now());
597        }
598    } else {
599        quote! {}
600    };
601
602    let expanded = quote! {
603        impl #name {
604            /// Auto-generated UPDATE query from #[derive(Updatable)]
605            /// Only `Some` fields on Option<T> are included in SET.
606            ///
607            /// Accepts either a `&DbPool` or a `&mut Transaction`:
608            /// ```ignore
609            /// dto.update(pool).await?;           // normal
610            /// dto.update(&mut *tx).await?;       // in transaction
611            /// ```
612            pub async fn update<'e, E>(&self, executor: E) -> karbon::error::AppResult<u64>
613            where
614                E: sqlx::Executor<'e, Database = karbon::db::Db>,
615            {
616                let mut set_clauses: Vec<String> = Vec::new();
617                let mut param_idx: usize = 0;
618
619                #(#set_pushes)*
620                #timestamps_set_push
621
622                if set_clauses.is_empty() {
623                    return Ok(0); // nothing to update
624                }
625
626                param_idx += 1;
627                let sql = format!(
628                    "UPDATE {} SET {} WHERE {} = {}",
629                    #table,
630                    set_clauses.join(", "),
631                    #pk_col,
632                    karbon::db::placeholder(param_idx),
633                );
634
635                let mut query = sqlx::query(&sql);
636
637                #(#bind_pushes)*
638                #timestamps_bind_push
639
640                // Bind the primary key for WHERE
641                query = query.bind(&self.#pk_ident);
642
643                let result = query.execute(executor).await?;
644                Ok(result.rows_affected())
645            }
646        }
647    };
648
649    TokenStream::from(expanded)
650}