rsrpc_macro/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5    parse_macro_input, Attribute, FnArg, Ident, ItemTrait, Pat, ReturnType, TraitItem, TraitItemFn,
6    Type,
7};
8
9// =============================================================================
10// HTTP METHOD TYPES
11// =============================================================================
12
/// HTTP verb attached to a service method through one of the
/// `#[get]`, `#[post]`, `#[put]`, `#[patch]` or `#[delete]` attributes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum HttpMethod {
    /// `#[get("...")]`
    Get,
    /// `#[post("...")]`
    Post,
    /// `#[put("...")]`
    Put,
    /// `#[patch("...")]`
    Patch,
    /// `#[delete("...")]`
    Delete,
}
22
23impl HttpMethod {
24    fn from_str(s: &str) -> Option<Self> {
25        match s.to_lowercase().as_str() {
26            "get" => Some(Self::Get),
27            "post" => Some(Self::Post),
28            "put" => Some(Self::Put),
29            "patch" => Some(Self::Patch),
30            "delete" => Some(Self::Delete),
31            _ => None,
32        }
33    }
34
35    fn to_tokens(&self) -> TokenStream2 {
36        match self {
37            Self::Get => quote!(::rsrpc::http::Method::GET),
38            Self::Post => quote!(::rsrpc::http::Method::POST),
39            Self::Put => quote!(::rsrpc::http::Method::PUT),
40            Self::Patch => quote!(::rsrpc::http::Method::PATCH),
41            Self::Delete => quote!(::rsrpc::http::Method::DELETE),
42        }
43    }
44
45    fn to_axum_method(&self) -> TokenStream2 {
46        match self {
47            Self::Get => quote!(::rsrpc::axum::routing::get),
48            Self::Post => quote!(::rsrpc::axum::routing::post),
49            Self::Put => quote!(::rsrpc::axum::routing::put),
50            Self::Patch => quote!(::rsrpc::axum::routing::patch),
51            Self::Delete => quote!(::rsrpc::axum::routing::delete),
52        }
53    }
54}
55
56/// Parsed HTTP route information
57#[derive(Debug, Clone)]
58struct HttpRoute {
59    method: HttpMethod,
60    path: String,              // Original path for URL building
61    axum_path: String,         // Axum-compatible path (uses :param instead of {param})
62    path_params: Vec<String>,  // Parameters extracted from {param} in path
63    query_params: Vec<String>, // Parameters extracted from ?param
64}
65
66/// Parse route string like "/logs/{id}/?limit" into components
67fn parse_route(path: &str, method: HttpMethod) -> HttpRoute {
68    let mut path_params = Vec::new();
69    let mut query_params = Vec::new();
70    let mut clean_path = String::new();
71    let mut axum_path = String::new();
72
73    // Split path from query params (marked with /?)
74    let (path_part, query_part) = if let Some(idx) = path.find("/?") {
75        (&path[..idx], Some(&path[idx + 2..]))
76    } else {
77        (path, None)
78    };
79
80    // Parse path parameters {param}
81    let mut chars = path_part.chars().peekable();
82    while let Some(c) = chars.next() {
83        if c == '{' {
84            let mut param = String::new();
85            while let Some(&pc) = chars.peek() {
86                chars.next();
87                if pc == '}' {
88                    break;
89                }
90                param.push(pc);
91            }
92            path_params.push(param.clone());
93            clean_path.push_str(&format!("{{{}}}", param));
94            axum_path.push(':');
95            axum_path.push_str(&param);
96        } else {
97            clean_path.push(c);
98            axum_path.push(c);
99        }
100    }
101
102    // Parse query parameters ?param1&param2 or ?param1/?param2
103    if let Some(query) = query_part {
104        for part in query.split(&['?', '&', '/'][..]) {
105            let param = part.trim();
106            if !param.is_empty() {
107                query_params.push(param.to_string());
108            }
109        }
110    }
111
112    HttpRoute {
113        method,
114        path: clean_path,
115        axum_path,
116        path_params,
117        query_params,
118    }
119}
120
121/// Extract HTTP route from method attributes
122fn parse_http_attrs(attrs: &[Attribute]) -> syn::Result<Option<HttpRoute>> {
123    for attr in attrs {
124        let path = attr.path();
125        if path.segments.len() == 1 {
126            let method_name = path.segments[0].ident.to_string();
127            if let Some(method) = HttpMethod::from_str(&method_name) {
128                // Parse the route string from attribute like #[get("/path")]
129                let route_str: syn::LitStr = attr.parse_args()?;
130                return Ok(Some(parse_route(&route_str.value(), method)));
131            }
132        }
133    }
134    Ok(None)
135}
136
137/// Marks a trait as an RPC service.
138///
139/// This macro generates:
140/// - Per-method request structs (private)
141/// - `impl Trait for Client<dyn Trait>` so clients can call methods directly
142/// - `<dyn Trait>::serve(impl)` method to create a server
143///
144/// # Example
145///
146/// ```ignore
147/// #[rsrpc::service]
148/// pub trait Worker: Send + Sync + 'static {
149///     async fn run_task(&self, task: Task) -> Result<Output, Error>;
150///     async fn status(&self) -> WorkerStatus;
151/// }
152///
153/// // Client: impl Worker for Client<dyn Worker>
154/// // Server: <dyn Worker>::serve(my_impl)
155/// ```
156///
157/// # Streaming
158///
159/// Methods returning `Result<RpcStream<T>>` are automatically handled as
160/// server-side streaming - no special annotation needed:
161///
162/// ```ignore
163/// #[rsrpc::service]
164/// pub trait LogService: Send + Sync + 'static {
165///     async fn stream_logs(&self, filter: Filter) -> Result<RpcStream<LogEntry>>;
166/// }
167/// ```
168///
169/// # Local (non-Send) services
170///
171/// Use `#[rsrpc::service(?Send)]` to allow non-Send futures. This is useful
172/// when your async methods hold non-Send types across await points:
173///
174/// ```ignore
175/// #[rsrpc::service(?Send)]
176/// pub trait LocalWorker: 'static {
177///     async fn run_task(&self, task: Task) -> Result<Output, Error>;
178/// }
179/// ```
180///
181/// Note: `?Send` services require a single-threaded runtime for the server.
182#[proc_macro_attribute]
183pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
184    let trait_def = parse_macro_input!(item as ItemTrait);
185    let not_send = attr.to_string().contains("?Send");
186    match generate_service(&trait_def, not_send) {
187        Ok(tokens) => tokens.into(),
188        Err(e) => e.to_compile_error().into(),
189    }
190}
191
192fn generate_service(trait_def: &ItemTrait, not_send: bool) -> syn::Result<TokenStream2> {
193    let trait_name = &trait_def.ident;
194    let trait_vis = &trait_def.vis;
195    let trait_name_lower = to_snake_case(&trait_name.to_string());
196    let mod_name = format_ident!("__{}_rpc_impl", trait_name_lower);
197
198    // Collect method info
199    let methods: Vec<MethodInfo> = trait_def
200        .items
201        .iter()
202        .filter_map(|item| {
203            if let TraitItem::Fn(method) = item {
204                Some(parse_method(method))
205            } else {
206                None
207            }
208        })
209        .collect::<syn::Result<Vec<_>>>()?;
210
211    // Generate request structs for each method
212    let request_structs: Vec<TokenStream2> =
213        methods.iter().map(|m| generate_request_struct(m)).collect();
214
215    // Generate method ID constants
216    let method_ids: Vec<TokenStream2> = methods
217        .iter()
218        .enumerate()
219        .map(|(idx, m)| {
220            let const_name = format_ident!("{}_METHOD_ID", m.name.to_string().to_uppercase());
221            let idx = idx as u16;
222            quote! {
223                const #const_name: u16 = #idx;
224            }
225        })
226        .collect();
227
228    // Generate Client<dyn Trait> impl
229    let client_impl_methods: Vec<TokenStream2> = methods
230        .iter()
231        .enumerate()
232        .map(|(idx, m)| generate_client_method(m, idx as u16, trait_name))
233        .collect();
234
235    // Generate dispatch match arms
236    let dispatch_arms: Vec<TokenStream2> = methods
237        .iter()
238        .enumerate()
239        .map(|(idx, m)| generate_dispatch_arm(m, idx as u16))
240        .collect();
241
242    // Keep original trait but add async_trait
243    let _trait_items = &trait_def.items;
244    let trait_supertraits = &trait_def.supertraits;
245    let trait_generics = &trait_def.generics;
246
247    // Choose async_trait attribute based on Send requirement
248    let async_trait_attr = if not_send {
249        quote! { #[::rsrpc::async_trait(?Send)] }
250    } else {
251        quote! { #[::rsrpc::async_trait] }
252    };
253
254    // Dispatch function return type depends on Send requirement
255    let dispatch_fn = if not_send {
256        quote! {
257            pub fn dispatch<'a, T: #trait_name + ?Sized>(
258                service: &'a T,
259                method_id: u16,
260                payload: &'a [u8],
261            ) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::rsrpc::DispatchResult> + 'a>> {
262                Box::pin(async move {
263                    match method_id {
264                        #(#dispatch_arms)*
265                        _ => ::rsrpc::DispatchResult::Error(format!("Unknown method ID: {}", method_id)),
266                    }
267                })
268            }
269        }
270    } else {
271        quote! {
272            pub fn dispatch<'a, T: #trait_name + ?Sized>(
273                service: &'a T,
274                method_id: u16,
275                payload: &'a [u8],
276            ) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::rsrpc::DispatchResult> + Send + 'a>> {
277                Box::pin(async move {
278                    match method_id {
279                        #(#dispatch_arms)*
280                        _ => ::rsrpc::DispatchResult::Error(format!("Unknown method ID: {}", method_id)),
281                    }
282                })
283            }
284        }
285    };
286
287    // Generate HTTP client methods
288    // Methods with HTTP attributes get real implementations
289    // Methods without HTTP attributes get panic stubs
290    let http_client_methods: Vec<TokenStream2> = methods
291        .iter()
292        .map(|m| {
293            if let Some(tokens) = generate_http_client_method(m, trait_name) {
294                tokens
295            } else {
296                // Generate panic stub for non-HTTP methods
297                let name = &m.name;
298                let return_type = &m.return_type;
299                let arg_decls: Vec<TokenStream2> = m
300                    .args
301                    .iter()
302                    .map(|(name, ty)| quote! { #name: #ty })
303                    .collect();
304                let method_name = name.to_string();
305                quote! {
306                    async fn #name(&self, #(#arg_decls),*) -> #return_type {
307                        panic!(
308                            "Method '{}' is not available over HTTP. Use RPC client instead.",
309                            #method_name
310                        )
311                    }
312                }
313            }
314        })
315        .collect();
316
317    // Generate HTTP server routes
318    let http_routes = generate_http_routes(&methods, trait_name, trait_vis);
319
320    // Check if any methods have HTTP attributes
321    let has_http_methods = methods.iter().any(|m| m.http.is_some());
322
323    // HTTP client impl (only if there are HTTP methods)
324    let http_client_impl = if !has_http_methods {
325        quote! {}
326    } else {
327        quote! {
328            #[cfg(feature = "http")]
329            #async_trait_attr
330            impl #trait_name for ::rsrpc::HttpClient<dyn #trait_name> {
331                #(#http_client_methods)*
332            }
333        }
334    };
335
336    // Strip HTTP attributes from trait items for the output
337    let clean_trait_items: Vec<TokenStream2> = trait_def
338        .items
339        .iter()
340        .map(|item| {
341            if let TraitItem::Fn(method) = item {
342                let mut clean_method = method.clone();
343                // Remove HTTP method attributes (get, post, put, patch, delete)
344                clean_method.attrs.retain(|attr| {
345                    let path = attr.path();
346                    if path.segments.len() == 1 {
347                        let name = path.segments[0].ident.to_string().to_lowercase();
348                        !matches!(name.as_str(), "get" | "post" | "put" | "patch" | "delete")
349                    } else {
350                        true
351                    }
352                });
353                quote! { #clean_method }
354            } else {
355                quote! { #item }
356            }
357        })
358        .collect();
359
360    Ok(quote! {
361        #async_trait_attr
362        #trait_vis trait #trait_name #trait_generics : #trait_supertraits {
363            #(#clean_trait_items)*
364        }
365
366        #[doc(hidden)]
367        mod #mod_name {
368            use super::*;
369            use ::rsrpc::serde::{Serialize, Deserialize};
370
371            #(#request_structs)*
372            #(#method_ids)*
373
374            // Dispatch function for the server
375            #dispatch_fn
376        }
377
378        #async_trait_attr
379        impl #trait_name for ::rsrpc::Client<dyn #trait_name> {
380            #(#client_impl_methods)*
381        }
382
383        #http_client_impl
384
385        /// Extension trait for creating servers from service implementations.
386        impl dyn #trait_name {
387            /// Create a server that hosts this service.
388            #trait_vis fn serve<T: #trait_name + 'static>(service: T) -> ::rsrpc::Server<dyn #trait_name> {
389                let service: ::std::sync::Arc<dyn #trait_name> = ::std::sync::Arc::new(service);
390                ::rsrpc::Server::from_arc(service, #mod_name::dispatch)
391            }
392        }
393
394        #http_routes
395    })
396}
397
398struct MethodInfo {
399    name: Ident,
400    args: Vec<(Ident, Type)>, // (name, type) excluding self
401    return_type: Type,
402    http: Option<HttpRoute>, // HTTP route info if method has #[get], #[post], etc.
403}
404
405fn parse_method(method: &TraitItemFn) -> syn::Result<MethodInfo> {
406    let name = method.sig.ident.clone();
407
408    // Extract HTTP route from attributes like #[get("/path")]
409    let http = parse_http_attrs(&method.attrs)?;
410
411    // Extract args, skipping self
412    let args: Vec<(Ident, Type)> = method
413        .sig
414        .inputs
415        .iter()
416        .filter_map(|arg| {
417            if let FnArg::Typed(pat_type) = arg {
418                if let Pat::Ident(pat_ident) = &*pat_type.pat {
419                    return Some((pat_ident.ident.clone(), (*pat_type.ty).clone()));
420                }
421            }
422            None
423        })
424        .collect();
425
426    // Extract return type
427    let return_type = match &method.sig.output {
428        ReturnType::Default => syn::parse_quote!(()),
429        ReturnType::Type(_, ty) => (**ty).clone(),
430    };
431
432    Ok(MethodInfo {
433        name,
434        args,
435        return_type,
436        http,
437    })
438}
439
440fn generate_request_struct(method: &MethodInfo) -> TokenStream2 {
441    // Use double underscore prefix to avoid conflicts with user types
442    let struct_name = format_ident!("__{}Request", to_pascal_case(&method.name.to_string()));
443    let fields: Vec<TokenStream2> = method
444        .args
445        .iter()
446        .map(|(name, ty)| {
447            quote! { pub #name: #ty }
448        })
449        .collect();
450
451    if fields.is_empty() {
452        quote! {
453            #[derive(Serialize, Deserialize)]
454            struct #struct_name;
455        }
456    } else {
457        quote! {
458            #[derive(Serialize, Deserialize)]
459            struct #struct_name {
460                #(#fields),*
461            }
462        }
463    }
464}
465
466fn generate_client_method(method: &MethodInfo, method_id: u16, trait_name: &Ident) -> TokenStream2 {
467    let name = &method.name;
468    let return_type = &method.return_type;
469
470    let arg_names: Vec<&Ident> = method.args.iter().map(|(n, _)| n).collect();
471    let arg_decls: Vec<TokenStream2> = method
472        .args
473        .iter()
474        .map(|(name, ty)| quote! { #name: #ty })
475        .collect();
476
477    // Use trait-based dispatch - ClientEncoding will select the right behavior
478    // based on whether the return type is Result<T> or Result<RpcStream<T>>
479    if arg_names.is_empty() {
480        quote! {
481            async fn #name(&self) -> #return_type {
482                <#return_type as ::rsrpc::ClientEncoding<dyn #trait_name>>::invoke(
483                    self,
484                    #method_id,
485                    &(),
486                ).await
487            }
488        }
489    } else {
490        let request_fields: Vec<TokenStream2> = method
491            .args
492            .iter()
493            .map(|(name, ty)| quote! { #name: #ty })
494            .collect();
495
496        quote! {
497            async fn #name(&self, #(#arg_decls),*) -> #return_type {
498                #[derive(::rsrpc::serde::Serialize)]
499                struct __Request { #(#request_fields),* }
500                <#return_type as ::rsrpc::ClientEncoding<dyn #trait_name>>::invoke(
501                    self,
502                    #method_id,
503                    &__Request { #(#arg_names),* },
504                ).await
505            }
506        }
507    }
508}
509
510fn generate_dispatch_arm(method: &MethodInfo, method_id: u16) -> TokenStream2 {
511    let name = &method.name;
512    let request_struct = format_ident!("__{}Request", to_pascal_case(&name.to_string()));
513    let arg_names: Vec<&Ident> = method.args.iter().map(|(n, _)| n).collect();
514
515    // Use trait-based dispatch - ServerEncoding will select the right behavior
516    // based on whether the return type is Result<T> or Result<RpcStream<T>>
517    if arg_names.is_empty() {
518        quote! {
519            #method_id => {
520                let result = service.#name().await;
521                ::rsrpc::ServerEncoding::into_dispatch(result)
522            }
523        }
524    } else {
525        quote! {
526            #method_id => {
527                let req: #request_struct = match ::rsrpc::postcard::from_bytes(payload) {
528                    Ok(r) => r,
529                    Err(e) => return ::rsrpc::DispatchResult::Error(e.to_string()),
530                };
531                let result = service.#name(#(req.#arg_names),*).await;
532                ::rsrpc::ServerEncoding::into_dispatch(result)
533            }
534        }
535    }
536}
537
/// Convert a `PascalCase` identifier to `snake_case`.
///
/// An `_` is inserted before every uppercase character except the first, and that
/// character is lowercased. The full Unicode lowercase mapping is used, which can
/// be more than one `char` (e.g. `'İ'` → `"i\u{307}"`). Taking only the first
/// `char` would silently drop part of it. Runs of capitals are not grouped:
/// `"HTTPServer"` becomes `"h_t_t_p_server"`.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + 4);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
552
/// Convert a `snake_case` identifier to `PascalCase`.
///
/// Underscores are removed and the character following them (and the first
/// character) is uppercased. The full Unicode uppercase mapping is used, so
/// `'ß'` becomes `"SS"` rather than being truncated to `"S"`.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;
    for c in s.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            result.extend(c.to_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }
    result
}
568
569// =============================================================================
570// HTTP CODE GENERATION
571// =============================================================================
572
573/// Generate HTTP client method implementation
574fn generate_http_client_method(method: &MethodInfo, _trait_name: &Ident) -> Option<TokenStream2> {
575    let route = method.http.as_ref()?;
576    let name = &method.name;
577    let return_type = &method.return_type;
578
579    let arg_decls: Vec<TokenStream2> = method
580        .args
581        .iter()
582        .map(|(name, ty)| quote! { #name: #ty })
583        .collect();
584
585    let http_method = route.method.to_tokens();
586
587    // Build path with parameter substitution
588    let path_template = &route.path;
589    let path_params: Vec<&Ident> = method
590        .args
591        .iter()
592        .filter(|(n, _)| route.path_params.contains(&n.to_string()))
593        .map(|(n, _)| n)
594        .collect();
595
596    // Query params
597    let query_params: Vec<&Ident> = method
598        .args
599        .iter()
600        .filter(|(n, _)| route.query_params.contains(&n.to_string()))
601        .map(|(n, _)| n)
602        .collect();
603
604    // Body params (everything not in path or query)
605    let body_params: Vec<(&Ident, &Type)> = method
606        .args
607        .iter()
608        .filter(|(n, _)| {
609            !route.path_params.contains(&n.to_string())
610                && !route.query_params.contains(&n.to_string())
611        })
612        .map(|(n, t)| (n, t))
613        .collect();
614
615    // Generate path building code
616    let path_build = if path_params.is_empty() {
617        quote! { #path_template.to_string() }
618    } else {
619        // Replace {param} with actual values
620        let mut format_str = path_template.clone();
621        let mut format_args = Vec::new();
622        for param in &path_params {
623            let param_str = param.to_string();
624            format_str = format_str.replace(&format!("{{{}}}", param_str), "{}");
625            format_args.push(quote! { #param });
626        }
627        quote! { format!(#format_str, #(#format_args),*) }
628    };
629
630    // Generate query string
631    let query_build = if query_params.is_empty() {
632        quote! { Vec::<(&str, String)>::new() }
633    } else {
634        let query_pairs: Vec<TokenStream2> = query_params
635            .iter()
636            .map(|p| {
637                let p_str = p.to_string();
638                quote! { (#p_str, #p.to_string()) }
639            })
640            .collect();
641        quote! { vec![#(#query_pairs),*] }
642    };
643
644    // Generate body
645    let body_build = if body_params.is_empty() {
646        quote! { None::<&()> }
647    } else if body_params.len() == 1 {
648        let (body_name, _) = body_params[0];
649        quote! { Some(&#body_name) }
650    } else {
651        // Multiple body params - wrap in anonymous struct
652        let body_fields: Vec<TokenStream2> =
653            body_params.iter().map(|(n, t)| quote! { #n: #t }).collect();
654        let body_values: Vec<&Ident> = body_params.iter().map(|(n, _)| *n).collect();
655        quote! {
656            {
657                #[derive(::rsrpc::serde::Serialize)]
658                struct __Body { #(#body_fields),* }
659                Some(&__Body { #(#body_values),* })
660            }
661        }
662    };
663
664    Some(quote! {
665        async fn #name(&self, #(#arg_decls),*) -> #return_type {
666            let path = #path_build;
667            let query = #query_build;
668            let body = #body_build;
669            self.request(#http_method, &path, &query, body).await
670        }
671    })
672}
673
674/// Generate query struct for axum handler
675fn generate_query_struct(method: &MethodInfo) -> Option<TokenStream2> {
676    let route = method.http.as_ref()?;
677    if route.query_params.is_empty() {
678        return None;
679    }
680
681    let struct_name = format_ident!("__{}Query", to_pascal_case(&method.name.to_string()));
682    let fields: Vec<TokenStream2> = method
683        .args
684        .iter()
685        .filter(|(n, _)| route.query_params.contains(&n.to_string()))
686        .map(|(name, ty)| quote! { pub #name: #ty })
687        .collect();
688
689    Some(quote! {
690        #[derive(::rsrpc::serde::Deserialize)]
691        struct #struct_name {
692            #(#fields),*
693        }
694    })
695}
696
697/// Generate HTTP server handler for a method
698fn generate_http_handler(method: &MethodInfo, trait_name: &Ident) -> Option<TokenStream2> {
699    let route = method.http.as_ref()?;
700    let name = &method.name;
701    let handler_name = format_ident!("__{}_handler", name);
702
703    // Path params extraction
704    let path_param_types: Vec<&Type> = method
705        .args
706        .iter()
707        .filter(|(n, _)| route.path_params.contains(&n.to_string()))
708        .map(|(_, t)| t)
709        .collect();
710
711    let path_param_names: Vec<&Ident> = method
712        .args
713        .iter()
714        .filter(|(n, _)| route.path_params.contains(&n.to_string()))
715        .map(|(n, _)| n)
716        .collect();
717
718    // Query params
719    let query_struct_name = format_ident!("__{}Query", to_pascal_case(&method.name.to_string()));
720    let has_query = !route.query_params.is_empty();
721    let query_param_names: Vec<&Ident> = method
722        .args
723        .iter()
724        .filter(|(n, _)| route.query_params.contains(&n.to_string()))
725        .map(|(n, _)| n)
726        .collect();
727
728    // Body params
729    let body_params: Vec<(&Ident, &Type)> = method
730        .args
731        .iter()
732        .filter(|(n, _)| {
733            !route.path_params.contains(&n.to_string())
734                && !route.query_params.contains(&n.to_string())
735        })
736        .map(|(n, t)| (n, t))
737        .collect();
738
739    // Build extractor list
740    let mut extractors = Vec::new();
741    let mut call_args = Vec::new();
742
743    // State extractor (always first)
744    extractors.push(quote! {
745        ::rsrpc::axum::extract::State(service): ::rsrpc::axum::extract::State<::std::sync::Arc<T>>
746    });
747
748    // Path extractor
749    if !path_param_names.is_empty() {
750        if path_param_names.len() == 1 {
751            let ty = path_param_types[0];
752            let name = path_param_names[0];
753            extractors.push(quote! {
754                ::rsrpc::axum::extract::Path(#name): ::rsrpc::axum::extract::Path<#ty>
755            });
756            call_args.push(quote! { #name });
757        } else {
758            let names = &path_param_names;
759            let types = &path_param_types;
760            extractors.push(quote! {
761                ::rsrpc::axum::extract::Path((#(#names),*)): ::rsrpc::axum::extract::Path<(#(#types),*)>
762            });
763            for name in &path_param_names {
764                call_args.push(quote! { #name });
765            }
766        }
767    }
768
769    // Query extractor
770    if has_query {
771        extractors.push(quote! {
772            ::rsrpc::axum::extract::Query(query): ::rsrpc::axum::extract::Query<#query_struct_name>
773        });
774        for name in &query_param_names {
775            call_args.push(quote! { query.#name });
776        }
777    }
778
779    // Body extractor
780    if !body_params.is_empty() {
781        if body_params.len() == 1 {
782            let (name, ty) = body_params[0];
783            extractors.push(quote! {
784                ::rsrpc::axum::extract::Json(#name): ::rsrpc::axum::extract::Json<#ty>
785            });
786            call_args.push(quote! { #name });
787        } else {
788            let body_struct_name =
789                format_ident!("__{}Body", to_pascal_case(&method.name.to_string()));
790            let body_names: Vec<&Ident> = body_params.iter().map(|(n, _)| *n).collect();
791
792            // The struct definition is generated in generate_http_routes
793            extractors.push(quote! {
794                ::rsrpc::axum::extract::Json(body): ::rsrpc::axum::extract::Json<#body_struct_name>
795            });
796            for name in &body_names {
797                call_args.push(quote! { body.#name });
798            }
799        }
800    }
801
802    Some(quote! {
803        async fn #handler_name<T: #trait_name>(
804            #(#extractors),*
805        ) -> impl ::rsrpc::axum::response::IntoResponse {
806            match service.#name(#(#call_args),*).await {
807                Ok(v) => ::rsrpc::axum::Json(v).into_response(),
808                Err(e) => (
809                    ::rsrpc::axum::http::StatusCode::INTERNAL_SERVER_ERROR,
810                    e.to_string()
811                ).into_response(),
812            }
813        }
814    })
815}
816
817/// Generate the http_routes method that returns an axum Router
818fn generate_http_routes(
819    methods: &[MethodInfo],
820    trait_name: &Ident,
821    trait_vis: &syn::Visibility,
822) -> TokenStream2 {
823    let http_methods: Vec<&MethodInfo> = methods.iter().filter(|m| m.http.is_some()).collect();
824
825    if http_methods.is_empty() {
826        return quote! {};
827    }
828
829    // Generate query structs
830    let query_structs: Vec<TokenStream2> = http_methods
831        .iter()
832        .filter_map(|m| generate_query_struct(m))
833        .collect();
834
835    // Generate body structs for multi-param bodies
836    let body_structs: Vec<TokenStream2> = http_methods
837        .iter()
838        .filter_map(|m| {
839            let route = m.http.as_ref()?;
840            let body_params: Vec<(&Ident, &Type)> = m
841                .args
842                .iter()
843                .filter(|(n, _)| {
844                    !route.path_params.contains(&n.to_string())
845                        && !route.query_params.contains(&n.to_string())
846                })
847                .map(|(n, t)| (n, t))
848                .collect();
849
850            if body_params.len() > 1 {
851                let struct_name = format_ident!("__{}Body", to_pascal_case(&m.name.to_string()));
852                let fields: Vec<TokenStream2> = body_params
853                    .iter()
854                    .map(|(n, t)| quote! { pub #n: #t })
855                    .collect();
856                Some(quote! {
857                    #[derive(::rsrpc::serde::Deserialize)]
858                    struct #struct_name {
859                        #(#fields),*
860                    }
861                })
862            } else {
863                None
864            }
865        })
866        .collect();
867
868    // Generate handlers
869    let handlers: Vec<TokenStream2> = http_methods
870        .iter()
871        .filter_map(|m| generate_http_handler(m, trait_name))
872        .collect();
873
874    // Generate route registrations
875    let routes: Vec<TokenStream2> = http_methods
876        .iter()
877        .filter_map(|m| {
878            let route = m.http.as_ref()?;
879            let handler_name = format_ident!("__{}_handler", m.name);
880            let axum_path = &route.axum_path;
881            let method_fn = route.method.to_axum_method();
882            Some(quote! {
883                .route(#axum_path, #method_fn(#handler_name::<T>))
884            })
885        })
886        .collect();
887
888    quote! {
889        #[cfg(feature = "http")]
890        const _: () = {
891            use ::rsrpc::axum::response::IntoResponse;
892
893            #(#query_structs)*
894            #(#body_structs)*
895            #(#handlers)*
896
897            impl dyn #trait_name {
898                /// Create an axum Router for HTTP endpoints.
899                #trait_vis fn http_routes<T: #trait_name + 'static>(
900                    service: ::std::sync::Arc<T>
901                ) -> ::rsrpc::axum::Router {
902                    ::rsrpc::axum::Router::new()
903                        #(#routes)*
904                        .with_state(service)
905                }
906            }
907        };
908    }
909}