Skip to main content

dioxus_fullstack_macro/
lib.rs

1// TODO: Create README, uncomment this: #![doc = include_str!("../README.md")]
2#![doc(html_logo_url = "https://avatars.githubusercontent.com/u/79236386")]
3#![doc(html_favicon_url = "https://avatars.githubusercontent.com/u/79236386")]
4
5use core::panic;
6use proc_macro::TokenStream;
7use proc_macro2::{Span, TokenStream as TokenStream2};
8use quote::ToTokens;
9use quote::{format_ident, quote};
10use std::collections::HashMap;
11use syn::{
12    Attribute, Expr, ExprClosure, Lit, Result,
13    token::{Brace, Star},
14};
15use syn::{
16    Error, ExprTuple, FnArg, Meta, PathArguments, PathSegment, Token, Type, TypePath, braced,
17    bracketed,
18    parse::ParseStream,
19    punctuated::Punctuated,
20    token::{Comma, Slash},
21};
22use syn::{Ident, ItemFn, LitStr, Path, parse::Parse, parse_quote};
23use syn::{LitBool, LitInt, Pat, PatType, spanned::Spanned};
24
25/// ## Usage
26///
27/// ```rust,ignore
28/// # use dioxus::prelude::*;
29/// # #[derive(serde::Deserialize, serde::Serialize)]
30/// # struct BlogPost;
31/// # async fn load_posts(category: &str) -> Result<Vec<BlogPost>> { unimplemented!() }
32///
33/// #[server]
34/// async fn blog_posts(category: String) -> Result<Vec<BlogPost>> {
35///     let posts = load_posts(&category).await?;
36///     // maybe do some other work
37///     Ok(posts)
38/// }
39/// ```
40///
41/// ## Named Arguments
42///
43/// You can use any combination of the following named arguments:
44/// - `endpoint`: a prefix at which the server function handler will be mounted (defaults to `/api`).
45///   Example: `endpoint = "/my_api/my_serverfn"`.
46/// - `input`: the encoding for the arguments, defaults to `Json<T>`
47///     - You may customize the encoding of the arguments by specifying a different type for `input`.
48///     - Any axum `IntoRequest` extractor can be used here, and dioxus provides
49///       - `Json<T>`: The default axum `Json` extractor that decodes JSON-encoded request bodies.
50///       - `Cbor<T>`: A custom axum `Cbor` extractor that decodes CBOR-encoded request bodies.
51///       - `MessagePack<T>`: A custom axum `MessagePack` extractor that decodes MessagePack-encoded request bodies.
52/// - `output`: the encoding for the response (defaults to `Json`).
53///     - The `output` argument specifies how the server should encode the response data.
54///     - Acceptable values include:
55///       - `Json`: A response encoded as JSON (default). This is ideal for most web applications.
56///       - `Cbor`: A response encoded in the CBOR format for efficient, binary-encoded data.
57/// - `client`: a custom `Client` implementation that will be used for this server function. This allows
58///   customization of the client-side behavior if needed.
59///
60/// ## Advanced Usage of `input` and `output` Fields
61///
62/// The `input` and `output` fields allow you to customize how arguments and responses are encoded and decoded.
63/// These fields impose specific trait bounds on the types you use. Here are detailed examples for different scenarios:
64///
65/// ## Adding layers to server functions
66///
67/// Layers allow you to transform the request and response of a server function. You can use layers
68/// to add authentication, logging, or other functionality to your server functions. Server functions integrate
69/// with the tower ecosystem, so you can use any layer that is compatible with tower.
70///
71/// Common layers include:
72/// - [`tower_http::trace::TraceLayer`](https://docs.rs/tower-http/latest/tower_http/trace/struct.TraceLayer.html) for tracing requests and responses
73/// - [`tower_http::compression::CompressionLayer`](https://docs.rs/tower-http/latest/tower_http/compression/struct.CompressionLayer.html) for compressing large responses
74/// - [`tower_http::cors::CorsLayer`](https://docs.rs/tower-http/latest/tower_http/cors/struct.CorsLayer.html) for adding CORS headers to responses
75/// - [`tower_http::timeout::TimeoutLayer`](https://docs.rs/tower-http/latest/tower_http/timeout/struct.TimeoutLayer.html) for adding timeouts to requests
76/// - [`tower_sessions::service::SessionManagerLayer`](https://docs.rs/tower-sessions/0.13.0/tower_sessions/service/struct.SessionManagerLayer.html) for adding session management to requests
77///
78/// You can add a tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to your server function with the middleware attribute:
79///
80/// ```rust,ignore
81/// # use dioxus::prelude::*;
82/// #[server]
83/// // The TraceLayer will log all requests to the console
84/// #[middleware(tower_http::timeout::TimeoutLayer::new(std::time::Duration::from_secs(5)))]
85/// pub async fn my_wacky_server_fn(input: Vec<String>) -> ServerFnResult<usize> {
86///     unimplemented!()
87/// }
88/// ```
89#[proc_macro_attribute]
90pub fn server(attr: proc_macro::TokenStream, mut item: TokenStream) -> TokenStream {
91    // Parse the attribute list using the old server_fn arg parser.
92    let args = match syn::parse::<ServerFnArgs>(attr) {
93        Ok(args) => args,
94        Err(err) => {
95            let err: TokenStream = err.to_compile_error().into();
96            item.extend(err);
97            return item;
98        }
99    };
100
101    let method = Method::Post(Ident::new("POST", proc_macro2::Span::call_site()));
102    let prefix = args
103        .prefix
104        .unwrap_or_else(|| LitStr::new("/api", Span::call_site()));
105
106    let route: Route = Route {
107        method: None,
108        path_params: vec![],
109        query_params: vec![],
110        route_lit: args.fn_path,
111        oapi_options: None,
112        server_args: args.server_args,
113        prefix: Some(prefix),
114        _input_encoding: args.input,
115        _output_encoding: args.output,
116    };
117
118    match route_impl_with_route(route, item.clone(), Some(method)) {
119        Ok(mut tokens) => {
120            // Let's add some deprecated warnings to the various fields from `args` if the user is using them...
121            // We don't generate structs anymore, don't use various protocols, etc
122            if let Some(name) = args.struct_name {
123                tokens.extend(quote! {
124                    const _: () = {
125                        #[deprecated(note = "Dioxus server functions no longer generate a struct for the server function. The function itself is used directly.")]
126                        struct #name;
127                        fn ___assert_deprecated() {
128                            let _ = #name;
129                        }
130
131                        ()
132                    };
133                });
134            }
135
136            //
137            tokens.into()
138        }
139
140        // Retain the original function item and append the error to it. Better for autocomplete.
141        Err(err) => {
142            let err: TokenStream = err.to_compile_error().into();
143            item.extend(err);
144            item
145        }
146    }
147}
148
149#[proc_macro_attribute]
150pub fn get(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
151    wrapped_route_impl(args, body, Some(Method::new_from_string("GET")))
152}
153
154#[proc_macro_attribute]
155pub fn post(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
156    wrapped_route_impl(args, body, Some(Method::new_from_string("POST")))
157}
158
159#[proc_macro_attribute]
160pub fn put(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
161    wrapped_route_impl(args, body, Some(Method::new_from_string("PUT")))
162}
163
164#[proc_macro_attribute]
165pub fn delete(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
166    wrapped_route_impl(args, body, Some(Method::new_from_string("DELETE")))
167}
168
169#[proc_macro_attribute]
170pub fn patch(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
171    wrapped_route_impl(args, body, Some(Method::new_from_string("PATCH")))
172}
173
174fn wrapped_route_impl(
175    attr: TokenStream,
176    mut item: TokenStream,
177    method: Option<Method>,
178) -> TokenStream {
179    match route_impl(attr, item.clone(), method) {
180        Ok(tokens) => tokens.into(),
181        Err(err) => {
182            let err: TokenStream = err.to_compile_error().into();
183            item.extend(err);
184            item
185        }
186    }
187}
188
189fn route_impl(
190    attr: TokenStream,
191    item: TokenStream,
192    method_from_macro: Option<Method>,
193) -> syn::Result<TokenStream2> {
194    let route = syn::parse::<Route>(attr)?;
195    route_impl_with_route(route, item, method_from_macro)
196}
197
198fn route_impl_with_route(
199    route: Route,
200    item: TokenStream,
201    method_from_macro: Option<Method>,
202) -> syn::Result<TokenStream2> {
203    // Parse the route and function
204    let mut function = syn::parse::<ItemFn>(item)?;
205
206    // Collect the middleware initializers
207    let middleware_layers = function
208        .attrs
209        .iter()
210        .filter(|attr| attr.path().is_ident("middleware"))
211        .map(|f| match &f.meta {
212            Meta::List(meta_list) => Ok({
213                let tokens = &meta_list.tokens;
214                quote! { .layer(#tokens) }
215            }),
216            _ => Err(Error::new(
217                f.span(),
218                "Expected middleware attribute to be a list, e.g. #[middleware(MyLayer::new())]",
219            )),
220        })
221        .collect::<Result<Vec<_>>>()?;
222
223    // don't re-emit the middleware attribute on the inner
224    function
225        .attrs
226        .retain(|attr| !attr.path().is_ident("middleware"));
227
228    // Attach `#[allow(unused_mut)]` to all original inputs to avoid warnings
229    let outer_inputs = function
230        .sig
231        .inputs
232        .iter()
233        .enumerate()
234        .map(|(i, arg)| match arg {
235            FnArg::Receiver(_receiver) => panic!("Self type is not supported"),
236            FnArg::Typed(pat_type) => match pat_type.pat.as_ref() {
237                Pat::Ident(_) => {
238                    quote! { #[allow(unused_mut)] #pat_type }
239                }
240                _ => {
241                    let ident = format_ident!("___Arg{}", i);
242                    let ty = &pat_type.ty;
243                    quote! { #[allow(unused_mut)] #ident: #ty }
244                }
245            },
246        })
247        .collect::<Punctuated<_, Token![,]>>();
248    // .collect::<Punctuated<_, Token![,]>>();
249
250    let route = CompiledRoute::from_route(route, &function, false, method_from_macro)?;
251    let query_params_struct = route.query_params_struct(false);
252    let method_ident = &route.method;
253    let body_json_args = route.remaining_pattypes_named(&function.sig.inputs);
254    let body_json_names = body_json_args
255        .iter()
256        .map(|(i, pat_type)| match &*pat_type.pat {
257            Pat::Ident(pat_ident) => pat_ident.ident.clone(),
258            _ => format_ident!("___Arg{}", i),
259        })
260        .collect::<Vec<_>>();
261    let body_json_types = body_json_args
262        .iter()
263        .map(|pat_type| &pat_type.1.ty)
264        .collect::<Vec<_>>();
265    let route_docs = route.to_doc_comments();
266
267    // Get the variables we need for code generation
268    let fn_on_server_name = &function.sig.ident;
269    let vis = &function.vis;
270    let (impl_generics, ty_generics, where_clause) = &function.sig.generics.split_for_impl();
271    let ty_generics = ty_generics.as_turbofish();
272    let fn_docs = function
273        .attrs
274        .iter()
275        .filter(|attr| attr.path().is_ident("doc"));
276
277    let __axum = quote! { dioxus_server::axum };
278
279    let output_type = match &function.sig.output {
280        syn::ReturnType::Default => parse_quote! { () },
281        syn::ReturnType::Type(_, ty) => (*ty).clone(),
282    };
283
284    let query_param_names = route
285        .query_params
286        .iter()
287        .filter(|c| !c.catch_all)
288        .map(|param| &param.binding);
289
290    let path_param_args = route.path_params.iter().map(|(_slash, param)| match param {
291        PathParam::Capture(_lit, _brace_1, ident, _ty, _brace_2) => {
292            Some(quote! { #ident = #ident, })
293        }
294        PathParam::WildCard(_lit, _brace_1, _star, ident, _ty, _brace_2) => {
295            Some(quote! { #ident = #ident, })
296        }
297        PathParam::Static(_lit) => None,
298    });
299
300    let out_ty = match output_type.as_ref() {
301        Type::Tuple(tuple) if tuple.elems.is_empty() => parse_quote! { () },
302        _ => output_type.clone(),
303    };
304
305    let mut function_on_server = function.clone();
306    function_on_server
307        .sig
308        .inputs
309        .extend(route.server_args.clone());
310
311    let server_names = route
312        .server_args
313        .iter()
314        .enumerate()
315        .map(|(i, pat_type)| match pat_type {
316            FnArg::Typed(_pat_type) => format_ident!("___sarg___{}", i),
317            FnArg::Receiver(_) => panic!("Self type is not supported"),
318        })
319        .collect::<Vec<_>>();
320
321    let server_types = route
322        .server_args
323        .iter()
324        .map(|pat_type| match pat_type {
325            FnArg::Receiver(_) => parse_quote! { () },
326            FnArg::Typed(pat_type) => (*pat_type.ty).clone(),
327        })
328        .collect::<Vec<_>>();
329
330    let body_struct_impl = {
331        let tys = body_json_types
332            .iter()
333            .enumerate()
334            .map(|(idx, _)| format_ident!("__Ty{}", idx));
335
336        let names = body_json_names.iter().enumerate().map(|(idx, name)| {
337            let ty_name = format_ident!("__Ty{}", idx);
338            quote! { #name: #ty_name }
339        });
340
341        quote! {
342            #[derive(serde::Serialize, serde::Deserialize)]
343            #[serde(crate = "serde")]
344            struct ___Body_Serialize___< #(#tys,)* > {
345                #(#names,)*
346            }
347        }
348    };
349
350    // This unpacks the body struct into the individual variables that get scoped
351    let unpack_closure = {
352        let unpack_args = body_json_names.iter().map(|name| quote! { data.#name });
353        quote! {
354            |data| { ( #(#unpack_args,)* ) }
355        }
356    };
357
358    let as_axum_path = route.to_axum_path_string();
359
360    let query_endpoint = if let Some(full_url) = route.url_without_queries_for_format() {
361        quote! { format!(#full_url, #( #path_param_args)*) }
362    } else {
363        quote! { __ENDPOINT_PATH.to_string() }
364    };
365
366    let endpoint_path = {
367        let prefix = route
368            .prefix
369            .as_ref()
370            .cloned()
371            .unwrap_or_else(|| LitStr::new("", Span::call_site()));
372
373        let route_lit = if let Some(lit) = as_axum_path {
374            quote! { #lit }
375        } else {
376            let name =
377                route.route_lit.as_ref().cloned().unwrap_or_else(|| {
378                    LitStr::new(&fn_on_server_name.to_string(), Span::call_site())
379                });
380            quote! {
381                concat!(
382                    "/",
383                    #name
384                )
385            }
386        };
387
388        let hash = match route.prefix.as_ref() {
389            // Implicit route lit, we need to hash the function signature to avoid collisions
390            Some(_) if route.route_lit.is_none() => {
391                // let enable_hash = option_env!("DISABLE_SERVER_FN_HASH").is_none();
392                let key_env_var = match option_env!("SERVER_FN_OVERRIDE_KEY") {
393                    Some(_) => "SERVER_FN_OVERRIDE_KEY",
394                    None => "CARGO_MANIFEST_DIR",
395                };
396                quote! {
397                    dioxus_fullstack::xxhash_rust::const_xxh64::xxh64(
398                        concat!(env!(#key_env_var), ":", module_path!()).as_bytes(),
399                        0
400                    )
401                }
402            }
403
404            // Explicit route lit, no need to hash
405            _ => quote! { "" },
406        };
407
408        quote! {
409            dioxus_fullstack::const_format::concatcp!(#prefix, #route_lit, #hash)
410        }
411    };
412
413    let extracted_idents = route.extracted_idents();
414
415    let query_tokens = if route.query_is_catchall() {
416        let query = route
417            .query_params
418            .iter()
419            .find(|param| param.catch_all)
420            .unwrap();
421        let input = &function.sig.inputs[query.arg_idx];
422        let name = match input {
423            FnArg::Typed(pat_type) => match pat_type.pat.as_ref() {
424                Pat::Ident(pat_ident) => pat_ident.ident.clone(),
425                _ => format_ident!("___Arg{}", query.arg_idx),
426            },
427            FnArg::Receiver(_receiver) => panic!(),
428        };
429        quote! {
430            #name
431        }
432    } else {
433        quote! {
434            __QueryParams__ { #(#query_param_names,)* }
435        }
436    };
437
438    let extracted_as_server_headers = route.extracted_as_server_headers(query_tokens.clone());
439
440    Ok(quote! {
441        #(#fn_docs)*
442        #route_docs
443        #[deny(
444            unexpected_cfgs,
445            reason = "
446==========================================================================================
447  Using Dioxus Server Functions requires a `server` feature flag in your `Cargo.toml`.
448  Please add the following to your `Cargo.toml`:
449
450  ```toml
451  [features]
452  server = [\"dioxus/server\"]
453  ```
454
455  To enable better Rust-Analyzer support, you can make `server` a default feature:
456  ```toml
457  [features]
458  default = [\"web\", \"server\"]
459  web = [\"dioxus/web\"]
460  server = [\"dioxus/server\"]
461  ```
462==========================================================================================
463        "
464        )]
465        #vis async fn #fn_on_server_name #impl_generics( #outer_inputs ) -> #out_ty #where_clause {
466            use dioxus_fullstack::serde as serde;
467            use dioxus_fullstack::{
468                // concrete types
469                ServerFnEncoder, ServerFnDecoder, FullstackContext,
470
471                // "magic" traits for encoding/decoding on the client
472                ExtractRequest, EncodeRequest, RequestDecodeResult, RequestDecodeErr,
473
474                // "magic" traits for encoding/decoding on the server
475                MakeAxumResponse, MakeAxumError,
476            };
477
478            #query_params_struct
479
480            #body_struct_impl
481
482            const __ENDPOINT_PATH: &str = #endpoint_path;
483
484            {
485                _ = dioxus_fullstack::assert_is_result::<#out_ty>();
486
487                let verify_token = (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new())
488                    .verify_can_serialize();
489
490                dioxus_fullstack::assert_can_encode(verify_token);
491
492                let decode_token = (&&&&&ServerFnDecoder::<#out_ty>::new())
493                    .verify_can_deserialize();
494
495                dioxus_fullstack::assert_can_decode(decode_token);
496            };
497
498
499            // On the client, we make the request to the server
500            // We want to support extremely flexible error types and return types, making this more complex than it should
501            #[allow(clippy::unused_unit)]
502            #[cfg(not(feature = "server"))]
503            {
504                let client = dioxus_fullstack::ClientRequest::new(
505                    dioxus_fullstack::http::Method::#method_ident,
506                    #query_endpoint,
507                    &#query_tokens,
508                );
509
510                let response = (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new())
511                    .fetch_client(client, ___Body_Serialize___ { #(#body_json_names,)* }, #unpack_closure)
512                    .await;
513
514                let decoded = (&&&&&ServerFnDecoder::<#out_ty>::new())
515                    .decode_client_response(response)
516                    .await;
517
518                let result = (&&&&&ServerFnDecoder::<#out_ty>::new())
519                    .decode_client_err(decoded)
520                    .await;
521
522                return result;
523            }
524
525            // On the server, we expand the tokens and submit the function to inventory
526            #[cfg(feature = "server")] {
527                #function_on_server
528
529                #[allow(clippy::unused_unit)]
530                fn __inner__function__ #impl_generics(
531                    ___state: #__axum::extract::State<FullstackContext>,
532                    ___request: #__axum::extract::Request,
533                ) -> std::pin::Pin<Box<dyn std::future::Future<Output = #__axum::response::Response>>> #where_clause {
534                    Box::pin(async move {
535                         match (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new()).extract_axum(___state.0, ___request, #unpack_closure).await {
536                            Ok(((#(#body_json_names,)* ), (#(#extracted_as_server_headers,)* #(#server_names,)*) )) => {
537                                // Call the user function
538                                let res = #fn_on_server_name #ty_generics(#(#extracted_idents,)* #(#body_json_names,)* #(#server_names,)*).await;
539
540                                // Encode the response Into a `Result<T, E>`
541                                let encoded = (&&&&&&ServerFnDecoder::<#out_ty>::new()).make_axum_response(res);
542
543                                // And then encode `Result<T, E>` into `Response`
544                                (&&&&&ServerFnDecoder::<#out_ty>::new()).make_axum_error(encoded)
545                            },
546                            Err(res) => res,
547                        }
548                    })
549                }
550
551                dioxus_server::inventory::submit! {
552                    dioxus_server::ServerFunction::new(
553                        dioxus_server::http::Method::#method_ident,
554                        __ENDPOINT_PATH,
555                        || {
556                            dioxus_server::ServerFunction::make_handler(dioxus_server::http::Method::#method_ident, __inner__function__ #ty_generics)
557                                #(#middleware_layers)*
558                        }
559                    )
560                }
561
562                // Extract the server arguments from the context if needed.
563                let (#(#server_names,)*) = dioxus_fullstack::FullstackContext::extract::<(#(#server_types,)*), _>().await?;
564
565                // Call the function directly
566                return #fn_on_server_name #ty_generics(
567                    #(#extracted_idents,)*
568                    #(#body_json_names,)*
569                    #(#server_names,)*
570                ).await;
571            }
572
573            #[allow(unreachable_code)]
574            {
575                unreachable!()
576            }
577        }
578    })
579}
580
581struct CompiledRoute {
582    method: Method,
583    #[allow(clippy::type_complexity)]
584    path_params: Vec<(Slash, PathParam)>,
585    query_params: Vec<QueryParam>,
586    route_lit: Option<LitStr>,
587    prefix: Option<LitStr>,
588    oapi_options: Option<OapiOptions>,
589    server_args: Punctuated<FnArg, Comma>,
590}
591
592struct QueryParam {
593    arg_idx: usize,
594    name: String,
595    binding: Ident,
596    catch_all: bool,
597    ty: Box<Type>,
598}
599
600impl CompiledRoute {
601    fn to_axum_path_string(&self) -> Option<String> {
602        if self.prefix.is_some() {
603            return None;
604        }
605
606        let mut path = String::new();
607
608        for (_slash, param) in &self.path_params {
609            path.push('/');
610            match param {
611                PathParam::Capture(lit, _brace_1, _, _, _brace_2) => {
612                    path.push('{');
613                    path.push_str(&lit.value());
614                    path.push('}');
615                }
616                PathParam::WildCard(lit, _brace_1, _, _, _, _brace_2) => {
617                    path.push('{');
618                    path.push('*');
619                    path.push_str(&lit.value());
620                    path.push('}');
621                }
622                PathParam::Static(lit) => path.push_str(&lit.value()),
623            }
624        }
625
626        Some(path)
627    }
628
629    /// Removes the arguments in `route` from `args`, and merges them in the output.
630    pub fn from_route(
631        mut route: Route,
632        function: &ItemFn,
633        with_aide: bool,
634        method_from_macro: Option<Method>,
635    ) -> syn::Result<Self> {
636        if !with_aide && route.oapi_options.is_some() {
637            return Err(syn::Error::new(
638                Span::call_site(),
639                "Use `api_route` instead of `route` to use OpenAPI options",
640            ));
641        } else if with_aide && route.oapi_options.is_none() {
642            route.oapi_options = Some(OapiOptions {
643                summary: None,
644                description: None,
645                id: None,
646                hidden: None,
647                tags: None,
648                security: None,
649                responses: None,
650                transform: None,
651            });
652        }
653
654        let sig = &function.sig;
655        let mut arg_map = sig
656            .inputs
657            .iter()
658            .enumerate()
659            .filter_map(|(i, item)| match item {
660                syn::FnArg::Receiver(_) => None,
661                syn::FnArg::Typed(pat_type) => Some((i, pat_type)),
662            })
663            .filter_map(|(i, pat_type)| match &*pat_type.pat {
664                syn::Pat::Ident(ident) => Some((ident.ident.clone(), (pat_type.ty.clone(), i))),
665                _ => None,
666            })
667            .collect::<HashMap<_, _>>();
668
669        for (_slash, path_param) in &mut route.path_params {
670            match path_param {
671                PathParam::Capture(_lit, _, ident, ty, _) => {
672                    let (new_ident, new_ty) = arg_map.remove_entry(ident).ok_or_else(|| {
673                        syn::Error::new(
674                            ident.span(),
675                            format!("path parameter `{}` not found in function arguments", ident),
676                        )
677                    })?;
678                    *ident = new_ident;
679                    *ty = new_ty.0;
680                }
681                PathParam::WildCard(_lit, _, _star, ident, ty, _) => {
682                    let (new_ident, new_ty) = arg_map.remove_entry(ident).ok_or_else(|| {
683                        syn::Error::new(
684                            ident.span(),
685                            format!("path parameter `{}` not found in function arguments", ident),
686                        )
687                    })?;
688                    *ident = new_ident;
689                    *ty = new_ty.0;
690                }
691                PathParam::Static(_lit) => {}
692            }
693        }
694
695        let mut query_params = Vec::new();
696        for param in route.query_params {
697            let (ident, ty) = arg_map.remove_entry(&param.binding).ok_or_else(|| {
698                syn::Error::new(
699                    param.binding.span(),
700                    format!(
701                        "query parameter `{}` not found in function arguments",
702                        param.binding
703                    ),
704                )
705            })?;
706            query_params.push(QueryParam {
707                binding: ident,
708                name: param.name,
709                catch_all: param.catch_all,
710                ty: ty.0,
711                arg_idx: ty.1,
712            });
713        }
714
715        // Disallow multiple query params if one is a catch-all
716        if query_params.iter().any(|param| param.catch_all) && query_params.len() > 1 {
717            return Err(syn::Error::new(
718                Span::call_site(),
719                "Cannot have multiple query parameters when one is a catch-all",
720            ));
721        }
722
723        if let Some(options) = route.oapi_options.as_mut() {
724            options.merge_with_fn(function)
725        }
726
727        let method = match (method_from_macro, route.method) {
728            (Some(method), None) => method,
729            (None, Some(method)) => method,
730            (Some(_), Some(_)) => {
731                return Err(syn::Error::new(
732                    Span::call_site(),
733                    "HTTP method specified both in macro and in attribute",
734                ));
735            }
736            (None, None) => {
737                return Err(syn::Error::new(
738                    Span::call_site(),
739                    "HTTP method not specified in macro or in attribute",
740                ));
741            }
742        };
743
744        Ok(Self {
745            method,
746            route_lit: route.route_lit,
747            path_params: route.path_params,
748            query_params,
749            oapi_options: route.oapi_options,
750            prefix: route.prefix,
751            server_args: route.server_args,
752        })
753    }
754
755    pub fn query_is_catchall(&self) -> bool {
756        self.query_params.iter().any(|param| param.catch_all)
757    }
758
759    pub fn extracted_as_server_headers(&self, query_tokens: TokenStream2) -> Vec<Pat> {
760        let mut out = vec![];
761
762        // Add the path extractor
763        out.push({
764            let path_iter = self
765                .path_params
766                .iter()
767                .filter_map(|(_slash, path_param)| path_param.capture());
768            let idents = path_iter.clone().map(|item| item.0);
769            parse_quote! {
770                dioxus_server::axum::extract::Path((#(#idents,)*))
771            }
772        });
773
774        out.push(parse_quote!(
775            dioxus_fullstack::payloads::Query(#query_tokens)
776        ));
777
778        out
779    }
780
781    pub fn query_params_struct(&self, with_aide: bool) -> TokenStream2 {
782        let fields = self.query_params.iter().map(|item| {
783            let name = &item.name;
784            let binding = &item.binding;
785            let ty = &item.ty;
786            if item.catch_all {
787                quote! {}
788            } else if item.binding != item.name {
789                quote! {
790                    #[serde(rename = #name)]
791                    #binding: #ty,
792                }
793            } else {
794                quote! { #binding: #ty, }
795            }
796        });
797        let derive = match with_aide {
798            true => quote! {
799                #[derive(serde::Deserialize, serde::Serialize, ::schemars::JsonSchema)]
800                #[serde(crate = "serde")]
801            },
802            false => quote! {
803                #[derive(serde::Deserialize, serde::Serialize)]
804                #[serde(crate = "serde")]
805            },
806        };
807        quote! {
808            #derive
809            struct __QueryParams__ {
810                #(#fields)*
811            }
812        }
813    }
814
815    pub fn extracted_idents(&self) -> Vec<Ident> {
816        let mut idents = Vec::new();
817        for (_slash, path_param) in &self.path_params {
818            if let Some((ident, _ty)) = path_param.capture() {
819                idents.push(ident.clone());
820            }
821        }
822        for param in &self.query_params {
823            idents.push(param.binding.clone());
824        }
825        idents
826    }
827
828    fn remaining_pattypes_named(&self, args: &Punctuated<FnArg, Comma>) -> Vec<(usize, PatType)> {
829        args.iter()
830            .enumerate()
831            .filter_map(|(i, item)| {
832                if let FnArg::Typed(pat_type) = item {
833                    if let syn::Pat::Ident(pat_ident) = &*pat_type.pat
834                        && (self.path_params.iter().any(|(_slash, path_param)| {
835                            if let Some((path_ident, _ty)) = path_param.capture() {
836                                path_ident == &pat_ident.ident
837                            } else {
838                                false
839                            }
840                        }) || self
841                            .query_params
842                            .iter()
843                            .any(|query| query.binding == pat_ident.ident))
844                    {
845                        return None;
846                    }
847
848                    Some((i, pat_type.clone()))
849                } else {
850                    unimplemented!("Self type is not supported")
851                }
852            })
853            .collect()
854    }
855
856    pub(crate) fn to_doc_comments(&self) -> TokenStream2 {
857        let mut doc = format!(
858            "# Handler information
859- Method: `{}`
860- Path: `{}`",
861            self.method.to_axum_method_name(),
862            self.route_lit
863                .as_ref()
864                .map(|lit| lit.value())
865                .unwrap_or_else(|| "<auto>".into()),
866        );
867
868        if let Some(options) = &self.oapi_options {
869            let summary = options
870                .summary
871                .as_ref()
872                .map(|(_, summary)| format!("\"{}\"", summary.value()))
873                .unwrap_or("None".to_string());
874            let description = options
875                .description
876                .as_ref()
877                .map(|(_, description)| format!("\"{}\"", description.value()))
878                .unwrap_or("None".to_string());
879            let id = options
880                .id
881                .as_ref()
882                .map(|(_, id)| format!("\"{}\"", id.value()))
883                .unwrap_or("None".to_string());
884            let hidden = options
885                .hidden
886                .as_ref()
887                .map(|(_, hidden)| hidden.value().to_string())
888                .unwrap_or("None".to_string());
889            let tags = options
890                .tags
891                .as_ref()
892                .map(|(_, tags)| tags.to_string())
893                .unwrap_or("[]".to_string());
894            let security = options
895                .security
896                .as_ref()
897                .map(|(_, security)| security.to_string())
898                .unwrap_or("{}".to_string());
899
900            doc = format!(
901                "{doc}
902
903## OpenAPI
904- Summary: `{summary}`
905- Description: `{description}`
906- Operation id: `{id}`
907- Tags: `{tags}`
908- Security: `{security}`
909- Hidden: `{hidden}`
910"
911            );
912        }
913
914        quote!(
915            #[doc = #doc]
916        )
917    }
918
919    fn url_without_queries_for_format(&self) -> Option<String> {
920        // If there's a prefix, then it's an old-style route, and we can't generate a format string.
921        if self.prefix.is_some() {
922            return None;
923        }
924
925        // If there's no explicit route, we can't generate a format string this way.
926        let _lit = self.route_lit.as_ref()?;
927
928        let url_without_queries =
929            self.path_params
930                .iter()
931                .fold(String::new(), |mut acc, (_slash, param)| {
932                    acc.push('/');
933                    match param {
934                        PathParam::Capture(lit, _brace_1, _, _, _brace_2) => {
935                            acc.push_str(&format!("{{{}}}", lit.value()));
936                        }
937                        PathParam::WildCard(lit, _brace_1, _, _, _, _brace_2) => {
938                            // no `*` since we want to use the argument *as the wildcard* when making requests
939                            // it's not super applicable to server functions, more for general route generation
940                            acc.push_str(&format!("{{{}}}", lit.value()));
941                        }
942                        PathParam::Static(lit) => {
943                            acc.push_str(&lit.value());
944                        }
945                    }
946                    acc
947                });
948
949        let prefix = self
950            .prefix
951            .as_ref()
952            .cloned()
953            .unwrap_or_else(|| LitStr::new("", Span::call_site()))
954            .value();
955        let full_url = format!(
956            "{}{}{}",
957            prefix,
958            if url_without_queries.starts_with("/") {
959                ""
960            } else {
961                "/"
962            },
963            url_without_queries
964        );
965
966        Some(full_url)
967    }
968}
969
970struct RouteParser {
971    path_params: Vec<(Slash, PathParam)>,
972    query_params: Vec<QueryParam>,
973}
974
975impl RouteParser {
976    fn new(lit: LitStr) -> syn::Result<Self> {
977        let val = lit.value();
978        let span = lit.span();
979        let split_route = val.split('?').collect::<Vec<_>>();
980        if split_route.len() > 2 {
981            return Err(syn::Error::new(span, "expected at most one '?'"));
982        }
983
984        let path = split_route[0];
985        if !path.starts_with('/') {
986            return Err(syn::Error::new(span, "expected path to start with '/'"));
987        }
988        let path = path.strip_prefix('/').unwrap();
989
990        let mut path_params = Vec::new();
991
992        for path_param in path.split('/') {
993            path_params.push((
994                Slash(span),
995                PathParam::new(path_param, span, Box::new(parse_quote!(())))?,
996            ));
997        }
998
999        let path_param_len = path_params.len();
1000        for (i, (_slash, path_param)) in path_params.iter().enumerate() {
1001            match path_param {
1002                PathParam::WildCard(_, _, _, _, _, _) => {
1003                    if i != path_param_len - 1 {
1004                        return Err(syn::Error::new(
1005                            span,
1006                            "wildcard path param must be the last path param",
1007                        ));
1008                    }
1009                }
1010                PathParam::Capture(_, _, _, _, _) => (),
1011                PathParam::Static(lit) => {
1012                    if lit.value() == "*" && i != path_param_len - 1 {
1013                        return Err(syn::Error::new(
1014                            span,
1015                            "wildcard path param must be the last path param",
1016                        ));
1017                    }
1018                }
1019            }
1020        }
1021
1022        let mut query_params = Vec::new();
1023        if split_route.len() == 2 {
1024            let query = split_route[1];
1025            for query_param in query.split('&') {
1026                if query_param.starts_with(":") {
1027                    let ident = Ident::new(query_param.strip_prefix(":").unwrap(), span);
1028
1029                    query_params.push(QueryParam {
1030                        name: ident.to_string(),
1031                        binding: ident,
1032                        catch_all: true,
1033                        ty: parse_quote!(()),
1034                        arg_idx: usize::MAX,
1035                    });
1036                } else if query_param.starts_with("{") && query_param.ends_with("}") {
1037                    let ident = Ident::new(
1038                        query_param
1039                            .strip_prefix("{")
1040                            .unwrap()
1041                            .strip_suffix("}")
1042                            .unwrap(),
1043                        span,
1044                    );
1045
1046                    query_params.push(QueryParam {
1047                        name: ident.to_string(),
1048                        binding: ident,
1049                        catch_all: true,
1050                        ty: parse_quote!(()),
1051                        arg_idx: usize::MAX,
1052                    });
1053                } else {
1054                    // if there's an `=` in the query param, we only take the left side as the name, and the right side is the binding
1055                    let name;
1056                    let binding;
1057                    if let Some((n, b)) = query_param.split_once('=') {
1058                        name = n;
1059                        binding = Ident::new(b, span);
1060                    } else {
1061                        name = query_param;
1062                        binding = Ident::new(query_param, span);
1063                    }
1064
1065                    query_params.push(QueryParam {
1066                        name: name.to_string(),
1067                        binding,
1068                        catch_all: false,
1069                        ty: parse_quote!(()),
1070                        arg_idx: usize::MAX,
1071                    });
1072                }
1073            }
1074        }
1075
1076        Ok(Self {
1077            path_params,
1078            query_params,
1079        })
1080    }
1081}
1082
1083enum PathParam {
1084    WildCard(LitStr, Brace, Star, Ident, Box<Type>, Brace),
1085    Capture(LitStr, Brace, Ident, Box<Type>, Brace),
1086    Static(LitStr),
1087}
1088
1089impl PathParam {
1090    fn _captures(&self) -> bool {
1091        matches!(self, Self::Capture(..) | Self::WildCard(..))
1092    }
1093
1094    fn capture(&self) -> Option<(&Ident, &Type)> {
1095        match self {
1096            Self::Capture(_, _, ident, ty, _) => Some((ident, ty)),
1097            Self::WildCard(_, _, _, ident, ty, _) => Some((ident, ty)),
1098            _ => None,
1099        }
1100    }
1101
1102    fn new(str: &str, span: Span, ty: Box<Type>) -> syn::Result<Self> {
1103        let ok = if str.starts_with('{') {
1104            let str = str
1105                .strip_prefix('{')
1106                .unwrap()
1107                .strip_suffix('}')
1108                .ok_or_else(|| {
1109                    syn::Error::new(span, "expected path param to be wrapped in curly braces")
1110                })?;
1111            Self::Capture(
1112                LitStr::new(str, span),
1113                Brace(span),
1114                Ident::new(str, span),
1115                ty,
1116                Brace(span),
1117            )
1118        } else if str.starts_with('*') && str.len() > 1 {
1119            let str = str.strip_prefix('*').unwrap();
1120            Self::WildCard(
1121                LitStr::new(str, span),
1122                Brace(span),
1123                Star(span),
1124                Ident::new(str, span),
1125                ty,
1126                Brace(span),
1127            )
1128        } else if str.starts_with(':') && str.len() > 1 {
1129            let str = str.strip_prefix(':').unwrap();
1130            Self::Capture(
1131                LitStr::new(str, span),
1132                Brace(span),
1133                Ident::new(str, span),
1134                ty,
1135                Brace(span),
1136            )
1137        } else {
1138            Self::Static(LitStr::new(str, span))
1139        };
1140
1141        Ok(ok)
1142    }
1143}
1144
/// OpenAPI metadata parsed from a route's `{ key: value, ... }` block.
///
/// Each field stores the key identifier as written, paired with its value. Unset
/// `summary`, `description` and `id` can be filled in later by `merge_with_fn`.
struct OapiOptions {
    /// `summary: "..."`
    summary: Option<(Ident, LitStr)>,
    /// `description: "..."`
    description: Option<(Ident, LitStr)>,
    /// `id: "..."`: the operation id.
    id: Option<(Ident, LitStr)>,
    /// `hidden: true | false`
    hidden: Option<(Ident, LitBool)>,
    /// `tags: ["...", ...]`
    tags: Option<(Ident, StrArray)>,
    /// `security: { "scheme": ["scope", ...], ... }`
    security: Option<(Ident, Security)>,
    /// `responses: { 200: Type, ... }`
    responses: Option<(Ident, Responses)>,
    /// `transform: |...| ...`: a closure expression.
    transform: Option<(Ident, ExprClosure)>,
}
1155
1156struct Security(Vec<(LitStr, StrArray)>);
1157impl Parse for Security {
1158    fn parse(input: ParseStream) -> syn::Result<Self> {
1159        let inner;
1160        braced!(inner in input);
1161
1162        let mut arr = Vec::new();
1163        while !inner.is_empty() {
1164            let scheme = inner.parse::<LitStr>()?;
1165            let _ = inner.parse::<Token![:]>()?;
1166            let scopes = inner.parse::<StrArray>()?;
1167            let _ = inner.parse::<Token![,]>().ok();
1168            arr.push((scheme, scopes));
1169        }
1170
1171        Ok(Self(arr))
1172    }
1173}
1174
1175impl std::fmt::Display for Security {
1176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1177        write!(f, "{{")?;
1178        for (i, (scheme, scopes)) in self.0.iter().enumerate() {
1179            if i > 0 {
1180                write!(f, ", ")?;
1181            }
1182            write!(f, "{}: {}", scheme.value(), scopes)?;
1183        }
1184        write!(f, "}}")
1185    }
1186}
1187
1188struct Responses(Vec<(LitInt, Type)>);
1189impl Parse for Responses {
1190    fn parse(input: ParseStream) -> syn::Result<Self> {
1191        let inner;
1192        braced!(inner in input);
1193
1194        let mut arr = Vec::new();
1195        while !inner.is_empty() {
1196            let status = inner.parse::<LitInt>()?;
1197            let _ = inner.parse::<Token![:]>()?;
1198            let ty = inner.parse::<Type>()?;
1199            let _ = inner.parse::<Token![,]>().ok();
1200            arr.push((status, ty));
1201        }
1202
1203        Ok(Self(arr))
1204    }
1205}
1206
1207impl std::fmt::Display for Responses {
1208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1209        write!(f, "{{")?;
1210        for (i, (status, ty)) in self.0.iter().enumerate() {
1211            if i > 0 {
1212                write!(f, ", ")?;
1213            }
1214            write!(f, "{}: {}", status, ty.to_token_stream())?;
1215        }
1216        write!(f, "}}")
1217    }
1218}
1219
1220#[derive(Clone)]
1221struct StrArray(Vec<LitStr>);
1222impl Parse for StrArray {
1223    fn parse(input: ParseStream) -> syn::Result<Self> {
1224        let inner;
1225        bracketed!(inner in input);
1226        let mut arr = Vec::new();
1227        while !inner.is_empty() {
1228            arr.push(inner.parse::<LitStr>()?);
1229            inner.parse::<Token![,]>().ok();
1230        }
1231        Ok(Self(arr))
1232    }
1233}
1234
1235impl std::fmt::Display for StrArray {
1236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1237        write!(f, "[")?;
1238        for (i, lit) in self.0.iter().enumerate() {
1239            if i > 0 {
1240                write!(f, ", ")?;
1241            }
1242            write!(f, "\"{}\"", lit.value())?;
1243        }
1244        write!(f, "]")
1245    }
1246}
1247
1248impl Parse for OapiOptions {
1249    fn parse(input: ParseStream) -> syn::Result<Self> {
1250        let mut this = Self {
1251            summary: None,
1252            description: None,
1253            id: None,
1254            hidden: None,
1255            tags: None,
1256            security: None,
1257            responses: None,
1258            transform: None,
1259        };
1260
1261        while !input.is_empty() {
1262            let ident = input.parse::<Ident>()?;
1263            let _ = input.parse::<Token![:]>()?;
1264            match ident.to_string().as_str() {
1265                "summary" => this.summary = Some((ident, input.parse()?)),
1266                "description" => this.description = Some((ident, input.parse()?)),
1267                "id" => this.id = Some((ident, input.parse()?)),
1268                "hidden" => this.hidden = Some((ident, input.parse()?)),
1269                "tags" => this.tags = Some((ident, input.parse()?)),
1270                "security" => this.security = Some((ident, input.parse()?)),
1271                "responses" => this.responses = Some((ident, input.parse()?)),
1272                "transform" => this.transform = Some((ident, input.parse()?)),
1273                _ => {
1274                    return Err(syn::Error::new(
1275                        ident.span(),
1276                        "unexpected field, expected one of (summary, description, id, hidden, tags, security, responses, transform)",
1277                    ));
1278                }
1279            }
1280            let _ = input.parse::<Token![,]>().ok();
1281        }
1282
1283        Ok(this)
1284    }
1285}
1286
1287impl OapiOptions {
1288    fn merge_with_fn(&mut self, function: &ItemFn) {
1289        if self.description.is_none() {
1290            self.description = doc_iter(&function.attrs)
1291                .skip(2)
1292                .map(|item| item.value())
1293                .reduce(|mut acc, item| {
1294                    acc.push('\n');
1295                    acc.push_str(&item);
1296                    acc
1297                })
1298                .map(|item| (parse_quote!(description), parse_quote!(#item)))
1299        }
1300        if self.summary.is_none() {
1301            self.summary = doc_iter(&function.attrs)
1302                .next()
1303                .map(|item| (parse_quote!(summary), item.clone()))
1304        }
1305        if self.id.is_none() {
1306            let id = &function.sig.ident;
1307            self.id = Some((parse_quote!(id), LitStr::new(&id.to_string(), id.span())));
1308        }
1309    }
1310}
1311
1312fn doc_iter(attrs: &[Attribute]) -> impl Iterator<Item = &LitStr> + '_ {
1313    attrs
1314        .iter()
1315        .filter(|attr| attr.path().is_ident("doc"))
1316        .map(|attr| {
1317            let Meta::NameValue(meta) = &attr.meta else {
1318                panic!("doc attribute is not a name-value attribute");
1319            };
1320            let Expr::Lit(lit) = &meta.value else {
1321                panic!("doc attribute is not a string literal");
1322            };
1323            let Lit::Str(lit_str) = &lit.lit else {
1324                panic!("doc attribute is not a string literal");
1325            };
1326            lit_str
1327        })
1328}
1329
/// A parsed route attribute. It is read from tokens of the form
/// `[METHOD] "route" [{ openapi options }] [, arg: Type, ...]`.
struct Route {
    /// Optional HTTP method written before the route literal.
    method: Option<Method>,
    /// Path segments of the route literal, each paired with a `/` token.
    path_params: Vec<(Slash, PathParam)>,
    /// Query parameters declared after `?` in the route literal.
    query_params: Vec<QueryParam>,
    /// The route literal as written.
    route_lit: Option<LitStr>,
    /// Legacy URL prefix for old-style routes. `Route::parse` always sets this to `None`.
    prefix: Option<LitStr>,
    /// OpenAPI options from the optional `{ ... }` block after the route literal.
    oapi_options: Option<OapiOptions>,
    /// Server-only extractor arguments that follow a `,` (e.g. `headers: HeaderMap`).
    server_args: Punctuated<FnArg, Comma>,

    // todo: support these since `server_fn` had them
    _input_encoding: Option<Type>,
    _output_encoding: Option<Type>,
}
1343
1344impl Parse for Route {
1345    fn parse(input: ParseStream) -> syn::Result<Self> {
1346        let method = if input.peek(Ident) {
1347            Some(input.parse::<Method>()?)
1348        } else {
1349            None
1350        };
1351
1352        let route_lit = input.parse::<LitStr>()?;
1353        let RouteParser {
1354            path_params,
1355            query_params,
1356        } = RouteParser::new(route_lit.clone())?;
1357
1358        let oapi_options = input
1359            .peek(Brace)
1360            .then(|| {
1361                let inner;
1362                braced!(inner in input);
1363                inner.parse::<OapiOptions>()
1364            })
1365            .transpose()?;
1366
1367        let server_args = if input.peek(Comma) {
1368            let _ = input.parse::<Comma>()?;
1369            input.parse_terminated(FnArg::parse, Comma)?
1370        } else {
1371            Punctuated::new()
1372        };
1373
1374        Ok(Route {
1375            method,
1376            path_params,
1377            query_params,
1378            route_lit: Some(route_lit),
1379            oapi_options,
1380            server_args,
1381            prefix: None,
1382            _input_encoding: None,
1383            _output_encoding: None,
1384        })
1385    }
1386}
1387
/// An HTTP method.
///
/// Each variant keeps the identifier it came from. `ToTokens` re-emits that identifier, and
/// `to_axum_method_name` reuses its span.
#[derive(Clone)]
enum Method {
    Get(Ident),
    Post(Ident),
    Put(Ident),
    Delete(Ident),
    Head(Ident),
    Connect(Ident),
    Options(Ident),
    Trace(Ident),
    Patch(Ident),
}
1400
1401impl ToTokens for Method {
1402    fn to_tokens(&self, tokens: &mut TokenStream2) {
1403        match self {
1404            Self::Get(ident)
1405            | Self::Post(ident)
1406            | Self::Put(ident)
1407            | Self::Delete(ident)
1408            | Self::Head(ident)
1409            | Self::Connect(ident)
1410            | Self::Options(ident)
1411            | Self::Trace(ident)
1412            | Self::Patch(ident) => {
1413                ident.to_tokens(tokens);
1414            }
1415        }
1416    }
1417}
1418
1419impl Parse for Method {
1420    fn parse(input: ParseStream) -> syn::Result<Self> {
1421        let ident = input.parse::<Ident>()?;
1422        match ident.to_string().to_uppercase().as_str() {
1423            "GET" => Ok(Self::Get(ident)),
1424            "POST" => Ok(Self::Post(ident)),
1425            "PUT" => Ok(Self::Put(ident)),
1426            "DELETE" => Ok(Self::Delete(ident)),
1427            "HEAD" => Ok(Self::Head(ident)),
1428            "CONNECT" => Ok(Self::Connect(ident)),
1429            "OPTIONS" => Ok(Self::Options(ident)),
1430            "TRACE" => Ok(Self::Trace(ident)),
1431            _ => Err(input
1432                .error("expected one of (GET, POST, PUT, DELETE, HEAD, CONNECT, OPTIONS, TRACE)")),
1433        }
1434    }
1435}
1436
1437impl Method {
1438    fn to_axum_method_name(&self) -> Ident {
1439        match self {
1440            Self::Get(span) => Ident::new("get", span.span()),
1441            Self::Post(span) => Ident::new("post", span.span()),
1442            Self::Put(span) => Ident::new("put", span.span()),
1443            Self::Delete(span) => Ident::new("delete", span.span()),
1444            Self::Head(span) => Ident::new("head", span.span()),
1445            Self::Connect(span) => Ident::new("connect", span.span()),
1446            Self::Options(span) => Ident::new("options", span.span()),
1447            Self::Trace(span) => Ident::new("trace", span.span()),
1448            Self::Patch(span) => Ident::new("patch", span.span()),
1449        }
1450    }
1451
1452    fn new_from_string(s: &str) -> Self {
1453        match s.to_uppercase().as_str() {
1454            "GET" => Self::Get(Ident::new("GET", Span::call_site())),
1455            "POST" => Self::Post(Ident::new("POST", Span::call_site())),
1456            "PUT" => Self::Put(Ident::new("PUT", Span::call_site())),
1457            "DELETE" => Self::Delete(Ident::new("DELETE", Span::call_site())),
1458            "HEAD" => Self::Head(Ident::new("HEAD", Span::call_site())),
1459            "CONNECT" => Self::Connect(Ident::new("CONNECT", Span::call_site())),
1460            "OPTIONS" => Self::Options(Ident::new("OPTIONS", Span::call_site())),
1461            "TRACE" => Self::Trace(Ident::new("TRACE", Span::call_site())),
1462            "PATCH" => Self::Patch(Ident::new("PATCH", Span::call_site())),
1463            _ => panic!("expected one of (GET, POST, PUT, DELETE, HEAD, CONNECT, OPTIONS, TRACE)"),
1464        }
1465    }
1466}
1467
/// Custom keyword tokens for the attribute parsers.
/// NOTE(review): their use sites are outside this section; presumably they are used by the
/// `#[server(...)]` argument parser.
mod kw {
    // Defines the `kw::with` token type, which parses and peeks the bare word `with`.
    syn::custom_keyword!(with);
}
1471
/// The arguments to the `server` macro.
///
/// These originally came from the `server_fn` crate, but many no longer apply after the 0.7 fullstack
/// overhaul. We keep the parser here for temporary backwards compatibility with existing code, but
/// these arguments will be removed in a future release.
#[derive(Debug)]
#[non_exhaustive]
// NOTE(review): presumably this silences warnings for fields that are parsed only for the
// backwards compatibility described above and are never read; confirm before removing it.
#[allow(unused)]
struct ServerFnArgs {
    /// The name of the struct that will implement the server function trait
    /// and be submitted to inventory.
    struct_name: Option<Ident>,
    /// The prefix to use for the server function URL.
    prefix: Option<LitStr>,
    /// The input http encoding to use for the server function.
    input: Option<Type>,
    /// Additional traits to derive on the input struct for the server function.
    input_derive: Option<ExprTuple>,
    /// The output http encoding to use for the server function.
    output: Option<Type>,
    /// The endpoint given as `endpoint = "..."`. Despite the field name, the parser fills it
    /// from the `endpoint` key.
    fn_path: Option<LitStr>,
    /// The server type to use for the server function.
    server: Option<Type>,
    /// The client type to use for the server function.
    client: Option<Type>,
    /// The custom wrapper to use for the server function struct.
    custom_wrapper: Option<syn::Path>,
    /// If the generated input type should implement `From` for the only field in the input
    impl_from: Option<LitBool>,
    /// If the generated input type should implement `Deref` to the only field in the input
    impl_deref: Option<LitBool>,
    /// The protocol to use for the server function implementation.
    protocol: Option<Type>,
    /// NOTE(review): not documented upstream. Presumably it records whether a built-in
    /// encoding was selected; confirm against where the parser sets it.
    builtin_encoding: bool,
    /// Server-only extractors (e.g., headers: HeaderMap, cookies: Cookies).
    /// These are arguments that exist purely on the server side.
    server_args: Punctuated<FnArg, Comma>,
}
1511
1512impl Parse for ServerFnArgs {
1513    fn parse(stream: ParseStream) -> syn::Result<Self> {
1514        // legacy 4-part arguments
1515        let mut struct_name: Option<Ident> = None;
1516        let mut prefix: Option<LitStr> = None;
1517        let mut encoding: Option<LitStr> = None;
1518        let mut fn_path: Option<LitStr> = None;
1519
1520        // new arguments: can only be keyed by name
1521        let mut input: Option<Type> = None;
1522        let mut input_derive: Option<ExprTuple> = None;
1523        let mut output: Option<Type> = None;
1524        let mut server: Option<Type> = None;
1525        let mut client: Option<Type> = None;
1526        let mut custom_wrapper: Option<syn::Path> = None;
1527        let mut impl_from: Option<LitBool> = None;
1528        let mut impl_deref: Option<LitBool> = None;
1529        let mut protocol: Option<Type> = None;
1530
1531        let mut use_key_and_value = false;
1532        let mut arg_pos = 0;
1533
1534        // Server-only extractors (key: Type pattern)
1535        // These come after config options (key = value pattern)
1536        // Example: #[server(endpoint = "/api/chat", headers: HeaderMap, cookies: Cookies)]
1537        let mut server_args: Punctuated<FnArg, Comma> = Punctuated::new();
1538
1539        while !stream.is_empty() {
1540            // Check if this looks like an extractor (Ident : Type)
1541            // If so, break out to parse extractors - they must come last
1542            if stream.peek(Ident) && stream.peek2(Token![:]) {
1543                break;
1544            }
1545
1546            arg_pos += 1;
1547            let lookahead = stream.lookahead1();
1548            if lookahead.peek(Ident) {
1549                let key_or_value: Ident = stream.parse()?;
1550
1551                let lookahead = stream.lookahead1();
1552                if lookahead.peek(Token![=]) {
1553                    stream.parse::<Token![=]>()?;
1554                    let key = key_or_value;
1555                    use_key_and_value = true;
1556                    if key == "name" {
1557                        if struct_name.is_some() {
1558                            return Err(syn::Error::new(
1559                                key.span(),
1560                                "keyword argument repeated: `name`",
1561                            ));
1562                        }
1563                        struct_name = Some(stream.parse()?);
1564                    } else if key == "prefix" {
1565                        if prefix.is_some() {
1566                            return Err(syn::Error::new(
1567                                key.span(),
1568                                "keyword argument repeated: `prefix`",
1569                            ));
1570                        }
1571                        prefix = Some(stream.parse()?);
1572                    } else if key == "encoding" {
1573                        if encoding.is_some() {
1574                            return Err(syn::Error::new(
1575                                key.span(),
1576                                "keyword argument repeated: `encoding`",
1577                            ));
1578                        }
1579                        encoding = Some(stream.parse()?);
1580                    } else if key == "endpoint" {
1581                        if fn_path.is_some() {
1582                            return Err(syn::Error::new(
1583                                key.span(),
1584                                "keyword argument repeated: `endpoint`",
1585                            ));
1586                        }
1587                        fn_path = Some(stream.parse()?);
1588                    } else if key == "input" {
1589                        if encoding.is_some() {
1590                            return Err(syn::Error::new(
1591                                key.span(),
1592                                "`encoding` and `input` should not both be \
1593                                 specified",
1594                            ));
1595                        } else if input.is_some() {
1596                            return Err(syn::Error::new(
1597                                key.span(),
1598                                "keyword argument repeated: `input`",
1599                            ));
1600                        }
1601                        input = Some(stream.parse()?);
1602                    } else if key == "input_derive" {
1603                        if input_derive.is_some() {
1604                            return Err(syn::Error::new(
1605                                key.span(),
1606                                "keyword argument repeated: `input_derive`",
1607                            ));
1608                        }
1609                        input_derive = Some(stream.parse()?);
1610                    } else if key == "output" {
1611                        if encoding.is_some() {
1612                            return Err(syn::Error::new(
1613                                key.span(),
1614                                "`encoding` and `output` should not both be \
1615                                 specified",
1616                            ));
1617                        } else if output.is_some() {
1618                            return Err(syn::Error::new(
1619                                key.span(),
1620                                "keyword argument repeated: `output`",
1621                            ));
1622                        }
1623                        output = Some(stream.parse()?);
1624                    } else if key == "server" {
1625                        if server.is_some() {
1626                            return Err(syn::Error::new(
1627                                key.span(),
1628                                "keyword argument repeated: `server`",
1629                            ));
1630                        }
1631                        server = Some(stream.parse()?);
1632                    } else if key == "client" {
1633                        if client.is_some() {
1634                            return Err(syn::Error::new(
1635                                key.span(),
1636                                "keyword argument repeated: `client`",
1637                            ));
1638                        }
1639                        client = Some(stream.parse()?);
1640                    } else if key == "custom" {
1641                        if custom_wrapper.is_some() {
1642                            return Err(syn::Error::new(
1643                                key.span(),
1644                                "keyword argument repeated: `custom`",
1645                            ));
1646                        }
1647                        custom_wrapper = Some(stream.parse()?);
1648                    } else if key == "impl_from" {
1649                        if impl_from.is_some() {
1650                            return Err(syn::Error::new(
1651                                key.span(),
1652                                "keyword argument repeated: `impl_from`",
1653                            ));
1654                        }
1655                        impl_from = Some(stream.parse()?);
1656                    } else if key == "impl_deref" {
1657                        if impl_deref.is_some() {
1658                            return Err(syn::Error::new(
1659                                key.span(),
1660                                "keyword argument repeated: `impl_deref`",
1661                            ));
1662                        }
1663                        impl_deref = Some(stream.parse()?);
1664                    } else if key == "protocol" {
1665                        if protocol.is_some() {
1666                            return Err(syn::Error::new(
1667                                key.span(),
1668                                "keyword argument repeated: `protocol`",
1669                            ));
1670                        }
1671                        protocol = Some(stream.parse()?);
1672                    } else {
1673                        return Err(lookahead.error());
1674                    }
1675                } else {
1676                    let value = key_or_value;
1677                    if use_key_and_value {
1678                        return Err(syn::Error::new(
1679                            value.span(),
1680                            "positional argument follows keyword argument",
1681                        ));
1682                    }
1683                    if arg_pos == 1 {
1684                        struct_name = Some(value)
1685                    } else {
1686                        return Err(syn::Error::new(value.span(), "expected string literal"));
1687                    }
1688                }
1689            } else if lookahead.peek(LitStr) {
1690                if use_key_and_value {
1691                    return Err(syn::Error::new(
1692                        stream.span(),
1693                        "If you use keyword arguments (e.g., `name` = \
1694                         Something), then you can no longer use arguments \
1695                         without a keyword.",
1696                    ));
1697                }
1698                match arg_pos {
1699                    1 => return Err(lookahead.error()),
1700                    2 => prefix = Some(stream.parse()?),
1701                    3 => encoding = Some(stream.parse()?),
1702                    4 => fn_path = Some(stream.parse()?),
1703                    _ => return Err(syn::Error::new(stream.span(), "unexpected extra argument")),
1704                }
1705            } else {
1706                return Err(lookahead.error());
1707            }
1708
1709            if !stream.is_empty() {
1710                stream.parse::<Token![,]>()?;
1711            }
1712        }
1713
1714        // Now parse any remaining extractors (key: Type pattern)
1715        while !stream.is_empty() {
1716            if stream.peek(Ident) && stream.peek2(Token![:]) {
1717                server_args.push_value(stream.parse::<FnArg>()?);
1718                if stream.peek(Comma) {
1719                    server_args.push_punct(stream.parse::<Comma>()?);
1720                } else {
1721                    break;
1722                }
1723            } else {
1724                break;
1725            }
1726        }
1727
1728        // parse legacy encoding into input/output
1729        let mut builtin_encoding = false;
1730        if let Some(encoding) = encoding {
1731            match encoding.value().to_lowercase().as_str() {
1732                "url" => {
1733                    input = Some(type_from_ident(syn::parse_quote!(Url)));
1734                    output = Some(type_from_ident(syn::parse_quote!(Json)));
1735                    builtin_encoding = true;
1736                }
1737                "cbor" => {
1738                    input = Some(type_from_ident(syn::parse_quote!(Cbor)));
1739                    output = Some(type_from_ident(syn::parse_quote!(Cbor)));
1740                    builtin_encoding = true;
1741                }
1742                "getcbor" => {
1743                    input = Some(type_from_ident(syn::parse_quote!(GetUrl)));
1744                    output = Some(type_from_ident(syn::parse_quote!(Cbor)));
1745                    builtin_encoding = true;
1746                }
1747                "getjson" => {
1748                    input = Some(type_from_ident(syn::parse_quote!(GetUrl)));
1749                    output = Some(syn::parse_quote!(Json));
1750                    builtin_encoding = true;
1751                }
1752                _ => return Err(syn::Error::new(encoding.span(), "Encoding not found.")),
1753            }
1754        }
1755
1756        Ok(Self {
1757            struct_name,
1758            prefix,
1759            input,
1760            input_derive,
1761            output,
1762            fn_path,
1763            builtin_encoding,
1764            server,
1765            client,
1766            custom_wrapper,
1767            impl_from,
1768            impl_deref,
1769            protocol,
1770            server_args,
1771        })
1772    }
1773}
1774
/// A single typed argument of a server function, as parsed from the macro input.
///
/// Receiver arguments (`self`, `&self`, ...) are rejected while parsing, so the
/// argument is always a `pattern: Type` pair.
// `server_fn_attributes` is collected but never read yet, hence the `allow`.
#[allow(unused)]
// todo - we used to support a number of these attributes and pass them along to serde. bring them back.
#[derive(Debug, Clone)]
struct ServerFnArg {
    /// The attributes that were attached to the argument and accepted during parsing:
    /// `#[server(...)]` attributes (rewritten to `#[serde(...)]`) plus passthrough
    /// `doc` / `allow` / `deny` / `ignore` attributes.
    ///
    /// NOTE: the `ToTokens` impl does not emit these, so they are currently dropped
    /// from the generated code.
    server_fn_attributes: Vec<Attribute>,
    /// The argument itself (`pattern: Type`), with all of its attributes removed.
    arg: syn::PatType,
}
1785
1786impl ToTokens for ServerFnArg {
1787    fn to_tokens(&self, tokens: &mut TokenStream2) {
1788        let ServerFnArg { arg, .. } = self;
1789        tokens.extend(quote! {
1790            #arg
1791        });
1792    }
1793}
1794
1795impl Parse for ServerFnArg {
1796    fn parse(input: ParseStream) -> Result<Self> {
1797        let arg: syn::FnArg = input.parse()?;
1798        let mut arg = match arg {
1799            FnArg::Receiver(_) => {
1800                return Err(syn::Error::new(
1801                    arg.span(),
1802                    "cannot use receiver types in server function macro",
1803                ));
1804            }
1805            FnArg::Typed(t) => t,
1806        };
1807
1808        fn rename_path(path: Path, from_ident: Ident, to_ident: Ident) -> Path {
1809            if path.is_ident(&from_ident) {
1810                Path {
1811                    leading_colon: None,
1812                    segments: Punctuated::from_iter([PathSegment {
1813                        ident: to_ident,
1814                        arguments: PathArguments::None,
1815                    }]),
1816                }
1817            } else {
1818                path
1819            }
1820        }
1821
1822        let server_fn_attributes = arg
1823            .attrs
1824            .iter()
1825            .cloned()
1826            .map(|attr| {
1827                if attr.path().is_ident("server") {
1828                    // Allow the following attributes:
1829                    // - #[server(default)]
1830                    // - #[server(rename = "fieldName")]
1831
1832                    // Rename `server` to `serde`
1833                    let attr = Attribute {
1834                        meta: match attr.meta {
1835                            Meta::Path(path) => Meta::Path(rename_path(
1836                                path,
1837                                format_ident!("server"),
1838                                format_ident!("serde"),
1839                            )),
1840                            Meta::List(mut list) => {
1841                                list.path = rename_path(
1842                                    list.path,
1843                                    format_ident!("server"),
1844                                    format_ident!("serde"),
1845                                );
1846                                Meta::List(list)
1847                            }
1848                            Meta::NameValue(mut name_value) => {
1849                                name_value.path = rename_path(
1850                                    name_value.path,
1851                                    format_ident!("server"),
1852                                    format_ident!("serde"),
1853                                );
1854                                Meta::NameValue(name_value)
1855                            }
1856                        },
1857                        ..attr
1858                    };
1859
1860                    let args = attr.parse_args::<Meta>()?;
1861                    match args {
1862                        // #[server(default)]
1863                        Meta::Path(path) if path.is_ident("default") => Ok(attr.clone()),
1864                        // #[server(flatten)]
1865                        Meta::Path(path) if path.is_ident("flatten") => Ok(attr.clone()),
1866                        // #[server(default = "value")]
1867                        Meta::NameValue(name_value) if name_value.path.is_ident("default") => {
1868                            Ok(attr.clone())
1869                        }
1870                        // #[server(skip)]
1871                        Meta::Path(path) if path.is_ident("skip") => Ok(attr.clone()),
1872                        // #[server(rename = "value")]
1873                        Meta::NameValue(name_value) if name_value.path.is_ident("rename") => {
1874                            Ok(attr.clone())
1875                        }
1876                        _ => Err(Error::new(
1877                            attr.span(),
1878                            "Unrecognized #[server] attribute, expected \
1879                             #[server(default)] or #[server(rename = \
1880                             \"fieldName\")]",
1881                        )),
1882                    }
1883                } else if attr.path().is_ident("doc") {
1884                    // Allow #[doc = "documentation"]
1885                    Ok(attr.clone())
1886                } else if attr.path().is_ident("allow") {
1887                    // Allow #[allow(...)]
1888                    Ok(attr.clone())
1889                } else if attr.path().is_ident("deny") {
1890                    // Allow #[deny(...)]
1891                    Ok(attr.clone())
1892                } else if attr.path().is_ident("ignore") {
1893                    // Allow #[ignore]
1894                    Ok(attr.clone())
1895                } else {
1896                    Err(Error::new(
1897                        attr.span(),
1898                        "Unrecognized attribute, expected #[server(...)]",
1899                    ))
1900                }
1901            })
1902            .collect::<Result<Vec<_>>>()?;
1903        arg.attrs = vec![];
1904        Ok(ServerFnArg {
1905            arg,
1906            server_fn_attributes,
1907        })
1908    }
1909}
1910
1911fn type_from_ident(ident: Ident) -> Type {
1912    let mut segments = Punctuated::new();
1913    segments.push(PathSegment {
1914        ident,
1915        arguments: PathArguments::None,
1916    });
1917    Type::Path(TypePath {
1918        qself: None,
1919        path: Path {
1920            leading_colon: None,
1921            segments,
1922        },
1923    })
1924}