// rapina_macros/lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{FnArg, ItemFn, LitStr, Pat};
4
5/// Parsed route macro attribute: `"/path"` or `"/path", group = "/prefix"`.
6struct RouteAttr {
7    path: LitStr,
8    group: Option<LitStr>,
9}
10
11impl syn::parse::Parse for RouteAttr {
12    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
13        let path: LitStr = input.parse()?;
14        let group = if input.peek(syn::Token![,]) {
15            input.parse::<syn::Token![,]>()?;
16            let ident: syn::Ident = input.parse()?;
17            if ident != "group" {
18                return Err(syn::Error::new(ident.span(), "expected `group`"));
19            }
20            input.parse::<syn::Token![=]>()?;
21            let value: LitStr = input.parse()?;
22            Some(value)
23        } else {
24            None
25        };
26        if !input.is_empty() {
27            return Err(input.error("unexpected tokens after route attribute"));
28        }
29        Ok(RouteAttr { path, group })
30    }
31}
32
/// Concatenates a `group` prefix and a route path into one path at macro-expansion time.
///
/// Trailing slashes are stripped from `prefix`, and a leading `/` is supplied for `path` when it
/// is missing. An empty or `/` path maps to the bare prefix, or to `/` if the prefix is empty too.
fn join_paths(prefix: &str, path: &str) -> String {
    let base = prefix.trim_end_matches('/');
    match path {
        "" | "/" if base.is_empty() => String::from("/"),
        "" | "/" => base.to_owned(),
        rooted if rooted.starts_with('/') => format!("{base}{rooted}"),
        relative => format!("{base}/{relative}"),
    }
}
49
50mod schema;
51
52/// Registers a GET route handler.
53///
54/// # Syntax
55///
56/// ```ignore
57/// #[get("/users")]
58/// async fn list_users() -> Json<Vec<User>> { /* ... */ }
59///
60/// // Single path parameter:
61/// #[get("/users/:id")]
62/// async fn get_user(id: Path<u64>) -> Json<User> { /* ... */ }
63///
64/// // Multiple path parameters — tuple, positional (left to right in pattern):
65/// #[get("/orgs/:org_id/teams/:team_id")]
66/// async fn get_team(Path((org_id, team_id)): Path<(u64, u64)>) -> Json<Team> { /* ... */ }
67///
68/// // With a group prefix (registers at /api/users):
69/// #[get("/users", group = "/api")]
70/// async fn list_users() -> Json<Vec<User>> { /* ... */ }
71/// ```
72///
73/// The `group` parameter joins the prefix with the path at compile time,
74/// so the handler is registered at the full path during auto-discovery.
75#[proc_macro_attribute]
76pub fn get(attr: TokenStream, item: TokenStream) -> TokenStream {
77    route_macro("GET", attr, item)
78}
79
80/// Registers a POST route handler.
81///
82/// See [`get`] for syntax details including the optional `group` parameter.
83#[proc_macro_attribute]
84pub fn post(attr: TokenStream, item: TokenStream) -> TokenStream {
85    route_macro("POST", attr, item)
86}
87
88/// Registers a PUT route handler.
89///
90/// See [`get`] for syntax details including the optional `group` parameter.
91#[proc_macro_attribute]
92pub fn put(attr: TokenStream, item: TokenStream) -> TokenStream {
93    route_macro("PUT", attr, item)
94}
95
96/// Registers a PATCH route handler.
97///
98/// # Example
99///
100/// ```ignore
101/// #[patch("/users/:id")]
102/// async fn update_user(Path(id): Path<u64>) -> Json<User> { /* ... */ }
103/// ```
104///
105/// See [`get`] for syntax details including the optional `group` parameter.
106#[proc_macro_attribute]
107pub fn patch(attr: TokenStream, item: TokenStream) -> TokenStream {
108    route_macro("PATCH", attr, item)
109}
110
111/// Registers a DELETE route handler.
112///
113/// See [`get`] for syntax details including the optional `group` parameter.
114#[proc_macro_attribute]
115pub fn delete(attr: TokenStream, item: TokenStream) -> TokenStream {
116    route_macro("DELETE", attr, item)
117}
118
119/// Marks a route as public (no authentication required).
120///
121/// When authentication is enabled via `Rapina::with_auth()`, all routes
122/// require a valid JWT token by default. Use `#[public]` to allow
123/// unauthenticated access to specific routes.
124///
125/// # Example
126///
127/// ```ignore
128/// use rapina::prelude::*;
129///
130/// #[public]
131/// #[get("/health")]
132/// async fn health() -> &'static str {
133///     "ok"
134/// }
135///
136/// #[public]
137/// #[post("/login")]
138/// async fn login(body: Json<LoginRequest>) -> Result<Json<TokenResponse>> {
139///     // ... authenticate and return token
140/// }
141/// ```
142///
143/// Note: Routes starting with `/__rapina` are automatically public.
144#[proc_macro_attribute]
145pub fn public(_attr: TokenStream, item: TokenStream) -> TokenStream {
146    let func: ItemFn = syn::parse(item.clone()).expect("#[public] must be applied to a function");
147    let func_name_str = func.sig.ident.to_string();
148    let item2: proc_macro2::TokenStream = item.into();
149    quote! {
150        #item2
151        rapina::inventory::submit! {
152            rapina::discovery::PublicMarker {
153                handler_name: #func_name_str,
154            }
155        }
156    }
157    .into()
158}
159
160fn route_macro_core(
161    method: &str,
162    attr: proc_macro2::TokenStream,
163    item: proc_macro2::TokenStream,
164) -> proc_macro2::TokenStream {
165    let route_attr: RouteAttr = syn::parse2(attr).expect("expected path as string literal");
166    let path_str = if let Some(ref group) = route_attr.group {
167        let g = group.value();
168        assert!(
169            g.starts_with('/'),
170            "group prefix must start with `/`, got: {g:?}"
171        );
172        join_paths(&g, &route_attr.path.value())
173    } else {
174        route_attr.path.value()
175    };
176    let mut func: ItemFn = syn::parse2(item).expect("expected function");
177
178    let func_name = &func.sig.ident;
179    let func_name_str = func_name.to_string();
180    let func_vis = &func.vis;
181
182    // Extract #[public] attribute if present (when #[public] is below the route macro)
183    let is_public = extract_public_attr(&mut func.attrs);
184
185    // Extract #[errors(ErrorType)] attribute if present
186    let error_type = extract_errors_attr(&mut func.attrs);
187
188    // Extract #[cache(ttl = N)] attribute if present
189    let cache_ttl = extract_cache_attr(&mut func.attrs);
190
191    let error_responses_impl = if let Some(err_type) = &error_type {
192        quote! {
193            fn error_responses() -> Vec<rapina::error::ErrorVariant> {
194                <#err_type as rapina::error::DocumentedError>::error_variants()
195            }
196        }
197    } else {
198        quote! {}
199    };
200
201    // Extract return type for schema generation
202    let response_schema_impl = if let syn::ReturnType::Type(_, return_type) = &func.sig.output {
203        if let Some(inner_type) = extract_json_inner_type(return_type) {
204            quote! {
205                fn response_schema() -> Option<serde_json::Value> {
206                    Some(serde_json::to_value(rapina::schemars::schema_for!(#inner_type)).unwrap())
207                }
208            }
209        } else {
210            quote! {}
211        }
212    } else {
213        quote! {}
214    };
215
216    let args: Vec<_> = func.sig.inputs.iter().collect();
217
218    // Extract return type for type annotation (helps with type inference in async blocks)
219    let return_type_annotation = match &func.sig.output {
220        syn::ReturnType::Type(_, ty) => quote! { : #ty },
221        syn::ReturnType::Default => quote! {},
222    };
223
224    // Optional cache TTL header injection
225    let cache_header_injection = if let Some(ttl) = cache_ttl {
226        let ttl_str = ttl.to_string();
227        quote! {
228            let mut __rapina_response = __rapina_response;
229            __rapina_response.headers_mut().insert(
230                "x-rapina-cache-ttl",
231                rapina::http::HeaderValue::from_static(#ttl_str),
232            );
233        }
234    } else {
235        quote! {}
236    };
237
238    // Build the handler body
239    // Use __rapina_ prefix for internal variables to avoid shadowing user's variables
240    let handler_body = if args.is_empty() {
241        let inner_block = &func.block;
242        quote! {
243            let __rapina_result #return_type_annotation = (async #inner_block).await;
244            let __rapina_response = rapina::response::IntoResponse::into_response(__rapina_result);
245            #cache_header_injection
246            __rapina_response
247        }
248    } else {
249        let inner_block = &func.block;
250
251        if args.len() == 1 {
252            // Single arg: pass request directly to FromRequest
253            let arg = &args[0];
254            if let FnArg::Typed(pat_type) = arg {
255                let pat = &pat_type.pat;
256                let arg_type = &pat_type.ty;
257                let tmp = syn::Ident::new("__rapina_arg_0", proc_macro2::Span::call_site());
258                quote! {
259                    let #tmp = match <#arg_type as rapina::extract::FromRequest>::from_request(__rapina_req, &__rapina_params, &__rapina_state).await {
260                        Ok(v) => v,
261                        Err(e) => return rapina::response::IntoResponse::into_response(e),
262                    };
263                    let #pat = #tmp;
264                    let __rapina_result #return_type_annotation = (async #inner_block).await;
265                    let __rapina_response = rapina::response::IntoResponse::into_response(__rapina_result);
266                    #cache_header_injection
267                    __rapina_response
268                }
269            } else {
270                unreachable!("handler argument must be a typed pattern")
271            }
272        } else {
273            // Multiple args: all but last use FromRequestParts, last uses FromRequest
274            let mut parts_extractions = Vec::new();
275
276            for (i, arg) in args[..args.len() - 1].iter().enumerate() {
277                if let FnArg::Typed(pat_type) = arg {
278                    let pat = &pat_type.pat;
279                    let arg_type = &pat_type.ty;
280                    let tmp = syn::Ident::new(
281                        &format!("__rapina_arg_{}", i),
282                        proc_macro2::Span::call_site(),
283                    );
284                    parts_extractions.push(quote! {
285                        let #tmp = match <#arg_type as rapina::extract::FromRequestParts>::from_request_parts(&__rapina_parts, &__rapina_params, &__rapina_state).await {
286                            Ok(v) => v,
287                            Err(e) => return rapina::response::IntoResponse::into_response(e),
288                        };
289                        let #pat = #tmp;
290                    });
291                }
292            }
293
294            let last_arg = args.last().unwrap();
295            let last_extraction = if let FnArg::Typed(pat_type) = last_arg {
296                let pat = &pat_type.pat;
297                let arg_type = &pat_type.ty;
298                let tmp = syn::Ident::new(
299                    &format!("__rapina_arg_{}", args.len() - 1),
300                    proc_macro2::Span::call_site(),
301                );
302                quote! {
303                    let __rapina_req = rapina::http::Request::from_parts(__rapina_parts, __rapina_body);
304                    let #tmp = match <#arg_type as rapina::extract::FromRequest>::from_request(__rapina_req, &__rapina_params, &__rapina_state).await {
305                        Ok(v) => v,
306                        Err(e) => return rapina::response::IntoResponse::into_response(e),
307                    };
308                    let #pat = #tmp;
309                }
310            } else {
311                unreachable!("handler argument must be a typed pattern")
312            };
313
314            quote! {
315                let (__rapina_parts, __rapina_body) = __rapina_req.into_parts();
316                #(#parts_extractions)*
317                #last_extraction
318                let __rapina_result #return_type_annotation = (async #inner_block).await;
319                let __rapina_response = rapina::response::IntoResponse::into_response(__rapina_result);
320                #cache_header_injection
321                __rapina_response
322            }
323        }
324    };
325
326    // Build the router method call for the register function
327    let router_method = syn::Ident::new(&method.to_lowercase(), proc_macro2::Span::call_site());
328    let register_fn_name = syn::Ident::new(
329        &format!("__rapina_register_{}", func_name_str),
330        proc_macro2::Span::call_site(),
331    );
332
333    // Generate the struct, Handler impl, and inventory submission
334    quote! {
335        #[derive(Clone, Copy)]
336        #[allow(non_camel_case_types)]
337        #func_vis struct #func_name;
338
339        impl rapina::handler::Handler for #func_name {
340            const NAME: &'static str = #func_name_str;
341
342            #response_schema_impl
343            #error_responses_impl
344
345            fn call(
346                &self,
347                __rapina_req: rapina::hyper::Request<rapina::hyper::body::Incoming>,
348                __rapina_params: rapina::extract::PathParams,
349                __rapina_state: std::sync::Arc<rapina::state::AppState>,
350            ) -> std::pin::Pin<Box<dyn std::future::Future<Output = rapina::hyper::Response<rapina::response::BoxBody>> + Send>> {
351                Box::pin(async move {
352                    #handler_body
353                })
354            }
355        }
356
357        #[doc(hidden)]
358        fn #register_fn_name(__rapina_router: rapina::router::Router) -> rapina::router::Router {
359            __rapina_router.#router_method(#path_str, #func_name)
360        }
361
362        rapina::inventory::submit! {
363            rapina::discovery::RouteDescriptor {
364                method: #method,
365                path: #path_str,
366                handler_name: #func_name_str,
367                is_public: #is_public,
368                response_schema: <#func_name as rapina::handler::Handler>::response_schema,
369                error_responses: <#func_name as rapina::handler::Handler>::error_responses,
370                register: #register_fn_name,
371            }
372        }
373    }
374}
375
376/// Extracts the inner type from Json<T> wrapper for schema generation
377fn extract_json_inner_type(return_type: &syn::Type) -> Option<proc_macro2::TokenStream> {
378    if let syn::Type::Path(type_path) = return_type
379        && let Some(last_segment) = type_path.path.segments.last()
380    {
381        // Direct Json<T>
382        if last_segment.ident == "Json"
383            && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
384            && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
385        {
386            return Some(quote!(#inner_type));
387        }
388
389        // Result<Json<T>> or Result<Json<T>, E>
390        if last_segment.ident == "Result"
391            && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
392            && let Some(syn::GenericArgument::Type(ok_type)) = args.args.first()
393        {
394            return extract_json_inner_type(ok_type);
395        }
396    }
397    None
398}
399
400/// Extract #[errors(ErrorType)] attribute from function attributes, removing it if found.
401fn extract_errors_attr(attrs: &mut Vec<syn::Attribute>) -> Option<syn::Type> {
402    let idx = attrs
403        .iter()
404        .position(|attr| attr.path().is_ident("errors"))?;
405    let attr = attrs.remove(idx);
406    let err_type: syn::Type = attr.parse_args().expect("expected #[errors(ErrorType)]");
407    Some(err_type)
408}
409
410/// Extract #[cache(ttl = N)] attribute from function attributes, removing it if found.
411fn extract_cache_attr(attrs: &mut Vec<syn::Attribute>) -> Option<u64> {
412    let idx = attrs
413        .iter()
414        .position(|attr| attr.path().is_ident("cache"))?;
415    let attr = attrs.remove(idx);
416
417    let mut ttl: Option<u64> = None;
418    attr.parse_nested_meta(|meta| {
419        if meta.path.is_ident("ttl") {
420            let value = meta.value()?;
421            let lit: syn::LitInt = value.parse()?;
422            ttl = Some(lit.base10_parse()?);
423            Ok(())
424        } else {
425            Err(meta.error("expected `ttl`"))
426        }
427    })
428    .expect("expected #[cache(ttl = N)]");
429
430    ttl
431}
432
433/// Extract #[public] attribute from function attributes, removing it if found.
434fn extract_public_attr(attrs: &mut Vec<syn::Attribute>) -> bool {
435    if let Some(idx) = attrs.iter().position(|attr| attr.path().is_ident("public")) {
436        attrs.remove(idx);
437        true
438    } else {
439        false
440    }
441}
442
443/// Registers a channel handler for the relay system.
444///
445/// Channel handlers receive [`RelayEvent`](rapina::relay::RelayEvent) events
446/// when clients subscribe, send messages, or disconnect from matching topics.
447///
448/// The pattern supports exact matches and prefix matches (trailing `*`):
449///
450/// - `"chat:lobby"` — matches only the exact topic `"chat:lobby"`
451/// - `"room:*"` — matches any topic starting with `"room:"`
452///
453/// The first parameter must be `RelayEvent`. Remaining parameters are
454/// extracted via `FromRequestParts` with synthetic request parts (same
455/// extractors as HTTP handlers, minus body extractors).
456///
457/// # Example
458///
459/// ```ignore
460/// use rapina::prelude::*;
461/// use rapina::relay::{Relay, RelayEvent};
462///
463/// #[relay("room:*")]
464/// async fn room(event: RelayEvent, relay: Relay) -> Result<()> {
465///     match &event {
466///         RelayEvent::Join { topic, conn_id } => {
467///             relay.track(topic, *conn_id, serde_json::json!({}));
468///         }
469///         RelayEvent::Message { topic, event: ev, payload, .. } => {
470///             relay.push(topic, ev, payload).await?;
471///         }
472///         RelayEvent::Leave { topic, conn_id } => {
473///             relay.untrack(topic, *conn_id);
474///         }
475///     }
476///     Ok(())
477/// }
478/// ```
479#[proc_macro_attribute]
480pub fn relay(attr: TokenStream, item: TokenStream) -> TokenStream {
481    relay_macro_impl(attr.into(), item.into()).into()
482}
483
484fn relay_macro_impl(
485    attr: proc_macro2::TokenStream,
486    item: proc_macro2::TokenStream,
487) -> proc_macro2::TokenStream {
488    let pattern: LitStr = syn::parse2(attr).expect("expected pattern as string literal");
489    let pattern_str = pattern.value();
490    let func: ItemFn = syn::parse2(item).expect("#[relay] must be applied to an async function");
491
492    let func_name = &func.sig.ident;
493    let func_name_str = func_name.to_string();
494
495    let is_prefix = pattern_str.ends_with('*');
496    let match_prefix_str = if is_prefix {
497        &pattern_str[..pattern_str.len() - 1]
498    } else {
499        &pattern_str
500    };
501
502    let wrapper_name = syn::Ident::new(
503        &format!("__rapina_channel_{}", func_name_str),
504        proc_macro2::Span::call_site(),
505    );
506
507    // First arg is RelayEvent (passed directly). Remaining args are extractors.
508    let args: Vec<_> = func.sig.inputs.iter().collect();
509
510    let mut extractor_extractions = Vec::new();
511    let mut call_args = vec![quote! { __rapina_event }];
512
513    for (i, arg) in args.iter().enumerate() {
514        if i == 0 {
515            // First arg is RelayEvent — passed directly, not extracted
516            continue;
517        }
518        if let FnArg::Typed(pat_type) = arg {
519            if let Pat::Ident(pat_ident) = &*pat_type.pat {
520                let arg_name = &pat_ident.ident;
521                let arg_type = &pat_type.ty;
522
523                extractor_extractions.push(quote! {
524                    let #arg_name = <#arg_type as rapina::extract::FromRequestParts>::from_request_parts(
525                        &__rapina_parts, &__rapina_params, &__rapina_state
526                    ).await?;
527                });
528
529                call_args.push(quote! { #arg_name });
530            }
531        }
532    }
533
534    quote! {
535        #func
536
537        // Generated by #[relay] — not user-facing API
538        #[doc(hidden)]
539        fn #wrapper_name(
540            __rapina_event: rapina::relay::RelayEvent,
541            __rapina_state: std::sync::Arc<rapina::state::AppState>,
542            __rapina_current_user: Option<rapina::auth::CurrentUser>,
543        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = std::result::Result<(), rapina::error::Error>> + Send>> {
544            Box::pin(async move {
545                let (mut __rapina_parts, _) = rapina::http::Request::new(()).into_parts();
546                if let Some(u) = __rapina_current_user {
547                    __rapina_parts.extensions.insert(u);
548                }
549                let __rapina_params = rapina::extract::PathParams::new();
550                #(#extractor_extractions)*
551                #func_name(#(#call_args),*).await
552            })
553        }
554
555        rapina::inventory::submit! {
556            rapina::relay::ChannelDescriptor {
557                pattern: #pattern_str,
558                is_prefix: #is_prefix,
559                match_prefix: #match_prefix_str,
560                handler_name: #func_name_str,
561                handle: #wrapper_name,
562            }
563        }
564    }
565}
566
567fn route_macro(method: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
568    route_macro_core(method, attr.into(), item.into()).into()
569}
570
571/// Derive macro for type-safe configuration
572///
573/// Generates a `from_env()` method that loads configuration from environment variables.
574#[proc_macro_derive(Config, attributes(env, default))]
575pub fn derive_config(input: TokenStream) -> TokenStream {
576    derive_config_impl(input.into()).into()
577}
578
579/// Define database entities with Prisma-like syntax.
580///
581/// This macro generates SeaORM entity definitions from a declarative syntax
582/// where types indicate relationships. Each entity automatically gets `id`,
583/// `created_at`, and `updated_at` fields.
584///
585/// # Syntax
586///
587/// ```ignore
588/// rapina::schema! {
589///     User {
590///         email: String,
591///         name: String,
592///         posts: Vec<Post>,        // has_many relationship
593///     }
594///
595///     Post {
596///         title: String,
597///         content: Text,           // TEXT column type
598///         author: User,            // belongs_to -> generates author_id
599///         comments: Vec<Comment>,
600///     }
601///
602///     Comment {
603///         content: Text,
604///         post: Post,
605///         author: Option<User>,    // optional belongs_to
606///     }
607/// }
608/// ```
609///
610/// # Generated Code
611///
612/// For each entity, the macro generates a SeaORM module with:
613/// - `Model` struct with auto `id`, `created_at`, `updated_at`
614/// - `Relation` enum with proper SeaORM attributes
615/// - `Related<T>` trait implementations
616/// - `ActiveModelBehavior` implementation
617///
618/// # Supported Types
619///
620/// | Schema Type | Rust Type | Notes |
621/// |-------------|-----------|-------|
622/// | `String` | `String` | Default varchar |
623/// | `Text` | `String` | TEXT column |
624/// | `i32` | `i32` | |
625/// | `i64` | `i64` | |
626/// | `f32` | `f32` | |
627/// | `f64` | `f64` | |
628/// | `bool` | `bool` | |
629/// | `Uuid` | `Uuid` | |
630/// | `DateTime` | `DateTimeUtc` | |
631/// | `Date` | `Date` | |
632/// | `Decimal` | `Decimal` | |
633/// | `Json` | `Json` | |
634/// | `Option<T>` | `Option<T>` | Nullable |
635/// | `Vec<Entity>` | - | has_many relationship |
636/// | `Entity` | - | belongs_to (generates FK) |
637#[proc_macro]
638pub fn schema(input: TokenStream) -> TokenStream {
639    schema::schema_impl(input.into()).into()
640}
641
642fn derive_config_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
643    let input: syn::DeriveInput = syn::parse2(input).expect("expected struct");
644    let name = &input.ident;
645
646    let fields = match &input.data {
647        syn::Data::Struct(data) => match &data.fields {
648            syn::Fields::Named(fields) => &fields.named,
649            _ => panic!("Config derive only supports structs with named fields"),
650        },
651        _ => panic!("Config derive only supports structs"),
652    };
653
654    let mut field_inits = Vec::new();
655    let mut missing_checks = Vec::new();
656
657    for field in fields {
658        let field_name = field.ident.as_ref().unwrap();
659        let field_type = &field.ty;
660
661        // Find #[env = "VAR_NAME"] attribute
662        let env_var = field
663            .attrs
664            .iter()
665            .find_map(|attr| {
666                if attr.path().is_ident("env")
667                    && let syn::Meta::NameValue(nv) = &attr.meta
668                    && let syn::Expr::Lit(expr_lit) = &nv.value
669                    && let syn::Lit::Str(lit_str) = &expr_lit.lit
670                {
671                    return Some(lit_str.value());
672                }
673                None
674            })
675            .unwrap_or_else(|| field_name.to_string().to_uppercase());
676
677        // Find #[default = "value"] attribute
678        let default_value = field.attrs.iter().find_map(|attr| {
679            if attr.path().is_ident("default")
680                && let syn::Meta::NameValue(nv) = &attr.meta
681                && let syn::Expr::Lit(expr_lit) = &nv.value
682                && let syn::Lit::Str(lit_str) = &expr_lit.lit
683            {
684                return Some(lit_str.value());
685            }
686            None
687        });
688
689        let env_var_lit = syn::LitStr::new(&env_var, proc_macro2::Span::call_site());
690
691        if let Some(default) = default_value {
692            let default_lit = syn::LitStr::new(&default, proc_macro2::Span::call_site());
693            field_inits.push(quote! {
694                #field_name: rapina::config::get_env_or(#env_var_lit, #default_lit).parse().unwrap_or_else(|_| #default_lit.parse().unwrap())
695            });
696        } else {
697            field_inits.push(quote! {
698                #field_name: rapina::config::get_env_parsed::<#field_type>(#env_var_lit)?
699            });
700            missing_checks.push(quote! {
701                if std::env::var(#env_var_lit).is_err() {
702                    missing.push(#env_var_lit);
703                }
704            });
705        }
706    }
707
708    quote! {
709        impl #name {
710            pub fn from_env() -> std::result::Result<Self, rapina::config::ConfigError> {
711                let mut missing: Vec<&str> = Vec::new();
712                #(#missing_checks)*
713
714                if !missing.is_empty() {
715                    return Err(rapina::config::ConfigError::MissingMultiple(
716                        missing.into_iter().map(String::from).collect()
717                    ));
718                }
719
720                Ok(Self {
721                    #(#field_inits),*
722                })
723            }
724        }
725    }
726}
727
728#[cfg(test)]
729mod tests {
730    use super::{join_paths, relay_macro_impl, route_macro_core};
731    use quote::quote;
732
733    #[test]
734    fn test_generates_struct_with_handler_impl() {
735        let path = quote!("/");
736        let input = quote! {
737            async fn hello() -> &'static str {
738                "Hello, Rapina!"
739            }
740        };
741
742        let output = route_macro_core("GET", path, input);
743        let output_str = output.to_string();
744
745        // Check struct is generated
746        assert!(output_str.contains("struct hello"));
747        // Check Handler impl is generated
748        assert!(output_str.contains("impl rapina :: handler :: Handler for hello"));
749        // Check NAME constant
750        assert!(output_str.contains("const NAME"));
751        assert!(output_str.contains("\"hello\""));
752    }
753
754    #[test]
755    fn test_generates_handler_with_extractors() {
756        let path = quote!("/users/:id");
757        let input = quote! {
758            async fn get_user(id: rapina::extract::Path<u64>) -> String {
759                format!("{}", id.into_inner())
760            }
761        };
762
763        let output = route_macro_core("GET", path, input);
764        let output_str = output.to_string();
765
766        assert!(output_str.contains("struct get_user"));
767        // Single arg is last arg — uses FromRequest (blanket impl handles parts-only)
768        assert!(output_str.contains("FromRequest"));
769        // Single arg should NOT destructure request into parts
770        assert!(!output_str.contains("into_parts"));
771    }
772
773    #[test]
774    fn test_function_with_multiple_extractors() {
775        let path = quote!("/users");
776        let input = quote! {
777            async fn create_user(
778                id: rapina::extract::Path<u64>,
779                body: rapina::extract::Json<String>
780            ) -> String {
781                "created".to_string()
782            }
783        };
784
785        let output = route_macro_core("POST", path, input);
786        let output_str = output.to_string();
787
788        // Check struct is generated
789        assert!(output_str.contains("struct create_user"));
790        // Check both extractors are handled
791        assert!(output_str.contains("FromRequestParts"));
792        assert!(output_str.contains("FromRequest"));
793    }
794
795    #[test]
796    fn test_two_body_extractors_no_macro_panic() {
797        // With positional convention, the macro does NOT panic for multiple body consumers.
798        // Instead, it generates code where the first Json is bounded by FromRequestParts
799        // (which it doesn't implement), so the compiler catches it at type-check time.
800        let path = quote!("/users");
801        let input = quote! {
802            async fn handler(
803                body1: rapina::extract::Json<String>,
804                body2: rapina::extract::Json<String>
805            ) -> String {
806                "ok".to_string()
807            }
808        };
809
810        // Should NOT panic — macro expansion succeeds, compiler catches the error later
811        let output = route_macro_core("POST", path, input);
812        let output_str = output.to_string();
813
814        // First arg gets FromRequestParts (will fail at compile time since Json doesn't impl it)
815        assert!(output_str.contains("FromRequestParts"));
816        // Last arg gets FromRequest
817        assert!(output_str.contains("FromRequest"));
818    }
819
820    #[test]
821    fn test_custom_type_name_not_misclassified() {
822        // UserPathInfo contains "Path" but should NOT be routed to FromRequestParts
823        // Positional convention: single (last) arg always uses FromRequest
824        let path = quote!("/users");
825        let input = quote! {
826            async fn handler(info: UserPathInfo) -> String {
827                "ok".to_string()
828            }
829        };
830
831        let output = route_macro_core("POST", path, input);
832        let output_str = output.to_string();
833
834        assert!(output_str.contains("FromRequest"));
835        assert!(!output_str.contains("FromRequestParts"));
836    }
837
838    #[test]
839    fn test_multiple_parts_only_extractors_positional() {
840        // All parts-only extractors: first N-1 use FromRequestParts, last uses FromRequest
841        let path = quote!("/users/:id");
842        let input = quote! {
843            async fn handler(
844                id: rapina::extract::Path<u64>,
845                query: rapina::extract::Query<Params>,
846                headers: rapina::extract::Headers,
847            ) -> String {
848                "ok".to_string()
849            }
850        };
851
852        let output = route_macro_core("GET", path, input);
853        let output_str = output.to_string();
854
855        // First two args use FromRequestParts
856        assert!(output_str.contains("FromRequestParts"));
857        // Last arg uses FromRequest (via blanket impl at runtime)
858        assert!(output_str.contains("FromRequest"));
859        // Request is destructured for multi-arg case
860        assert!(output_str.contains("into_parts"));
861        // Request is reassembled for last arg
862        assert!(output_str.contains("from_parts"));
863    }
864
865    #[test]
866    #[should_panic(expected = "expected function")]
867    fn test_invalid_input_panics() {
868        let path = quote!("/");
869        let invalid_input = quote! { not_a_function };
870
871        route_macro_core("GET", path, invalid_input);
872    }
873
874    #[test]
875    fn test_json_return_type_generates_response_schema() {
876        let path = quote!("/users");
877        let input = quote! {
878            async fn get_user() -> Json<UserResponse> {
879                Json(UserResponse { id: 1 })
880            }
881        };
882
883        let output = route_macro_core("GET", path, input);
884        let output_str = output.to_string();
885
886        // Check response_schema method is generated with schema_for!
887        assert!(output_str.contains("fn response_schema"));
888        assert!(output_str.contains("rapina :: schemars :: schema_for !"));
889        assert!(output_str.contains("UserResponse"));
890    }
891
892    #[test]
893    fn test_result_json_return_type_generates_response_schema() {
894        let path = quote!("/users");
895        let input = quote! {
896            async fn get_user() -> Result<Json<UserResponse>> {
897                Ok(Json(UserResponse { id: 1 }))
898            }
899        };
900
901        let output = route_macro_core("GET", path, input);
902        let output_str = output.to_string();
903
904        assert!(output_str.contains("fn response_schema"));
905        assert!(output_str.contains("rapina :: schemars :: schema_for !"));
906        assert!(output_str.contains("UserResponse"));
907    }
908
909    #[test]
910    fn test_errors_attr_generates_error_responses() {
911        let path = quote!("/users");
912        let input = quote! {
913            #[errors(UserError)]
914            async fn get_user() -> Result<Json<UserResponse>> {
915                Ok(Json(UserResponse { id: 1 }))
916            }
917        };
918
919        let output = route_macro_core("GET", path, input);
920        let output_str = output.to_string();
921
922        assert!(output_str.contains("fn error_responses"));
923        assert!(output_str.contains("DocumentedError"));
924        assert!(output_str.contains("UserError"));
925    }
926
927    #[test]
928    fn test_non_json_return_type_no_response_schema() {
929        let path = quote!("/health");
930        let input = quote! {
931            async fn health() -> &'static str {
932                "ok"
933            }
934        };
935
936        let output = route_macro_core("GET", path, input);
937        let output_str = output.to_string();
938
939        // Check response_schema method is NOT generated for non-Json types
940        assert!(!output_str.contains("fn response_schema"));
941        assert!(!output_str.contains("schema_for"));
942    }
943
944    #[test]
945    fn test_user_state_variable_not_shadowed() {
946        // Regression test for issue #134 - user naming their extractor 'state'
947        // should not conflict with internal macro variables
948        let path = quote!("/users");
949        let input = quote! {
950            async fn list_users(state: rapina::extract::State<MyState>) -> String {
951                "ok".to_string()
952            }
953        };
954
955        let output = route_macro_core("GET", path, input);
956        let output_str = output.to_string();
957
958        // Internal variables should use __rapina_ prefix
959        assert!(output_str.contains("__rapina_state"));
960        assert!(output_str.contains("__rapina_params"));
961        // User's variable 'state' should still be extracted
962        assert!(output_str.contains("let state ="));
963    }
964
965    #[test]
966    fn test_no_closure_wrapper_for_type_inference() {
967        // Regression test for issue #134 - Result type inference should work
968        let path = quote!("/users");
969        let input = quote! {
970            async fn get_user() -> Result<String, Error> {
971                Ok("user".to_string())
972            }
973        };
974
975        let output = route_macro_core("GET", path, input);
976        let output_str = output.to_string();
977
978        // Should NOT use closure wrapper (|| async ...)
979        assert!(!output_str.contains("|| async"));
980        // Should use typed result with async block (: ReturnType = (async ...).await)
981        assert!(output_str.contains("__rapina_result"));
982        assert!(output_str.contains("Result < String , Error >"));
983    }
984
985    #[test]
986    fn test_emits_route_descriptor() {
987        let path = quote!("/users");
988        let input = quote! {
989            async fn list_users() -> &'static str {
990                "users"
991            }
992        };
993
994        let output = route_macro_core("GET", path, input);
995        let output_str = output.to_string();
996
997        assert!(output_str.contains("inventory :: submit !"));
998        assert!(output_str.contains("RouteDescriptor"));
999        assert!(output_str.contains("method : \"GET\""));
1000        assert!(output_str.contains("path : \"/users\""));
1001        assert!(output_str.contains("handler_name : \"list_users\""));
1002        assert!(output_str.contains("is_public : false"));
1003        assert!(output_str.contains("__rapina_register_list_users"));
1004    }
1005
1006    #[test]
1007    fn test_emits_route_descriptor_with_method() {
1008        let path = quote!("/users");
1009        let input = quote! {
1010            async fn create_user() -> &'static str {
1011                "created"
1012            }
1013        };
1014
1015        let output = route_macro_core("POST", path, input);
1016        let output_str = output.to_string();
1017
1018        assert!(output_str.contains("method : \"POST\""));
1019        assert!(output_str.contains("__rapina_router . post"));
1020    }
1021
1022    #[test]
1023    fn test_public_attr_below_route_sets_is_public() {
1024        let path = quote!("/health");
1025        let input = quote! {
1026            #[public]
1027            async fn health() -> &'static str {
1028                "ok"
1029            }
1030        };
1031
1032        let output = route_macro_core("GET", path, input);
1033        let output_str = output.to_string();
1034
1035        assert!(output_str.contains("is_public : true"));
1036    }
1037
1038    #[test]
1039    fn test_cache_attr_injects_ttl_header() {
1040        let path = quote!("/products");
1041        let input = quote! {
1042            #[cache(ttl = 60)]
1043            async fn list_products() -> &'static str {
1044                "products"
1045            }
1046        };
1047
1048        let output = route_macro_core("GET", path, input);
1049        let output_str = output.to_string();
1050
1051        assert!(output_str.contains("x-rapina-cache-ttl"));
1052        assert!(output_str.contains("60"));
1053    }
1054
1055    #[test]
1056    fn test_relay_macro_generates_wrapper_and_inventory() {
1057        let attr = quote!("room:*");
1058        let input = quote! {
1059            async fn room(event: rapina::relay::RelayEvent, relay: rapina::relay::Relay) -> Result<(), rapina::error::Error> {
1060                Ok(())
1061            }
1062        };
1063
1064        let output = relay_macro_impl(attr, input);
1065        let output_str = output.to_string();
1066
1067        // Original function is preserved
1068        assert!(output_str.contains("async fn room"));
1069        // Wrapper function is generated
1070        assert!(output_str.contains("__rapina_channel_room"));
1071        // Inventory submission
1072        assert!(output_str.contains("inventory :: submit !"));
1073        assert!(output_str.contains("ChannelDescriptor"));
1074        assert!(output_str.contains("pattern : \"room:*\""));
1075        assert!(output_str.contains("is_prefix : true"));
1076        assert!(output_str.contains("match_prefix : \"room:\""));
1077        assert!(output_str.contains("handler_name : \"room\""));
1078    }
1079
1080    #[test]
1081    fn test_relay_macro_exact_match() {
1082        let attr = quote!("chat:lobby");
1083        let input = quote! {
1084            async fn lobby(event: rapina::relay::RelayEvent) -> Result<(), rapina::error::Error> {
1085                Ok(())
1086            }
1087        };
1088
1089        let output = relay_macro_impl(attr, input);
1090        let output_str = output.to_string();
1091
1092        assert!(output_str.contains("is_prefix : false"));
1093        assert!(output_str.contains("match_prefix : \"chat:lobby\""));
1094    }
1095
1096    #[test]
1097    fn test_relay_macro_extracts_additional_params() {
1098        let attr = quote!("room:*");
1099        let input = quote! {
1100            async fn room(
1101                event: rapina::relay::RelayEvent,
1102                relay: rapina::relay::Relay,
1103                log: rapina::extract::State<TestLog>,
1104            ) -> Result<(), rapina::error::Error> {
1105                Ok(())
1106            }
1107        };
1108
1109        let output = relay_macro_impl(attr, input);
1110        let output_str = output.to_string();
1111
1112        // Both extractors should use FromRequestParts
1113        assert!(output_str.contains("let relay ="));
1114        assert!(output_str.contains("let log ="));
1115        assert!(output_str.contains("FromRequestParts"));
1116    }
1117
1118    #[test]
1119    fn test_no_cache_attr_no_ttl_header() {
1120        let path = quote!("/products");
1121        let input = quote! {
1122            async fn list_products() -> &'static str {
1123                "products"
1124            }
1125        };
1126
1127        let output = route_macro_core("GET", path, input);
1128        let output_str = output.to_string();
1129
1130        assert!(!output_str.contains("x-rapina-cache-ttl"));
1131    }
1132
1133    #[test]
1134    fn test_cache_attr_with_extractors() {
1135        let path = quote!("/users/:id");
1136        let input = quote! {
1137            #[cache(ttl = 120)]
1138            async fn get_user(id: rapina::extract::Path<u64>) -> String {
1139                format!("{}", id.into_inner())
1140            }
1141        };
1142
1143        let output = route_macro_core("GET", path, input);
1144        let output_str = output.to_string();
1145
1146        assert!(output_str.contains("x-rapina-cache-ttl"));
1147        assert!(output_str.contains("120"));
1148        // Single arg uses FromRequest (positional convention)
1149        assert!(output_str.contains("FromRequest"));
1150    }
1151
1152    #[test]
1153    fn test_group_param_joins_path() {
1154        let attr = quote!("/users", group = "/api");
1155        let input = quote! {
1156            async fn list_users() -> &'static str {
1157                "users"
1158            }
1159        };
1160
1161        let output = route_macro_core("GET", attr, input);
1162        let output_str = output.to_string();
1163
1164        assert!(output_str.contains("path : \"/api/users\""));
1165        assert!(output_str.contains("__rapina_router . get (\"/api/users\""));
1166    }
1167
1168    #[test]
1169    fn test_group_param_with_nested_prefix() {
1170        let attr = quote!("/items", group = "/api/v1");
1171        let input = quote! {
1172            async fn list_items() -> &'static str {
1173                "items"
1174            }
1175        };
1176
1177        let output = route_macro_core("GET", attr, input);
1178        let output_str = output.to_string();
1179
1180        assert!(output_str.contains("path : \"/api/v1/items\""));
1181    }
1182
1183    #[test]
1184    fn test_without_group_param_backward_compatible() {
1185        let attr = quote!("/users");
1186        let input = quote! {
1187            async fn list_users() -> &'static str {
1188                "users"
1189            }
1190        };
1191
1192        let output = route_macro_core("GET", attr, input);
1193        let output_str = output.to_string();
1194
1195        assert!(output_str.contains("path : \"/users\""));
1196        assert!(output_str.contains("__rapina_router . get (\"/users\""));
1197    }
1198
1199    #[test]
1200    #[should_panic(expected = "group prefix must start with `/`")]
1201    fn test_group_prefix_must_start_with_slash() {
1202        let attr = quote!("/users", group = "api");
1203        let input = quote! {
1204            async fn list_users() -> &'static str {
1205                "users"
1206            }
1207        };
1208
1209        route_macro_core("GET", attr, input);
1210    }
1211
1212    #[test]
1213    fn test_group_with_trailing_slash_normalized() {
1214        let attr = quote!("/users", group = "/api/");
1215        let input = quote! {
1216            async fn list_users() -> &'static str {
1217                "users"
1218            }
1219        };
1220
1221        let output = route_macro_core("GET", attr, input);
1222        let output_str = output.to_string();
1223
1224        assert!(output_str.contains("path : \"/api/users\""));
1225    }
1226
1227    #[test]
1228    fn test_group_with_public_attr() {
1229        let attr = quote!("/health", group = "/api");
1230        let input = quote! {
1231            #[public]
1232            async fn health() -> &'static str {
1233                "ok"
1234            }
1235        };
1236
1237        let output = route_macro_core("GET", attr, input);
1238        let output_str = output.to_string();
1239
1240        assert!(output_str.contains("path : \"/api/health\""));
1241        assert!(output_str.contains("is_public : true"));
1242    }
1243
1244    #[test]
1245    fn test_group_with_cache_attr() {
1246        let attr = quote!("/products", group = "/api");
1247        let input = quote! {
1248            #[cache(ttl = 60)]
1249            async fn list_products() -> &'static str {
1250                "products"
1251            }
1252        };
1253
1254        let output = route_macro_core("GET", attr, input);
1255        let output_str = output.to_string();
1256
1257        assert!(output_str.contains("path : \"/api/products\""));
1258        assert!(output_str.contains("x-rapina-cache-ttl"));
1259        assert!(output_str.contains("60"));
1260    }
1261
1262    #[test]
1263    fn test_group_with_errors_attr() {
1264        let attr = quote!("/users", group = "/api");
1265        let input = quote! {
1266            #[errors(UserError)]
1267            async fn get_user() -> Result<Json<UserResponse>> {
1268                Ok(Json(UserResponse { id: 1 }))
1269            }
1270        };
1271
1272        let output = route_macro_core("GET", attr, input);
1273        let output_str = output.to_string();
1274
1275        assert!(output_str.contains("path : \"/api/users\""));
1276        assert!(output_str.contains("fn error_responses"));
1277        assert!(output_str.contains("UserError"));
1278    }
1279
1280    #[test]
1281    fn test_group_with_all_methods() {
1282        for method in &["GET", "POST", "PUT", "DELETE"] {
1283            let attr = quote!("/items", group = "/api");
1284            let input = quote! {
1285                async fn handler() -> &'static str {
1286                    "ok"
1287                }
1288            };
1289
1290            let output = route_macro_core(method, attr, input);
1291            let output_str = output.to_string();
1292
1293            assert!(
1294                output_str.contains("path : \"/api/items\""),
1295                "{method} should produce /api/items"
1296            );
1297            let method_lower = method.to_lowercase();
1298            assert!(
1299                output_str.contains(&format!("__rapina_router . {method_lower}")),
1300                "{method} should use .{method_lower}() on router"
1301            );
1302        }
1303    }
1304
1305    #[test]
1306    fn test_join_paths_basic() {
1307        assert_eq!(join_paths("/api", "/users"), "/api/users");
1308        assert_eq!(join_paths("/api/v1", "/items"), "/api/v1/items");
1309    }
1310
1311    #[test]
1312    fn test_join_paths_trailing_slash() {
1313        assert_eq!(join_paths("/api/", "/users"), "/api/users");
1314    }
1315
1316    #[test]
1317    fn test_join_paths_empty_path() {
1318        assert_eq!(join_paths("/api", ""), "/api");
1319        assert_eq!(join_paths("/api", "/"), "/api");
1320    }
1321
1322    #[test]
1323    fn test_join_paths_empty_prefix() {
1324        assert_eq!(join_paths("", "/users"), "/users");
1325        assert_eq!(join_paths("", ""), "/");
1326    }
1327}