dioxus_fullstack_macro/
lib.rs

1// TODO: Create README, uncomment this: #![doc = include_str!("../README.md")]
2#![doc(html_logo_url = "https://avatars.githubusercontent.com/u/79236386")]
3#![doc(html_favicon_url = "https://avatars.githubusercontent.com/u/79236386")]
4
5use core::panic;
6use proc_macro::TokenStream;
7use proc_macro2::{Span, TokenStream as TokenStream2};
8use quote::ToTokens;
9use quote::{format_ident, quote};
10use std::collections::HashMap;
11use syn::{
12    braced, bracketed,
13    parse::ParseStream,
14    punctuated::Punctuated,
15    token::{Comma, Slash},
16    Error, ExprTuple, FnArg, Meta, PathArguments, PathSegment, Token, Type, TypePath,
17};
18use syn::{parse::Parse, parse_quote, Ident, ItemFn, LitStr, Path};
19use syn::{spanned::Spanned, LitBool, LitInt, Pat, PatType};
20use syn::{
21    token::{Brace, Star},
22    Attribute, Expr, ExprClosure, Lit, Result,
23};
24
25/// ## Usage
26///
27/// ```rust,ignore
28/// # use dioxus::prelude::*;
29/// # #[derive(serde::Deserialize, serde::Serialize)]
30/// # struct BlogPost;
31/// # async fn load_posts(category: &str) -> Result<Vec<BlogPost>> { unimplemented!() }
32///
33/// #[server]
34/// async fn blog_posts(category: String) -> Result<Vec<BlogPost>> {
35///     let posts = load_posts(&category).await?;
36///     // maybe do some other work
37///     Ok(posts)
38/// }
39/// ```
40///
41/// ## Named Arguments
42///
43/// You can use any combination of the following named arguments:
44/// - `endpoint`: a prefix at which the server function handler will be mounted (defaults to `/api`).
45///   Example: `endpoint = "/my_api/my_serverfn"`.
46/// - `input`: the encoding for the arguments, defaults to `Json<T>`
47///     - You may customize the encoding of the arguments by specifying a different type for `input`.
48///     - Any axum `IntoRequest` extractor can be used here, and dioxus provides
49///       - `Json<T>`: The default axum `Json` extractor that decodes JSON-encoded request bodies.
50///       - `Cbor<T>`: A custom axum `Cbor` extractor that decodes CBOR-encoded request bodies.
51///       - `MessagePack<T>`: A custom axum `MessagePack` extractor that decodes MessagePack-encoded request bodies.
52/// - `output`: the encoding for the response (defaults to `Json`).
53///     - The `output` argument specifies how the server should encode the response data.
54///     - Acceptable values include:
55///       - `Json`: A response encoded as JSON (default). This is ideal for most web applications.
56///       - `Cbor`: A response encoded in the CBOR format for efficient, binary-encoded data.
57/// - `client`: a custom `Client` implementation that will be used for this server function. This allows
58///   customization of the client-side behavior if needed.
59///
60/// ## Advanced Usage of `input` and `output` Fields
61///
62/// The `input` and `output` fields allow you to customize how arguments and responses are encoded and decoded.
63/// These fields impose specific trait bounds on the types you use. Here are detailed examples for different scenarios:
64///
65/// ## Adding layers to server functions
66///
67/// Layers allow you to transform the request and response of a server function. You can use layers
68/// to add authentication, logging, or other functionality to your server functions. Server functions integrate
69/// with the tower ecosystem, so you can use any layer that is compatible with tower.
70///
71/// Common layers include:
72/// - [`tower_http::trace::TraceLayer`](https://docs.rs/tower-http/latest/tower_http/trace/struct.TraceLayer.html) for tracing requests and responses
73/// - [`tower_http::compression::CompressionLayer`](https://docs.rs/tower-http/latest/tower_http/compression/struct.CompressionLayer.html) for compressing large responses
74/// - [`tower_http::cors::CorsLayer`](https://docs.rs/tower-http/latest/tower_http/cors/struct.CorsLayer.html) for adding CORS headers to responses
75/// - [`tower_http::timeout::TimeoutLayer`](https://docs.rs/tower-http/latest/tower_http/timeout/struct.TimeoutLayer.html) for adding timeouts to requests
76/// - [`tower_sessions::service::SessionManagerLayer`](https://docs.rs/tower-sessions/0.13.0/tower_sessions/service/struct.SessionManagerLayer.html) for adding session management to requests
77///
78/// You can add a tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to your server function with the middleware attribute:
79///
80/// ```rust,ignore
81/// # use dioxus::prelude::*;
82/// #[server]
83/// // The TraceLayer will log all requests to the console
84/// #[middleware(tower_http::timeout::TimeoutLayer::new(std::time::Duration::from_secs(5)))]
85/// pub async fn my_wacky_server_fn(input: Vec<String>) -> ServerFnResult<usize> {
86///     unimplemented!()
87/// }
88/// ```
89#[proc_macro_attribute]
90pub fn server(attr: proc_macro::TokenStream, mut item: TokenStream) -> TokenStream {
91    // Parse the attribute list using the old server_fn arg parser.
92    let args = match syn::parse::<ServerFnArgs>(attr) {
93        Ok(args) => args,
94        Err(err) => {
95            let err: TokenStream = err.to_compile_error().into();
96            item.extend(err);
97            return item;
98        }
99    };
100
101    let method = Method::Post(Ident::new("POST", proc_macro2::Span::call_site()));
102    let prefix = args
103        .prefix
104        .unwrap_or_else(|| LitStr::new("/api", Span::call_site()));
105
106    let route: Route = Route {
107        method: None,
108        path_params: vec![],
109        query_params: vec![],
110        route_lit: args.fn_path,
111        oapi_options: None,
112        server_args: Default::default(),
113        prefix: Some(prefix),
114        _input_encoding: args.input,
115        _output_encoding: args.output,
116    };
117
118    match route_impl_with_route(route, item.clone(), Some(method)) {
119        Ok(mut tokens) => {
120            // Let's add some deprecated warnings to the various fields from `args` if the user is using them...
121            // We don't generate structs anymore, don't use various protocols, etc
122            if let Some(name) = args.struct_name {
123                tokens.extend(quote! {
124                    const _: () = {
125                        #[deprecated(note = "Dioxus server functions no longer generate a struct for the server function. The function itself is used directly.")]
126                        struct #name;
127                        fn ___assert_deprecated() {
128                            let _ = #name;
129                        }
130
131                        ()
132                    };
133                });
134            }
135
136            //
137            tokens.into()
138        }
139
140        // Retain the original function item and append the error to it. Better for autocomplete.
141        Err(err) => {
142            let err: TokenStream = err.to_compile_error().into();
143            item.extend(err);
144            item
145        }
146    }
147}
148
149#[proc_macro_attribute]
150pub fn get(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
151    wrapped_route_impl(args, body, Some(Method::new_from_string("GET")))
152}
153
154#[proc_macro_attribute]
155pub fn post(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
156    wrapped_route_impl(args, body, Some(Method::new_from_string("POST")))
157}
158
159#[proc_macro_attribute]
160pub fn put(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
161    wrapped_route_impl(args, body, Some(Method::new_from_string("PUT")))
162}
163
164#[proc_macro_attribute]
165pub fn delete(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
166    wrapped_route_impl(args, body, Some(Method::new_from_string("DELETE")))
167}
168
169#[proc_macro_attribute]
170pub fn patch(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
171    wrapped_route_impl(args, body, Some(Method::new_from_string("PATCH")))
172}
173
174fn wrapped_route_impl(
175    attr: TokenStream,
176    mut item: TokenStream,
177    method: Option<Method>,
178) -> TokenStream {
179    match route_impl(attr, item.clone(), method) {
180        Ok(tokens) => tokens.into(),
181        Err(err) => {
182            let err: TokenStream = err.to_compile_error().into();
183            item.extend(err);
184            item
185        }
186    }
187}
188
189fn route_impl(
190    attr: TokenStream,
191    item: TokenStream,
192    method_from_macro: Option<Method>,
193) -> syn::Result<TokenStream2> {
194    let route = syn::parse::<Route>(attr)?;
195    route_impl_with_route(route, item, method_from_macro)
196}
197
198fn route_impl_with_route(
199    route: Route,
200    item: TokenStream,
201    method_from_macro: Option<Method>,
202) -> syn::Result<TokenStream2> {
203    // Parse the route and function
204    let mut function = syn::parse::<ItemFn>(item)?;
205
206    // Collect the middleware initializers
207    let middleware_layers = function
208        .attrs
209        .iter()
210        .filter(|attr| attr.path().is_ident("middleware"))
211        .map(|f| match &f.meta {
212            Meta::List(meta_list) => Ok({
213                let tokens = &meta_list.tokens;
214                quote! { .layer(#tokens) }
215            }),
216            _ => Err(Error::new(
217                f.span(),
218                "Expected middleware attribute to be a list, e.g. #[middleware(MyLayer::new())]",
219            )),
220        })
221        .collect::<Result<Vec<_>>>()?;
222
223    // don't re-emit the middleware attribute on the inner
224    function
225        .attrs
226        .retain(|attr| !attr.path().is_ident("middleware"));
227
228    // Attach `#[allow(unused_mut)]` to all original inputs to avoid warnings
229    let outer_inputs = function
230        .sig
231        .inputs
232        .iter()
233        .enumerate()
234        .map(|(i, arg)| match arg {
235            FnArg::Receiver(_receiver) => panic!("Self type is not supported"),
236            FnArg::Typed(pat_type) => match pat_type.pat.as_ref() {
237                Pat::Ident(_) => {
238                    quote! { #[allow(unused_mut)] #pat_type }
239                }
240                _ => {
241                    let ident = format_ident!("___Arg{}", i);
242                    let ty = &pat_type.ty;
243                    quote! { #[allow(unused_mut)] #ident: #ty }
244                }
245            },
246        })
247        .collect::<Punctuated<_, Token![,]>>();
248    // .collect::<Punctuated<_, Token![,]>>();
249
250    let route = CompiledRoute::from_route(route, &function, false, method_from_macro)?;
251    let query_params_struct = route.query_params_struct(false);
252    let method_ident = &route.method;
253    let body_json_args = route.remaining_pattypes_named(&function.sig.inputs);
254    let body_json_names = body_json_args
255        .iter()
256        .map(|(i, pat_type)| match &*pat_type.pat {
257            Pat::Ident(ref pat_ident) => pat_ident.ident.clone(),
258            _ => format_ident!("___Arg{}", i),
259        })
260        .collect::<Vec<_>>();
261    let body_json_types = body_json_args
262        .iter()
263        .map(|pat_type| &pat_type.1.ty)
264        .collect::<Vec<_>>();
265    let route_docs = route.to_doc_comments();
266
267    // Get the variables we need for code generation
268    let fn_on_server_name = &function.sig.ident;
269    let vis = &function.vis;
270    let (impl_generics, ty_generics, where_clause) = &function.sig.generics.split_for_impl();
271    let ty_generics = ty_generics.as_turbofish();
272    let fn_docs = function
273        .attrs
274        .iter()
275        .filter(|attr| attr.path().is_ident("doc"));
276
277    let __axum = quote! { dioxus_server::axum };
278
279    let output_type = match &function.sig.output {
280        syn::ReturnType::Default => parse_quote! { () },
281        syn::ReturnType::Type(_, ty) => (*ty).clone(),
282    };
283
284    let query_param_names = route
285        .query_params
286        .iter()
287        .filter(|c| !c.catch_all)
288        .map(|param| &param.binding);
289
290    let path_param_args = route.path_params.iter().map(|(_slash, param)| match param {
291        PathParam::Capture(_lit, _brace_1, ident, _ty, _brace_2) => {
292            Some(quote! { #ident = #ident, })
293        }
294        PathParam::WildCard(_lit, _brace_1, _star, ident, _ty, _brace_2) => {
295            Some(quote! { #ident = #ident, })
296        }
297        PathParam::Static(_lit) => None,
298    });
299
300    let out_ty = match output_type.as_ref() {
301        Type::Tuple(tuple) if tuple.elems.is_empty() => parse_quote! { () },
302        _ => output_type.clone(),
303    };
304
305    let mut function_on_server = function.clone();
306    function_on_server
307        .sig
308        .inputs
309        .extend(route.server_args.clone());
310
311    let server_names = route
312        .server_args
313        .iter()
314        .enumerate()
315        .map(|(i, pat_type)| match pat_type {
316            FnArg::Typed(_pat_type) => format_ident!("___sarg___{}", i),
317            FnArg::Receiver(_) => panic!("Self type is not supported"),
318        })
319        .collect::<Vec<_>>();
320
321    let server_types = route
322        .server_args
323        .iter()
324        .map(|pat_type| match pat_type {
325            FnArg::Receiver(_) => parse_quote! { () },
326            FnArg::Typed(pat_type) => (*pat_type.ty).clone(),
327        })
328        .collect::<Vec<_>>();
329
330    let body_struct_impl = {
331        let tys = body_json_types
332            .iter()
333            .enumerate()
334            .map(|(idx, _)| format_ident!("__Ty{}", idx));
335
336        let names = body_json_names.iter().enumerate().map(|(idx, name)| {
337            let ty_name = format_ident!("__Ty{}", idx);
338            quote! { #name: #ty_name }
339        });
340
341        quote! {
342            #[derive(serde::Serialize, serde::Deserialize)]
343            #[serde(crate = "serde")]
344            struct ___Body_Serialize___< #(#tys,)* > {
345                #(#names,)*
346            }
347        }
348    };
349
350    // This unpacks the body struct into the individual variables that get scoped
351    let unpack_closure = {
352        let unpack_args = body_json_names.iter().map(|name| quote! { data.#name });
353        quote! {
354            |data| { ( #(#unpack_args,)* ) }
355        }
356    };
357
358    let as_axum_path = route.to_axum_path_string();
359
360    let query_endpoint = if let Some(full_url) = route.url_without_queries_for_format() {
361        quote! { format!(#full_url, #( #path_param_args)*) }
362    } else {
363        quote! { __ENDPOINT_PATH.to_string() }
364    };
365
366    let endpoint_path = {
367        let prefix = route
368            .prefix
369            .as_ref()
370            .cloned()
371            .unwrap_or_else(|| LitStr::new("", Span::call_site()));
372
373        let route_lit = if let Some(lit) = as_axum_path {
374            quote! { #lit }
375        } else {
376            let name =
377                route.route_lit.as_ref().cloned().unwrap_or_else(|| {
378                    LitStr::new(&fn_on_server_name.to_string(), Span::call_site())
379                });
380            quote! {
381                concat!(
382                    "/",
383                    #name
384                )
385            }
386        };
387
388        let hash = match route.prefix.as_ref() {
389            // Implicit route lit, we need to hash the function signature to avoid collisions
390            Some(_) if route.route_lit.is_none() => {
391                // let enable_hash = option_env!("DISABLE_SERVER_FN_HASH").is_none();
392                let key_env_var = match option_env!("SERVER_FN_OVERRIDE_KEY") {
393                    Some(_) => "SERVER_FN_OVERRIDE_KEY",
394                    None => "CARGO_MANIFEST_DIR",
395                };
396                quote! {
397                    dioxus_fullstack::xxhash_rust::const_xxh64::xxh64(
398                        concat!(env!(#key_env_var), ":", module_path!()).as_bytes(),
399                        0
400                    )
401                }
402            }
403
404            // Explicit route lit, no need to hash
405            _ => quote! { "" },
406        };
407
408        quote! {
409            dioxus_fullstack::const_format::concatcp!(#prefix, #route_lit, #hash)
410        }
411    };
412
413    let extracted_idents = route.extracted_idents();
414
415    let query_tokens = if route.query_is_catchall() {
416        let query = route
417            .query_params
418            .iter()
419            .find(|param| param.catch_all)
420            .unwrap();
421        let input = &function.sig.inputs[query.arg_idx];
422        let name = match input {
423            FnArg::Typed(pat_type) => match pat_type.pat.as_ref() {
424                Pat::Ident(ref pat_ident) => pat_ident.ident.clone(),
425                _ => format_ident!("___Arg{}", query.arg_idx),
426            },
427            FnArg::Receiver(_receiver) => panic!(),
428        };
429        quote! {
430            #name
431        }
432    } else {
433        quote! {
434            __QueryParams__ { #(#query_param_names,)* }
435        }
436    };
437
438    let extracted_as_server_headers = route.extracted_as_server_headers(query_tokens.clone());
439
440    Ok(quote! {
441        #(#fn_docs)*
442        #route_docs
443        #[deny(
444            unexpected_cfgs,
445            reason = "
446==========================================================================================
447  Using Dioxus Server Functions requires a `server` feature flag in your `Cargo.toml`.
448  Please add the following to your `Cargo.toml`:
449
450  ```toml
451  [features]
452  server = [\"dioxus/server\"]
453  ```
454
455  To enable better Rust-Analyzer support, you can make `server` a default feature:
456  ```toml
457  [features]
458  default = [\"web\", \"server\"]
459  web = [\"dioxus/web\"]
460  server = [\"dioxus/server\"]
461  ```
462==========================================================================================
463        "
464        )]
465        #vis async fn #fn_on_server_name #impl_generics( #outer_inputs ) -> #out_ty #where_clause {
466            use dioxus_fullstack::serde as serde;
467            use dioxus_fullstack::{
468                // concrete types
469                ServerFnEncoder, ServerFnDecoder, FullstackContext,
470
471                // "magic" traits for encoding/decoding on the client
472                ExtractRequest, EncodeRequest, RequestDecodeResult, RequestDecodeErr,
473
474                // "magic" traits for encoding/decoding on the server
475                MakeAxumResponse, MakeAxumError,
476            };
477
478            #query_params_struct
479
480            #body_struct_impl
481
482            const __ENDPOINT_PATH: &str = #endpoint_path;
483
484            {
485                _ = dioxus_fullstack::assert_is_result::<#out_ty>();
486
487                let verify_token = (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new())
488                    .verify_can_serialize();
489
490                dioxus_fullstack::assert_can_encode(verify_token);
491
492                let decode_token = (&&&&&ServerFnDecoder::<#out_ty>::new())
493                    .verify_can_deserialize();
494
495                dioxus_fullstack::assert_can_decode(decode_token);
496            };
497
498
499            // On the client, we make the request to the server
500            // We want to support extremely flexible error types and return types, making this more complex than it should
501            #[allow(clippy::unused_unit)]
502            #[cfg(not(feature = "server"))]
503            {
504                let client = dioxus_fullstack::ClientRequest::new(
505                    dioxus_fullstack::http::Method::#method_ident,
506                    #query_endpoint,
507                    &#query_tokens,
508                );
509
510                let response = (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new())
511                    .fetch_client(client, ___Body_Serialize___ { #(#body_json_names,)* }, #unpack_closure)
512                    .await;
513
514                let decoded = (&&&&&ServerFnDecoder::<#out_ty>::new())
515                    .decode_client_response(response)
516                    .await;
517
518                let result = (&&&&&ServerFnDecoder::<#out_ty>::new())
519                    .decode_client_err(decoded)
520                    .await;
521
522                return result;
523            }
524
525            // On the server, we expand the tokens and submit the function to inventory
526            #[cfg(feature = "server")] {
527                #function_on_server
528
529                #[allow(clippy::unused_unit)]
530                fn __inner__function__ #impl_generics(
531                    ___state: #__axum::extract::State<FullstackContext>,
532                    ___request: #__axum::extract::Request,
533                ) -> std::pin::Pin<Box<dyn std::future::Future<Output = #__axum::response::Response>>> #where_clause {
534                    Box::pin(async move {
535                         match (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new()).extract_axum(___state.0, ___request, #unpack_closure).await {
536                            Ok(((#(#body_json_names,)* ), (#(#extracted_as_server_headers,)* #(#server_names,)*) )) => {
537                                // Call the user function
538                                let res = #fn_on_server_name #ty_generics(#(#extracted_idents,)* #(#body_json_names,)* #(#server_names,)*).await;
539
540                                // Encode the response Into a `Result<T, E>`
541                                let encoded = (&&&&&&ServerFnDecoder::<#out_ty>::new()).make_axum_response(res);
542
543                                // And then encode `Result<T, E>` into `Response`
544                                (&&&&&ServerFnDecoder::<#out_ty>::new()).make_axum_error(encoded)
545                            },
546                            Err(res) => res,
547                        }
548                    })
549                }
550
551                dioxus_server::inventory::submit! {
552                    dioxus_server::ServerFunction::new(
553                        dioxus_server::http::Method::#method_ident,
554                        __ENDPOINT_PATH,
555                        || {
556                            dioxus_server::ServerFunction::make_handler(dioxus_server::http::Method::#method_ident, __inner__function__ #ty_generics)
557                                #(#middleware_layers)*
558                        }
559                    )
560                }
561
562                // Extract the server arguments from the context if needed.
563                let (#(#server_names,)*) = dioxus_fullstack::FullstackContext::extract::<(#(#server_types,)*), _>().await?;
564
565                // Call the function directly
566                return #fn_on_server_name #ty_generics(
567                    #(#extracted_idents,)*
568                    #(#body_json_names,)*
569                    #(#server_names,)*
570                ).await;
571            }
572
573            #[allow(unreachable_code)]
574            {
575                unreachable!()
576            }
577        }
578    })
579}
580
581struct CompiledRoute {
582    method: Method,
583    #[allow(clippy::type_complexity)]
584    path_params: Vec<(Slash, PathParam)>,
585    query_params: Vec<QueryParam>,
586    route_lit: Option<LitStr>,
587    prefix: Option<LitStr>,
588    oapi_options: Option<OapiOptions>,
589    server_args: Punctuated<FnArg, Comma>,
590}
591
592struct QueryParam {
593    arg_idx: usize,
594    name: String,
595    binding: Ident,
596    catch_all: bool,
597    ty: Box<Type>,
598}
599
600impl CompiledRoute {
601    fn to_axum_path_string(&self) -> Option<String> {
602        if self.prefix.is_some() {
603            return None;
604        }
605
606        let mut path = String::new();
607
608        for (_slash, param) in &self.path_params {
609            path.push('/');
610            match param {
611                PathParam::Capture(lit, _brace_1, _, _, _brace_2) => {
612                    path.push('{');
613                    path.push_str(&lit.value());
614                    path.push('}');
615                }
616                PathParam::WildCard(lit, _brace_1, _, _, _, _brace_2) => {
617                    path.push('{');
618                    path.push('*');
619                    path.push_str(&lit.value());
620                    path.push('}');
621                }
622                PathParam::Static(lit) => path.push_str(&lit.value()),
623            }
624        }
625
626        Some(path)
627    }
628
629    /// Removes the arguments in `route` from `args`, and merges them in the output.
630    pub fn from_route(
631        mut route: Route,
632        function: &ItemFn,
633        with_aide: bool,
634        method_from_macro: Option<Method>,
635    ) -> syn::Result<Self> {
636        if !with_aide && route.oapi_options.is_some() {
637            return Err(syn::Error::new(
638                Span::call_site(),
639                "Use `api_route` instead of `route` to use OpenAPI options",
640            ));
641        } else if with_aide && route.oapi_options.is_none() {
642            route.oapi_options = Some(OapiOptions {
643                summary: None,
644                description: None,
645                id: None,
646                hidden: None,
647                tags: None,
648                security: None,
649                responses: None,
650                transform: None,
651            });
652        }
653
654        let sig = &function.sig;
655        let mut arg_map = sig
656            .inputs
657            .iter()
658            .enumerate()
659            .filter_map(|(i, item)| match item {
660                syn::FnArg::Receiver(_) => None,
661                syn::FnArg::Typed(pat_type) => Some((i, pat_type)),
662            })
663            .filter_map(|(i, pat_type)| match &*pat_type.pat {
664                syn::Pat::Ident(ident) => Some((ident.ident.clone(), (pat_type.ty.clone(), i))),
665                _ => None,
666            })
667            .collect::<HashMap<_, _>>();
668
669        for (_slash, path_param) in &mut route.path_params {
670            match path_param {
671                PathParam::Capture(_lit, _, ident, ty, _) => {
672                    let (new_ident, new_ty) = arg_map.remove_entry(ident).ok_or_else(|| {
673                        syn::Error::new(
674                            ident.span(),
675                            format!("path parameter `{}` not found in function arguments", ident),
676                        )
677                    })?;
678                    *ident = new_ident;
679                    *ty = new_ty.0;
680                }
681                PathParam::WildCard(_lit, _, _star, ident, ty, _) => {
682                    let (new_ident, new_ty) = arg_map.remove_entry(ident).ok_or_else(|| {
683                        syn::Error::new(
684                            ident.span(),
685                            format!("path parameter `{}` not found in function arguments", ident),
686                        )
687                    })?;
688                    *ident = new_ident;
689                    *ty = new_ty.0;
690                }
691                PathParam::Static(_lit) => {}
692            }
693        }
694
695        let mut query_params = Vec::new();
696        for param in route.query_params {
697            let (ident, ty) = arg_map.remove_entry(&param.binding).ok_or_else(|| {
698                syn::Error::new(
699                    param.binding.span(),
700                    format!(
701                        "query parameter `{}` not found in function arguments",
702                        param.binding
703                    ),
704                )
705            })?;
706            query_params.push(QueryParam {
707                binding: ident,
708                name: param.name,
709                catch_all: param.catch_all,
710                ty: ty.0,
711                arg_idx: ty.1,
712            });
713        }
714
715        // Disallow multiple query params if one is a catch-all
716        if query_params.iter().any(|param| param.catch_all) && query_params.len() > 1 {
717            return Err(syn::Error::new(
718                Span::call_site(),
719                "Cannot have multiple query parameters when one is a catch-all",
720            ));
721        }
722
723        if let Some(options) = route.oapi_options.as_mut() {
724            options.merge_with_fn(function)
725        }
726
727        let method = match (method_from_macro, route.method) {
728            (Some(method), None) => method,
729            (None, Some(method)) => method,
730            (Some(_), Some(_)) => {
731                return Err(syn::Error::new(
732                    Span::call_site(),
733                    "HTTP method specified both in macro and in attribute",
734                ))
735            }
736            (None, None) => {
737                return Err(syn::Error::new(
738                    Span::call_site(),
739                    "HTTP method not specified in macro or in attribute",
740                ))
741            }
742        };
743
744        Ok(Self {
745            method,
746            route_lit: route.route_lit,
747            path_params: route.path_params,
748            query_params,
749            oapi_options: route.oapi_options,
750            prefix: route.prefix,
751            server_args: route.server_args,
752        })
753    }
754
755    pub fn query_is_catchall(&self) -> bool {
756        self.query_params.iter().any(|param| param.catch_all)
757    }
758
759    pub fn extracted_as_server_headers(&self, query_tokens: TokenStream2) -> Vec<Pat> {
760        let mut out = vec![];
761
762        // Add the path extractor
763        out.push({
764            let path_iter = self
765                .path_params
766                .iter()
767                .filter_map(|(_slash, path_param)| path_param.capture());
768            let idents = path_iter.clone().map(|item| item.0);
769            parse_quote! {
770                dioxus_server::axum::extract::Path((#(#idents,)*))
771            }
772        });
773
774        out.push(parse_quote!(
775            dioxus_fullstack::payloads::Query(#query_tokens)
776        ));
777
778        out
779    }
780
781    pub fn query_params_struct(&self, with_aide: bool) -> TokenStream2 {
782        let fields = self.query_params.iter().map(|item| {
783            let name = &item.name;
784            let binding = &item.binding;
785            let ty = &item.ty;
786            if item.catch_all {
787                quote! {}
788            } else if item.binding != item.name {
789                quote! {
790                    #[serde(rename = #name)]
791                    #binding: #ty,
792                }
793            } else {
794                quote! { #binding: #ty, }
795            }
796        });
797        let derive = match with_aide {
798            true => quote! {
799                #[derive(serde::Deserialize, serde::Serialize, ::schemars::JsonSchema)]
800                #[serde(crate = "serde")]
801            },
802            false => quote! {
803                #[derive(serde::Deserialize, serde::Serialize)]
804                #[serde(crate = "serde")]
805            },
806        };
807        quote! {
808            #derive
809            struct __QueryParams__ {
810                #(#fields)*
811            }
812        }
813    }
814
815    pub fn extracted_idents(&self) -> Vec<Ident> {
816        let mut idents = Vec::new();
817        for (_slash, path_param) in &self.path_params {
818            if let Some((ident, _ty)) = path_param.capture() {
819                idents.push(ident.clone());
820            }
821        }
822        for param in &self.query_params {
823            idents.push(param.binding.clone());
824        }
825        idents
826    }
827
828    fn remaining_pattypes_named(&self, args: &Punctuated<FnArg, Comma>) -> Vec<(usize, PatType)> {
829        args.iter()
830            .enumerate()
831            .filter_map(|(i, item)| {
832                if let FnArg::Typed(pat_type) = item {
833                    if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
834                        if self.path_params.iter().any(|(_slash, path_param)| {
835                            if let Some((path_ident, _ty)) = path_param.capture() {
836                                path_ident == &pat_ident.ident
837                            } else {
838                                false
839                            }
840                        }) || self
841                            .query_params
842                            .iter()
843                            .any(|query| query.binding == pat_ident.ident)
844                        {
845                            return None;
846                        }
847                    }
848
849                    Some((i, pat_type.clone()))
850                } else {
851                    unimplemented!("Self type is not supported")
852                }
853            })
854            .collect()
855    }
856
857    pub(crate) fn to_doc_comments(&self) -> TokenStream2 {
858        let mut doc = format!(
859            "# Handler information
860- Method: `{}`
861- Path: `{}`",
862            self.method.to_axum_method_name(),
863            self.route_lit
864                .as_ref()
865                .map(|lit| lit.value())
866                .unwrap_or_else(|| "<auto>".into()),
867        );
868
869        if let Some(options) = &self.oapi_options {
870            let summary = options
871                .summary
872                .as_ref()
873                .map(|(_, summary)| format!("\"{}\"", summary.value()))
874                .unwrap_or("None".to_string());
875            let description = options
876                .description
877                .as_ref()
878                .map(|(_, description)| format!("\"{}\"", description.value()))
879                .unwrap_or("None".to_string());
880            let id = options
881                .id
882                .as_ref()
883                .map(|(_, id)| format!("\"{}\"", id.value()))
884                .unwrap_or("None".to_string());
885            let hidden = options
886                .hidden
887                .as_ref()
888                .map(|(_, hidden)| hidden.value().to_string())
889                .unwrap_or("None".to_string());
890            let tags = options
891                .tags
892                .as_ref()
893                .map(|(_, tags)| tags.to_string())
894                .unwrap_or("[]".to_string());
895            let security = options
896                .security
897                .as_ref()
898                .map(|(_, security)| security.to_string())
899                .unwrap_or("{}".to_string());
900
901            doc = format!(
902                "{doc}
903
904## OpenAPI
905- Summary: `{summary}`
906- Description: `{description}`
907- Operation id: `{id}`
908- Tags: `{tags}`
909- Security: `{security}`
910- Hidden: `{hidden}`
911"
912            );
913        }
914
915        quote!(
916            #[doc = #doc]
917        )
918    }
919
920    fn url_without_queries_for_format(&self) -> Option<String> {
921        // If there's a prefix, then it's an old-style route, and we can't generate a format string.
922        if self.prefix.is_some() {
923            return None;
924        }
925
926        // If there's no explicit route, we can't generate a format string this way.
927        let _lit = self.route_lit.as_ref()?;
928
929        let url_without_queries =
930            self.path_params
931                .iter()
932                .fold(String::new(), |mut acc, (_slash, param)| {
933                    acc.push('/');
934                    match param {
935                        PathParam::Capture(lit, _brace_1, _, _, _brace_2) => {
936                            acc.push_str(&format!("{{{}}}", lit.value()));
937                        }
938                        PathParam::WildCard(lit, _brace_1, _, _, _, _brace_2) => {
939                            // no `*` since we want to use the argument *as the wildcard* when making requests
940                            // it's not super applicable to server functions, more for general route generation
941                            acc.push_str(&format!("{{{}}}", lit.value()));
942                        }
943                        PathParam::Static(lit) => {
944                            acc.push_str(&lit.value());
945                        }
946                    }
947                    acc
948                });
949
950        let prefix = self
951            .prefix
952            .as_ref()
953            .cloned()
954            .unwrap_or_else(|| LitStr::new("", Span::call_site()))
955            .value();
956        let full_url = format!(
957            "{}{}{}",
958            prefix,
959            if url_without_queries.starts_with("/") {
960                ""
961            } else {
962                "/"
963            },
964            url_without_queries
965        );
966
967        Some(full_url)
968    }
969}
970
971struct RouteParser {
972    path_params: Vec<(Slash, PathParam)>,
973    query_params: Vec<QueryParam>,
974}
975
976impl RouteParser {
977    fn new(lit: LitStr) -> syn::Result<Self> {
978        let val = lit.value();
979        let span = lit.span();
980        let split_route = val.split('?').collect::<Vec<_>>();
981        if split_route.len() > 2 {
982            return Err(syn::Error::new(span, "expected at most one '?'"));
983        }
984
985        let path = split_route[0];
986        if !path.starts_with('/') {
987            return Err(syn::Error::new(span, "expected path to start with '/'"));
988        }
989        let path = path.strip_prefix('/').unwrap();
990
991        let mut path_params = Vec::new();
992
993        for path_param in path.split('/') {
994            path_params.push((
995                Slash(span),
996                PathParam::new(path_param, span, Box::new(parse_quote!(())))?,
997            ));
998        }
999
1000        let path_param_len = path_params.len();
1001        for (i, (_slash, path_param)) in path_params.iter().enumerate() {
1002            match path_param {
1003                PathParam::WildCard(_, _, _, _, _, _) => {
1004                    if i != path_param_len - 1 {
1005                        return Err(syn::Error::new(
1006                            span,
1007                            "wildcard path param must be the last path param",
1008                        ));
1009                    }
1010                }
1011                PathParam::Capture(_, _, _, _, _) => (),
1012                PathParam::Static(lit) => {
1013                    if lit.value() == "*" && i != path_param_len - 1 {
1014                        return Err(syn::Error::new(
1015                            span,
1016                            "wildcard path param must be the last path param",
1017                        ));
1018                    }
1019                }
1020            }
1021        }
1022
1023        let mut query_params = Vec::new();
1024        if split_route.len() == 2 {
1025            let query = split_route[1];
1026            for query_param in query.split('&') {
1027                if query_param.starts_with(":") {
1028                    let ident = Ident::new(query_param.strip_prefix(":").unwrap(), span);
1029
1030                    query_params.push(QueryParam {
1031                        name: ident.to_string(),
1032                        binding: ident,
1033                        catch_all: true,
1034                        ty: parse_quote!(()),
1035                        arg_idx: usize::MAX,
1036                    });
1037                } else if query_param.starts_with("{") && query_param.ends_with("}") {
1038                    let ident = Ident::new(
1039                        query_param
1040                            .strip_prefix("{")
1041                            .unwrap()
1042                            .strip_suffix("}")
1043                            .unwrap(),
1044                        span,
1045                    );
1046
1047                    query_params.push(QueryParam {
1048                        name: ident.to_string(),
1049                        binding: ident,
1050                        catch_all: true,
1051                        ty: parse_quote!(()),
1052                        arg_idx: usize::MAX,
1053                    });
1054                } else {
1055                    // if there's an `=` in the query param, we only take the left side as the name, and the right side is the binding
1056                    let name;
1057                    let binding;
1058                    if let Some((n, b)) = query_param.split_once('=') {
1059                        name = n;
1060                        binding = Ident::new(b, span);
1061                    } else {
1062                        name = query_param;
1063                        binding = Ident::new(query_param, span);
1064                    }
1065
1066                    query_params.push(QueryParam {
1067                        name: name.to_string(),
1068                        binding,
1069                        catch_all: false,
1070                        ty: parse_quote!(()),
1071                        arg_idx: usize::MAX,
1072                    });
1073                }
1074            }
1075        }
1076
1077        Ok(Self {
1078            path_params,
1079            query_params,
1080        })
1081    }
1082}
1083
1084enum PathParam {
1085    WildCard(LitStr, Brace, Star, Ident, Box<Type>, Brace),
1086    Capture(LitStr, Brace, Ident, Box<Type>, Brace),
1087    Static(LitStr),
1088}
1089
1090impl PathParam {
1091    fn _captures(&self) -> bool {
1092        matches!(self, Self::Capture(..) | Self::WildCard(..))
1093    }
1094
1095    fn capture(&self) -> Option<(&Ident, &Type)> {
1096        match self {
1097            Self::Capture(_, _, ident, ty, _) => Some((ident, ty)),
1098            Self::WildCard(_, _, _, ident, ty, _) => Some((ident, ty)),
1099            _ => None,
1100        }
1101    }
1102
1103    fn new(str: &str, span: Span, ty: Box<Type>) -> syn::Result<Self> {
1104        let ok = if str.starts_with('{') {
1105            let str = str
1106                .strip_prefix('{')
1107                .unwrap()
1108                .strip_suffix('}')
1109                .ok_or_else(|| {
1110                    syn::Error::new(span, "expected path param to be wrapped in curly braces")
1111                })?;
1112            Self::Capture(
1113                LitStr::new(str, span),
1114                Brace(span),
1115                Ident::new(str, span),
1116                ty,
1117                Brace(span),
1118            )
1119        } else if str.starts_with('*') && str.len() > 1 {
1120            let str = str.strip_prefix('*').unwrap();
1121            Self::WildCard(
1122                LitStr::new(str, span),
1123                Brace(span),
1124                Star(span),
1125                Ident::new(str, span),
1126                ty,
1127                Brace(span),
1128            )
1129        } else if str.starts_with(':') && str.len() > 1 {
1130            let str = str.strip_prefix(':').unwrap();
1131            Self::Capture(
1132                LitStr::new(str, span),
1133                Brace(span),
1134                Ident::new(str, span),
1135                ty,
1136                Brace(span),
1137            )
1138        } else {
1139            Self::Static(LitStr::new(str, span))
1140        };
1141
1142        Ok(ok)
1143    }
1144}
1145
/// OpenAPI metadata written in braces after a route literal, e.g.
/// `{ summary: "List posts", tags: ["blog"] }`.
///
/// Each populated entry stores the key identifier as written alongside its parsed value.
struct OapiOptions {
    /// Operation summary; `merge_with_fn` defaults it to the function's first doc line.
    summary: Option<(Ident, LitStr)>,
    /// Operation description; `merge_with_fn` defaults it to the doc lines after the first two.
    description: Option<(Ident, LitStr)>,
    /// Operation id; `merge_with_fn` defaults it to the function's name.
    id: Option<(Ident, LitStr)>,
    /// `hidden: true` / `hidden: false`. Presumably hides the operation from the generated
    /// OpenAPI document; confirm where it is consumed.
    hidden: Option<(Ident, LitBool)>,
    /// `tags: ["a", "b"]`.
    tags: Option<(Ident, StrArray)>,
    /// `security: { "scheme": ["scope", ...] }`.
    security: Option<(Ident, Security)>,
    /// `responses: { 200: Type, ... }`, pairing status codes with response types.
    responses: Option<(Ident, Responses)>,
    /// `transform: |...| ...`; how this closure is applied is not visible in this section.
    transform: Option<(Ident, ExprClosure)>,
}
1156
1157struct Security(Vec<(LitStr, StrArray)>);
1158impl Parse for Security {
1159    fn parse(input: ParseStream) -> syn::Result<Self> {
1160        let inner;
1161        braced!(inner in input);
1162
1163        let mut arr = Vec::new();
1164        while !inner.is_empty() {
1165            let scheme = inner.parse::<LitStr>()?;
1166            let _ = inner.parse::<Token![:]>()?;
1167            let scopes = inner.parse::<StrArray>()?;
1168            let _ = inner.parse::<Token![,]>().ok();
1169            arr.push((scheme, scopes));
1170        }
1171
1172        Ok(Self(arr))
1173    }
1174}
1175
1176impl std::fmt::Display for Security {
1177    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1178        write!(f, "{{")?;
1179        for (i, (scheme, scopes)) in self.0.iter().enumerate() {
1180            if i > 0 {
1181                write!(f, ", ")?;
1182            }
1183            write!(f, "{}: {}", scheme.value(), scopes)?;
1184        }
1185        write!(f, "}}")
1186    }
1187}
1188
1189struct Responses(Vec<(LitInt, Type)>);
1190impl Parse for Responses {
1191    fn parse(input: ParseStream) -> syn::Result<Self> {
1192        let inner;
1193        braced!(inner in input);
1194
1195        let mut arr = Vec::new();
1196        while !inner.is_empty() {
1197            let status = inner.parse::<LitInt>()?;
1198            let _ = inner.parse::<Token![:]>()?;
1199            let ty = inner.parse::<Type>()?;
1200            let _ = inner.parse::<Token![,]>().ok();
1201            arr.push((status, ty));
1202        }
1203
1204        Ok(Self(arr))
1205    }
1206}
1207
1208impl std::fmt::Display for Responses {
1209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1210        write!(f, "{{")?;
1211        for (i, (status, ty)) in self.0.iter().enumerate() {
1212            if i > 0 {
1213                write!(f, ", ")?;
1214            }
1215            write!(f, "{}: {}", status, ty.to_token_stream())?;
1216        }
1217        write!(f, "}}")
1218    }
1219}
1220
1221#[derive(Clone)]
1222struct StrArray(Vec<LitStr>);
1223impl Parse for StrArray {
1224    fn parse(input: ParseStream) -> syn::Result<Self> {
1225        let inner;
1226        bracketed!(inner in input);
1227        let mut arr = Vec::new();
1228        while !inner.is_empty() {
1229            arr.push(inner.parse::<LitStr>()?);
1230            inner.parse::<Token![,]>().ok();
1231        }
1232        Ok(Self(arr))
1233    }
1234}
1235
1236impl std::fmt::Display for StrArray {
1237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1238        write!(f, "[")?;
1239        for (i, lit) in self.0.iter().enumerate() {
1240            if i > 0 {
1241                write!(f, ", ")?;
1242            }
1243            write!(f, "\"{}\"", lit.value())?;
1244        }
1245        write!(f, "]")
1246    }
1247}
1248
1249impl Parse for OapiOptions {
1250    fn parse(input: ParseStream) -> syn::Result<Self> {
1251        let mut this = Self {
1252            summary: None,
1253            description: None,
1254            id: None,
1255            hidden: None,
1256            tags: None,
1257            security: None,
1258            responses: None,
1259            transform: None,
1260        };
1261
1262        while !input.is_empty() {
1263            let ident = input.parse::<Ident>()?;
1264            let _ = input.parse::<Token![:]>()?;
1265            match ident.to_string().as_str() {
1266                "summary" => this.summary = Some((ident, input.parse()?)),
1267                "description" => this.description = Some((ident, input.parse()?)),
1268                "id" => this.id = Some((ident, input.parse()?)),
1269                "hidden" => this.hidden = Some((ident, input.parse()?)),
1270                "tags" => this.tags = Some((ident, input.parse()?)),
1271                "security" => this.security = Some((ident, input.parse()?)),
1272                "responses" => this.responses = Some((ident, input.parse()?)),
1273                "transform" => this.transform = Some((ident, input.parse()?)),
1274                _ => {
1275                    return Err(syn::Error::new(
1276                        ident.span(),
1277                        "unexpected field, expected one of (summary, description, id, hidden, tags, security, responses, transform)",
1278                    ))
1279                }
1280            }
1281            let _ = input.parse::<Token![,]>().ok();
1282        }
1283
1284        Ok(this)
1285    }
1286}
1287
1288impl OapiOptions {
1289    fn merge_with_fn(&mut self, function: &ItemFn) {
1290        if self.description.is_none() {
1291            self.description = doc_iter(&function.attrs)
1292                .skip(2)
1293                .map(|item| item.value())
1294                .reduce(|mut acc, item| {
1295                    acc.push('\n');
1296                    acc.push_str(&item);
1297                    acc
1298                })
1299                .map(|item| (parse_quote!(description), parse_quote!(#item)))
1300        }
1301        if self.summary.is_none() {
1302            self.summary = doc_iter(&function.attrs)
1303                .next()
1304                .map(|item| (parse_quote!(summary), item.clone()))
1305        }
1306        if self.id.is_none() {
1307            let id = &function.sig.ident;
1308            self.id = Some((parse_quote!(id), LitStr::new(&id.to_string(), id.span())));
1309        }
1310    }
1311}
1312
1313fn doc_iter(attrs: &[Attribute]) -> impl Iterator<Item = &LitStr> + '_ {
1314    attrs
1315        .iter()
1316        .filter(|attr| attr.path().is_ident("doc"))
1317        .map(|attr| {
1318            let Meta::NameValue(meta) = &attr.meta else {
1319                panic!("doc attribute is not a name-value attribute");
1320            };
1321            let Expr::Lit(lit) = &meta.value else {
1322                panic!("doc attribute is not a string literal");
1323            };
1324            let Lit::Str(lit_str) = &lit.lit else {
1325                panic!("doc attribute is not a string literal");
1326            };
1327            lit_str
1328        })
1329}
1330
/// A parsed route attribute. It holds:
/// - an optional HTTP method;
/// - the route literal, split into path and query parameters;
/// - optional OpenAPI options;
/// - any extra arguments listed after the route.
struct Route {
    /// The HTTP method, if an identifier preceded the route literal.
    method: Option<Method>,
    /// Path segments of the route literal, as produced by `RouteParser`.
    path_params: Vec<(Slash, PathParam)>,
    /// Query parameters of the route literal, as produced by `RouteParser`.
    query_params: Vec<QueryParam>,
    /// The original route literal; `Route::parse` always sets it.
    route_lit: Option<LitStr>,
    /// Prefix of an old-style route. `Route::parse` leaves it `None`; presumably it is set when
    /// a route is built from legacy `#[server]` arguments (confirm against that constructor).
    prefix: Option<LitStr>,
    /// OpenAPI options from the brace-delimited block after the route literal.
    oapi_options: Option<OapiOptions>,
    /// Extra function arguments listed after a comma following the route literal.
    server_args: Punctuated<FnArg, Comma>,

    // todo: support these since `server_fn` had them
    _input_encoding: Option<Type>,
    _output_encoding: Option<Type>,
}
1344
1345impl Parse for Route {
1346    fn parse(input: ParseStream) -> syn::Result<Self> {
1347        let method = if input.peek(Ident) {
1348            Some(input.parse::<Method>()?)
1349        } else {
1350            None
1351        };
1352
1353        let route_lit = input.parse::<LitStr>()?;
1354        let RouteParser {
1355            path_params,
1356            query_params,
1357        } = RouteParser::new(route_lit.clone())?;
1358
1359        let oapi_options = input
1360            .peek(Brace)
1361            .then(|| {
1362                let inner;
1363                braced!(inner in input);
1364                inner.parse::<OapiOptions>()
1365            })
1366            .transpose()?;
1367
1368        let server_args = if input.peek(Comma) {
1369            let _ = input.parse::<Comma>()?;
1370            input.parse_terminated(FnArg::parse, Comma)?
1371        } else {
1372            Punctuated::new()
1373        };
1374
1375        Ok(Route {
1376            method,
1377            path_params,
1378            query_params,
1379            route_lit: Some(route_lit),
1380            oapi_options,
1381            server_args,
1382            prefix: None,
1383            _input_encoding: None,
1384            _output_encoding: None,
1385        })
1386    }
1387}
1388
/// An HTTP method.
///
/// Each variant keeps the identifier it was created from, so that identifier's span can be
/// reused in generated code (see `to_axum_method_name`).
#[derive(Clone)]
enum Method {
    Get(Ident),
    Post(Ident),
    Put(Ident),
    Delete(Ident),
    Head(Ident),
    Connect(Ident),
    Options(Ident),
    Trace(Ident),
    Patch(Ident),
}
1401
1402impl ToTokens for Method {
1403    fn to_tokens(&self, tokens: &mut TokenStream2) {
1404        match self {
1405            Self::Get(ident)
1406            | Self::Post(ident)
1407            | Self::Put(ident)
1408            | Self::Delete(ident)
1409            | Self::Head(ident)
1410            | Self::Connect(ident)
1411            | Self::Options(ident)
1412            | Self::Trace(ident)
1413            | Self::Patch(ident) => {
1414                ident.to_tokens(tokens);
1415            }
1416        }
1417    }
1418}
1419
1420impl Parse for Method {
1421    fn parse(input: ParseStream) -> syn::Result<Self> {
1422        let ident = input.parse::<Ident>()?;
1423        match ident.to_string().to_uppercase().as_str() {
1424            "GET" => Ok(Self::Get(ident)),
1425            "POST" => Ok(Self::Post(ident)),
1426            "PUT" => Ok(Self::Put(ident)),
1427            "DELETE" => Ok(Self::Delete(ident)),
1428            "HEAD" => Ok(Self::Head(ident)),
1429            "CONNECT" => Ok(Self::Connect(ident)),
1430            "OPTIONS" => Ok(Self::Options(ident)),
1431            "TRACE" => Ok(Self::Trace(ident)),
1432            _ => Err(input
1433                .error("expected one of (GET, POST, PUT, DELETE, HEAD, CONNECT, OPTIONS, TRACE)")),
1434        }
1435    }
1436}
1437
1438impl Method {
1439    fn to_axum_method_name(&self) -> Ident {
1440        match self {
1441            Self::Get(span) => Ident::new("get", span.span()),
1442            Self::Post(span) => Ident::new("post", span.span()),
1443            Self::Put(span) => Ident::new("put", span.span()),
1444            Self::Delete(span) => Ident::new("delete", span.span()),
1445            Self::Head(span) => Ident::new("head", span.span()),
1446            Self::Connect(span) => Ident::new("connect", span.span()),
1447            Self::Options(span) => Ident::new("options", span.span()),
1448            Self::Trace(span) => Ident::new("trace", span.span()),
1449            Self::Patch(span) => Ident::new("patch", span.span()),
1450        }
1451    }
1452
1453    fn new_from_string(s: &str) -> Self {
1454        match s.to_uppercase().as_str() {
1455            "GET" => Self::Get(Ident::new("GET", Span::call_site())),
1456            "POST" => Self::Post(Ident::new("POST", Span::call_site())),
1457            "PUT" => Self::Put(Ident::new("PUT", Span::call_site())),
1458            "DELETE" => Self::Delete(Ident::new("DELETE", Span::call_site())),
1459            "HEAD" => Self::Head(Ident::new("HEAD", Span::call_site())),
1460            "CONNECT" => Self::Connect(Ident::new("CONNECT", Span::call_site())),
1461            "OPTIONS" => Self::Options(Ident::new("OPTIONS", Span::call_site())),
1462            "TRACE" => Self::Trace(Ident::new("TRACE", Span::call_site())),
1463            "PATCH" => Self::Patch(Ident::new("PATCH", Span::call_site())),
1464            _ => panic!("expected one of (GET, POST, PUT, DELETE, HEAD, CONNECT, OPTIONS, TRACE)"),
1465        }
1466    }
1467}
1468
/// Custom keywords used by the parsers in this crate.
mod kw {
    // Defines `kw::with` as a keyword token that can be parsed and peeked. Its use sites are
    // outside this section.
    syn::custom_keyword!(with);
}
1472
/// The arguments to the `server` macro.
///
/// These originally came from the `server_fn` crate, but many no longer apply after the 0.7
/// fullstack overhaul. The parser is kept temporarily for backwards compatibility with existing
/// code; these arguments will be removed in a future release.
#[derive(Debug)]
#[non_exhaustive]
// Fields are kept so legacy arguments still parse. Presumably many are never read afterwards,
// which is why `unused` is allowed here.
#[allow(unused)]
struct ServerFnArgs {
    /// The name of the struct that will implement the server function trait
    /// and be submitted to inventory.
    struct_name: Option<Ident>,
    /// The prefix to use for the server function URL.
    prefix: Option<LitStr>,
    /// The input http encoding to use for the server function.
    input: Option<Type>,
    /// Additional traits to derive on the input struct for the server function.
    input_derive: Option<ExprTuple>,
    /// The output http encoding to use for the server function.
    output: Option<Type>,
    /// The endpoint path, from the `endpoint = "..."` keyword argument.
    fn_path: Option<LitStr>,
    /// The server type to use for the server function.
    server: Option<Type>,
    /// The client type to use for the server function.
    client: Option<Type>,
    /// The custom wrapper to use for the server function struct.
    custom_wrapper: Option<syn::Path>,
    /// Whether the generated input type should implement `From<T>`, where `T` is the type of
    /// its only field.
    impl_from: Option<LitBool>,
    /// Whether the generated input type should implement `Deref` to its only field.
    impl_deref: Option<LitBool>,
    /// The protocol to use for the server function implementation.
    protocol: Option<Type>,
    /// NOTE(review): not set anywhere in this section. Presumably records whether one of the
    /// built-in encodings was selected; confirm against the rest of the parser.
    builtin_encoding: bool,
}
1509
1510impl Parse for ServerFnArgs {
1511    fn parse(stream: ParseStream) -> syn::Result<Self> {
1512        // legacy 4-part arguments
1513        let mut struct_name: Option<Ident> = None;
1514        let mut prefix: Option<LitStr> = None;
1515        let mut encoding: Option<LitStr> = None;
1516        let mut fn_path: Option<LitStr> = None;
1517
1518        // new arguments: can only be keyed by name
1519        let mut input: Option<Type> = None;
1520        let mut input_derive: Option<ExprTuple> = None;
1521        let mut output: Option<Type> = None;
1522        let mut server: Option<Type> = None;
1523        let mut client: Option<Type> = None;
1524        let mut custom_wrapper: Option<syn::Path> = None;
1525        let mut impl_from: Option<LitBool> = None;
1526        let mut impl_deref: Option<LitBool> = None;
1527        let mut protocol: Option<Type> = None;
1528
1529        let mut use_key_and_value = false;
1530        let mut arg_pos = 0;
1531
1532        while !stream.is_empty() {
1533            arg_pos += 1;
1534            let lookahead = stream.lookahead1();
1535            if lookahead.peek(Ident) {
1536                let key_or_value: Ident = stream.parse()?;
1537
1538                let lookahead = stream.lookahead1();
1539                if lookahead.peek(Token![=]) {
1540                    stream.parse::<Token![=]>()?;
1541                    let key = key_or_value;
1542                    use_key_and_value = true;
1543                    if key == "name" {
1544                        if struct_name.is_some() {
1545                            return Err(syn::Error::new(
1546                                key.span(),
1547                                "keyword argument repeated: `name`",
1548                            ));
1549                        }
1550                        struct_name = Some(stream.parse()?);
1551                    } else if key == "prefix" {
1552                        if prefix.is_some() {
1553                            return Err(syn::Error::new(
1554                                key.span(),
1555                                "keyword argument repeated: `prefix`",
1556                            ));
1557                        }
1558                        prefix = Some(stream.parse()?);
1559                    } else if key == "encoding" {
1560                        if encoding.is_some() {
1561                            return Err(syn::Error::new(
1562                                key.span(),
1563                                "keyword argument repeated: `encoding`",
1564                            ));
1565                        }
1566                        encoding = Some(stream.parse()?);
1567                    } else if key == "endpoint" {
1568                        if fn_path.is_some() {
1569                            return Err(syn::Error::new(
1570                                key.span(),
1571                                "keyword argument repeated: `endpoint`",
1572                            ));
1573                        }
1574                        fn_path = Some(stream.parse()?);
1575                    } else if key == "input" {
1576                        if encoding.is_some() {
1577                            return Err(syn::Error::new(
1578                                key.span(),
1579                                "`encoding` and `input` should not both be \
1580                                 specified",
1581                            ));
1582                        } else if input.is_some() {
1583                            return Err(syn::Error::new(
1584                                key.span(),
1585                                "keyword argument repeated: `input`",
1586                            ));
1587                        }
1588                        input = Some(stream.parse()?);
1589                    } else if key == "input_derive" {
1590                        if input_derive.is_some() {
1591                            return Err(syn::Error::new(
1592                                key.span(),
1593                                "keyword argument repeated: `input_derive`",
1594                            ));
1595                        }
1596                        input_derive = Some(stream.parse()?);
1597                    } else if key == "output" {
1598                        if encoding.is_some() {
1599                            return Err(syn::Error::new(
1600                                key.span(),
1601                                "`encoding` and `output` should not both be \
1602                                 specified",
1603                            ));
1604                        } else if output.is_some() {
1605                            return Err(syn::Error::new(
1606                                key.span(),
1607                                "keyword argument repeated: `output`",
1608                            ));
1609                        }
1610                        output = Some(stream.parse()?);
1611                    } else if key == "server" {
1612                        if server.is_some() {
1613                            return Err(syn::Error::new(
1614                                key.span(),
1615                                "keyword argument repeated: `server`",
1616                            ));
1617                        }
1618                        server = Some(stream.parse()?);
1619                    } else if key == "client" {
1620                        if client.is_some() {
1621                            return Err(syn::Error::new(
1622                                key.span(),
1623                                "keyword argument repeated: `client`",
1624                            ));
1625                        }
1626                        client = Some(stream.parse()?);
1627                    } else if key == "custom" {
1628                        if custom_wrapper.is_some() {
1629                            return Err(syn::Error::new(
1630                                key.span(),
1631                                "keyword argument repeated: `custom`",
1632                            ));
1633                        }
1634                        custom_wrapper = Some(stream.parse()?);
1635                    } else if key == "impl_from" {
1636                        if impl_from.is_some() {
1637                            return Err(syn::Error::new(
1638                                key.span(),
1639                                "keyword argument repeated: `impl_from`",
1640                            ));
1641                        }
1642                        impl_from = Some(stream.parse()?);
1643                    } else if key == "impl_deref" {
1644                        if impl_deref.is_some() {
1645                            return Err(syn::Error::new(
1646                                key.span(),
1647                                "keyword argument repeated: `impl_deref`",
1648                            ));
1649                        }
1650                        impl_deref = Some(stream.parse()?);
1651                    } else if key == "protocol" {
1652                        if protocol.is_some() {
1653                            return Err(syn::Error::new(
1654                                key.span(),
1655                                "keyword argument repeated: `protocol`",
1656                            ));
1657                        }
1658                        protocol = Some(stream.parse()?);
1659                    } else {
1660                        return Err(lookahead.error());
1661                    }
1662                } else {
1663                    let value = key_or_value;
1664                    if use_key_and_value {
1665                        return Err(syn::Error::new(
1666                            value.span(),
1667                            "positional argument follows keyword argument",
1668                        ));
1669                    }
1670                    if arg_pos == 1 {
1671                        struct_name = Some(value)
1672                    } else {
1673                        return Err(syn::Error::new(value.span(), "expected string literal"));
1674                    }
1675                }
1676            } else if lookahead.peek(LitStr) {
1677                if use_key_and_value {
1678                    return Err(syn::Error::new(
1679                        stream.span(),
1680                        "If you use keyword arguments (e.g., `name` = \
1681                         Something), then you can no longer use arguments \
1682                         without a keyword.",
1683                    ));
1684                }
1685                match arg_pos {
1686                    1 => return Err(lookahead.error()),
1687                    2 => prefix = Some(stream.parse()?),
1688                    3 => encoding = Some(stream.parse()?),
1689                    4 => fn_path = Some(stream.parse()?),
1690                    _ => return Err(syn::Error::new(stream.span(), "unexpected extra argument")),
1691                }
1692            } else {
1693                return Err(lookahead.error());
1694            }
1695
1696            if !stream.is_empty() {
1697                stream.parse::<Token![,]>()?;
1698            }
1699        }
1700
1701        // parse legacy encoding into input/output
1702        let mut builtin_encoding = false;
1703        if let Some(encoding) = encoding {
1704            match encoding.value().to_lowercase().as_str() {
1705                "url" => {
1706                    input = Some(type_from_ident(syn::parse_quote!(Url)));
1707                    output = Some(type_from_ident(syn::parse_quote!(Json)));
1708                    builtin_encoding = true;
1709                }
1710                "cbor" => {
1711                    input = Some(type_from_ident(syn::parse_quote!(Cbor)));
1712                    output = Some(type_from_ident(syn::parse_quote!(Cbor)));
1713                    builtin_encoding = true;
1714                }
1715                "getcbor" => {
1716                    input = Some(type_from_ident(syn::parse_quote!(GetUrl)));
1717                    output = Some(type_from_ident(syn::parse_quote!(Cbor)));
1718                    builtin_encoding = true;
1719                }
1720                "getjson" => {
1721                    input = Some(type_from_ident(syn::parse_quote!(GetUrl)));
1722                    output = Some(syn::parse_quote!(Json));
1723                    builtin_encoding = true;
1724                }
1725                _ => return Err(syn::Error::new(encoding.span(), "Encoding not found.")),
1726            }
1727        }
1728
1729        Ok(Self {
1730            struct_name,
1731            prefix,
1732            input,
1733            input_derive,
1734            output,
1735            fn_path,
1736            builtin_encoding,
1737            server,
1738            client,
1739            custom_wrapper,
1740            impl_from,
1741            impl_deref,
1742            protocol,
1743        })
1744    }
1745}
1746
1747/// An argument type in a server function.
1748#[allow(unused)]
1749// todo - we used to support a number of these attributes and pass them along to serde. bring them back.
1750#[derive(Debug, Clone)]
1751struct ServerFnArg {
1752    /// The attributes on the server function argument.
1753    server_fn_attributes: Vec<Attribute>,
1754    /// The type of the server function argument.
1755    arg: syn::PatType,
1756}
1757
1758impl ToTokens for ServerFnArg {
1759    fn to_tokens(&self, tokens: &mut TokenStream2) {
1760        let ServerFnArg { arg, .. } = self;
1761        tokens.extend(quote! {
1762            #arg
1763        });
1764    }
1765}
1766
1767impl Parse for ServerFnArg {
1768    fn parse(input: ParseStream) -> Result<Self> {
1769        let arg: syn::FnArg = input.parse()?;
1770        let mut arg = match arg {
1771            FnArg::Receiver(_) => {
1772                return Err(syn::Error::new(
1773                    arg.span(),
1774                    "cannot use receiver types in server function macro",
1775                ))
1776            }
1777            FnArg::Typed(t) => t,
1778        };
1779
1780        fn rename_path(path: Path, from_ident: Ident, to_ident: Ident) -> Path {
1781            if path.is_ident(&from_ident) {
1782                Path {
1783                    leading_colon: None,
1784                    segments: Punctuated::from_iter([PathSegment {
1785                        ident: to_ident,
1786                        arguments: PathArguments::None,
1787                    }]),
1788                }
1789            } else {
1790                path
1791            }
1792        }
1793
1794        let server_fn_attributes = arg
1795            .attrs
1796            .iter()
1797            .cloned()
1798            .map(|attr| {
1799                if attr.path().is_ident("server") {
1800                    // Allow the following attributes:
1801                    // - #[server(default)]
1802                    // - #[server(rename = "fieldName")]
1803
1804                    // Rename `server` to `serde`
1805                    let attr = Attribute {
1806                        meta: match attr.meta {
1807                            Meta::Path(path) => Meta::Path(rename_path(
1808                                path,
1809                                format_ident!("server"),
1810                                format_ident!("serde"),
1811                            )),
1812                            Meta::List(mut list) => {
1813                                list.path = rename_path(
1814                                    list.path,
1815                                    format_ident!("server"),
1816                                    format_ident!("serde"),
1817                                );
1818                                Meta::List(list)
1819                            }
1820                            Meta::NameValue(mut name_value) => {
1821                                name_value.path = rename_path(
1822                                    name_value.path,
1823                                    format_ident!("server"),
1824                                    format_ident!("serde"),
1825                                );
1826                                Meta::NameValue(name_value)
1827                            }
1828                        },
1829                        ..attr
1830                    };
1831
1832                    let args = attr.parse_args::<Meta>()?;
1833                    match args {
1834                        // #[server(default)]
1835                        Meta::Path(path) if path.is_ident("default") => Ok(attr.clone()),
1836                        // #[server(flatten)]
1837                        Meta::Path(path) if path.is_ident("flatten") => Ok(attr.clone()),
1838                        // #[server(default = "value")]
1839                        Meta::NameValue(name_value) if name_value.path.is_ident("default") => {
1840                            Ok(attr.clone())
1841                        }
1842                        // #[server(skip)]
1843                        Meta::Path(path) if path.is_ident("skip") => Ok(attr.clone()),
1844                        // #[server(rename = "value")]
1845                        Meta::NameValue(name_value) if name_value.path.is_ident("rename") => {
1846                            Ok(attr.clone())
1847                        }
1848                        _ => Err(Error::new(
1849                            attr.span(),
1850                            "Unrecognized #[server] attribute, expected \
1851                             #[server(default)] or #[server(rename = \
1852                             \"fieldName\")]",
1853                        )),
1854                    }
1855                } else if attr.path().is_ident("doc") {
1856                    // Allow #[doc = "documentation"]
1857                    Ok(attr.clone())
1858                } else if attr.path().is_ident("allow") {
1859                    // Allow #[allow(...)]
1860                    Ok(attr.clone())
1861                } else if attr.path().is_ident("deny") {
1862                    // Allow #[deny(...)]
1863                    Ok(attr.clone())
1864                } else if attr.path().is_ident("ignore") {
1865                    // Allow #[ignore]
1866                    Ok(attr.clone())
1867                } else {
1868                    Err(Error::new(
1869                        attr.span(),
1870                        "Unrecognized attribute, expected #[server(...)]",
1871                    ))
1872                }
1873            })
1874            .collect::<Result<Vec<_>>>()?;
1875        arg.attrs = vec![];
1876        Ok(ServerFnArg {
1877            arg,
1878            server_fn_attributes,
1879        })
1880    }
1881}
1882
1883fn type_from_ident(ident: Ident) -> Type {
1884    let mut segments = Punctuated::new();
1885    segments.push(PathSegment {
1886        ident,
1887        arguments: PathArguments::None,
1888    });
1889    Type::Path(TypePath {
1890        qself: None,
1891        path: Path {
1892            leading_colon: None,
1893            segments,
1894        },
1895    })
1896}