Skip to main content

chopin_macros/
lib.rs

1//! Derive macros and attribute macros for the Chopin framework.
2//!
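//! | Macro | Kind | Purpose |
//! |---|---|---|
//! | `#[get("/path")]`, `#[post]`, `#[put]`, `#[delete]`, `#[patch]`, `#[head]`, `#[options]`, `#[trace]`, `#[connect]` | attribute | Register a handler function as a route at link time via `inventory` |
//! | `#[require_role(Claims, role)]`, `#[require_scope(Claims, "scope")]` | attribute | Wrap a handler body in a JWT bearer-token check followed by a role or OAuth scope check |
//! | `#[derive(IntoResponse)]` | derive | Generate `From<MyError> for Response` for an error enum |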
7//!
8//! ## Route registration
9//!
10//! Route attribute macros register handlers globally.  Call
11//! `Chopin::new().mount_all_routes()` to load all registered routes at startup:
12//!
13//! ```rust,ignore
14//! use chopin_macros::{get, post, delete};
15//! use chopin_core::{Context, Response, Chopin};
16//!
17//! #[get("/users/:id")]
18//! fn show_user(ctx: Context) -> Response { Response::text("hello") }
19//!
20//! #[post("/users")]
21//! fn create_user(ctx: Context) -> Response { Response::text("created") }
22//!
23//! fn main() {
24//!     Chopin::new().mount_all_routes().serve("0.0.0.0:8080").unwrap();
25//! }
26//! ```
27//!
28//! ## `#[derive(IntoResponse)]`
29//!
//! Generates `From<YourError> for Response`, so handlers returning `Result<Response, YourError>` can use `?`:
31//!
32//! ```rust,ignore
33//! use chopin_macros::IntoResponse;
34//!
35//! #[derive(IntoResponse)]
36//! enum ApiError {
37//!     #[status(404)] NotFound,
38//!     #[status(422)] Validation(String),
39//!     #[status(500)] Internal,
40//! }
41//! ```
42use proc_macro::TokenStream;
43use quote::quote;
44use syn::{Data, DeriveInput, Fields, ItemFn, parse_macro_input};
45
46// ─── #[derive(IntoResponse)] ─────────────────────────────────────────────────
47
48/// Derive `From<YourError> for chopin_core::http::Response` for an error enum.
49///
50/// Each variant must be annotated with `#[status(N)]` where `N` is the HTTP
51/// status code.  Variants without `#[status]` default to `500`.
52///
53/// # Example
54///
55/// ```rust,ignore
56/// use chopin_macros::IntoResponse;
57///
58/// #[derive(IntoResponse)]
59/// pub enum PostError {
60///     #[status(404)] NotFound(i32),
61///     #[status(422)] Validation(String),
62///     #[status(500)] Db(chopin_orm::OrmError),
63/// }
64///
65/// // Handlers can now use `?` directly:
66/// #[get("/posts/:id")]
67/// fn show(ctx: Context) -> Response {
68///     let id: i32 = ctx.param_parse("id")?;
69///     let post = services::get(id)?;   // PostError converts → Response
70///     Response::json(&post)
71/// }
72/// ```
73#[proc_macro_derive(IntoResponse, attributes(status))]
74pub fn derive_into_response(input: TokenStream) -> TokenStream {
75    let input = parse_macro_input!(input as DeriveInput);
76    let name = &input.ident;
77
78    let data_enum = match &input.data {
79        Data::Enum(e) => e,
80        _ => {
81            return syn::Error::new_spanned(name, "#[derive(IntoResponse)] only works on enums")
82                .to_compile_error()
83                .into();
84        }
85    };
86
87    let arms = data_enum.variants.iter().map(|variant| {
88        let variant_name = &variant.ident;
89
90        // Read #[status(N)], default 500.
91        let status_code: u16 = variant
92            .attrs
93            .iter()
94            .find_map(|attr| {
95                if !attr.path().is_ident("status") {
96                    return None;
97                }
98                attr.parse_args::<syn::LitInt>().ok()?.base10_parse().ok()
99            })
100            .unwrap_or(500u16);
101
102        match &variant.fields {
103            Fields::Unit => quote! {
104                #name::#variant_name => ::chopin_core::http::Response::new(#status_code),
105            },
106            Fields::Unnamed(_) => quote! {
107                #name::#variant_name(..) => ::chopin_core::http::Response::new(#status_code),
108            },
109            Fields::Named(_) => quote! {
110                #name::#variant_name { .. } => ::chopin_core::http::Response::new(#status_code),
111            },
112        }
113    });
114
115    TokenStream::from(quote! {
116        impl From<#name> for ::chopin_core::http::Response {
117            fn from(e: #name) -> Self {
118                match e {
119                    #(#arms)*
120                }
121            }
122        }
123    })
124}
125
126#[proc_macro_attribute]
127pub fn get(attr: TokenStream, item: TokenStream) -> TokenStream {
128    generate_route("Get", attr, item)
129}
130
131#[proc_macro_attribute]
132pub fn post(attr: TokenStream, item: TokenStream) -> TokenStream {
133    generate_route("Post", attr, item)
134}
135
136#[proc_macro_attribute]
137pub fn put(attr: TokenStream, item: TokenStream) -> TokenStream {
138    generate_route("Put", attr, item)
139}
140
141#[proc_macro_attribute]
142pub fn delete(attr: TokenStream, item: TokenStream) -> TokenStream {
143    generate_route("Delete", attr, item)
144}
145
146#[proc_macro_attribute]
147pub fn patch(attr: TokenStream, item: TokenStream) -> TokenStream {
148    generate_route("Patch", attr, item)
149}
150
151#[proc_macro_attribute]
152pub fn head(attr: TokenStream, item: TokenStream) -> TokenStream {
153    generate_route("Head", attr, item)
154}
155
156#[proc_macro_attribute]
157pub fn options(attr: TokenStream, item: TokenStream) -> TokenStream {
158    generate_route("Options", attr, item)
159}
160
161#[proc_macro_attribute]
162pub fn trace(attr: TokenStream, item: TokenStream) -> TokenStream {
163    generate_route("Trace", attr, item)
164}
165
166#[proc_macro_attribute]
167pub fn connect(attr: TokenStream, item: TokenStream) -> TokenStream {
168    generate_route("Connect", attr, item)
169}
170
171// ─── #[require_role] ──────────────────────────────────────────────────────────
172
173/// Inline role guard that wraps a Chopin handler with a JWT + RBAC check.
174///
175/// Returns `401` for missing/invalid tokens and `403` for insufficient role.
176///
177/// # Usage
178///
179/// ```rust,ignore
180/// use chopin_macros::{get, require_role};
181/// use chopin_auth::{Role, StandardClaims};
182///
183/// #[derive(Debug, Clone, PartialEq)]
184/// enum MyRole { Admin, User }
185/// impl Role for MyRole {}
186///
187/// type Claims = StandardClaims<MyRole>;
188///
189/// // Place #[require_role] ABOVE #[get] so it wraps the handler body
190/// // before the route is registered in the inventory.
191/// #[require_role(Claims, MyRole::Admin)]
192/// #[get("/admin/dashboard")]
193/// pub fn admin_dashboard(ctx: chopin_core::Context) -> chopin_core::Response {
194///     ctx.json(&"welcome, admin")
195/// }
196/// ```
197///
198/// # Requirements
199/// - `chopin_auth` must be a dependency of the consuming crate.
200/// - [`chopin_auth::init_jwt_manager`] must be called before the server starts.
201/// - The claims type must implement [`chopin_auth::HasJti`].
202/// - The claims type must implement [`chopin_auth::middleware::RoleCheck<R>`]
203///   for the given role type (satisfied automatically by [`chopin_auth::StandardClaims<R>`]).
204#[proc_macro_attribute]
205pub fn require_role(attr: TokenStream, item: TokenStream) -> TokenStream {
206    let args = parse_macro_input!(attr as RequireRoleArgs);
207    let mut func = parse_macro_input!(item as ItemFn);
208
209    let claims_type = &args.claims_type;
210    let role_expr = &args.role_expr;
211    let ctx_ident = first_param_ident(&func);
212
213    let original_stmts = func.block.stmts.clone();
214
215    let new_block: syn::Block = syn::parse_quote! {
216        {
217            let __chopin_token = (0..#ctx_ident.req.header_count as usize)
218                .find_map(|__ci| {
219                    let (__ck, __cv) = #ctx_ident.req.headers[__ci];
220                    if __ck.eq_ignore_ascii_case("Authorization") {
221                        __cv.strip_prefix("Bearer ")
222                    } else {
223                        None
224                    }
225                });
226            let Some(__chopin_token) = __chopin_token else {
227                return ::chopin_core::http::Response::new(401);
228            };
229            let Some(__chopin_mgr) = ::chopin_auth::extractor::GLOBAL_JWT_MANAGER.get() else {
230                return ::chopin_core::http::Response::server_error();
231            };
232            let __chopin_claims = match __chopin_mgr.decode::<#claims_type>(__chopin_token) {
233                ::std::result::Result::Ok(__c) => __c,
234                ::std::result::Result::Err(_) => {
235                    return ::chopin_core::http::Response::new(401);
236                }
237            };
238            if !::chopin_auth::middleware::RoleCheck::has_role(&__chopin_claims, &#role_expr) {
239                return ::chopin_core::http::Response::new(403);
240            }
241            #(#original_stmts)*
242        }
243    };
244
245    *func.block = new_block;
246    TokenStream::from(quote! { #func })
247}
248
249// ─── #[require_scope] ─────────────────────────────────────────────────────────
250
251/// Inline OAuth 2.0 scope guard that wraps a Chopin handler with a JWT + scope check.
252///
253/// Returns `401` for missing/invalid tokens and `403` for insufficient scope.
254///
255/// # Usage
256///
257/// ```rust,ignore
258/// use chopin_macros::{get, require_scope};
259/// use chopin_auth::StandardClaims;
260///
261/// type Claims = StandardClaims<()>;
262///
263/// // Place #[require_scope] ABOVE #[get].
264/// #[require_scope(Claims, "read:reports")]
265/// #[get("/reports")]
266/// pub fn list_reports(ctx: chopin_core::Context) -> chopin_core::Response {
267///     ctx.json(&"reports")
268/// }
269/// ```
270///
271/// # Requirements
272/// - `chopin_auth` must be a dependency of the consuming crate.
273/// - [`chopin_auth::init_jwt_manager`] must be called before the server starts.
274/// - The claims type must implement [`chopin_auth::middleware::ScopeCheck`]
275///   (satisfied automatically by [`chopin_auth::StandardClaims<R>`]).
276#[proc_macro_attribute]
277pub fn require_scope(attr: TokenStream, item: TokenStream) -> TokenStream {
278    let args = parse_macro_input!(attr as RequireScopeArgs);
279    let mut func = parse_macro_input!(item as ItemFn);
280
281    let claims_type = &args.claims_type;
282    let scope = &args.scope;
283    let ctx_ident = first_param_ident(&func);
284
285    let original_stmts = func.block.stmts.clone();
286
287    let new_block: syn::Block = syn::parse_quote! {
288        {
289            let __chopin_token = (0..#ctx_ident.req.header_count as usize)
290                .find_map(|__ci| {
291                    let (__ck, __cv) = #ctx_ident.req.headers[__ci];
292                    if __ck.eq_ignore_ascii_case("Authorization") {
293                        __cv.strip_prefix("Bearer ")
294                    } else {
295                        None
296                    }
297                });
298            let Some(__chopin_token) = __chopin_token else {
299                return ::chopin_core::http::Response::new(401);
300            };
301            let Some(__chopin_mgr) = ::chopin_auth::extractor::GLOBAL_JWT_MANAGER.get() else {
302                return ::chopin_core::http::Response::server_error();
303            };
304            let __chopin_claims = match __chopin_mgr.decode::<#claims_type>(__chopin_token) {
305                ::std::result::Result::Ok(__c) => __c,
306                ::std::result::Result::Err(_) => {
307                    return ::chopin_core::http::Response::new(401);
308                }
309            };
310            if !::chopin_auth::middleware::ScopeCheck::has_scope(&__chopin_claims, #scope) {
311                return ::chopin_core::http::Response::new(403);
312            }
313            #(#original_stmts)*
314        }
315    };
316
317    *func.block = new_block;
318    TokenStream::from(quote! { #func })
319}
320
321// ─── Argument parsers ─────────────────────────────────────────────────────────
322
323struct RequireRoleArgs {
324    claims_type: syn::Type,
325    role_expr: syn::Expr,
326}
327
328impl syn::parse::Parse for RequireRoleArgs {
329    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
330        let claims_type = input.parse::<syn::Type>()?;
331        input.parse::<syn::Token![,]>()?;
332        let role_expr = input.parse::<syn::Expr>()?;
333        Ok(RequireRoleArgs {
334            claims_type,
335            role_expr,
336        })
337    }
338}
339
340struct RequireScopeArgs {
341    claims_type: syn::Type,
342    scope: syn::LitStr,
343}
344
345impl syn::parse::Parse for RequireScopeArgs {
346    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
347        let claims_type = input.parse::<syn::Type>()?;
348        input.parse::<syn::Token![,]>()?;
349        let scope = input.parse::<syn::LitStr>()?;
350        Ok(RequireScopeArgs { claims_type, scope })
351    }
352}
353
354// ─── Helpers ──────────────────────────────────────────────────────────────────
355
356/// Extract the identifier of the first function parameter (typically `ctx`).
357fn first_param_ident(func: &ItemFn) -> proc_macro2::Ident {
358    func.sig
359        .inputs
360        .first()
361        .and_then(|arg| match arg {
362            syn::FnArg::Typed(pt) => match pt.pat.as_ref() {
363                syn::Pat::Ident(pi) => Some(pi.ident.clone()),
364                _ => None,
365            },
366            _ => None,
367        })
368        .unwrap_or_else(|| syn::Ident::new("ctx", proc_macro2::Span::call_site()))
369}
370
371fn generate_route(method: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
372    let path = parse_macro_input!(attr as syn::LitStr).value();
373    let input_fn = parse_macro_input!(item as ItemFn);
374
375    let fn_name = &input_fn.sig.ident;
376    let method_ident = syn::Ident::new(method, proc_macro2::Span::call_site());
377
378    // Generate a trampoline name that is unlikely to clash with user symbols.
379    let trampoline_name = syn::Ident::new(
380        &format!("__chopin_{}_h", fn_name),
381        proc_macro2::Span::call_site(),
382    );
383
384    // Extract doc comments
385    let mut docs = Vec::new();
386    for attr in &input_fn.attrs {
387        if attr.path().is_ident("doc")
388            && let syn::Meta::NameValue(nv) = &attr.meta
389            && let syn::Expr::Lit(syn::ExprLit {
390                lit: syn::Lit::Str(s),
391                ..
392            }) = &nv.value
393        {
394            docs.push(s.value().trim().to_string());
395        }
396    }
397
398    let summary = docs.first().cloned().unwrap_or_default();
399    let description = if docs.len() > 1 {
400        docs[1..].join("\n")
401    } else {
402        String::new()
403    };
404
405    // The trampoline converts any `impl IntoResponse` return type (including
406    // `Response`, `Result<Response, E>`, etc.) into a plain `Response`.
407    // This allows handlers to use `?` for early error returns.
408    let expanded = quote! {
409        #input_fn
410
411        #[doc(hidden)]
412        #[allow(non_snake_case)]
413        fn #trampoline_name(__ctx: ::chopin_core::Context) -> ::chopin_core::Response {
414            ::chopin_core::http::IntoResponse::into_response(#fn_name(__ctx))
415        }
416
417        ::chopin_core::inventory::submit! {
418            ::chopin_core::RouteDef {
419                method: ::chopin_core::http::Method::#method_ident,
420                path: #path,
421                handler: #trampoline_name,
422                summary: #summary,
423                description: #description,
424            }
425        }
426    };
427
428    TokenStream::from(expanded)
429}