Skip to main content

karbon_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    parse::{Parse, ParseStream},
5    parse_macro_input, Data, DeriveInput, Fields, FnArg, ImplItem, ItemImpl, LitStr, Pat, Token,
6    Type,
7};
8
// ─────────────────────────────────────────────
// #[controller(prefix = "...", role = "...", state = "...")]
// ─────────────────────────────────────────────

/// Arguments accepted by the `#[controller(...)]` attribute.
struct ControllerArgs {
    /// Value returned by the generated `prefix()` function; empty when not given.
    prefix: String,
    /// Controller-wide role, used for methods that have no `#[require_role]` of their own.
    role: Option<String>,
    /// Router state type written as a string; `None` means the default state type.
    state: Option<String>,
}
18
19impl Parse for ControllerArgs {
20    fn parse(input: ParseStream) -> syn::Result<Self> {
21        let mut prefix = None;
22        let mut role = None;
23        let mut state = None;
24
25        while !input.is_empty() {
26            let ident: syn::Ident = input.parse()?;
27            let _: Token![=] = input.parse()?;
28            let lit: LitStr = input.parse()?;
29
30            match ident.to_string().as_str() {
31                "prefix" => prefix = Some(lit.value()),
32                "role" => role = Some(lit.value()),
33                "state" => state = Some(lit.value()),
34                _ => return Err(syn::Error::new(ident.span(), "expected `prefix`, `role`, or `state`")),
35            }
36
37            // Consume optional comma
38            let _ = input.parse::<Token![,]>();
39        }
40
41        Ok(ControllerArgs {
42            prefix: prefix.unwrap_or_default(),
43            role,
44            state,
45        })
46    }
47}
48
49/// Route info extracted from method attributes
50struct RouteInfo {
51    method: String,
52    path: String,
53    fn_name: syn::Ident,
54}
55
56/// Find the parameter name that has type AuthGuard in a method signature
57fn find_auth_guard_param(method: &syn::ImplItemFn) -> Option<syn::Ident> {
58    for arg in &method.sig.inputs {
59        if let FnArg::Typed(pat_type) = arg {
60            if let Type::Path(type_path) = pat_type.ty.as_ref() {
61                let last_seg = type_path.path.segments.last();
62                if let Some(seg) = last_seg {
63                    if seg.ident == "AuthGuard" {
64                        // Extract the pattern name
65                        if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
66                            return Some(pat_ident.ident.clone());
67                        }
68                    }
69                }
70            }
71        }
72    }
73    None
74}
75
76/// Extracts route annotations from an impl method.
77fn extract_route_info(item: &ImplItem) -> Option<RouteInfo> {
78    let method_item = match item {
79        ImplItem::Fn(m) => m,
80        _ => return None,
81    };
82
83    let fn_name = method_item.sig.ident.clone();
84    let mut http_method = None;
85    let mut path = None;
86
87    for attr in &method_item.attrs {
88        let seg = attr.path().segments.last()?;
89        let name = seg.ident.to_string();
90
91        match name.as_str() {
92            "get" | "post" | "put" | "delete" | "patch" => {
93                http_method = Some(name);
94                if let Ok(lit) = attr.parse_args::<LitStr>() {
95                    path = Some(lit.value());
96                }
97            }
98            _ => {}
99        }
100    }
101
102    Some(RouteInfo {
103        method: http_method?,
104        path: path.unwrap_or_else(|| "/".to_string()),
105        fn_name,
106    })
107}
108
109/// # Controller macro
110///
111/// Generates a `router()` function from annotated methods.
112/// `#[require_role("ROLE_X")]` auto-injects `auth.require_role("ROLE_X")?;`
113/// at the beginning of the handler body.
114///
115/// ## Usage
116/// ```ignore
117/// #[controller(prefix = "/api/v1/admin/articles")]
118/// impl ArticleController {
119///     #[get("/")]
120///     #[require_role("ROLE_ADMIN")]
121///     async fn list(
122///         auth: AuthGuard,
123///         State(state): State<AppState>,
124///     ) -> AppResult<impl IntoResponse> {
125///         // auth.require_role("ROLE_ADMIN")?; ← auto-injected!
126///         let items = ...;
127///         Ok(Json(items))
128///     }
129/// }
130/// ```
131#[proc_macro_attribute]
132pub fn controller(args: TokenStream, input: TokenStream) -> TokenStream {
133    let args = parse_macro_input!(args as ControllerArgs);
134    let mut impl_block = parse_macro_input!(input as ItemImpl);
135    let prefix = &args.prefix;
136
137    // Extract route info BEFORE modifying the impl block
138    let routes: Vec<RouteInfo> = impl_block
139        .items
140        .iter()
141        .filter_map(extract_route_info)
142        .collect();
143
144    let controller_role = args.role.clone();
145
146    // Inject role checks and strip custom attributes
147    for item in &mut impl_block.items {
148        if let ImplItem::Fn(method) = item {
149            // Find require_role value for this method (method-level overrides controller-level)
150            let mut role_value = None;
151            for attr in &method.attrs {
152                let name = attr
153                    .path()
154                    .segments
155                    .last()
156                    .map(|s| s.ident.to_string())
157                    .unwrap_or_default();
158                if name == "require_role" {
159                    if let Ok(lit) = attr.parse_args::<LitStr>() {
160                        role_value = Some(lit.value());
161                    }
162                }
163            }
164
165            // Fall back to controller-level role if no method-level role
166            let effective_role = role_value.or_else(|| controller_role.clone());
167
168            // If a role is set, inject the check at the start of the body
169            if let Some(role) = effective_role {
170                if let Some(auth_param) = find_auth_guard_param(method) {
171                    let check = syn::parse2::<syn::Stmt>(quote! {
172                        #auth_param.require_role(#role)?;
173                    })
174                    .expect("Failed to parse role check statement");
175
176                    method.block.stmts.insert(0, check);
177                }
178            }
179
180            // Strip custom attributes
181            method.attrs.retain(|attr| {
182                let name = attr
183                    .path()
184                    .segments
185                    .last()
186                    .map(|s| s.ident.to_string())
187                    .unwrap_or_default();
188                !matches!(
189                    name.as_str(),
190                    "get" | "post" | "put" | "delete" | "patch" | "require_role"
191                )
192            });
193        }
194    }
195
196    // Build route registrations grouped by path
197    let mut path_groups: std::collections::BTreeMap<String, Vec<&RouteInfo>> =
198        std::collections::BTreeMap::new();
199    for route in &routes {
200        path_groups
201            .entry(route.path.clone())
202            .or_default()
203            .push(route);
204    }
205
206    let route_registrations: Vec<proc_macro2::TokenStream> = path_groups
207        .iter()
208        .map(|(path, methods)| {
209            let mut chain = Vec::new();
210            for (i, route) in methods.iter().enumerate() {
211                let method_ident = format_ident!("{}", route.method);
212                let fn_name = &route.fn_name;
213
214                if i == 0 {
215                    chain.push(quote! {
216                        axum::routing::#method_ident(Self::#fn_name)
217                    });
218                } else {
219                    chain.push(quote! {
220                        .#method_ident(Self::#fn_name)
221                    });
222                }
223            }
224
225            quote! {
226                .route(#path, #(#chain)*)
227            }
228        })
229        .collect();
230
231    let self_ty = &impl_block.self_ty;
232
233    // Use custom state type if provided, otherwise default to karbon::http::AppState
234    let state_type: proc_macro2::TokenStream = if let Some(ref state_path) = args.state {
235        state_path.parse().unwrap_or_else(|_| quote! { karbon::http::AppState })
236    } else {
237        quote! { karbon::http::AppState }
238    };
239
240    let expanded = quote! {
241        #impl_block
242
243        impl #self_ty {
244            /// Auto-generated router from #[controller] annotations
245            pub fn router() -> axum::Router<#state_type> {
246                axum::Router::new()
247                    #(#route_registrations)*
248            }
249
250            /// Returns the route prefix for this controller
251            pub fn prefix() -> &'static str {
252                #prefix
253            }
254        }
255    };
256
257    TokenStream::from(expanded)
258}
259
260// ─────────────────────────────────────────────
261// Standalone route macros (marker attributes, consumed by #[controller])
262// ─────────────────────────────────────────────
263
264/// Mark a handler as a GET route: `#[get("/path")]`
265#[proc_macro_attribute]
266pub fn get(_args: TokenStream, input: TokenStream) -> TokenStream {
267    input
268}
269
270/// Mark a handler as a POST route: `#[post("/path")]`
271#[proc_macro_attribute]
272pub fn post(_args: TokenStream, input: TokenStream) -> TokenStream {
273    input
274}
275
276/// Mark a handler as a PUT route: `#[put("/path")]`
277#[proc_macro_attribute]
278pub fn put(_args: TokenStream, input: TokenStream) -> TokenStream {
279    input
280}
281
282/// Mark a handler as a DELETE route: `#[delete("/path")]`
283#[proc_macro_attribute]
284pub fn delete(_args: TokenStream, input: TokenStream) -> TokenStream {
285    input
286}
287
288/// Mark a handler as a PATCH route: `#[patch("/path")]`
289#[proc_macro_attribute]
290pub fn patch(_args: TokenStream, input: TokenStream) -> TokenStream {
291    input
292}
293
294/// Require a role — auto-injects `auth.require_role("ROLE")?;` at the start of the handler.
295/// The handler must have an `AuthGuard` parameter.
296/// Usage: `#[require_role("ROLE_ADMIN")]`
297#[proc_macro_attribute]
298pub fn require_role(_args: TokenStream, input: TokenStream) -> TokenStream {
299    input
300}
301
302// ─────────────────────────────────────────────
303// Helper: extract #[table_name = "..."] from attributes
304// ─────────────────────────────────────────────
305
306fn extract_table_name(attrs: &[syn::Attribute]) -> Option<String> {
307    for attr in attrs {
308        if attr.path().is_ident("table_name") {
309            if let Ok(lit) = attr.parse_args::<LitStr>() {
310                return Some(lit.value());
311            }
312        }
313    }
314    None
315}
316
317/// Check if a field has a specific helper attribute like #[auto_increment], #[skip_insert], etc.
318fn has_field_attr(field: &syn::Field, name: &str) -> bool {
319    field.attrs.iter().any(|attr| attr.path().is_ident(name))
320}
321
322/// Check if the struct has a specific attribute like #[timestamps]
323fn has_struct_attr(attrs: &[syn::Attribute], name: &str) -> bool {
324    attrs.iter().any(|attr| attr.path().is_ident(name))
325}
326
327/// Extract #[slug_from("field_name")] value from a field
328fn extract_slug_from(field: &syn::Field) -> Option<String> {
329    for attr in &field.attrs {
330        if attr.path().is_ident("slug_from") {
331            if let Ok(lit) = attr.parse_args::<LitStr>() {
332                return Some(lit.value());
333            }
334        }
335    }
336    None
337}
338
339/// Check if a field type is Option<T>
340fn is_option_type(ty: &Type) -> bool {
341    if let Type::Path(type_path) = ty {
342        if let Some(seg) = type_path.path.segments.last() {
343            return seg.ident == "Option";
344        }
345    }
346    false
347}
348
349// ─────────────────────────────────────────────
350// #[derive(Insertable)]
351// ─────────────────────────────────────────────
352
353/// Derive macro that generates an `insert()` method for a struct.
354///
355/// ## Attributes
356/// - `#[table_name("table")]` — required, specifies the DB table
357/// - `#[auto_increment]` — on a field: exclude from INSERT (auto-generated by DB)
358/// - `#[skip_insert]` — on a field: exclude from INSERT (e.g. created with DEFAULT)
359/// - `#[timestamps]` — on the struct: auto-adds `created_at` column with `Utc::now()`
360/// - `#[slug_from("field")]` — on a field: auto-generates slug from another field if empty
361///
362/// ## Example
363/// ```ignore
364/// #[derive(Insertable)]
365/// #[table_name("posts")]
366/// #[timestamps]
367/// pub struct NewPost {
368///     pub title: String,
369///     #[slug_from("title")]
370///     pub slug: String,
371/// }
372///
373/// // Generated:
374/// // dto.insert(pool).await?;
375/// // → INSERT INTO posts (title, slug, created_at) VALUES (?, ?, ?)
376/// // slug auto-generated from title if empty
377/// ```
378#[proc_macro_derive(Insertable, attributes(table_name, auto_increment, skip_insert, timestamps, slug_from))]
379pub fn derive_insertable(input: TokenStream) -> TokenStream {
380    let input = parse_macro_input!(input as DeriveInput);
381    let name = &input.ident;
382
383    let table = match extract_table_name(&input.attrs) {
384        Some(t) => t,
385        None => {
386            return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
387                .to_compile_error()
388                .into();
389        }
390    };
391
392    let fields = match &input.data {
393        Data::Struct(data) => match &data.fields {
394            Fields::Named(f) => &f.named,
395            _ => {
396                return syn::Error::new_spanned(name, "Insertable only works on structs with named fields")
397                    .to_compile_error()
398                    .into();
399            }
400        },
401        _ => {
402            return syn::Error::new_spanned(name, "Insertable only works on structs")
403                .to_compile_error()
404                .into();
405        }
406    };
407
408    let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
409
410    // Collect fields that participate in INSERT (skip auto_increment and skip_insert)
411    let insert_fields: Vec<_> = fields
412        .iter()
413        .filter(|f| !has_field_attr(f, "auto_increment") && !has_field_attr(f, "skip_insert"))
414        .collect();
415
416    let mut column_names: Vec<String> = insert_fields
417        .iter()
418        .map(|f| f.ident.as_ref().unwrap().to_string())
419        .collect();
420
421    // Auto-timestamps: add created_at column
422    if has_timestamps {
423        column_names.push("created_at".to_string());
424    }
425
426
427    // Build bind calls, handling #[slug_from] fields
428    let mut bind_calls: Vec<proc_macro2::TokenStream> = Vec::new();
429    let mut slug_lets: Vec<proc_macro2::TokenStream> = Vec::new();
430
431    for field in &insert_fields {
432        let ident = field.ident.as_ref().unwrap();
433
434        if let Some(source_field) = extract_slug_from(field) {
435            let source_ident = format_ident!("{}", source_field);
436            let var_name = format_ident!("__slug_{}", ident);
437            slug_lets.push(quote! {
438                let #var_name = if self.#ident.is_empty() {
439                    slug::slugify(&self.#source_ident)
440                } else {
441                    self.#ident.clone()
442                };
443            });
444            bind_calls.push(quote! { .bind(&#var_name) });
445        } else {
446            bind_calls.push(quote! { .bind(&self.#ident) });
447        }
448    }
449
450    // Auto-timestamps: bind chrono::Utc::now()
451    if has_timestamps {
452        bind_calls.push(quote! { .bind(chrono::Utc::now()) });
453    }
454
455    // Build the SQL string at compile time
456    let columns_str = column_names.join(", ");
457    let placeholders_str: String = (1..=column_names.len())
458        .map(|_| "?".to_string())
459        .collect::<Vec<_>>()
460        .join(", ");
461    let sql_literal = format!("INSERT INTO {} ({}) VALUES ({})", table, columns_str, placeholders_str);
462
463    let expanded = quote! {
464        impl #name {
465            /// Auto-generated INSERT query from #[derive(Insertable)]
466            pub async fn insert<'e, E>(&self, executor: E) -> karbon::error::AppResult<u64>
467            where
468                E: sqlx::Executor<'e, Database = karbon::db::Db>,
469            {
470                #(#slug_lets)*
471                let result = sqlx::query(#sql_literal)
472                    #(#bind_calls)*
473                    .execute(executor)
474                    .await?;
475                Ok(karbon::db::last_insert_id(&result))
476            }
477        }
478    };
479
480    TokenStream::from(expanded)
481}
482
483// ─────────────────────────────────────────────
484// #[derive(Updatable)]
485// ─────────────────────────────────────────────
486
487/// Derive macro that generates an `update()` method for a struct.
488///
489/// All `Option<T>` fields are treated as partial updates — only `Some` values
490/// are included in the SET clause. Non-Option fields are always included.
491///
492/// ## Attributes
493/// - `#[table_name("table")]` — required
494/// - `#[primary_key]` — marks the WHERE clause field (required, exactly one)
495/// - `#[skip_update]` — exclude from SET clause
496///
497/// ## Example
498/// ```ignore
499/// #[derive(Updatable)]
500/// #[table_name("user")]
501/// pub struct UpdateUser {
502///     #[primary_key]
503///     pub id: i64,
504///     pub username: Option<String>,
505///     pub email: Option<String>,
506///     pub active: Option<bool>,
507/// }
508///
509/// // Generated:
510/// // UpdateUser { id: 42, username: Some("new"), email: None, active: None }.update(pool).await?;
511/// // → UPDATE user SET username=? WHERE id=?  (only Some fields)
512/// ```
513#[proc_macro_derive(Updatable, attributes(table_name, primary_key, skip_update, timestamps))]
514pub fn derive_updatable(input: TokenStream) -> TokenStream {
515    let input = parse_macro_input!(input as DeriveInput);
516    let name = &input.ident;
517
518    let table = match extract_table_name(&input.attrs) {
519        Some(t) => t,
520        None => {
521            return syn::Error::new_spanned(&input.ident, "Missing #[table_name(\"...\")]")
522                .to_compile_error()
523                .into();
524        }
525    };
526
527    let fields = match &input.data {
528        Data::Struct(data) => match &data.fields {
529            Fields::Named(f) => &f.named,
530            _ => {
531                return syn::Error::new_spanned(name, "Updatable only works on structs with named fields")
532                    .to_compile_error()
533                    .into();
534            }
535        },
536        _ => {
537            return syn::Error::new_spanned(name, "Updatable only works on structs")
538                .to_compile_error()
539                .into();
540        }
541    };
542
543    // Find the primary key field
544    let pk_field = fields.iter().find(|f| has_field_attr(f, "primary_key"));
545    let pk_field = match pk_field {
546        Some(f) => f,
547        None => {
548            return syn::Error::new_spanned(name, "Updatable requires exactly one #[primary_key] field")
549                .to_compile_error()
550                .into();
551        }
552    };
553    let pk_ident = pk_field.ident.as_ref().unwrap();
554    let pk_col = pk_ident.to_string();
555
556    let has_timestamps = has_struct_attr(&input.attrs, "timestamps");
557
558    // Collect updatable fields (not primary_key, not skip_update)
559    let update_fields: Vec<_> = fields
560        .iter()
561        .filter(|f| !has_field_attr(f, "primary_key") && !has_field_attr(f, "skip_update"))
562        .collect();
563
564    // Build the dynamic update logic with placeholder tracking
565    let mut set_pushes = Vec::new();
566    let mut bind_pushes = Vec::new();
567
568    for field in &update_fields {
569        let ident = field.ident.as_ref().unwrap();
570        let col_name = ident.to_string();
571
572        if is_option_type(&field.ty) {
573            set_pushes.push(quote! {
574                if self.#ident.is_some() {
575                    param_idx += 1;
576                    set_clauses.push(format!("{} = {}", #col_name, karbon::db::placeholder(param_idx)));
577                }
578            });
579            bind_pushes.push(quote! {
580                if let Some(ref val) = self.#ident {
581                    query = query.bind(val);
582                }
583            });
584        } else {
585            set_pushes.push(quote! {
586                param_idx += 1;
587                set_clauses.push(format!("{} = {}", #col_name, karbon::db::placeholder(param_idx)));
588            });
589            bind_pushes.push(quote! {
590                query = query.bind(&self.#ident);
591            });
592        }
593    }
594
595    // Auto-timestamps: add updated_at
596    let timestamps_set_push = if has_timestamps {
597        quote! {
598            param_idx += 1;
599            set_clauses.push(format!("{} = {}", "updated_at", karbon::db::placeholder(param_idx)));
600        }
601    } else {
602        quote! {}
603    };
604
605    let timestamps_bind_push = if has_timestamps {
606        quote! {
607            query = query.bind(chrono::Utc::now());
608        }
609    } else {
610        quote! {}
611    };
612
613    let expanded = quote! {
614        impl #name {
615            /// Auto-generated UPDATE query from #[derive(Updatable)]
616            /// Only `Some` fields on Option<T> are included in SET.
617            ///
618            /// Accepts either a `&DbPool` or a `&mut Transaction`:
619            /// ```ignore
620            /// dto.update(pool).await?;           // normal
621            /// dto.update(&mut *tx).await?;       // in transaction
622            /// ```
623            pub async fn update<'e, E>(&self, executor: E) -> karbon::error::AppResult<u64>
624            where
625                E: sqlx::Executor<'e, Database = karbon::db::Db>,
626            {
627                let mut set_clauses: Vec<String> = Vec::new();
628                let mut param_idx: usize = 0;
629
630                #(#set_pushes)*
631                #timestamps_set_push
632
633                if set_clauses.is_empty() {
634                    return Ok(0); // nothing to update
635                }
636
637                param_idx += 1;
638                let sql = format!(
639                    "UPDATE {} SET {} WHERE {} = {}",
640                    #table,
641                    set_clauses.join(", "),
642                    #pk_col,
643                    karbon::db::placeholder(param_idx),
644                );
645
646                let mut query = sqlx::query(&sql);
647
648                #(#bind_pushes)*
649                #timestamps_bind_push
650
651                // Bind the primary key for WHERE
652                query = query.bind(&self.#pk_ident);
653
654                let result = query.execute(executor).await?;
655                Ok(result.rows_affected())
656            }
657        }
658    };
659
660    TokenStream::from(expanded)
661}