Skip to main content

rapina_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{FnArg, ItemFn, LitStr, Pat};
4
5/// Parsed route macro attribute: `"/path"` or `"/path", group = "/prefix"`.
6struct RouteAttr {
7    path: LitStr,
8    group: Option<LitStr>,
9}
10
11impl syn::parse::Parse for RouteAttr {
12    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
13        let path: LitStr = input.parse()?;
14        let group = if input.peek(syn::Token![,]) {
15            input.parse::<syn::Token![,]>()?;
16            let ident: syn::Ident = input.parse()?;
17            if ident != "group" {
18                return Err(syn::Error::new(ident.span(), "expected `group`"));
19            }
20            input.parse::<syn::Token![=]>()?;
21            let value: LitStr = input.parse()?;
22            Some(value)
23        } else {
24            None
25        };
26        if !input.is_empty() {
27            return Err(input.error("unexpected tokens after route attribute"));
28        }
29        Ok(RouteAttr { path, group })
30    }
31}
32
/// Joins a group prefix and a route path into one path at compile time.
///
/// Trailing slashes on `prefix` are dropped. An empty or `/` path yields the
/// prefix itself, or `/` when the prefix is empty too. A path without a leading
/// slash gets one inserted between prefix and path.
fn join_paths(prefix: &str, path: &str) -> String {
    let base = prefix.trim_end_matches('/');
    match path {
        "" | "/" if base.is_empty() => String::from("/"),
        "" | "/" => base.to_owned(),
        p if p.starts_with('/') => format!("{base}{p}"),
        p => format!("{base}/{p}"),
    }
}
49
50mod schema;
51
52/// Registers a GET route handler.
53///
54/// # Syntax
55///
56/// ```ignore
57/// #[get("/users")]
58/// async fn list_users() -> Json<Vec<User>> { /* ... */ }
59///
60/// // With a group prefix (registers at /api/users):
61/// #[get("/users", group = "/api")]
62/// async fn list_users() -> Json<Vec<User>> { /* ... */ }
63/// ```
64///
65/// The `group` parameter joins the prefix with the path at compile time,
66/// so the handler is registered at the full path during auto-discovery.
67#[proc_macro_attribute]
68pub fn get(attr: TokenStream, item: TokenStream) -> TokenStream {
69    route_macro("GET", attr, item)
70}
71
72/// Registers a POST route handler.
73///
74/// See [`get`] for syntax details including the optional `group` parameter.
75#[proc_macro_attribute]
76pub fn post(attr: TokenStream, item: TokenStream) -> TokenStream {
77    route_macro("POST", attr, item)
78}
79
80/// Registers a PUT route handler.
81///
82/// See [`get`] for syntax details including the optional `group` parameter.
83#[proc_macro_attribute]
84pub fn put(attr: TokenStream, item: TokenStream) -> TokenStream {
85    route_macro("PUT", attr, item)
86}
87
88/// Registers a DELETE route handler.
89///
90/// See [`get`] for syntax details including the optional `group` parameter.
91#[proc_macro_attribute]
92pub fn delete(attr: TokenStream, item: TokenStream) -> TokenStream {
93    route_macro("DELETE", attr, item)
94}
95
96/// Marks a route as public (no authentication required).
97///
98/// When authentication is enabled via `Rapina::with_auth()`, all routes
99/// require a valid JWT token by default. Use `#[public]` to allow
100/// unauthenticated access to specific routes.
101///
102/// # Example
103///
104/// ```ignore
105/// use rapina::prelude::*;
106///
107/// #[public]
108/// #[get("/health")]
109/// async fn health() -> &'static str {
110///     "ok"
111/// }
112///
113/// #[public]
114/// #[post("/login")]
115/// async fn login(body: Json<LoginRequest>) -> Result<Json<TokenResponse>> {
116///     // ... authenticate and return token
117/// }
118/// ```
119///
120/// Note: Routes starting with `/__rapina` are automatically public.
121#[proc_macro_attribute]
122pub fn public(_attr: TokenStream, item: TokenStream) -> TokenStream {
123    let func: ItemFn = syn::parse(item.clone()).expect("#[public] must be applied to a function");
124    let func_name_str = func.sig.ident.to_string();
125    let item2: proc_macro2::TokenStream = item.into();
126    quote! {
127        #item2
128        rapina::inventory::submit! {
129            rapina::discovery::PublicMarker {
130                handler_name: #func_name_str,
131            }
132        }
133    }
134    .into()
135}
136
137fn route_macro_core(
138    method: &str,
139    attr: proc_macro2::TokenStream,
140    item: proc_macro2::TokenStream,
141) -> proc_macro2::TokenStream {
142    let route_attr: RouteAttr = syn::parse2(attr).expect("expected path as string literal");
143    let path_str = if let Some(ref group) = route_attr.group {
144        let g = group.value();
145        assert!(
146            g.starts_with('/'),
147            "group prefix must start with `/`, got: {g:?}"
148        );
149        join_paths(&g, &route_attr.path.value())
150    } else {
151        route_attr.path.value()
152    };
153    let mut func: ItemFn = syn::parse2(item).expect("expected function");
154
155    let func_name = &func.sig.ident;
156    let func_name_str = func_name.to_string();
157    let func_vis = &func.vis;
158
159    // Extract #[public] attribute if present (when #[public] is below the route macro)
160    let is_public = extract_public_attr(&mut func.attrs);
161
162    // Extract #[errors(ErrorType)] attribute if present
163    let error_type = extract_errors_attr(&mut func.attrs);
164
165    // Extract #[cache(ttl = N)] attribute if present
166    let cache_ttl = extract_cache_attr(&mut func.attrs);
167
168    let error_responses_impl = if let Some(err_type) = &error_type {
169        quote! {
170            fn error_responses() -> Vec<rapina::error::ErrorVariant> {
171                <#err_type as rapina::error::DocumentedError>::error_variants()
172            }
173        }
174    } else {
175        quote! {}
176    };
177
178    // Extract return type for schema generation
179    let response_schema_impl = if let syn::ReturnType::Type(_, return_type) = &func.sig.output {
180        if let Some(inner_type) = extract_json_inner_type(return_type) {
181            quote! {
182                fn response_schema() -> Option<serde_json::Value> {
183                    Some(serde_json::to_value(rapina::schemars::schema_for!(#inner_type)).unwrap())
184                }
185            }
186        } else {
187            quote! {}
188        }
189    } else {
190        quote! {}
191    };
192
193    let args: Vec<_> = func.sig.inputs.iter().collect();
194
195    // Extract return type for type annotation (helps with type inference in async blocks)
196    let return_type_annotation = match &func.sig.output {
197        syn::ReturnType::Type(_, ty) => quote! { : #ty },
198        syn::ReturnType::Default => quote! {},
199    };
200
201    // Optional cache TTL header injection
202    let cache_header_injection = if let Some(ttl) = cache_ttl {
203        let ttl_str = ttl.to_string();
204        quote! {
205            let mut __rapina_response = __rapina_response;
206            __rapina_response.headers_mut().insert(
207                "x-rapina-cache-ttl",
208                rapina::http::HeaderValue::from_static(#ttl_str),
209            );
210        }
211    } else {
212        quote! {}
213    };
214
215    // Build the handler body
216    // Use __rapina_ prefix for internal variables to avoid shadowing user's variables
217    let handler_body = if args.is_empty() {
218        let inner_block = &func.block;
219        quote! {
220            let __rapina_result #return_type_annotation = (async #inner_block).await;
221            let __rapina_response = rapina::response::IntoResponse::into_response(__rapina_result);
222            #cache_header_injection
223            __rapina_response
224        }
225    } else {
226        let mut parts_extractions = Vec::new();
227        let mut body_extractors: Vec<(syn::Ident, Box<syn::Type>)> = Vec::new();
228
229        for arg in &args {
230            if let FnArg::Typed(pat_type) = arg
231                && let Pat::Ident(pat_ident) = &*pat_type.pat
232            {
233                let arg_name = &pat_ident.ident;
234                let arg_type = &pat_type.ty;
235
236                let type_str = quote!(#arg_type).to_string();
237                if is_parts_only_extractor(&type_str) {
238                    parts_extractions.push(quote! {
239                        let #arg_name = match <#arg_type as rapina::extract::FromRequestParts>::from_request_parts(&__rapina_parts, &__rapina_params, &__rapina_state).await {
240                            Ok(v) => v,
241                            Err(e) => return rapina::response::IntoResponse::into_response(e),
242                        };
243                    });
244                } else {
245                    body_extractors.push((arg_name.clone(), arg_type.clone()));
246                }
247            }
248        }
249
250        let body_extraction = if body_extractors.is_empty() {
251            quote! {}
252        } else if body_extractors.len() == 1 {
253            let (arg_name, arg_type) = &body_extractors[0];
254            quote! {
255                let __rapina_req = rapina::http::Request::from_parts(__rapina_parts, __rapina_body);
256                let #arg_name = match <#arg_type as rapina::extract::FromRequest>::from_request(__rapina_req, &__rapina_params, &__rapina_state).await {
257                    Ok(v) => v,
258                    Err(e) => return rapina::response::IntoResponse::into_response(e),
259                };
260            }
261        } else {
262            let names: Vec<_> = body_extractors.iter().map(|(n, _)| n.to_string()).collect();
263            panic!(
264                "Multiple body-consuming extractors are not supported: {}. Only one extractor can consume the request body.",
265                names.join(", ")
266            );
267        };
268
269        let inner_block = &func.block;
270
271        quote! {
272            let (__rapina_parts, __rapina_body) = __rapina_req.into_parts();
273            #(#parts_extractions)*
274            #body_extraction
275            let __rapina_result #return_type_annotation = (async #inner_block).await;
276            let __rapina_response = rapina::response::IntoResponse::into_response(__rapina_result);
277            #cache_header_injection
278            __rapina_response
279        }
280    };
281
282    // Build the router method call for the register function
283    let router_method = syn::Ident::new(&method.to_lowercase(), proc_macro2::Span::call_site());
284    let register_fn_name = syn::Ident::new(
285        &format!("__rapina_register_{}", func_name_str),
286        proc_macro2::Span::call_site(),
287    );
288
289    // Generate the struct, Handler impl, and inventory submission
290    quote! {
291        #[derive(Clone, Copy)]
292        #[allow(non_camel_case_types)]
293        #func_vis struct #func_name;
294
295        impl rapina::handler::Handler for #func_name {
296            const NAME: &'static str = #func_name_str;
297
298            #response_schema_impl
299            #error_responses_impl
300
301            fn call(
302                &self,
303                __rapina_req: rapina::hyper::Request<rapina::hyper::body::Incoming>,
304                __rapina_params: rapina::extract::PathParams,
305                __rapina_state: std::sync::Arc<rapina::state::AppState>,
306            ) -> std::pin::Pin<Box<dyn std::future::Future<Output = rapina::hyper::Response<rapina::response::BoxBody>> + Send>> {
307                Box::pin(async move {
308                    #handler_body
309                })
310            }
311        }
312
313        #[doc(hidden)]
314        fn #register_fn_name(__rapina_router: rapina::router::Router) -> rapina::router::Router {
315            __rapina_router.#router_method(#path_str, #func_name)
316        }
317
318        rapina::inventory::submit! {
319            rapina::discovery::RouteDescriptor {
320                method: #method,
321                path: #path_str,
322                handler_name: #func_name_str,
323                is_public: #is_public,
324                response_schema: <#func_name as rapina::handler::Handler>::response_schema,
325                error_responses: <#func_name as rapina::handler::Handler>::error_responses,
326                register: #register_fn_name,
327            }
328        }
329    }
330}
331
/// Decides whether a handler parameter type is resolved from request parts only
/// (`FromRequestParts`) rather than by consuming the body (`FromRequest`).
///
/// `type_str` is the stringified type produced by `quote!(#ty).to_string()`, e.g.
/// `"rapina :: extract :: Path < u64 >"`. Only the outermost type name (the last
/// path segment before any generic arguments) is checked against the marker list.
/// A marker word inside generic arguments therefore does not misclassify a body
/// extractor: `Json<SearchQuery>` and `Json<UserState>` stay body extractors.
/// `Option<T>` is classified by its inner `T`.
fn is_parts_only_extractor(type_str: &str) -> bool {
    // Substrings identifying rapina's parts-only extractors by outer type name.
    const PARTS_MARKERS: [&str; 9] = [
        "Path",
        "Query",
        "Headers",
        "State",
        "Context",
        "CurrentUser",
        "Db",
        "Cookie",
        "Relay",
    ];

    let type_str = type_str.trim();
    let (head, generic_args) = match type_str.split_once('<') {
        Some((head, rest)) => (head, Some(rest)),
        None => (type_str, None),
    };
    // `rapina :: extract :: Path ` -> `Path`
    let outer = head.rsplit("::").next().unwrap_or(head).trim();

    if outer == "Option" {
        if let Some(inner) = generic_args.and_then(|args| args.trim_end().strip_suffix('>')) {
            return is_parts_only_extractor(inner);
        }
    }

    PARTS_MARKERS.iter().any(|marker| outer.contains(marker))
}
343
344/// Extracts the inner type from Json<T> wrapper for schema generation
345fn extract_json_inner_type(return_type: &syn::Type) -> Option<proc_macro2::TokenStream> {
346    if let syn::Type::Path(type_path) = return_type
347        && let Some(last_segment) = type_path.path.segments.last()
348    {
349        // Direct Json<T>
350        if last_segment.ident == "Json"
351            && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
352            && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
353        {
354            return Some(quote!(#inner_type));
355        }
356
357        // Result<Json<T>> or Result<Json<T>, E>
358        if last_segment.ident == "Result"
359            && let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments
360            && let Some(syn::GenericArgument::Type(ok_type)) = args.args.first()
361        {
362            return extract_json_inner_type(ok_type);
363        }
364    }
365    None
366}
367
368/// Extract #[errors(ErrorType)] attribute from function attributes, removing it if found.
369fn extract_errors_attr(attrs: &mut Vec<syn::Attribute>) -> Option<syn::Type> {
370    let idx = attrs
371        .iter()
372        .position(|attr| attr.path().is_ident("errors"))?;
373    let attr = attrs.remove(idx);
374    let err_type: syn::Type = attr.parse_args().expect("expected #[errors(ErrorType)]");
375    Some(err_type)
376}
377
378/// Extract #[cache(ttl = N)] attribute from function attributes, removing it if found.
379fn extract_cache_attr(attrs: &mut Vec<syn::Attribute>) -> Option<u64> {
380    let idx = attrs
381        .iter()
382        .position(|attr| attr.path().is_ident("cache"))?;
383    let attr = attrs.remove(idx);
384
385    let mut ttl: Option<u64> = None;
386    attr.parse_nested_meta(|meta| {
387        if meta.path.is_ident("ttl") {
388            let value = meta.value()?;
389            let lit: syn::LitInt = value.parse()?;
390            ttl = Some(lit.base10_parse()?);
391            Ok(())
392        } else {
393            Err(meta.error("expected `ttl`"))
394        }
395    })
396    .expect("expected #[cache(ttl = N)]");
397
398    ttl
399}
400
401/// Extract #[public] attribute from function attributes, removing it if found.
402fn extract_public_attr(attrs: &mut Vec<syn::Attribute>) -> bool {
403    if let Some(idx) = attrs.iter().position(|attr| attr.path().is_ident("public")) {
404        attrs.remove(idx);
405        true
406    } else {
407        false
408    }
409}
410
411/// Registers a channel handler for the relay system.
412///
413/// Channel handlers receive [`RelayEvent`](rapina::relay::RelayEvent) events
414/// when clients subscribe, send messages, or disconnect from matching topics.
415///
416/// The pattern supports exact matches and prefix matches (trailing `*`):
417///
418/// - `"chat:lobby"` — matches only the exact topic `"chat:lobby"`
419/// - `"room:*"` — matches any topic starting with `"room:"`
420///
421/// The first parameter must be `RelayEvent`. Remaining parameters are
422/// extracted via `FromRequestParts` with synthetic request parts (same
423/// extractors as HTTP handlers, minus body extractors).
424///
425/// # Example
426///
427/// ```ignore
428/// use rapina::prelude::*;
429/// use rapina::relay::{Relay, RelayEvent};
430///
431/// #[relay("room:*")]
432/// async fn room(event: RelayEvent, relay: Relay) -> Result<()> {
433///     match &event {
434///         RelayEvent::Join { topic, conn_id } => {
435///             relay.track(topic, *conn_id, serde_json::json!({}));
436///         }
437///         RelayEvent::Message { topic, event: ev, payload, .. } => {
438///             relay.push(topic, ev, payload).await?;
439///         }
440///         RelayEvent::Leave { topic, conn_id } => {
441///             relay.untrack(topic, *conn_id);
442///         }
443///     }
444///     Ok(())
445/// }
446/// ```
447#[proc_macro_attribute]
448pub fn relay(attr: TokenStream, item: TokenStream) -> TokenStream {
449    relay_macro_impl(attr.into(), item.into()).into()
450}
451
452fn relay_macro_impl(
453    attr: proc_macro2::TokenStream,
454    item: proc_macro2::TokenStream,
455) -> proc_macro2::TokenStream {
456    let pattern: LitStr = syn::parse2(attr).expect("expected pattern as string literal");
457    let pattern_str = pattern.value();
458    let func: ItemFn = syn::parse2(item).expect("#[relay] must be applied to an async function");
459
460    let func_name = &func.sig.ident;
461    let func_name_str = func_name.to_string();
462
463    let is_prefix = pattern_str.ends_with('*');
464    let match_prefix_str = if is_prefix {
465        &pattern_str[..pattern_str.len() - 1]
466    } else {
467        &pattern_str
468    };
469
470    let wrapper_name = syn::Ident::new(
471        &format!("__rapina_channel_{}", func_name_str),
472        proc_macro2::Span::call_site(),
473    );
474
475    // First arg is RelayEvent (passed directly). Remaining args are extractors.
476    let args: Vec<_> = func.sig.inputs.iter().collect();
477
478    let mut extractor_extractions = Vec::new();
479    let mut call_args = vec![quote! { __rapina_event }];
480
481    for (i, arg) in args.iter().enumerate() {
482        if i == 0 {
483            // First arg is RelayEvent — passed directly, not extracted
484            continue;
485        }
486        if let FnArg::Typed(pat_type) = arg {
487            if let Pat::Ident(pat_ident) = &*pat_type.pat {
488                let arg_name = &pat_ident.ident;
489                let arg_type = &pat_type.ty;
490
491                extractor_extractions.push(quote! {
492                    let #arg_name = <#arg_type as rapina::extract::FromRequestParts>::from_request_parts(
493                        &__rapina_parts, &__rapina_params, &__rapina_state
494                    ).await?;
495                });
496
497                call_args.push(quote! { #arg_name });
498            }
499        }
500    }
501
502    quote! {
503        #func
504
505        // Generated by #[relay] — not user-facing API
506        #[doc(hidden)]
507        fn #wrapper_name(
508            __rapina_event: rapina::relay::RelayEvent,
509            __rapina_state: std::sync::Arc<rapina::state::AppState>,
510            __rapina_current_user: Option<rapina::auth::CurrentUser>,
511        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = std::result::Result<(), rapina::error::Error>> + Send>> {
512            Box::pin(async move {
513                let (mut __rapina_parts, _) = rapina::http::Request::new(()).into_parts();
514                if let Some(u) = __rapina_current_user {
515                    __rapina_parts.extensions.insert(u);
516                }
517                let __rapina_params: rapina::extract::PathParams = std::collections::HashMap::new();
518                #(#extractor_extractions)*
519                #func_name(#(#call_args),*).await
520            })
521        }
522
523        rapina::inventory::submit! {
524            rapina::relay::ChannelDescriptor {
525                pattern: #pattern_str,
526                is_prefix: #is_prefix,
527                match_prefix: #match_prefix_str,
528                handler_name: #func_name_str,
529                handle: #wrapper_name,
530            }
531        }
532    }
533}
534
535fn route_macro(method: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
536    route_macro_core(method, attr.into(), item.into()).into()
537}
538
539/// Derive macro for type-safe configuration
540///
541/// Generates a `from_env()` method that loads configuration from environment variables.
542#[proc_macro_derive(Config, attributes(env, default))]
543pub fn derive_config(input: TokenStream) -> TokenStream {
544    derive_config_impl(input.into()).into()
545}
546
547/// Define database entities with Prisma-like syntax.
548///
549/// This macro generates SeaORM entity definitions from a declarative syntax
550/// where types indicate relationships. Each entity automatically gets `id`,
551/// `created_at`, and `updated_at` fields.
552///
553/// # Syntax
554///
555/// ```ignore
556/// rapina::schema! {
557///     User {
558///         email: String,
559///         name: String,
560///         posts: Vec<Post>,        // has_many relationship
561///     }
562///
563///     Post {
564///         title: String,
565///         content: Text,           // TEXT column type
566///         author: User,            // belongs_to -> generates author_id
567///         comments: Vec<Comment>,
568///     }
569///
570///     Comment {
571///         content: Text,
572///         post: Post,
573///         author: Option<User>,    // optional belongs_to
574///     }
575/// }
576/// ```
577///
578/// # Generated Code
579///
580/// For each entity, the macro generates a SeaORM module with:
581/// - `Model` struct with auto `id`, `created_at`, `updated_at`
582/// - `Relation` enum with proper SeaORM attributes
583/// - `Related<T>` trait implementations
584/// - `ActiveModelBehavior` implementation
585///
586/// # Supported Types
587///
588/// | Schema Type | Rust Type | Notes |
589/// |-------------|-----------|-------|
590/// | `String` | `String` | Default varchar |
591/// | `Text` | `String` | TEXT column |
592/// | `i32` | `i32` | |
593/// | `i64` | `i64` | |
594/// | `f32` | `f32` | |
595/// | `f64` | `f64` | |
596/// | `bool` | `bool` | |
597/// | `Uuid` | `Uuid` | |
598/// | `DateTime` | `DateTimeUtc` | |
599/// | `Date` | `Date` | |
600/// | `Decimal` | `Decimal` | |
601/// | `Json` | `Json` | |
602/// | `Option<T>` | `Option<T>` | Nullable |
603/// | `Vec<Entity>` | - | has_many relationship |
604/// | `Entity` | - | belongs_to (generates FK) |
605#[proc_macro]
606pub fn schema(input: TokenStream) -> TokenStream {
607    schema::schema_impl(input.into()).into()
608}
609
610fn derive_config_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
611    let input: syn::DeriveInput = syn::parse2(input).expect("expected struct");
612    let name = &input.ident;
613
614    let fields = match &input.data {
615        syn::Data::Struct(data) => match &data.fields {
616            syn::Fields::Named(fields) => &fields.named,
617            _ => panic!("Config derive only supports structs with named fields"),
618        },
619        _ => panic!("Config derive only supports structs"),
620    };
621
622    let mut field_inits = Vec::new();
623    let mut missing_checks = Vec::new();
624
625    for field in fields {
626        let field_name = field.ident.as_ref().unwrap();
627        let field_type = &field.ty;
628
629        // Find #[env = "VAR_NAME"] attribute
630        let env_var = field
631            .attrs
632            .iter()
633            .find_map(|attr| {
634                if attr.path().is_ident("env")
635                    && let syn::Meta::NameValue(nv) = &attr.meta
636                    && let syn::Expr::Lit(expr_lit) = &nv.value
637                    && let syn::Lit::Str(lit_str) = &expr_lit.lit
638                {
639                    return Some(lit_str.value());
640                }
641                None
642            })
643            .unwrap_or_else(|| field_name.to_string().to_uppercase());
644
645        // Find #[default = "value"] attribute
646        let default_value = field.attrs.iter().find_map(|attr| {
647            if attr.path().is_ident("default")
648                && let syn::Meta::NameValue(nv) = &attr.meta
649                && let syn::Expr::Lit(expr_lit) = &nv.value
650                && let syn::Lit::Str(lit_str) = &expr_lit.lit
651            {
652                return Some(lit_str.value());
653            }
654            None
655        });
656
657        let env_var_lit = syn::LitStr::new(&env_var, proc_macro2::Span::call_site());
658
659        if let Some(default) = default_value {
660            let default_lit = syn::LitStr::new(&default, proc_macro2::Span::call_site());
661            field_inits.push(quote! {
662                #field_name: rapina::config::get_env_or(#env_var_lit, #default_lit).parse().unwrap_or_else(|_| #default_lit.parse().unwrap())
663            });
664        } else {
665            field_inits.push(quote! {
666                #field_name: rapina::config::get_env_parsed::<#field_type>(#env_var_lit)?
667            });
668            missing_checks.push(quote! {
669                if std::env::var(#env_var_lit).is_err() {
670                    missing.push(#env_var_lit);
671                }
672            });
673        }
674    }
675
676    quote! {
677        impl #name {
678            pub fn from_env() -> std::result::Result<Self, rapina::config::ConfigError> {
679                let mut missing: Vec<&str> = Vec::new();
680                #(#missing_checks)*
681
682                if !missing.is_empty() {
683                    return Err(rapina::config::ConfigError::MissingMultiple(
684                        missing.into_iter().map(String::from).collect()
685                    ));
686                }
687
688                Ok(Self {
689                    #(#field_inits),*
690                })
691            }
692        }
693    }
694}
695
696#[cfg(test)]
697mod tests {
698    use super::{join_paths, relay_macro_impl, route_macro_core};
699    use quote::quote;
700
701    #[test]
702    fn test_generates_struct_with_handler_impl() {
703        let path = quote!("/");
704        let input = quote! {
705            async fn hello() -> &'static str {
706                "Hello, Rapina!"
707            }
708        };
709
710        let output = route_macro_core("GET", path, input);
711        let output_str = output.to_string();
712
713        // Check struct is generated
714        assert!(output_str.contains("struct hello"));
715        // Check Handler impl is generated
716        assert!(output_str.contains("impl rapina :: handler :: Handler for hello"));
717        // Check NAME constant
718        assert!(output_str.contains("const NAME"));
719        assert!(output_str.contains("\"hello\""));
720    }
721
722    #[test]
723    fn test_generates_handler_with_extractors() {
724        let path = quote!("/users/:id");
725        let input = quote! {
726            async fn get_user(id: rapina::extract::Path<u64>) -> String {
727                format!("{}", id.into_inner())
728            }
729        };
730
731        let output = route_macro_core("GET", path, input);
732        let output_str = output.to_string();
733
734        // Check struct is generated
735        assert!(output_str.contains("struct get_user"));
736        // Check extraction code is present
737        assert!(output_str.contains("FromRequestParts"));
738    }
739
740    #[test]
741    fn test_function_with_multiple_extractors() {
742        let path = quote!("/users");
743        let input = quote! {
744            async fn create_user(
745                id: rapina::extract::Path<u64>,
746                body: rapina::extract::Json<String>
747            ) -> String {
748                "created".to_string()
749            }
750        };
751
752        let output = route_macro_core("POST", path, input);
753        let output_str = output.to_string();
754
755        // Check struct is generated
756        assert!(output_str.contains("struct create_user"));
757        // Check both extractors are handled
758        assert!(output_str.contains("FromRequestParts"));
759        assert!(output_str.contains("FromRequest"));
760    }
761
762    #[test]
763    #[should_panic(expected = "Multiple body-consuming extractors are not supported")]
764    fn test_multiple_body_extractors_panics() {
765        let path = quote!("/users");
766        let input = quote! {
767            async fn handler(
768                body1: rapina::extract::Json<String>,
769                body2: rapina::extract::Json<String>
770            ) -> String {
771                "ok".to_string()
772            }
773        };
774
775        route_macro_core("POST", path, input);
776    }
777
778    #[test]
779    #[should_panic(expected = "expected function")]
780    fn test_invalid_input_panics() {
781        let path = quote!("/");
782        let invalid_input = quote! { not_a_function };
783
784        route_macro_core("GET", path, invalid_input);
785    }
786
787    #[test]
788    fn test_json_return_type_generates_response_schema() {
789        let path = quote!("/users");
790        let input = quote! {
791            async fn get_user() -> Json<UserResponse> {
792                Json(UserResponse { id: 1 })
793            }
794        };
795
796        let output = route_macro_core("GET", path, input);
797        let output_str = output.to_string();
798
799        // Check response_schema method is generated with schema_for!
800        assert!(output_str.contains("fn response_schema"));
801        assert!(output_str.contains("rapina :: schemars :: schema_for !"));
802        assert!(output_str.contains("UserResponse"));
803    }
804
805    #[test]
806    fn test_result_json_return_type_generates_response_schema() {
807        let path = quote!("/users");
808        let input = quote! {
809            async fn get_user() -> Result<Json<UserResponse>> {
810                Ok(Json(UserResponse { id: 1 }))
811            }
812        };
813
814        let output = route_macro_core("GET", path, input);
815        let output_str = output.to_string();
816
817        assert!(output_str.contains("fn response_schema"));
818        assert!(output_str.contains("rapina :: schemars :: schema_for !"));
819        assert!(output_str.contains("UserResponse"));
820    }
821
822    #[test]
823    fn test_errors_attr_generates_error_responses() {
824        let path = quote!("/users");
825        let input = quote! {
826            #[errors(UserError)]
827            async fn get_user() -> Result<Json<UserResponse>> {
828                Ok(Json(UserResponse { id: 1 }))
829            }
830        };
831
832        let output = route_macro_core("GET", path, input);
833        let output_str = output.to_string();
834
835        assert!(output_str.contains("fn error_responses"));
836        assert!(output_str.contains("DocumentedError"));
837        assert!(output_str.contains("UserError"));
838    }
839
840    #[test]
841    fn test_non_json_return_type_no_response_schema() {
842        let path = quote!("/health");
843        let input = quote! {
844            async fn health() -> &'static str {
845                "ok"
846            }
847        };
848
849        let output = route_macro_core("GET", path, input);
850        let output_str = output.to_string();
851
852        // Check response_schema method is NOT generated for non-Json types
853        assert!(!output_str.contains("fn response_schema"));
854        assert!(!output_str.contains("schema_for"));
855    }
856
857    #[test]
858    fn test_user_state_variable_not_shadowed() {
859        // Regression test for issue #134 - user naming their extractor 'state'
860        // should not conflict with internal macro variables
861        let path = quote!("/users");
862        let input = quote! {
863            async fn list_users(state: rapina::extract::State<MyState>) -> String {
864                "ok".to_string()
865            }
866        };
867
868        let output = route_macro_core("GET", path, input);
869        let output_str = output.to_string();
870
871        // Internal variables should use __rapina_ prefix
872        assert!(output_str.contains("__rapina_state"));
873        assert!(output_str.contains("__rapina_params"));
874        // User's variable 'state' should still be extracted
875        assert!(output_str.contains("let state ="));
876    }
877
878    #[test]
879    fn test_no_closure_wrapper_for_type_inference() {
880        // Regression test for issue #134 - Result type inference should work
881        let path = quote!("/users");
882        let input = quote! {
883            async fn get_user() -> Result<String, Error> {
884                Ok("user".to_string())
885            }
886        };
887
888        let output = route_macro_core("GET", path, input);
889        let output_str = output.to_string();
890
891        // Should NOT use closure wrapper (|| async ...)
892        assert!(!output_str.contains("|| async"));
893        // Should use typed result with async block (: ReturnType = (async ...).await)
894        assert!(output_str.contains("__rapina_result"));
895        assert!(output_str.contains("Result < String , Error >"));
896    }
897
898    #[test]
899    fn test_emits_route_descriptor() {
900        let path = quote!("/users");
901        let input = quote! {
902            async fn list_users() -> &'static str {
903                "users"
904            }
905        };
906
907        let output = route_macro_core("GET", path, input);
908        let output_str = output.to_string();
909
910        assert!(output_str.contains("inventory :: submit !"));
911        assert!(output_str.contains("RouteDescriptor"));
912        assert!(output_str.contains("method : \"GET\""));
913        assert!(output_str.contains("path : \"/users\""));
914        assert!(output_str.contains("handler_name : \"list_users\""));
915        assert!(output_str.contains("is_public : false"));
916        assert!(output_str.contains("__rapina_register_list_users"));
917    }
918
919    #[test]
920    fn test_emits_route_descriptor_with_method() {
921        let path = quote!("/users");
922        let input = quote! {
923            async fn create_user() -> &'static str {
924                "created"
925            }
926        };
927
928        let output = route_macro_core("POST", path, input);
929        let output_str = output.to_string();
930
931        assert!(output_str.contains("method : \"POST\""));
932        assert!(output_str.contains("__rapina_router . post"));
933    }
934
935    #[test]
936    fn test_public_attr_below_route_sets_is_public() {
937        let path = quote!("/health");
938        let input = quote! {
939            #[public]
940            async fn health() -> &'static str {
941                "ok"
942            }
943        };
944
945        let output = route_macro_core("GET", path, input);
946        let output_str = output.to_string();
947
948        assert!(output_str.contains("is_public : true"));
949    }
950
951    #[test]
952    fn test_cache_attr_injects_ttl_header() {
953        let path = quote!("/products");
954        let input = quote! {
955            #[cache(ttl = 60)]
956            async fn list_products() -> &'static str {
957                "products"
958            }
959        };
960
961        let output = route_macro_core("GET", path, input);
962        let output_str = output.to_string();
963
964        assert!(output_str.contains("x-rapina-cache-ttl"));
965        assert!(output_str.contains("60"));
966    }
967
968    #[test]
969    fn test_relay_macro_generates_wrapper_and_inventory() {
970        let attr = quote!("room:*");
971        let input = quote! {
972            async fn room(event: rapina::relay::RelayEvent, relay: rapina::relay::Relay) -> Result<(), rapina::error::Error> {
973                Ok(())
974            }
975        };
976
977        let output = relay_macro_impl(attr, input);
978        let output_str = output.to_string();
979
980        // Original function is preserved
981        assert!(output_str.contains("async fn room"));
982        // Wrapper function is generated
983        assert!(output_str.contains("__rapina_channel_room"));
984        // Inventory submission
985        assert!(output_str.contains("inventory :: submit !"));
986        assert!(output_str.contains("ChannelDescriptor"));
987        assert!(output_str.contains("pattern : \"room:*\""));
988        assert!(output_str.contains("is_prefix : true"));
989        assert!(output_str.contains("match_prefix : \"room:\""));
990        assert!(output_str.contains("handler_name : \"room\""));
991    }
992
993    #[test]
994    fn test_relay_macro_exact_match() {
995        let attr = quote!("chat:lobby");
996        let input = quote! {
997            async fn lobby(event: rapina::relay::RelayEvent) -> Result<(), rapina::error::Error> {
998                Ok(())
999            }
1000        };
1001
1002        let output = relay_macro_impl(attr, input);
1003        let output_str = output.to_string();
1004
1005        assert!(output_str.contains("is_prefix : false"));
1006        assert!(output_str.contains("match_prefix : \"chat:lobby\""));
1007    }
1008
1009    #[test]
1010    fn test_relay_macro_extracts_additional_params() {
1011        let attr = quote!("room:*");
1012        let input = quote! {
1013            async fn room(
1014                event: rapina::relay::RelayEvent,
1015                relay: rapina::relay::Relay,
1016                log: rapina::extract::State<TestLog>,
1017            ) -> Result<(), rapina::error::Error> {
1018                Ok(())
1019            }
1020        };
1021
1022        let output = relay_macro_impl(attr, input);
1023        let output_str = output.to_string();
1024
1025        // Both extractors should use FromRequestParts
1026        assert!(output_str.contains("let relay ="));
1027        assert!(output_str.contains("let log ="));
1028        assert!(output_str.contains("FromRequestParts"));
1029    }
1030
1031    #[test]
1032    fn test_no_cache_attr_no_ttl_header() {
1033        let path = quote!("/products");
1034        let input = quote! {
1035            async fn list_products() -> &'static str {
1036                "products"
1037            }
1038        };
1039
1040        let output = route_macro_core("GET", path, input);
1041        let output_str = output.to_string();
1042
1043        assert!(!output_str.contains("x-rapina-cache-ttl"));
1044    }
1045
1046    #[test]
1047    fn test_cache_attr_with_extractors() {
1048        let path = quote!("/users/:id");
1049        let input = quote! {
1050            #[cache(ttl = 120)]
1051            async fn get_user(id: rapina::extract::Path<u64>) -> String {
1052                format!("{}", id.into_inner())
1053            }
1054        };
1055
1056        let output = route_macro_core("GET", path, input);
1057        let output_str = output.to_string();
1058
1059        assert!(output_str.contains("x-rapina-cache-ttl"));
1060        assert!(output_str.contains("120"));
1061        assert!(output_str.contains("FromRequestParts"));
1062    }
1063
1064    #[test]
1065    fn test_group_param_joins_path() {
1066        let attr = quote!("/users", group = "/api");
1067        let input = quote! {
1068            async fn list_users() -> &'static str {
1069                "users"
1070            }
1071        };
1072
1073        let output = route_macro_core("GET", attr, input);
1074        let output_str = output.to_string();
1075
1076        assert!(output_str.contains("path : \"/api/users\""));
1077        assert!(output_str.contains("__rapina_router . get (\"/api/users\""));
1078    }
1079
1080    #[test]
1081    fn test_group_param_with_nested_prefix() {
1082        let attr = quote!("/items", group = "/api/v1");
1083        let input = quote! {
1084            async fn list_items() -> &'static str {
1085                "items"
1086            }
1087        };
1088
1089        let output = route_macro_core("GET", attr, input);
1090        let output_str = output.to_string();
1091
1092        assert!(output_str.contains("path : \"/api/v1/items\""));
1093    }
1094
1095    #[test]
1096    fn test_without_group_param_backward_compatible() {
1097        let attr = quote!("/users");
1098        let input = quote! {
1099            async fn list_users() -> &'static str {
1100                "users"
1101            }
1102        };
1103
1104        let output = route_macro_core("GET", attr, input);
1105        let output_str = output.to_string();
1106
1107        assert!(output_str.contains("path : \"/users\""));
1108        assert!(output_str.contains("__rapina_router . get (\"/users\""));
1109    }
1110
1111    #[test]
1112    #[should_panic(expected = "group prefix must start with `/`")]
1113    fn test_group_prefix_must_start_with_slash() {
1114        let attr = quote!("/users", group = "api");
1115        let input = quote! {
1116            async fn list_users() -> &'static str {
1117                "users"
1118            }
1119        };
1120
1121        route_macro_core("GET", attr, input);
1122    }
1123
1124    #[test]
1125    fn test_group_with_trailing_slash_normalized() {
1126        let attr = quote!("/users", group = "/api/");
1127        let input = quote! {
1128            async fn list_users() -> &'static str {
1129                "users"
1130            }
1131        };
1132
1133        let output = route_macro_core("GET", attr, input);
1134        let output_str = output.to_string();
1135
1136        assert!(output_str.contains("path : \"/api/users\""));
1137    }
1138
1139    #[test]
1140    fn test_group_with_public_attr() {
1141        let attr = quote!("/health", group = "/api");
1142        let input = quote! {
1143            #[public]
1144            async fn health() -> &'static str {
1145                "ok"
1146            }
1147        };
1148
1149        let output = route_macro_core("GET", attr, input);
1150        let output_str = output.to_string();
1151
1152        assert!(output_str.contains("path : \"/api/health\""));
1153        assert!(output_str.contains("is_public : true"));
1154    }
1155
1156    #[test]
1157    fn test_group_with_cache_attr() {
1158        let attr = quote!("/products", group = "/api");
1159        let input = quote! {
1160            #[cache(ttl = 60)]
1161            async fn list_products() -> &'static str {
1162                "products"
1163            }
1164        };
1165
1166        let output = route_macro_core("GET", attr, input);
1167        let output_str = output.to_string();
1168
1169        assert!(output_str.contains("path : \"/api/products\""));
1170        assert!(output_str.contains("x-rapina-cache-ttl"));
1171        assert!(output_str.contains("60"));
1172    }
1173
1174    #[test]
1175    fn test_group_with_errors_attr() {
1176        let attr = quote!("/users", group = "/api");
1177        let input = quote! {
1178            #[errors(UserError)]
1179            async fn get_user() -> Result<Json<UserResponse>> {
1180                Ok(Json(UserResponse { id: 1 }))
1181            }
1182        };
1183
1184        let output = route_macro_core("GET", attr, input);
1185        let output_str = output.to_string();
1186
1187        assert!(output_str.contains("path : \"/api/users\""));
1188        assert!(output_str.contains("fn error_responses"));
1189        assert!(output_str.contains("UserError"));
1190    }
1191
1192    #[test]
1193    fn test_group_with_all_methods() {
1194        for method in &["GET", "POST", "PUT", "DELETE"] {
1195            let attr = quote!("/items", group = "/api");
1196            let input = quote! {
1197                async fn handler() -> &'static str {
1198                    "ok"
1199                }
1200            };
1201
1202            let output = route_macro_core(method, attr, input);
1203            let output_str = output.to_string();
1204
1205            assert!(
1206                output_str.contains("path : \"/api/items\""),
1207                "{method} should produce /api/items"
1208            );
1209            let method_lower = method.to_lowercase();
1210            assert!(
1211                output_str.contains(&format!("__rapina_router . {method_lower}")),
1212                "{method} should use .{method_lower}() on router"
1213            );
1214        }
1215    }
1216
1217    #[test]
1218    fn test_join_paths_basic() {
1219        assert_eq!(join_paths("/api", "/users"), "/api/users");
1220        assert_eq!(join_paths("/api/v1", "/items"), "/api/v1/items");
1221    }
1222
1223    #[test]
1224    fn test_join_paths_trailing_slash() {
1225        assert_eq!(join_paths("/api/", "/users"), "/api/users");
1226    }
1227
1228    #[test]
1229    fn test_join_paths_empty_path() {
1230        assert_eq!(join_paths("/api", ""), "/api");
1231        assert_eq!(join_paths("/api", "/"), "/api");
1232    }
1233
1234    #[test]
1235    fn test_join_paths_empty_prefix() {
1236        assert_eq!(join_paths("", "/users"), "/users");
1237        assert_eq!(join_paths("", ""), "/");
1238    }
1239}