// karbon_macros/lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    parse::{Parse, ParseStream},
5    parse_macro_input, Data, DeriveInput, Fields, FnArg, ImplItem, ItemImpl, LitStr, Pat, Token,
6    Type,
7};
8
9// ─────────────────────────────────────────────
10// #[controller(prefix = "/api/v1/admin/articles")]
11// ─────────────────────────────────────────────
12
13struct ControllerArgs {
14    prefix: String,
15    role: Option<String>,
16}
17
18impl Parse for ControllerArgs {
19    fn parse(input: ParseStream) -> syn::Result<Self> {
20        let mut prefix = None;
21        let mut role = None;
22
23        while !input.is_empty() {
24            let ident: syn::Ident = input.parse()?;
25            let _: Token![=] = input.parse()?;
26            let lit: LitStr = input.parse()?;
27
28            match ident.to_string().as_str() {
29                "prefix" => prefix = Some(lit.value()),
30                "role" => role = Some(lit.value()),
31                _ => return Err(syn::Error::new(ident.span(), "expected `prefix` or `role`")),
32            }
33
34            // Consume optional comma
35            let _ = input.parse::<Token![,]>();
36        }
37
38        Ok(ControllerArgs {
39            prefix: prefix.unwrap_or_default(),
40            role,
41        })
42    }
43}
44
45/// Route info extracted from method attributes
46struct RouteInfo {
47    method: String,
48    path: String,
49    fn_name: syn::Ident,
50}
51
52/// Find the parameter name that has type AuthGuard in a method signature
53fn find_auth_guard_param(method: &syn::ImplItemFn) -> Option<syn::Ident> {
54    for arg in &method.sig.inputs {
55        if let FnArg::Typed(pat_type) = arg {
56            if let Type::Path(type_path) = pat_type.ty.as_ref() {
57                let last_seg = type_path.path.segments.last();
58                if let Some(seg) = last_seg {
59                    if seg.ident == "AuthGuard" {
60                        // Extract the pattern name
61                        if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
62                            return Some(pat_ident.ident.clone());
63                        }
64                    }
65                }
66            }
67        }
68    }
69    None
70}
71
72/// Extracts route annotations from an impl method.
73fn extract_route_info(item: &ImplItem) -> Option<RouteInfo> {
74    let method_item = match item {
75        ImplItem::Fn(m) => m,
76        _ => return None,
77    };
78
79    let fn_name = method_item.sig.ident.clone();
80    let mut http_method = None;
81    let mut path = None;
82
83    for attr in &method_item.attrs {
84        let seg = attr.path().segments.last()?;
85        let name = seg.ident.to_string();
86
87        match name.as_str() {
88            "get" | "post" | "put" | "delete" | "patch" => {
89                http_method = Some(name);
90                if let Ok(lit) = attr.parse_args::<LitStr>() {
91                    path = Some(lit.value());
92                }
93            }
94            _ => {}
95        }
96    }
97
98    Some(RouteInfo {
99        method: http_method?,
100        path: path.unwrap_or_else(|| "/".to_string()),
101        fn_name,
102    })
103}
104
105/// # Controller macro
106///
107/// Generates a `router()` function from annotated methods.
108/// `#[require_role("ROLE_X")]` auto-injects `auth.require_role("ROLE_X")?;`
109/// at the beginning of the handler body.
110///
111/// ## Usage
112/// ```ignore
113/// #[controller(prefix = "/api/v1/admin/articles")]
114/// impl ArticleController {
115///     #[get("/")]
116///     #[require_role("ROLE_ADMIN")]
117///     async fn list(
118///         auth: AuthGuard,
119///         State(state): State<AppState>,
120///     ) -> AppResult<impl IntoResponse> {
121///         // auth.require_role("ROLE_ADMIN")?; ← auto-injected!
122///         let items = ...;
123///         Ok(Json(items))
124///     }
125/// }
126/// ```
127#[proc_macro_attribute]
128pub fn controller(args: TokenStream, input: TokenStream) -> TokenStream {
129    let args = parse_macro_input!(args as ControllerArgs);
130    let mut impl_block = parse_macro_input!(input as ItemImpl);
131    let prefix = &args.prefix;
132
133    // Extract route info BEFORE modifying the impl block
134    let routes: Vec<RouteInfo> = impl_block
135        .items
136        .iter()
137        .filter_map(extract_route_info)
138        .collect();
139
140    let controller_role = args.role.clone();
141
142    // Inject role checks and strip custom attributes
143    for item in &mut impl_block.items {
144        if let ImplItem::Fn(method) = item {
145            // Find require_role value for this method (method-level overrides controller-level)
146            let mut role_value = None;
147            for attr in &method.attrs {
148                let name = attr
149                    .path()
150                    .segments
151                    .last()
152                    .map(|s| s.ident.to_string())
153                    .unwrap_or_default();
154                if name == "require_role" {
155                    if let Ok(lit) = attr.parse_args::<LitStr>() {
156                        role_value = Some(lit.value());
157                    }
158                }
159            }
160
161            // Fall back to controller-level role if no method-level role
162            let effective_role = role_value.or_else(|| controller_role.clone());
163
164            // If a role is set, inject the check at the start of the body
165            if let Some(role) = effective_role {
166                if let Some(auth_param) = find_auth_guard_param(method) {
167                    let check = syn::parse2::<syn::Stmt>(quote! {
168                        #auth_param.require_role(#role)?;
169                    })
170                    .expect("Failed to parse role check statement");
171
172                    method.block.stmts.insert(0, check);
173                }
174            }
175
176            // Strip custom attributes
177            method.attrs.retain(|attr| {
178                let name = attr
179                    .path()
180                    .segments
181                    .last()
182                    .map(|s| s.ident.to_string())
183                    .unwrap_or_default();
184                !matches!(
185                    name.as_str(),
186                    "get" | "post" | "put" | "delete" | "patch" | "require_role"
187                )
188            });
189        }
190    }
191
192    // Build route registrations grouped by path
193    let mut path_groups: std::collections::BTreeMap<String, Vec<&RouteInfo>> =
194        std::collections::BTreeMap::new();
195    for route in &routes {
196        path_groups
197            .entry(route.path.clone())
198            .or_default()
199            .push(route);
200    }
201
202    let route_registrations: Vec<proc_macro2::TokenStream> = path_groups
203        .iter()
204        .map(|(path, methods)| {
205            let mut chain = Vec::new();
206            for (i, route) in methods.iter().enumerate() {
207                let method_ident = format_ident!("{}", route.method);
208                let fn_name = &route.fn_name;
209
210                if i == 0 {
211                    chain.push(quote! {
212                        axum::routing::#method_ident(Self::#fn_name)
213                    });
214                } else {
215                    chain.push(quote! {
216                        .#method_ident(Self::#fn_name)
217                    });
218                }
219            }
220
221            quote! {
222                .route(#path, #(#chain)*)
223            }
224        })
225        .collect();
226
227    let self_ty = &impl_block.self_ty;
228
229    let expanded = quote! {
230        #impl_block
231
232        impl #self_ty {
233            /// Auto-generated router from #[controller] annotations
234            pub fn router() -> axum::Router<framework::http::AppState> {
235                axum::Router::new()
236                    #(#route_registrations)*
237            }
238
239            /// Returns the route prefix for this controller
240            pub fn prefix() -> &'static str {
241                #prefix
242            }
243        }
244    };
245
246    TokenStream::from(expanded)
247}
248
249// ─────────────────────────────────────────────
250// Standalone route macros (marker attributes, consumed by #[controller])
251// ─────────────────────────────────────────────
252
253/// Mark a handler as a GET route: `#[get("/path")]`
254#[proc_macro_attribute]
255pub fn get(_args: TokenStream, input: TokenStream) -> TokenStream {
256    input
257}
258
259/// Mark a handler as a POST route: `#[post("/path")]`
260#[proc_macro_attribute]
261pub fn post(_args: TokenStream, input: TokenStream) -> TokenStream {
262    input
263}
264
265/// Mark a handler as a PUT route: `#[put("/path")]`
266#[proc_macro_attribute]
267pub fn put(_args: TokenStream, input: TokenStream) -> TokenStream {
268    input
269}
270
271/// Mark a handler as a DELETE route: `#[delete("/path")]`
272#[proc_macro_attribute]
273pub fn delete(_args: TokenStream, input: TokenStream) -> TokenStream {
274    input
275}
276
277/// Mark a handler as a PATCH route: `#[patch("/path")]`
278#[proc_macro_attribute]
279pub fn patch(_args: TokenStream, input: TokenStream) -> TokenStream {
280    input
281}
282
283/// Require a role — auto-injects `auth.require_role("ROLE")?;` at the start of the handler.
284/// The handler must have an `AuthGuard` parameter.
285/// Usage: `#[require_role("ROLE_ADMIN")]`
286#[proc_macro_attribute]
287pub fn require_role(_args: TokenStream, input: TokenStream) -> TokenStream {
288    input
289}
290
291// ─────────────────────────────────────────────
// Helper: extract #[table_name("...")] from attributes
293// ─────────────────────────────────────────────
294
295fn extract_table_name(attrs: &[syn::Attribute]) -> Option<String> {
296    for attr in attrs {
297        if attr.path().is_ident("table_name") {
298            if let Ok(lit) = attr.parse_args::<LitStr>() {
299                return Some(lit.value());
300            }
301        }
302    }
303    None
304}
305
306/// Check if a field has a specific helper attribute like #[auto_increment], #[skip_insert], etc.
307fn has_field_attr(field: &syn::Field, name: &str) -> bool {
308    field.attrs.iter().any(|attr| attr.path().is_ident(name))
309}
310
311/// Check if the struct has a specific attribute like #[timestamps]
312fn has_struct_attr(attrs: &[syn::Attribute], name: &str) -> bool {
313    attrs.iter().any(|attr| attr.path().is_ident(name))
314}
315
316/// Extract #[slug_from("field_name")] value from a field
317fn extract_slug_from(field: &syn::Field) -> Option<String> {
318    for attr in &field.attrs {
319        if attr.path().is_ident("slug_from") {
320            if let Ok(lit) = attr.parse_args::<LitStr>() {
321                return Some(lit.value());
322            }
323        }
324    }
325    None
326}
327
328/// Check if a field type is Option<T>
329fn is_option_type(ty: &Type) -> bool {
330    if let Type::Path(type_path) = ty {
331        if let Some(seg) = type_path.path.segments.last() {
332            return seg.ident == "Option";
333        }
334    }
335    false
336}
337
338// ─────────────────────────────────────────────
339// #[derive(Insertable)]
340// ─────────────────────────────────────────────
341
342/// Derive macro that generates an `insert()` method for a struct.
343///
344/// ## Attributes
345/// - `#[table_name("table")]` — required, specifies the DB table
346/// - `#[auto_increment]` — on a field: exclude from INSERT (auto-generated by DB)
347/// - `#[skip_insert]` — on a field: exclude from INSERT (e.g. created with DEFAULT)
348/// - `#[timestamps]` — on the struct: auto-adds `created_at` column with `Utc::now()`
349/// - `#[slug_from("field")]` — on a field: auto-generates slug from another field if empty
350///
351/// ## Example
352/// ```ignore
353/// #[derive(Insertable)]
354/// #[table_name("posts")]
355/// #[timestamps]
356/// pub struct NewPost {
357///     pub title: String,
358///     #[slug_from("title")]
359///     pub slug: String,
360/// }
361///
362/// // Generated:
363/// // dto.insert(pool).await?;
364/// // → INSERT INTO posts (title, slug, created_at) VALUES (?, ?, ?)
365/// // slug auto-generated from title if empty
366/// ```
367#[proc_macro_derive(Insertable, attributes(table_name, auto_increment, skip_insert, timestamps, slug_from))]
368pub fn derive_insertable(input: TokenStream) -> TokenStream {
369    let input = parse_macro_input!(input as DeriveInput);
370    let name = &input.ident;
371
372    let table = match extract_table_name(&input.attrs) {
373        Some(t) => t,
374        None => {
375            return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
376                .to_compile_error()
377                .into();
378        }
379    };
380
381    let fields = match &input.data {
382        Data::Struct(data) => match &data.fields {
383            Fields::Named(f) => &f.named,
384            _ => {
385                return syn::Error::new_spanned(name, "Insertable only works on structs with named fields")
386                    .to_compile_error()
387                    .into();
388            }
389        },
390        _ => {
391            return syn::Error::new_spanned(name, "Insertable only works on structs")
392                .to_compile_error()
393                .into();
394        }
395    };
396
397    let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
398
399    // Collect fields that participate in INSERT (skip auto_increment and skip_insert)
400    let insert_fields: Vec<_> = fields
401        .iter()
402        .filter(|f| !has_field_attr(f, "auto_increment") && !has_field_attr(f, "skip_insert"))
403        .collect();
404
405    let mut column_names: Vec<String> = insert_fields
406        .iter()
407        .map(|f| f.ident.as_ref().unwrap().to_string())
408        .collect();
409
410    // Auto-timestamps: add created_at column
411    if has_timestamps {
412        column_names.push("created_at".to_string());
413    }
414
415    let column_refs_tokens: Vec<proc_macro2::TokenStream> = column_names
416        .iter()
417        .map(|c| quote! { #c })
418        .collect();
419
420    // Build bind calls, handling #[slug_from] fields
421    let mut bind_calls: Vec<proc_macro2::TokenStream> = Vec::new();
422    let mut slug_lets: Vec<proc_macro2::TokenStream> = Vec::new();
423
424    for field in &insert_fields {
425        let ident = field.ident.as_ref().unwrap();
426
427        if let Some(source_field) = extract_slug_from(field) {
428            let source_ident = format_ident!("{}", source_field);
429            let var_name = format_ident!("__slug_{}", ident);
430            slug_lets.push(quote! {
431                let #var_name = if self.#ident.is_empty() {
432                    slug::slugify(&self.#source_ident)
433                } else {
434                    self.#ident.clone()
435                };
436            });
437            bind_calls.push(quote! { .bind(&#var_name) });
438        } else {
439            bind_calls.push(quote! { .bind(&self.#ident) });
440        }
441    }
442
443    // Auto-timestamps: bind chrono::Utc::now()
444    if has_timestamps {
445        bind_calls.push(quote! { .bind(chrono::Utc::now()) });
446    }
447
448    let expanded = quote! {
449        impl #name {
450            /// Auto-generated INSERT query from #[derive(Insertable)]
451            pub async fn insert<'e, E>(&self, executor: E) -> framework::error::AppResult<u64>
452            where
453                E: sqlx::Executor<'e, Database = framework::db::Db>,
454            {
455                #(#slug_lets)*
456                let sql = framework::db::insert_sql(#table, &[#(#column_refs_tokens),*]);
457                framework::db::execute_insert(
458                    sqlx::query(&sql)
459                        #(#bind_calls)*,
460                    executor
461                ).await
462            }
463        }
464    };
465
466    TokenStream::from(expanded)
467}
468
469// ─────────────────────────────────────────────
470// #[derive(Updatable)]
471// ─────────────────────────────────────────────
472
473/// Derive macro that generates an `update()` method for a struct.
474///
475/// All `Option<T>` fields are treated as partial updates — only `Some` values
476/// are included in the SET clause. Non-Option fields are always included.
477///
478/// ## Attributes
479/// - `#[table_name("table")]` — required
480/// - `#[primary_key]` — marks the WHERE clause field (required, exactly one)
481/// - `#[skip_update]` — exclude from SET clause
482///
483/// ## Example
484/// ```ignore
485/// #[derive(Updatable)]
486/// #[table_name("user")]
487/// pub struct UpdateUser {
488///     #[primary_key]
489///     pub id: i64,
490///     pub username: Option<String>,
491///     pub email: Option<String>,
492///     pub active: Option<bool>,
493/// }
494///
495/// // Generated:
496/// // UpdateUser { id: 42, username: Some("new"), email: None, active: None }.update(pool).await?;
497/// // → UPDATE user SET username=? WHERE id=?  (only Some fields)
498/// ```
499#[proc_macro_derive(Updatable, attributes(table_name, primary_key, skip_update, timestamps))]
500pub fn derive_updatable(input: TokenStream) -> TokenStream {
501    let input = parse_macro_input!(input as DeriveInput);
502    let name = &input.ident;
503
504    let table = match extract_table_name(&input.attrs) {
505        Some(t) => t,
506        None => {
507            return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
508                .to_compile_error()
509                .into();
510        }
511    };
512
513    let fields = match &input.data {
514        Data::Struct(data) => match &data.fields {
515            Fields::Named(f) => &f.named,
516            _ => {
517                return syn::Error::new_spanned(name, "Updatable only works on structs with named fields")
518                    .to_compile_error()
519                    .into();
520            }
521        },
522        _ => {
523            return syn::Error::new_spanned(name, "Updatable only works on structs")
524                .to_compile_error()
525                .into();
526        }
527    };
528
529    // Find the primary key field
530    let pk_field = fields.iter().find(|f| has_field_attr(f, "primary_key"));
531    let pk_field = match pk_field {
532        Some(f) => f,
533        None => {
534            return syn::Error::new_spanned(name, "Updatable requires exactly one #[primary_key] field")
535                .to_compile_error()
536                .into();
537        }
538    };
539    let pk_ident = pk_field.ident.as_ref().unwrap();
540    let pk_col = pk_ident.to_string();
541
542    let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
543
544    // Collect updatable fields (not primary_key, not skip_update)
545    let update_fields: Vec<_> = fields
546        .iter()
547        .filter(|f| !has_field_attr(f, "primary_key") && !has_field_attr(f, "skip_update"))
548        .collect();
549
550    // Build the dynamic update logic with placeholder tracking
551    let mut set_pushes = Vec::new();
552    let mut bind_pushes = Vec::new();
553
554    for field in &update_fields {
555        let ident = field.ident.as_ref().unwrap();
556        let col_name = ident.to_string();
557
558        if is_option_type(&field.ty) {
559            set_pushes.push(quote! {
560                if self.#ident.is_some() {
561                    param_idx += 1;
562                    set_clauses.push(format!("{} = {}", #col_name, framework::db::placeholder(param_idx)));
563                }
564            });
565            bind_pushes.push(quote! {
566                if let Some(ref val) = self.#ident {
567                    query = query.bind(val);
568                }
569            });
570        } else {
571            set_pushes.push(quote! {
572                param_idx += 1;
573                set_clauses.push(format!("{} = {}", #col_name, framework::db::placeholder(param_idx)));
574            });
575            bind_pushes.push(quote! {
576                query = query.bind(&self.#ident);
577            });
578        }
579    }
580
581    // Auto-timestamps: add updated_at
582    let timestamps_set_push = if has_timestamps {
583        quote! {
584            param_idx += 1;
585            set_clauses.push(format!("{} = {}", "updated_at", framework::db::placeholder(param_idx)));
586        }
587    } else {
588        quote! {}
589    };
590
591    let timestamps_bind_push = if has_timestamps {
592        quote! {
593            query = query.bind(chrono::Utc::now());
594        }
595    } else {
596        quote! {}
597    };
598
599    let expanded = quote! {
600        impl #name {
601            /// Auto-generated UPDATE query from #[derive(Updatable)]
602            /// Only `Some` fields on Option<T> are included in SET.
603            ///
604            /// Accepts either a `&DbPool` or a `&mut Transaction`:
605            /// ```ignore
606            /// dto.update(pool).await?;           // normal
607            /// dto.update(&mut *tx).await?;       // in transaction
608            /// ```
609            pub async fn update<'e, E>(&self, executor: E) -> framework::error::AppResult<u64>
610            where
611                E: sqlx::Executor<'e, Database = framework::db::Db>,
612            {
613                let mut set_clauses: Vec<String> = Vec::new();
614                let mut param_idx: usize = 0;
615
616                #(#set_pushes)*
617                #timestamps_set_push
618
619                if set_clauses.is_empty() {
620                    return Ok(0); // nothing to update
621                }
622
623                param_idx += 1;
624                let sql = format!(
625                    "UPDATE {} SET {} WHERE {} = {}",
626                    #table,
627                    set_clauses.join(", "),
628                    #pk_col,
629                    framework::db::placeholder(param_idx),
630                );
631
632                let mut query = sqlx::query(&sql);
633
634                #(#bind_pushes)*
635                #timestamps_bind_push
636
637                // Bind the primary key for WHERE
638                query = query.bind(&self.#pk_ident);
639
640                let result = query.execute(executor).await?;
641                Ok(result.rows_affected())
642            }
643        }
644    };
645
646    TokenStream::from(expanded)
647}