dioxus_fullstack_macro/
lib.rs

1// TODO: Create README, uncomment this: #![doc = include_str!("../README.md")]
2#![doc(html_logo_url = "https://avatars.githubusercontent.com/u/79236386")]
3#![doc(html_favicon_url = "https://avatars.githubusercontent.com/u/79236386")]
4
5use core::panic;
6use proc_macro::TokenStream;
7use proc_macro2::{Span, TokenStream as TokenStream2};
8use quote::ToTokens;
9use quote::{format_ident, quote};
10use std::collections::HashMap;
11use syn::{
12    braced, bracketed,
13    parse::ParseStream,
14    punctuated::Punctuated,
15    token::{Comma, Slash},
16    Error, ExprTuple, FnArg, Meta, PathArguments, PathSegment, Token, Type, TypePath,
17};
18use syn::{parse::Parse, parse_quote, Ident, ItemFn, LitStr, Path};
19use syn::{spanned::Spanned, LitBool, LitInt, Pat, PatType};
20use syn::{
21    token::{Brace, Star},
22    Attribute, Expr, ExprClosure, Lit, Result,
23};
24
25/// ## Usage
26///
27/// ```rust,ignore
28/// # use dioxus::prelude::*;
29/// # #[derive(serde::Deserialize, serde::Serialize)]
30/// # struct BlogPost;
31/// # async fn load_posts(category: &str) -> Result<Vec<BlogPost>> { unimplemented!() }
32///
33/// #[server]
34/// async fn blog_posts(category: String) -> Result<Vec<BlogPost>> {
35///     let posts = load_posts(&category).await?;
36///     // maybe do some other work
37///     Ok(posts)
38/// }
39/// ```
40///
41/// ## Named Arguments
42///
43/// You can use any combination of the following named arguments:
44/// - `endpoint`: a prefix at which the server function handler will be mounted (defaults to `/api`).
45///   Example: `endpoint = "/my_api/my_serverfn"`.
46/// - `input`: the encoding for the arguments, defaults to `Json<T>`
47///     - You may customize the encoding of the arguments by specifying a different type for `input`.
48///     - Any axum `IntoRequest` extractor can be used here, and dioxus provides
49///       - `Json<T>`: The default axum `Json` extractor that decodes JSON-encoded request bodies.
50///       - `Cbor<T>`: A custom axum `Cbor` extractor that decodes CBOR-encoded request bodies.
51///       - `MessagePack<T>`: A custom axum `MessagePack` extractor that decodes MessagePack-encoded request bodies.
52/// - `output`: the encoding for the response (defaults to `Json`).
53///     - The `output` argument specifies how the server should encode the response data.
54///     - Acceptable values include:
55///       - `Json`: A response encoded as JSON (default). This is ideal for most web applications.
56///       - `Cbor`: A response encoded in the CBOR format for efficient, binary-encoded data.
57/// - `client`: a custom `Client` implementation that will be used for this server function. This allows
58///   customization of the client-side behavior if needed.
59///
60/// ## Advanced Usage of `input` and `output` Fields
61///
62/// The `input` and `output` fields allow you to customize how arguments and responses are encoded and decoded.
63/// These fields impose specific trait bounds on the types you use. Here are detailed examples for different scenarios:
64///
65/// ## Adding layers to server functions
66///
67/// Layers allow you to transform the request and response of a server function. You can use layers
68/// to add authentication, logging, or other functionality to your server functions. Server functions integrate
69/// with the tower ecosystem, so you can use any layer that is compatible with tower.
70///
71/// Common layers include:
72/// - [`tower_http::trace::TraceLayer`](https://docs.rs/tower-http/latest/tower_http/trace/struct.TraceLayer.html) for tracing requests and responses
73/// - [`tower_http::compression::CompressionLayer`](https://docs.rs/tower-http/latest/tower_http/compression/struct.CompressionLayer.html) for compressing large responses
74/// - [`tower_http::cors::CorsLayer`](https://docs.rs/tower-http/latest/tower_http/cors/struct.CorsLayer.html) for adding CORS headers to responses
75/// - [`tower_http::timeout::TimeoutLayer`](https://docs.rs/tower-http/latest/tower_http/timeout/struct.TimeoutLayer.html) for adding timeouts to requests
76/// - [`tower_sessions::service::SessionManagerLayer`](https://docs.rs/tower-sessions/0.13.0/tower_sessions/service/struct.SessionManagerLayer.html) for adding session management to requests
77///
78/// You can add a tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to your server function with the middleware attribute:
79///
80/// ```rust,ignore
81/// # use dioxus::prelude::*;
82/// #[server]
83/// // The TraceLayer will log all requests to the console
84/// #[middleware(tower_http::timeout::TimeoutLayer::new(std::time::Duration::from_secs(5)))]
85/// pub async fn my_wacky_server_fn(input: Vec<String>) -> ServerFnResult<usize> {
86///     unimplemented!()
87/// }
88/// ```
89#[proc_macro_attribute]
90pub fn server(attr: proc_macro::TokenStream, mut item: TokenStream) -> TokenStream {
91    // Parse the attribute list using the old server_fn arg parser.
92    let args = match syn::parse::<ServerFnArgs>(attr) {
93        Ok(args) => args,
94        Err(err) => {
95            let err: TokenStream = err.to_compile_error().into();
96            item.extend(err);
97            return item;
98        }
99    };
100
101    let method = Method::Post(Ident::new("POST", proc_macro2::Span::call_site()));
102    let route: Route = Route {
103        method: None,
104        path_params: vec![],
105        query_params: vec![],
106        route_lit: args.fn_path,
107        oapi_options: None,
108        server_args: Default::default(),
109        prefix: Some(
110            args.prefix
111                .unwrap_or_else(|| LitStr::new("/api", Span::call_site())),
112        ),
113        _input_encoding: args.input,
114        _output_encoding: args.output,
115    };
116
117    match route_impl_with_route(route, item.clone(), Some(method)) {
118        Ok(mut tokens) => {
119            // Let's add some deprecated warnings to the various fields from `args` if the user is using them...
120            // We don't generate structs anymore, don't use various protocols, etc
121            if let Some(name) = args.struct_name {
122                tokens.extend(quote! {
123                    const _: () = {
124                        #[deprecated(note = "Dioxus server functions no longer generate a struct for the server function. The function itself is used directly.")]
125                        struct #name;
126                        fn ___assert_deprecated() {
127                            let _ = #name;
128                        }
129
130                        ()
131                    };
132                });
133            }
134
135            //
136            tokens.into()
137        }
138
139        // Retain the original function item and append the error to it. Better for autocomplete.
140        Err(err) => {
141            let err: TokenStream = err.to_compile_error().into();
142            item.extend(err);
143            item
144        }
145    }
146}
147
148#[proc_macro_attribute]
149pub fn get(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
150    wrapped_route_impl(args, body, Some(Method::new_from_string("GET")))
151}
152
153#[proc_macro_attribute]
154pub fn post(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
155    wrapped_route_impl(args, body, Some(Method::new_from_string("POST")))
156}
157
158#[proc_macro_attribute]
159pub fn put(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
160    wrapped_route_impl(args, body, Some(Method::new_from_string("PUT")))
161}
162
163#[proc_macro_attribute]
164pub fn delete(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
165    wrapped_route_impl(args, body, Some(Method::new_from_string("DELETE")))
166}
167
168#[proc_macro_attribute]
169pub fn patch(args: proc_macro::TokenStream, body: TokenStream) -> TokenStream {
170    wrapped_route_impl(args, body, Some(Method::new_from_string("PATCH")))
171}
172
173fn wrapped_route_impl(
174    attr: TokenStream,
175    mut item: TokenStream,
176    method: Option<Method>,
177) -> TokenStream {
178    match route_impl(attr, item.clone(), method) {
179        Ok(tokens) => tokens.into(),
180        Err(err) => {
181            let err: TokenStream = err.to_compile_error().into();
182            item.extend(err);
183            item
184        }
185    }
186}
187
188fn route_impl(
189    attr: TokenStream,
190    item: TokenStream,
191    method_from_macro: Option<Method>,
192) -> syn::Result<TokenStream2> {
193    let route = syn::parse::<Route>(attr)?;
194    route_impl_with_route(route, item, method_from_macro)
195}
196
197fn route_impl_with_route(
198    route: Route,
199    item: TokenStream,
200    method_from_macro: Option<Method>,
201) -> syn::Result<TokenStream2> {
202    // Parse the route and function
203    let mut function = syn::parse::<ItemFn>(item)?;
204
205    // Collect the middleware initializers
206    let middleware_layers = function
207        .attrs
208        .iter()
209        .filter(|attr| attr.path().is_ident("middleware"))
210        .map(|f| match &f.meta {
211            Meta::List(meta_list) => Ok({
212                let tokens = &meta_list.tokens;
213                quote! { .layer(#tokens) }
214            }),
215            _ => Err(Error::new(
216                f.span(),
217                "Expected middleware attribute to be a list, e.g. #[middleware(MyLayer::new())]",
218            )),
219        })
220        .collect::<Result<Vec<_>>>()?;
221
222    // don't re-emit the middleware attribute on the inner
223    function
224        .attrs
225        .retain(|attr| !attr.path().is_ident("middleware"));
226
227    // Attach `#[allow(unused_mut)]` to all original inputs to avoid warnings
228    let outer_inputs = function
229        .sig
230        .inputs
231        .iter()
232        .enumerate()
233        .map(|(i, arg)| match arg {
234            FnArg::Receiver(_receiver) => panic!("Self type is not supported"),
235            FnArg::Typed(pat_type) => match pat_type.pat.as_ref() {
236                Pat::Ident(_) => {
237                    quote! { #[allow(unused_mut)] #pat_type }
238                }
239                _ => {
240                    let ident = format_ident!("___Arg{}", i);
241                    let ty = &pat_type.ty;
242                    quote! { #[allow(unused_mut)] #ident: #ty }
243                }
244            },
245        })
246        .collect::<Punctuated<_, Token![,]>>();
247    // .collect::<Punctuated<_, Token![,]>>();
248
249    let route = CompiledRoute::from_route(route, &function, false, method_from_macro)?;
250    let query_params_struct = route.query_params_struct(false);
251    let method_ident = &route.method;
252    let body_json_args = route.remaining_pattypes_named(&function.sig.inputs);
253    let body_json_names = body_json_args
254        .iter()
255        .map(|(i, pat_type)| match &*pat_type.pat {
256            Pat::Ident(ref pat_ident) => pat_ident.ident.clone(),
257            _ => format_ident!("___Arg{}", i),
258        })
259        .collect::<Vec<_>>();
260    let body_json_types = body_json_args
261        .iter()
262        .map(|pat_type| &pat_type.1.ty)
263        .collect::<Vec<_>>();
264    let route_docs = route.to_doc_comments();
265
266    // Get the variables we need for code generation
267    let fn_on_server_name = &function.sig.ident;
268    let vis = &function.vis;
269    let (impl_generics, ty_generics, where_clause) = &function.sig.generics.split_for_impl();
270    let ty_generics = ty_generics.as_turbofish();
271    let fn_docs = function
272        .attrs
273        .iter()
274        .filter(|attr| attr.path().is_ident("doc"));
275
276    let __axum = quote! { dioxus_server::axum };
277
278    let output_type = match &function.sig.output {
279        syn::ReturnType::Default => parse_quote! { () },
280        syn::ReturnType::Type(_, ty) => (*ty).clone(),
281    };
282
283    let query_param_names = route
284        .query_params
285        .iter()
286        .filter(|c| !c.catch_all)
287        .map(|param| &param.binding);
288
289    let path_param_args = route.path_params.iter().map(|(_slash, param)| match param {
290        PathParam::Capture(_lit, _brace_1, ident, _ty, _brace_2) => {
291            Some(quote! { #ident = #ident, })
292        }
293        PathParam::WildCard(_lit, _brace_1, _star, ident, _ty, _brace_2) => {
294            Some(quote! { #ident = #ident, })
295        }
296        PathParam::Static(_lit) => None,
297    });
298
299    let out_ty = match output_type.as_ref() {
300        Type::Tuple(tuple) if tuple.elems.is_empty() => parse_quote! { () },
301        _ => output_type.clone(),
302    };
303
304    let mut function_on_server = function.clone();
305    function_on_server
306        .sig
307        .inputs
308        .extend(route.server_args.clone());
309
310    let server_names = route
311        .server_args
312        .iter()
313        .enumerate()
314        .map(|(i, pat_type)| match pat_type {
315            FnArg::Typed(_pat_type) => format_ident!("___sarg___{}", i),
316            FnArg::Receiver(_) => panic!("Self type is not supported"),
317        })
318        .collect::<Vec<_>>();
319
320    let server_types = route
321        .server_args
322        .iter()
323        .map(|pat_type| match pat_type {
324            FnArg::Receiver(_) => parse_quote! { () },
325            FnArg::Typed(pat_type) => (*pat_type.ty).clone(),
326        })
327        .collect::<Vec<_>>();
328
329    let body_struct_impl = {
330        let tys = body_json_types
331            .iter()
332            .enumerate()
333            .map(|(idx, _)| format_ident!("__Ty{}", idx));
334
335        let names = body_json_names.iter().enumerate().map(|(idx, name)| {
336            let ty_name = format_ident!("__Ty{}", idx);
337            quote! { #name: #ty_name }
338        });
339
340        quote! {
341            #[derive(serde::Serialize, serde::Deserialize)]
342            #[serde(crate = "serde")]
343            struct ___Body_Serialize___< #(#tys,)* > {
344                #(#names,)*
345            }
346        }
347    };
348
349    // This unpacks the body struct into the individual variables that get scoped
350    let unpack_closure = {
351        let unpack_args = body_json_names.iter().map(|name| quote! { data.#name });
352        quote! {
353            |data| { ( #(#unpack_args,)* ) }
354        }
355    };
356
357    let as_axum_path = route.to_axum_path_string();
358
359    let query_endpoint = if let Some(full_url) = route.url_without_queries_for_format() {
360        quote! { format!(#full_url, #( #path_param_args)*) }
361    } else {
362        quote! { __ENDPOINT_PATH.to_string() }
363    };
364
365    let endpoint_path = {
366        let prefix = route
367            .prefix
368            .as_ref()
369            .cloned()
370            .unwrap_or_else(|| LitStr::new("", Span::call_site()));
371
372        let route_lit = if !as_axum_path.is_empty() {
373            quote! { #as_axum_path }
374        } else {
375            quote! {
376                concat!(
377                    "/",
378                    stringify!(#fn_on_server_name)
379                )
380            }
381        };
382
383        let hash = match route.route_lit.as_ref() {
384            // Explicit route lit, no need to hash
385            Some(_) => quote! { "" },
386
387            // Implicit route lit, we need to hash the function signature to avoid collisions
388            None => {
389                // let enable_hash = option_env!("DISABLE_SERVER_FN_HASH").is_none();
390                let key_env_var = match option_env!("SERVER_FN_OVERRIDE_KEY") {
391                    Some(_) => "SERVER_FN_OVERRIDE_KEY",
392                    None => "CARGO_MANIFEST_DIR",
393                };
394                quote! {
395                    dioxus_fullstack::xxhash_rust::const_xxh64::xxh64(
396                        concat!(env!(#key_env_var), ":", module_path!()).as_bytes(),
397                        0
398                    )
399                }
400            }
401        };
402
403        quote! {
404            dioxus_fullstack::const_format::concatcp!(#prefix, #route_lit, #hash)
405        }
406    };
407
408    let extracted_idents = route.extracted_idents();
409
410    let query_tokens = if route.query_is_catchall() {
411        let query = route
412            .query_params
413            .iter()
414            .find(|param| param.catch_all)
415            .unwrap();
416        let input = &function.sig.inputs[query.arg_idx];
417        let name = match input {
418            FnArg::Typed(pat_type) => match pat_type.pat.as_ref() {
419                Pat::Ident(ref pat_ident) => pat_ident.ident.clone(),
420                _ => format_ident!("___Arg{}", query.arg_idx),
421            },
422            FnArg::Receiver(_receiver) => panic!(),
423        };
424        quote! {
425            #name
426        }
427    } else {
428        quote! {
429            __QueryParams__ { #(#query_param_names,)* }
430        }
431    };
432
433    let extracted_as_server_headers = route.extracted_as_server_headers(query_tokens.clone());
434
435    Ok(quote! {
436        #(#fn_docs)*
437        #route_docs
438        #[deny(
439            unexpected_cfgs,
440            reason = "
441==========================================================================================
442  Using Dioxus Server Functions requires a `server` feature flag in your `Cargo.toml`.
443  Please add the following to your `Cargo.toml`:
444
445  ```toml
446  [features]
447  server = [\"dioxus/server\"]
448  ```
449
450  To enable better Rust-Analyzer support, you can make `server` a default feature:
451  ```toml
452  [features]
453  default = [\"web\", \"server\"]
454  web = [\"dioxus/web\"]
455  server = [\"dioxus/server\"]
456  ```
457==========================================================================================
458        "
459        )]
460        #vis async fn #fn_on_server_name #impl_generics( #outer_inputs ) -> #out_ty #where_clause {
461            use dioxus_fullstack::serde as serde;
462            use dioxus_fullstack::{
463                // concrete types
464                ServerFnEncoder, ServerFnDecoder, FullstackContext,
465
466                // "magic" traits for encoding/decoding on the client
467                ExtractRequest, EncodeRequest, RequestDecodeResult, RequestDecodeErr,
468
469                // "magic" traits for encoding/decoding on the server
470                MakeAxumResponse, MakeAxumError,
471            };
472
473            #query_params_struct
474
475            #body_struct_impl
476
477            const __ENDPOINT_PATH: &str = #endpoint_path;
478
479            {
480                _ = dioxus_fullstack::assert_is_result::<#out_ty>();
481
482                let verify_token = (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new())
483                    .verify_can_serialize();
484
485                dioxus_fullstack::assert_can_encode(verify_token);
486
487                let decode_token = (&&&&&ServerFnDecoder::<#out_ty>::new())
488                    .verify_can_deserialize();
489
490                dioxus_fullstack::assert_can_decode(decode_token);
491            };
492
493
494            // On the client, we make the request to the server
495            // We want to support extremely flexible error types and return types, making this more complex than it should
496            #[allow(clippy::unused_unit)]
497            #[cfg(not(feature = "server"))]
498            {
499                let client = dioxus_fullstack::ClientRequest::new(
500                    dioxus_fullstack::http::Method::#method_ident,
501                    #query_endpoint,
502                    &#query_tokens,
503                );
504
505                let response = (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new())
506                    .fetch_client(client, ___Body_Serialize___ { #(#body_json_names,)* }, #unpack_closure)
507                    .await;
508
509                let decoded = (&&&&&ServerFnDecoder::<#out_ty>::new())
510                    .decode_client_response(response)
511                    .await;
512
513                let result = (&&&&&ServerFnDecoder::<#out_ty>::new())
514                    .decode_client_err(decoded)
515                    .await;
516
517                return result;
518            }
519
520            // On the server, we expand the tokens and submit the function to inventory
521            #[cfg(feature = "server")] {
522                #function_on_server
523
524                #[allow(clippy::unused_unit)]
525                fn __inner__function__ #impl_generics(
526                    ___state: #__axum::extract::State<FullstackContext>,
527                    ___request: #__axum::extract::Request,
528                ) -> std::pin::Pin<Box<dyn std::future::Future<Output = #__axum::response::Response>>> #where_clause {
529                    Box::pin(async move {
530                         match (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new()).extract_axum(___state.0, ___request, #unpack_closure).await {
531                            Ok(((#(#body_json_names,)* ), (#(#extracted_as_server_headers,)* #(#server_names,)*) )) => {
532                                // Call the user function
533                                let res = #fn_on_server_name #ty_generics(#(#extracted_idents,)* #(#body_json_names,)* #(#server_names,)*).await;
534
535                                // Encode the response Into a `Result<T, E>`
536                                let encoded = (&&&&&&ServerFnDecoder::<#out_ty>::new()).make_axum_response(res);
537
538                                // And then encode `Result<T, E>` into `Response`
539                                (&&&&&ServerFnDecoder::<#out_ty>::new()).make_axum_error(encoded)
540                            },
541                            Err(res) => res,
542                        }
543                    })
544                }
545
546                dioxus_server::inventory::submit! {
547                    dioxus_server::ServerFunction::new(
548                        dioxus_server::http::Method::#method_ident,
549                        __ENDPOINT_PATH,
550                        || {
551                            dioxus_server::ServerFunction::make_handler(dioxus_server::http::Method::#method_ident, __inner__function__ #ty_generics)
552                                #(#middleware_layers)*
553                        }
554                    )
555                }
556
557                // Extract the server arguments from the context
558                let (#(#server_names,)*) = dioxus_fullstack::FullstackContext::extract::<(#(#server_types,)*), _>().await?;
559
560                // Call the function directly
561                return #fn_on_server_name #ty_generics(
562                    #(#extracted_idents,)*
563                    #(#body_json_names,)*
564                    #(#server_names,)*
565                ).await;
566            }
567
568            #[allow(unreachable_code)]
569            {
570                unreachable!()
571            }
572        }
573    })
574}
575
576struct CompiledRoute {
577    method: Method,
578    #[allow(clippy::type_complexity)]
579    path_params: Vec<(Slash, PathParam)>,
580    query_params: Vec<QueryParam>,
581    route_lit: Option<LitStr>,
582    prefix: Option<LitStr>,
583    oapi_options: Option<OapiOptions>,
584    server_args: Punctuated<FnArg, Comma>,
585}
586
587struct QueryParam {
588    arg_idx: usize,
589    name: String,
590    binding: Ident,
591    catch_all: bool,
592    ty: Box<Type>,
593}
594
595impl CompiledRoute {
596    fn to_axum_path_string(&self) -> String {
597        let mut path = String::new();
598
599        for (_slash, param) in &self.path_params {
600            path.push('/');
601            match param {
602                PathParam::Capture(lit, _brace_1, _, _, _brace_2) => {
603                    path.push('{');
604                    path.push_str(&lit.value());
605                    path.push('}');
606                }
607                PathParam::WildCard(lit, _brace_1, _, _, _, _brace_2) => {
608                    path.push('{');
609                    path.push('*');
610                    path.push_str(&lit.value());
611                    path.push('}');
612                }
613                PathParam::Static(lit) => path.push_str(&lit.value()),
614            }
615        }
616
617        path
618    }
619
620    /// Removes the arguments in `route` from `args`, and merges them in the output.
621    pub fn from_route(
622        mut route: Route,
623        function: &ItemFn,
624        with_aide: bool,
625        method_from_macro: Option<Method>,
626    ) -> syn::Result<Self> {
627        if !with_aide && route.oapi_options.is_some() {
628            return Err(syn::Error::new(
629                Span::call_site(),
630                "Use `api_route` instead of `route` to use OpenAPI options",
631            ));
632        } else if with_aide && route.oapi_options.is_none() {
633            route.oapi_options = Some(OapiOptions {
634                summary: None,
635                description: None,
636                id: None,
637                hidden: None,
638                tags: None,
639                security: None,
640                responses: None,
641                transform: None,
642            });
643        }
644
645        let sig = &function.sig;
646        let mut arg_map = sig
647            .inputs
648            .iter()
649            .enumerate()
650            .filter_map(|(i, item)| match item {
651                syn::FnArg::Receiver(_) => None,
652                syn::FnArg::Typed(pat_type) => Some((i, pat_type)),
653            })
654            .filter_map(|(i, pat_type)| match &*pat_type.pat {
655                syn::Pat::Ident(ident) => Some((ident.ident.clone(), (pat_type.ty.clone(), i))),
656                _ => None,
657            })
658            .collect::<HashMap<_, _>>();
659
660        for (_slash, path_param) in &mut route.path_params {
661            match path_param {
662                PathParam::Capture(_lit, _, ident, ty, _) => {
663                    let (new_ident, new_ty) = arg_map.remove_entry(ident).ok_or_else(|| {
664                        syn::Error::new(
665                            ident.span(),
666                            format!("path parameter `{}` not found in function arguments", ident),
667                        )
668                    })?;
669                    *ident = new_ident;
670                    *ty = new_ty.0;
671                }
672                PathParam::WildCard(_lit, _, _star, ident, ty, _) => {
673                    let (new_ident, new_ty) = arg_map.remove_entry(ident).ok_or_else(|| {
674                        syn::Error::new(
675                            ident.span(),
676                            format!("path parameter `{}` not found in function arguments", ident),
677                        )
678                    })?;
679                    *ident = new_ident;
680                    *ty = new_ty.0;
681                }
682                PathParam::Static(_lit) => {}
683            }
684        }
685
686        let mut query_params = Vec::new();
687        for param in route.query_params {
688            let (ident, ty) = arg_map.remove_entry(&param.binding).ok_or_else(|| {
689                syn::Error::new(
690                    param.binding.span(),
691                    format!(
692                        "query parameter `{}` not found in function arguments",
693                        param.binding
694                    ),
695                )
696            })?;
697            query_params.push(QueryParam {
698                binding: ident,
699                name: param.name,
700                catch_all: param.catch_all,
701                ty: ty.0,
702                arg_idx: ty.1,
703            });
704        }
705
706        // Disallow multiple query params if one is a catch-all
707        if query_params.iter().any(|param| param.catch_all) && query_params.len() > 1 {
708            return Err(syn::Error::new(
709                Span::call_site(),
710                "Cannot have multiple query parameters when one is a catch-all",
711            ));
712        }
713
714        if let Some(options) = route.oapi_options.as_mut() {
715            options.merge_with_fn(function)
716        }
717
718        let method = match (method_from_macro, route.method) {
719            (Some(method), None) => method,
720            (None, Some(method)) => method,
721            (Some(_), Some(_)) => {
722                return Err(syn::Error::new(
723                    Span::call_site(),
724                    "HTTP method specified both in macro and in attribute",
725                ))
726            }
727            (None, None) => {
728                return Err(syn::Error::new(
729                    Span::call_site(),
730                    "HTTP method not specified in macro or in attribute",
731                ))
732            }
733        };
734
735        Ok(Self {
736            method,
737            route_lit: route.route_lit,
738            path_params: route.path_params,
739            query_params,
740            oapi_options: route.oapi_options,
741            prefix: route.prefix,
742            server_args: route.server_args,
743        })
744    }
745
746    pub fn query_is_catchall(&self) -> bool {
747        self.query_params.iter().any(|param| param.catch_all)
748    }
749
750    pub fn extracted_as_server_headers(&self, query_tokens: TokenStream2) -> Vec<Pat> {
751        let mut out = vec![];
752
753        // Add the path extractor
754        out.push({
755            let path_iter = self
756                .path_params
757                .iter()
758                .filter_map(|(_slash, path_param)| path_param.capture());
759            let idents = path_iter.clone().map(|item| item.0);
760            parse_quote! {
761                dioxus_server::axum::extract::Path((#(#idents,)*))
762            }
763        });
764
765        out.push(parse_quote!(
766            dioxus_server::axum::extract::Query(#query_tokens)
767        ));
768
769        out
770    }
771
772    pub fn query_params_struct(&self, with_aide: bool) -> TokenStream2 {
773        let fields = self.query_params.iter().map(|item| {
774            let name = &item.name;
775            let binding = &item.binding;
776            let ty = &item.ty;
777            if item.catch_all {
778                quote! {}
779            } else if item.binding != item.name {
780                quote! {
781                    #[serde(rename = #name)]
782                    #binding: #ty,
783                }
784            } else {
785                quote! { #binding: #ty, }
786            }
787        });
788        let derive = match with_aide {
789            true => quote! {
790                #[derive(serde::Deserialize, serde::Serialize, ::schemars::JsonSchema)]
791                #[serde(crate = "serde")]
792            },
793            false => quote! {
794                #[derive(serde::Deserialize, serde::Serialize)]
795                #[serde(crate = "serde")]
796            },
797        };
798        quote! {
799            #derive
800            struct __QueryParams__ {
801                #(#fields)*
802            }
803        }
804    }
805
806    pub fn extracted_idents(&self) -> Vec<Ident> {
807        let mut idents = Vec::new();
808        for (_slash, path_param) in &self.path_params {
809            if let Some((ident, _ty)) = path_param.capture() {
810                idents.push(ident.clone());
811            }
812        }
813        for param in &self.query_params {
814            idents.push(param.binding.clone());
815        }
816        idents
817    }
818
819    fn remaining_pattypes_named(&self, args: &Punctuated<FnArg, Comma>) -> Vec<(usize, PatType)> {
820        args.iter()
821            .enumerate()
822            .filter_map(|(i, item)| {
823                if let FnArg::Typed(pat_type) = item {
824                    if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
825                        if self.path_params.iter().any(|(_slash, path_param)| {
826                            if let Some((path_ident, _ty)) = path_param.capture() {
827                                path_ident == &pat_ident.ident
828                            } else {
829                                false
830                            }
831                        }) || self
832                            .query_params
833                            .iter()
834                            .any(|query| query.binding == pat_ident.ident)
835                        {
836                            return None;
837                        }
838                    }
839
840                    Some((i, pat_type.clone()))
841                } else {
842                    unimplemented!("Self type is not supported")
843                }
844            })
845            .collect()
846    }
847
848    pub(crate) fn to_doc_comments(&self) -> TokenStream2 {
849        let mut doc = format!(
850            "# Handler information
851- Method: `{}`
852- Path: `{}`",
853            self.method.to_axum_method_name(),
854            self.route_lit
855                .as_ref()
856                .map(|lit| lit.value())
857                .unwrap_or_else(|| "<auto>".into()),
858        );
859
860        if let Some(options) = &self.oapi_options {
861            let summary = options
862                .summary
863                .as_ref()
864                .map(|(_, summary)| format!("\"{}\"", summary.value()))
865                .unwrap_or("None".to_string());
866            let description = options
867                .description
868                .as_ref()
869                .map(|(_, description)| format!("\"{}\"", description.value()))
870                .unwrap_or("None".to_string());
871            let id = options
872                .id
873                .as_ref()
874                .map(|(_, id)| format!("\"{}\"", id.value()))
875                .unwrap_or("None".to_string());
876            let hidden = options
877                .hidden
878                .as_ref()
879                .map(|(_, hidden)| hidden.value().to_string())
880                .unwrap_or("None".to_string());
881            let tags = options
882                .tags
883                .as_ref()
884                .map(|(_, tags)| tags.to_string())
885                .unwrap_or("[]".to_string());
886            let security = options
887                .security
888                .as_ref()
889                .map(|(_, security)| security.to_string())
890                .unwrap_or("{}".to_string());
891
892            doc = format!(
893                "{doc}
894
895## OpenAPI
896- Summary: `{summary}`
897- Description: `{description}`
898- Operation id: `{id}`
899- Tags: `{tags}`
900- Security: `{security}`
901- Hidden: `{hidden}`
902"
903            );
904        }
905
906        quote!(
907            #[doc = #doc]
908        )
909    }
910
911    fn url_without_queries_for_format(&self) -> Option<String> {
912        // If there's no explicit route, we can't generate a format string this way.
913        let _lit = self.route_lit.as_ref()?;
914
915        let url_without_queries =
916            self.path_params
917                .iter()
918                .fold(String::new(), |mut acc, (_slash, param)| {
919                    acc.push('/');
920                    match param {
921                        PathParam::Capture(lit, _brace_1, _, _, _brace_2) => {
922                            acc.push_str(&format!("{{{}}}", lit.value()));
923                        }
924                        PathParam::WildCard(lit, _brace_1, _, _, _, _brace_2) => {
925                            // no `*` since we want to use the argument *as the wildcard* when making requests
926                            // it's not super applicable to server functions, more for general route generation
927                            acc.push_str(&format!("{{{}}}", lit.value()));
928                        }
929                        PathParam::Static(lit) => {
930                            acc.push_str(&lit.value());
931                        }
932                    }
933                    acc
934                });
935
936        let prefix = self
937            .prefix
938            .as_ref()
939            .cloned()
940            .unwrap_or_else(|| LitStr::new("", Span::call_site()))
941            .value();
942        let full_url = format!(
943            "{}{}{}",
944            prefix,
945            if url_without_queries.starts_with("/") {
946                ""
947            } else {
948                "/"
949            },
950            url_without_queries
951        );
952
953        Some(full_url)
954    }
955}
956
957struct RouteParser {
958    path_params: Vec<(Slash, PathParam)>,
959    query_params: Vec<QueryParam>,
960}
961
962impl RouteParser {
963    fn new(lit: LitStr) -> syn::Result<Self> {
964        let val = lit.value();
965        let span = lit.span();
966        let split_route = val.split('?').collect::<Vec<_>>();
967        if split_route.len() > 2 {
968            return Err(syn::Error::new(span, "expected at most one '?'"));
969        }
970
971        let path = split_route[0];
972        if !path.starts_with('/') {
973            return Err(syn::Error::new(span, "expected path to start with '/'"));
974        }
975        let path = path.strip_prefix('/').unwrap();
976
977        let mut path_params = Vec::new();
978
979        for path_param in path.split('/') {
980            path_params.push((
981                Slash(span),
982                PathParam::new(path_param, span, Box::new(parse_quote!(())))?,
983            ));
984        }
985
986        let path_param_len = path_params.len();
987        for (i, (_slash, path_param)) in path_params.iter().enumerate() {
988            match path_param {
989                PathParam::WildCard(_, _, _, _, _, _) => {
990                    if i != path_param_len - 1 {
991                        return Err(syn::Error::new(
992                            span,
993                            "wildcard path param must be the last path param",
994                        ));
995                    }
996                }
997                PathParam::Capture(_, _, _, _, _) => (),
998                PathParam::Static(lit) => {
999                    if lit.value() == "*" && i != path_param_len - 1 {
1000                        return Err(syn::Error::new(
1001                            span,
1002                            "wildcard path param must be the last path param",
1003                        ));
1004                    }
1005                }
1006            }
1007        }
1008
1009        let mut query_params = Vec::new();
1010        if split_route.len() == 2 {
1011            let query = split_route[1];
1012            for query_param in query.split('&') {
1013                if query_param.starts_with(":") {
1014                    let ident = Ident::new(query_param.strip_prefix(":").unwrap(), span);
1015
1016                    query_params.push(QueryParam {
1017                        name: ident.to_string(),
1018                        binding: ident,
1019                        catch_all: true,
1020                        ty: parse_quote!(()),
1021                        arg_idx: usize::MAX,
1022                    });
1023                } else if query_param.starts_with("{") && query_param.ends_with("}") {
1024                    let ident = Ident::new(
1025                        query_param
1026                            .strip_prefix("{")
1027                            .unwrap()
1028                            .strip_suffix("}")
1029                            .unwrap(),
1030                        span,
1031                    );
1032
1033                    query_params.push(QueryParam {
1034                        name: ident.to_string(),
1035                        binding: ident,
1036                        catch_all: true,
1037                        ty: parse_quote!(()),
1038                        arg_idx: usize::MAX,
1039                    });
1040                } else {
1041                    // if there's an `=` in the query param, we only take the left side as the name, and the right side is the binding
1042                    let name;
1043                    let binding;
1044                    if let Some((n, b)) = query_param.split_once('=') {
1045                        name = n;
1046                        binding = Ident::new(b, span);
1047                    } else {
1048                        name = query_param;
1049                        binding = Ident::new(query_param, span);
1050                    }
1051
1052                    query_params.push(QueryParam {
1053                        name: name.to_string(),
1054                        binding,
1055                        catch_all: false,
1056                        ty: parse_quote!(()),
1057                        arg_idx: usize::MAX,
1058                    });
1059                }
1060            }
1061        }
1062
1063        Ok(Self {
1064            path_params,
1065            query_params,
1066        })
1067    }
1068}
1069
1070enum PathParam {
1071    WildCard(LitStr, Brace, Star, Ident, Box<Type>, Brace),
1072    Capture(LitStr, Brace, Ident, Box<Type>, Brace),
1073    Static(LitStr),
1074}
1075
1076impl PathParam {
1077    fn _captures(&self) -> bool {
1078        matches!(self, Self::Capture(..) | Self::WildCard(..))
1079    }
1080
1081    fn capture(&self) -> Option<(&Ident, &Type)> {
1082        match self {
1083            Self::Capture(_, _, ident, ty, _) => Some((ident, ty)),
1084            Self::WildCard(_, _, _, ident, ty, _) => Some((ident, ty)),
1085            _ => None,
1086        }
1087    }
1088
1089    fn new(str: &str, span: Span, ty: Box<Type>) -> syn::Result<Self> {
1090        let ok = if str.starts_with('{') {
1091            let str = str
1092                .strip_prefix('{')
1093                .unwrap()
1094                .strip_suffix('}')
1095                .ok_or_else(|| {
1096                    syn::Error::new(span, "expected path param to be wrapped in curly braces")
1097                })?;
1098            Self::Capture(
1099                LitStr::new(str, span),
1100                Brace(span),
1101                Ident::new(str, span),
1102                ty,
1103                Brace(span),
1104            )
1105        } else if str.starts_with('*') && str.len() > 1 {
1106            let str = str.strip_prefix('*').unwrap();
1107            Self::WildCard(
1108                LitStr::new(str, span),
1109                Brace(span),
1110                Star(span),
1111                Ident::new(str, span),
1112                ty,
1113                Brace(span),
1114            )
1115        } else if str.starts_with(':') && str.len() > 1 {
1116            let str = str.strip_prefix(':').unwrap();
1117            Self::Capture(
1118                LitStr::new(str, span),
1119                Brace(span),
1120                Ident::new(str, span),
1121                ty,
1122                Brace(span),
1123            )
1124        } else {
1125            Self::Static(LitStr::new(str, span))
1126        };
1127
1128        Ok(ok)
1129    }
1130}
1131
1132struct OapiOptions {
1133    summary: Option<(Ident, LitStr)>,
1134    description: Option<(Ident, LitStr)>,
1135    id: Option<(Ident, LitStr)>,
1136    hidden: Option<(Ident, LitBool)>,
1137    tags: Option<(Ident, StrArray)>,
1138    security: Option<(Ident, Security)>,
1139    responses: Option<(Ident, Responses)>,
1140    transform: Option<(Ident, ExprClosure)>,
1141}
1142
1143struct Security(Vec<(LitStr, StrArray)>);
1144impl Parse for Security {
1145    fn parse(input: ParseStream) -> syn::Result<Self> {
1146        let inner;
1147        braced!(inner in input);
1148
1149        let mut arr = Vec::new();
1150        while !inner.is_empty() {
1151            let scheme = inner.parse::<LitStr>()?;
1152            let _ = inner.parse::<Token![:]>()?;
1153            let scopes = inner.parse::<StrArray>()?;
1154            let _ = inner.parse::<Token![,]>().ok();
1155            arr.push((scheme, scopes));
1156        }
1157
1158        Ok(Self(arr))
1159    }
1160}
1161
1162impl std::fmt::Display for Security {
1163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1164        write!(f, "{{")?;
1165        for (i, (scheme, scopes)) in self.0.iter().enumerate() {
1166            if i > 0 {
1167                write!(f, ", ")?;
1168            }
1169            write!(f, "{}: {}", scheme.value(), scopes)?;
1170        }
1171        write!(f, "}}")
1172    }
1173}
1174
1175struct Responses(Vec<(LitInt, Type)>);
1176impl Parse for Responses {
1177    fn parse(input: ParseStream) -> syn::Result<Self> {
1178        let inner;
1179        braced!(inner in input);
1180
1181        let mut arr = Vec::new();
1182        while !inner.is_empty() {
1183            let status = inner.parse::<LitInt>()?;
1184            let _ = inner.parse::<Token![:]>()?;
1185            let ty = inner.parse::<Type>()?;
1186            let _ = inner.parse::<Token![,]>().ok();
1187            arr.push((status, ty));
1188        }
1189
1190        Ok(Self(arr))
1191    }
1192}
1193
1194impl std::fmt::Display for Responses {
1195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1196        write!(f, "{{")?;
1197        for (i, (status, ty)) in self.0.iter().enumerate() {
1198            if i > 0 {
1199                write!(f, ", ")?;
1200            }
1201            write!(f, "{}: {}", status, ty.to_token_stream())?;
1202        }
1203        write!(f, "}}")
1204    }
1205}
1206
1207#[derive(Clone)]
1208struct StrArray(Vec<LitStr>);
1209impl Parse for StrArray {
1210    fn parse(input: ParseStream) -> syn::Result<Self> {
1211        let inner;
1212        bracketed!(inner in input);
1213        let mut arr = Vec::new();
1214        while !inner.is_empty() {
1215            arr.push(inner.parse::<LitStr>()?);
1216            inner.parse::<Token![,]>().ok();
1217        }
1218        Ok(Self(arr))
1219    }
1220}
1221
1222impl std::fmt::Display for StrArray {
1223    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1224        write!(f, "[")?;
1225        for (i, lit) in self.0.iter().enumerate() {
1226            if i > 0 {
1227                write!(f, ", ")?;
1228            }
1229            write!(f, "\"{}\"", lit.value())?;
1230        }
1231        write!(f, "]")
1232    }
1233}
1234
1235impl Parse for OapiOptions {
1236    fn parse(input: ParseStream) -> syn::Result<Self> {
1237        let mut this = Self {
1238            summary: None,
1239            description: None,
1240            id: None,
1241            hidden: None,
1242            tags: None,
1243            security: None,
1244            responses: None,
1245            transform: None,
1246        };
1247
1248        while !input.is_empty() {
1249            let ident = input.parse::<Ident>()?;
1250            let _ = input.parse::<Token![:]>()?;
1251            match ident.to_string().as_str() {
1252                "summary" => this.summary = Some((ident, input.parse()?)),
1253                "description" => this.description = Some((ident, input.parse()?)),
1254                "id" => this.id = Some((ident, input.parse()?)),
1255                "hidden" => this.hidden = Some((ident, input.parse()?)),
1256                "tags" => this.tags = Some((ident, input.parse()?)),
1257                "security" => this.security = Some((ident, input.parse()?)),
1258                "responses" => this.responses = Some((ident, input.parse()?)),
1259                "transform" => this.transform = Some((ident, input.parse()?)),
1260                _ => {
1261                    return Err(syn::Error::new(
1262                        ident.span(),
1263                        "unexpected field, expected one of (summary, description, id, hidden, tags, security, responses, transform)",
1264                    ))
1265                }
1266            }
1267            let _ = input.parse::<Token![,]>().ok();
1268        }
1269
1270        Ok(this)
1271    }
1272}
1273
1274impl OapiOptions {
1275    fn merge_with_fn(&mut self, function: &ItemFn) {
1276        if self.description.is_none() {
1277            self.description = doc_iter(&function.attrs)
1278                .skip(2)
1279                .map(|item| item.value())
1280                .reduce(|mut acc, item| {
1281                    acc.push('\n');
1282                    acc.push_str(&item);
1283                    acc
1284                })
1285                .map(|item| (parse_quote!(description), parse_quote!(#item)))
1286        }
1287        if self.summary.is_none() {
1288            self.summary = doc_iter(&function.attrs)
1289                .next()
1290                .map(|item| (parse_quote!(summary), item.clone()))
1291        }
1292        if self.id.is_none() {
1293            let id = &function.sig.ident;
1294            self.id = Some((parse_quote!(id), LitStr::new(&id.to_string(), id.span())));
1295        }
1296    }
1297}
1298
1299fn doc_iter(attrs: &[Attribute]) -> impl Iterator<Item = &LitStr> + '_ {
1300    attrs
1301        .iter()
1302        .filter(|attr| attr.path().is_ident("doc"))
1303        .map(|attr| {
1304            let Meta::NameValue(meta) = &attr.meta else {
1305                panic!("doc attribute is not a name-value attribute");
1306            };
1307            let Expr::Lit(lit) = &meta.value else {
1308                panic!("doc attribute is not a string literal");
1309            };
1310            let Lit::Str(lit_str) = &lit.lit else {
1311                panic!("doc attribute is not a string literal");
1312            };
1313            lit_str
1314        })
1315}
1316
1317struct Route {
1318    method: Option<Method>,
1319    path_params: Vec<(Slash, PathParam)>,
1320    query_params: Vec<QueryParam>,
1321    route_lit: Option<LitStr>,
1322    prefix: Option<LitStr>,
1323    oapi_options: Option<OapiOptions>,
1324    server_args: Punctuated<FnArg, Comma>,
1325
1326    // todo: support these since `server_fn` had them
1327    _input_encoding: Option<Type>,
1328    _output_encoding: Option<Type>,
1329}
1330
1331impl Parse for Route {
1332    fn parse(input: ParseStream) -> syn::Result<Self> {
1333        let method = if input.peek(Ident) {
1334            Some(input.parse::<Method>()?)
1335        } else {
1336            None
1337        };
1338
1339        let route_lit = input.parse::<LitStr>()?;
1340        let RouteParser {
1341            path_params,
1342            query_params,
1343        } = RouteParser::new(route_lit.clone())?;
1344
1345        let oapi_options = input
1346            .peek(Brace)
1347            .then(|| {
1348                let inner;
1349                braced!(inner in input);
1350                inner.parse::<OapiOptions>()
1351            })
1352            .transpose()?;
1353
1354        let server_args = if input.peek(Comma) {
1355            let _ = input.parse::<Comma>()?;
1356            input.parse_terminated(FnArg::parse, Comma)?
1357        } else {
1358            Punctuated::new()
1359        };
1360
1361        Ok(Route {
1362            method,
1363            path_params,
1364            query_params,
1365            route_lit: Some(route_lit),
1366            oapi_options,
1367            server_args,
1368            prefix: None,
1369            _input_encoding: None,
1370            _output_encoding: None,
1371        })
1372    }
1373}
1374
1375#[derive(Clone)]
1376enum Method {
1377    Get(Ident),
1378    Post(Ident),
1379    Put(Ident),
1380    Delete(Ident),
1381    Head(Ident),
1382    Connect(Ident),
1383    Options(Ident),
1384    Trace(Ident),
1385}
1386
1387impl ToTokens for Method {
1388    fn to_tokens(&self, tokens: &mut TokenStream2) {
1389        match self {
1390            Self::Get(ident)
1391            | Self::Post(ident)
1392            | Self::Put(ident)
1393            | Self::Delete(ident)
1394            | Self::Head(ident)
1395            | Self::Connect(ident)
1396            | Self::Options(ident)
1397            | Self::Trace(ident) => {
1398                ident.to_tokens(tokens);
1399            }
1400        }
1401    }
1402}
1403
1404impl Parse for Method {
1405    fn parse(input: ParseStream) -> syn::Result<Self> {
1406        let ident = input.parse::<Ident>()?;
1407        match ident.to_string().to_uppercase().as_str() {
1408            "GET" => Ok(Self::Get(ident)),
1409            "POST" => Ok(Self::Post(ident)),
1410            "PUT" => Ok(Self::Put(ident)),
1411            "DELETE" => Ok(Self::Delete(ident)),
1412            "HEAD" => Ok(Self::Head(ident)),
1413            "CONNECT" => Ok(Self::Connect(ident)),
1414            "OPTIONS" => Ok(Self::Options(ident)),
1415            "TRACE" => Ok(Self::Trace(ident)),
1416            _ => Err(input
1417                .error("expected one of (GET, POST, PUT, DELETE, HEAD, CONNECT, OPTIONS, TRACE)")),
1418        }
1419    }
1420}
1421
1422impl Method {
1423    fn to_axum_method_name(&self) -> Ident {
1424        match self {
1425            Self::Get(span) => Ident::new("get", span.span()),
1426            Self::Post(span) => Ident::new("post", span.span()),
1427            Self::Put(span) => Ident::new("put", span.span()),
1428            Self::Delete(span) => Ident::new("delete", span.span()),
1429            Self::Head(span) => Ident::new("head", span.span()),
1430            Self::Connect(span) => Ident::new("connect", span.span()),
1431            Self::Options(span) => Ident::new("options", span.span()),
1432            Self::Trace(span) => Ident::new("trace", span.span()),
1433        }
1434    }
1435
1436    fn new_from_string(s: &str) -> Self {
1437        match s.to_uppercase().as_str() {
1438            "GET" => Self::Get(Ident::new("GET", Span::call_site())),
1439            "POST" => Self::Post(Ident::new("POST", Span::call_site())),
1440            "PUT" => Self::Put(Ident::new("PUT", Span::call_site())),
1441            "DELETE" => Self::Delete(Ident::new("DELETE", Span::call_site())),
1442            "HEAD" => Self::Head(Ident::new("HEAD", Span::call_site())),
1443            "CONNECT" => Self::Connect(Ident::new("CONNECT", Span::call_site())),
1444            "OPTIONS" => Self::Options(Ident::new("OPTIONS", Span::call_site())),
1445            "TRACE" => Self::Trace(Ident::new("TRACE", Span::call_site())),
1446            _ => panic!("expected one of (GET, POST, PUT, DELETE, HEAD, CONNECT, OPTIONS, TRACE)"),
1447        }
1448    }
1449}
1450
/// Custom keyword tokens, generated with `syn::custom_keyword!`.
mod kw {
    // Defines a `kw::with` token type that matches the bare identifier `with`.
    syn::custom_keyword!(with);
}
1454
/// The arguments to the `server` macro.
///
/// These originally came from the `server_fn` crate, but many no longer apply after the 0.7 fullstack
/// overhaul. We keep the parser here for temporary backwards compatibility with existing code, but
/// these arguments will be removed in a future release.
#[derive(Debug)]
#[non_exhaustive]
// Fields are kept only for backwards-compatible parsing, so some may never be read.
#[allow(unused)]
struct ServerFnArgs {
    /// The name of the struct that will implement the server function trait
    /// and be submitted to inventory.
    struct_name: Option<Ident>,
    /// The prefix to use for the server function URL.
    prefix: Option<LitStr>,
    /// The input http encoding to use for the server function.
    input: Option<Type>,
    /// Additional traits to derive on the input struct for the server function.
    input_derive: Option<ExprTuple>,
    /// The output http encoding to use for the server function.
    output: Option<Type>,
    /// The value of the `endpoint = "..."` keyword argument; the `Parse` impl stores that key
    /// here.
    fn_path: Option<LitStr>,
    /// The server type to use for the server function.
    server: Option<Type>,
    /// The client type to use for the server function.
    client: Option<Type>,
    /// The custom wrapper to use for the server function struct.
    custom_wrapper: Option<syn::Path>,
    /// If the generated input type should implement `From` the only field in the input
    impl_from: Option<LitBool>,
    /// If the generated input type should implement `Deref` to the only field in the input
    impl_deref: Option<LitBool>,
    /// The protocol to use for the server function implementation.
    protocol: Option<Type>,
    /// NOTE(review): undocumented. The name suggests it records whether a built-in encoding was
    /// selected; confirm against the `Parse` impl that constructs this struct.
    builtin_encoding: bool,
}
1491
1492impl Parse for ServerFnArgs {
1493    fn parse(stream: ParseStream) -> syn::Result<Self> {
1494        // legacy 4-part arguments
1495        let mut struct_name: Option<Ident> = None;
1496        let mut prefix: Option<LitStr> = None;
1497        let mut encoding: Option<LitStr> = None;
1498        let mut fn_path: Option<LitStr> = None;
1499
1500        // new arguments: can only be keyed by name
1501        let mut input: Option<Type> = None;
1502        let mut input_derive: Option<ExprTuple> = None;
1503        let mut output: Option<Type> = None;
1504        let mut server: Option<Type> = None;
1505        let mut client: Option<Type> = None;
1506        let mut custom_wrapper: Option<syn::Path> = None;
1507        let mut impl_from: Option<LitBool> = None;
1508        let mut impl_deref: Option<LitBool> = None;
1509        let mut protocol: Option<Type> = None;
1510
1511        let mut use_key_and_value = false;
1512        let mut arg_pos = 0;
1513
1514        while !stream.is_empty() {
1515            arg_pos += 1;
1516            let lookahead = stream.lookahead1();
1517            if lookahead.peek(Ident) {
1518                let key_or_value: Ident = stream.parse()?;
1519
1520                let lookahead = stream.lookahead1();
1521                if lookahead.peek(Token![=]) {
1522                    stream.parse::<Token![=]>()?;
1523                    let key = key_or_value;
1524                    use_key_and_value = true;
1525                    if key == "name" {
1526                        if struct_name.is_some() {
1527                            return Err(syn::Error::new(
1528                                key.span(),
1529                                "keyword argument repeated: `name`",
1530                            ));
1531                        }
1532                        struct_name = Some(stream.parse()?);
1533                    } else if key == "prefix" {
1534                        if prefix.is_some() {
1535                            return Err(syn::Error::new(
1536                                key.span(),
1537                                "keyword argument repeated: `prefix`",
1538                            ));
1539                        }
1540                        prefix = Some(stream.parse()?);
1541                    } else if key == "encoding" {
1542                        if encoding.is_some() {
1543                            return Err(syn::Error::new(
1544                                key.span(),
1545                                "keyword argument repeated: `encoding`",
1546                            ));
1547                        }
1548                        encoding = Some(stream.parse()?);
1549                    } else if key == "endpoint" {
1550                        if fn_path.is_some() {
1551                            return Err(syn::Error::new(
1552                                key.span(),
1553                                "keyword argument repeated: `endpoint`",
1554                            ));
1555                        }
1556                        fn_path = Some(stream.parse()?);
1557                    } else if key == "input" {
1558                        if encoding.is_some() {
1559                            return Err(syn::Error::new(
1560                                key.span(),
1561                                "`encoding` and `input` should not both be \
1562                                 specified",
1563                            ));
1564                        } else if input.is_some() {
1565                            return Err(syn::Error::new(
1566                                key.span(),
1567                                "keyword argument repeated: `input`",
1568                            ));
1569                        }
1570                        input = Some(stream.parse()?);
1571                    } else if key == "input_derive" {
1572                        if input_derive.is_some() {
1573                            return Err(syn::Error::new(
1574                                key.span(),
1575                                "keyword argument repeated: `input_derive`",
1576                            ));
1577                        }
1578                        input_derive = Some(stream.parse()?);
1579                    } else if key == "output" {
1580                        if encoding.is_some() {
1581                            return Err(syn::Error::new(
1582                                key.span(),
1583                                "`encoding` and `output` should not both be \
1584                                 specified",
1585                            ));
1586                        } else if output.is_some() {
1587                            return Err(syn::Error::new(
1588                                key.span(),
1589                                "keyword argument repeated: `output`",
1590                            ));
1591                        }
1592                        output = Some(stream.parse()?);
1593                    } else if key == "server" {
1594                        if server.is_some() {
1595                            return Err(syn::Error::new(
1596                                key.span(),
1597                                "keyword argument repeated: `server`",
1598                            ));
1599                        }
1600                        server = Some(stream.parse()?);
1601                    } else if key == "client" {
1602                        if client.is_some() {
1603                            return Err(syn::Error::new(
1604                                key.span(),
1605                                "keyword argument repeated: `client`",
1606                            ));
1607                        }
1608                        client = Some(stream.parse()?);
1609                    } else if key == "custom" {
1610                        if custom_wrapper.is_some() {
1611                            return Err(syn::Error::new(
1612                                key.span(),
1613                                "keyword argument repeated: `custom`",
1614                            ));
1615                        }
1616                        custom_wrapper = Some(stream.parse()?);
1617                    } else if key == "impl_from" {
1618                        if impl_from.is_some() {
1619                            return Err(syn::Error::new(
1620                                key.span(),
1621                                "keyword argument repeated: `impl_from`",
1622                            ));
1623                        }
1624                        impl_from = Some(stream.parse()?);
1625                    } else if key == "impl_deref" {
1626                        if impl_deref.is_some() {
1627                            return Err(syn::Error::new(
1628                                key.span(),
1629                                "keyword argument repeated: `impl_deref`",
1630                            ));
1631                        }
1632                        impl_deref = Some(stream.parse()?);
1633                    } else if key == "protocol" {
1634                        if protocol.is_some() {
1635                            return Err(syn::Error::new(
1636                                key.span(),
1637                                "keyword argument repeated: `protocol`",
1638                            ));
1639                        }
1640                        protocol = Some(stream.parse()?);
1641                    } else {
1642                        return Err(lookahead.error());
1643                    }
1644                } else {
1645                    let value = key_or_value;
1646                    if use_key_and_value {
1647                        return Err(syn::Error::new(
1648                            value.span(),
1649                            "positional argument follows keyword argument",
1650                        ));
1651                    }
1652                    if arg_pos == 1 {
1653                        struct_name = Some(value)
1654                    } else {
1655                        return Err(syn::Error::new(value.span(), "expected string literal"));
1656                    }
1657                }
1658            } else if lookahead.peek(LitStr) {
1659                if use_key_and_value {
1660                    return Err(syn::Error::new(
1661                        stream.span(),
1662                        "If you use keyword arguments (e.g., `name` = \
1663                         Something), then you can no longer use arguments \
1664                         without a keyword.",
1665                    ));
1666                }
1667                match arg_pos {
1668                    1 => return Err(lookahead.error()),
1669                    2 => prefix = Some(stream.parse()?),
1670                    3 => encoding = Some(stream.parse()?),
1671                    4 => fn_path = Some(stream.parse()?),
1672                    _ => return Err(syn::Error::new(stream.span(), "unexpected extra argument")),
1673                }
1674            } else {
1675                return Err(lookahead.error());
1676            }
1677
1678            if !stream.is_empty() {
1679                stream.parse::<Token![,]>()?;
1680            }
1681        }
1682
1683        // parse legacy encoding into input/output
1684        let mut builtin_encoding = false;
1685        if let Some(encoding) = encoding {
1686            match encoding.value().to_lowercase().as_str() {
1687                "url" => {
1688                    input = Some(type_from_ident(syn::parse_quote!(Url)));
1689                    output = Some(type_from_ident(syn::parse_quote!(Json)));
1690                    builtin_encoding = true;
1691                }
1692                "cbor" => {
1693                    input = Some(type_from_ident(syn::parse_quote!(Cbor)));
1694                    output = Some(type_from_ident(syn::parse_quote!(Cbor)));
1695                    builtin_encoding = true;
1696                }
1697                "getcbor" => {
1698                    input = Some(type_from_ident(syn::parse_quote!(GetUrl)));
1699                    output = Some(type_from_ident(syn::parse_quote!(Cbor)));
1700                    builtin_encoding = true;
1701                }
1702                "getjson" => {
1703                    input = Some(type_from_ident(syn::parse_quote!(GetUrl)));
1704                    output = Some(syn::parse_quote!(Json));
1705                    builtin_encoding = true;
1706                }
1707                _ => return Err(syn::Error::new(encoding.span(), "Encoding not found.")),
1708            }
1709        }
1710
1711        Ok(Self {
1712            struct_name,
1713            prefix,
1714            input,
1715            input_derive,
1716            output,
1717            fn_path,
1718            builtin_encoding,
1719            server,
1720            client,
1721            custom_wrapper,
1722            impl_from,
1723            impl_deref,
1724            protocol,
1725        })
1726    }
1727}
1728
1729/// An argument type in a server function.
1730#[allow(unused)]
1731// todo - we used to support a number of these attributes and pass them along to serde. bring them back.
1732#[derive(Debug, Clone)]
1733struct ServerFnArg {
1734    /// The attributes on the server function argument.
1735    server_fn_attributes: Vec<Attribute>,
1736    /// The type of the server function argument.
1737    arg: syn::PatType,
1738}
1739
1740impl ToTokens for ServerFnArg {
1741    fn to_tokens(&self, tokens: &mut TokenStream2) {
1742        let ServerFnArg { arg, .. } = self;
1743        tokens.extend(quote! {
1744            #arg
1745        });
1746    }
1747}
1748
1749impl Parse for ServerFnArg {
1750    fn parse(input: ParseStream) -> Result<Self> {
1751        let arg: syn::FnArg = input.parse()?;
1752        let mut arg = match arg {
1753            FnArg::Receiver(_) => {
1754                return Err(syn::Error::new(
1755                    arg.span(),
1756                    "cannot use receiver types in server function macro",
1757                ))
1758            }
1759            FnArg::Typed(t) => t,
1760        };
1761
1762        fn rename_path(path: Path, from_ident: Ident, to_ident: Ident) -> Path {
1763            if path.is_ident(&from_ident) {
1764                Path {
1765                    leading_colon: None,
1766                    segments: Punctuated::from_iter([PathSegment {
1767                        ident: to_ident,
1768                        arguments: PathArguments::None,
1769                    }]),
1770                }
1771            } else {
1772                path
1773            }
1774        }
1775
1776        let server_fn_attributes = arg
1777            .attrs
1778            .iter()
1779            .cloned()
1780            .map(|attr| {
1781                if attr.path().is_ident("server") {
1782                    // Allow the following attributes:
1783                    // - #[server(default)]
1784                    // - #[server(rename = "fieldName")]
1785
1786                    // Rename `server` to `serde`
1787                    let attr = Attribute {
1788                        meta: match attr.meta {
1789                            Meta::Path(path) => Meta::Path(rename_path(
1790                                path,
1791                                format_ident!("server"),
1792                                format_ident!("serde"),
1793                            )),
1794                            Meta::List(mut list) => {
1795                                list.path = rename_path(
1796                                    list.path,
1797                                    format_ident!("server"),
1798                                    format_ident!("serde"),
1799                                );
1800                                Meta::List(list)
1801                            }
1802                            Meta::NameValue(mut name_value) => {
1803                                name_value.path = rename_path(
1804                                    name_value.path,
1805                                    format_ident!("server"),
1806                                    format_ident!("serde"),
1807                                );
1808                                Meta::NameValue(name_value)
1809                            }
1810                        },
1811                        ..attr
1812                    };
1813
1814                    let args = attr.parse_args::<Meta>()?;
1815                    match args {
1816                        // #[server(default)]
1817                        Meta::Path(path) if path.is_ident("default") => Ok(attr.clone()),
1818                        // #[server(flatten)]
1819                        Meta::Path(path) if path.is_ident("flatten") => Ok(attr.clone()),
1820                        // #[server(default = "value")]
1821                        Meta::NameValue(name_value) if name_value.path.is_ident("default") => {
1822                            Ok(attr.clone())
1823                        }
1824                        // #[server(skip)]
1825                        Meta::Path(path) if path.is_ident("skip") => Ok(attr.clone()),
1826                        // #[server(rename = "value")]
1827                        Meta::NameValue(name_value) if name_value.path.is_ident("rename") => {
1828                            Ok(attr.clone())
1829                        }
1830                        _ => Err(Error::new(
1831                            attr.span(),
1832                            "Unrecognized #[server] attribute, expected \
1833                             #[server(default)] or #[server(rename = \
1834                             \"fieldName\")]",
1835                        )),
1836                    }
1837                } else if attr.path().is_ident("doc") {
1838                    // Allow #[doc = "documentation"]
1839                    Ok(attr.clone())
1840                } else if attr.path().is_ident("allow") {
1841                    // Allow #[allow(...)]
1842                    Ok(attr.clone())
1843                } else if attr.path().is_ident("deny") {
1844                    // Allow #[deny(...)]
1845                    Ok(attr.clone())
1846                } else if attr.path().is_ident("ignore") {
1847                    // Allow #[ignore]
1848                    Ok(attr.clone())
1849                } else {
1850                    Err(Error::new(
1851                        attr.span(),
1852                        "Unrecognized attribute, expected #[server(...)]",
1853                    ))
1854                }
1855            })
1856            .collect::<Result<Vec<_>>>()?;
1857        arg.attrs = vec![];
1858        Ok(ServerFnArg {
1859            arg,
1860            server_fn_attributes,
1861        })
1862    }
1863}
1864
1865fn type_from_ident(ident: Ident) -> Type {
1866    let mut segments = Punctuated::new();
1867    segments.push(PathSegment {
1868        ident,
1869        arguments: PathArguments::None,
1870    });
1871    Type::Path(TypePath {
1872        qself: None,
1873        path: Path {
1874            leading_colon: None,
1875            segments,
1876        },
1877    })
1878}