dioxus_fullstack/server/
mod.rs

1//! Dioxus utilities for the [Axum](https://docs.rs/axum/latest/axum/index.html) server framework.
2//!
3//! # Example
4//! ```rust, no_run
5//! #![allow(non_snake_case)]
6//! use dioxus::prelude::*;
7//!
8//! fn main() {
9//!     #[cfg(feature = "web")]
10//!     // Hydrate the application on the client
11//!     dioxus::launch(app);
12//!     #[cfg(feature = "server")]
13//!     {
14//!         tokio::runtime::Runtime::new()
15//!             .unwrap()
16//!             .block_on(async move {
17//!                 // Get the address the server should run on. If the CLI is running, the CLI proxies fullstack into the main address
18//!                 // and we use the generated address the CLI gives us
19//!                 let address = dioxus::cli_config::fullstack_address_or_localhost();
20//!                 let listener = tokio::net::TcpListener::bind(address)
21//!                     .await
22//!                     .unwrap();
23//!                 axum::serve(
24//!                         listener,
25//!                         axum::Router::new()
26//!                             // Server side render the application, serve static assets, and register server functions
27//!                             .serve_dioxus_application(ServeConfigBuilder::default(), app)
28//!                             .into_make_service(),
29//!                     )
30//!                     .await
31//!                     .unwrap();
32//!             });
33//!      }
34//! }
35//!
36//! fn app() -> Element {
37//!     let mut text = use_signal(|| "...".to_string());
38//!
39//!     rsx! {
40//!         button {
41//!             onclick: move |_| async move {
42//!                 if let Ok(data) = get_server_data().await {
43//!                     text.set(data);
44//!                 }
45//!             },
46//!             "Run a server function"
47//!         }
48//!         "Server said: {text}"
49//!     }
50//! }
51//!
52//! #[server(GetServerData)]
53//! async fn get_server_data() -> Result<String, ServerFnError> {
54//!     Ok("Hello from the server!".to_string())
55//! }
56//! ```
57
58pub mod launch;
59
/// A single context factory: a thread-safe closure that produces a type-erased value each time
/// it is called.
#[allow(unused)]
type ContextProvider = Box<dyn Fn() -> Box<dyn std::any::Any> + Send + Sync + 'static>;

/// A cheaply clonable, shared list of context factories.
///
/// Within this module the factories are used to populate the server context of server functions
/// and the root context of server-rendered virtual doms.
#[allow(unused)]
pub(crate) type ContextProviders = Arc<Vec<ContextProvider>>;
63
64use axum::routing::*;
65use axum::{
66    body::{self, Body},
67    extract::State,
68    http::{Request, Response, StatusCode},
69    response::IntoResponse,
70};
71use dioxus_lib::prelude::{Element, VirtualDom};
72use http::header::*;
73
74use std::sync::Arc;
75
76use crate::prelude::*;
77
78/// A extension trait with utilities for integrating Dioxus with your Axum router.
79pub trait DioxusRouterExt<S> {
80    /// Registers server functions with the default handler. This handler function will pass an empty [`DioxusServerContext`] to your server functions.
81    ///
82    /// # Example
83    /// ```rust, no_run
84    /// # use dioxus_lib::prelude::*;
85    /// # use dioxus_fullstack::prelude::*;
86    /// #[tokio::main]
87    /// async fn main() {
88    ///     let addr = dioxus::cli_config::fullstack_address_or_localhost();
89    ///     let router = axum::Router::new()
90    ///         // Register server functions routes with the default handler
91    ///         .register_server_functions()
92    ///         .into_make_service();
93    ///     let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
94    ///     axum::serve(listener, router).await.unwrap();
95    /// }
96    /// ```
97    fn register_server_functions(self) -> Self
98    where
99        Self: Sized,
100    {
101        self.register_server_functions_with_context(Default::default())
102    }
103
104    /// Registers server functions with some additional context to insert into the [`DioxusServerContext`] for that handler.
105    ///
106    /// # Example
107    /// ```rust, no_run
108    /// # use dioxus_lib::prelude::*;
109    /// # use dioxus_fullstack::prelude::*;
110    /// # use std::sync::Arc;
111    /// #[tokio::main]
112    /// async fn main() {
113    ///     let addr = dioxus::cli_config::fullstack_address_or_localhost();
114    ///     let router = axum::Router::new()
115    ///         // Register server functions routes with the default handler
116    ///         .register_server_functions_with_context(Arc::new(vec![Box::new(|| Box::new(1234567890u32))]))
117    ///         .into_make_service();
118    ///     let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
119    ///     axum::serve(listener, router).await.unwrap();
120    /// }
121    /// ```
122    fn register_server_functions_with_context(self, context_providers: ContextProviders) -> Self;
123
124    /// Serves the static WASM for your Dioxus application (except the generated index.html).
125    ///
126    /// # Example
127    /// ```rust, no_run
128    /// # #![allow(non_snake_case)]
129    /// # use dioxus_lib::prelude::*;
130    /// # use dioxus_fullstack::prelude::*;
131    /// #[tokio::main]
132    /// async fn main() {
133    ///     let addr = dioxus::cli_config::fullstack_address_or_localhost();
134    ///     let router = axum::Router::new()
135    ///         // Server side render the application, serve static assets, and register server functions
136    ///         .serve_static_assets()
137    ///         // Server render the application
138    ///         // ...
139    ///         .into_make_service();
140    ///     let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
141    ///     axum::serve(listener, router).await.unwrap();
142    /// }
143    /// ```
144    fn serve_static_assets(self) -> Self
145    where
146        Self: Sized;
147
148    /// Serves the Dioxus application. This will serve a complete server side rendered application.
149    /// This will serve static assets, server render the application, register server functions, and integrate with hot reloading.
150    ///
151    /// # Example
152    /// ```rust, no_run
153    /// # #![allow(non_snake_case)]
154    /// # use dioxus_lib::prelude::*;
155    /// # use dioxus_fullstack::prelude::*;
156    /// #[tokio::main]
157    /// async fn main() {
158    ///     let addr = dioxus::cli_config::fullstack_address_or_localhost();
159    ///     let router = axum::Router::new()
160    ///         // Server side render the application, serve static assets, and register server functions
161    ///         .serve_dioxus_application(ServeConfig::new().unwrap(), app)
162    ///         .into_make_service();
163    ///     let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
164    ///     axum::serve(listener, router).await.unwrap();
165    /// }
166    ///
167    /// fn app() -> Element {
168    ///     rsx! { "Hello World" }
169    /// }
170    /// ```
171    fn serve_dioxus_application<Cfg, Error>(self, cfg: Cfg, app: fn() -> Element) -> Self
172    where
173        Cfg: TryInto<ServeConfig, Error = Error>,
174        Error: std::error::Error,
175        Self: Sized;
176}
177
178impl<S> DioxusRouterExt<S> for Router<S>
179where
180    S: Send + Sync + Clone + 'static,
181{
182    fn register_server_functions_with_context(
183        mut self,
184        context_providers: ContextProviders,
185    ) -> Self {
186        use http::method::Method;
187
188        for (path, method) in server_fn::axum::server_fn_paths() {
189            tracing::trace!("Registering server function: {} {}", method, path);
190            let context_providers = context_providers.clone();
191            let handler = move |req| handle_server_fns_inner(path, context_providers, req);
192            self = match method {
193                Method::GET => self.route(path, get(handler)),
194                Method::POST => self.route(path, post(handler)),
195                Method::PUT => self.route(path, put(handler)),
196                _ => unimplemented!("Unsupported server function method: {}", method),
197            };
198        }
199
200        self
201    }
202
203    fn serve_static_assets(mut self) -> Self {
204        use tower_http::services::{ServeDir, ServeFile};
205
206        let public_path = crate::public_path();
207
208        if !public_path.exists() {
209            return self;
210        }
211
212        // Serve all files in public folder except index.html
213        let dir = std::fs::read_dir(&public_path).unwrap_or_else(|e| {
214            panic!(
215                "Couldn't read public directory at {:?}: {}",
216                &public_path, e
217            )
218        });
219
220        for entry in dir.flatten() {
221            let path = entry.path();
222            if path.ends_with("index.html") {
223                continue;
224            }
225            let route = path
226                .strip_prefix(&public_path)
227                .unwrap()
228                .iter()
229                .map(|segment| {
230                    segment.to_str().unwrap_or_else(|| {
231                        panic!("Failed to convert path segment {:?} to string", segment)
232                    })
233                })
234                .collect::<Vec<_>>()
235                .join("/");
236            let route = format!("/{}", route);
237            if path.is_dir() {
238                self = self.nest_service(&route, ServeDir::new(path).precompressed_br());
239            } else {
240                self = self.nest_service(&route, ServeFile::new(path).precompressed_br());
241            }
242        }
243
244        self
245    }
246
247    fn serve_dioxus_application<Cfg, Error>(self, cfg: Cfg, app: fn() -> Element) -> Self
248    where
249        Cfg: TryInto<ServeConfig, Error = Error>,
250        Error: std::error::Error,
251    {
252        let cfg = cfg.try_into();
253        let context_providers = cfg
254            .as_ref()
255            .map(|cfg| cfg.context_providers.clone())
256            .unwrap_or_default();
257
258        // Add server functions and render index.html
259        let server = self
260            .serve_static_assets()
261            .register_server_functions_with_context(context_providers);
262
263        match cfg {
264            Ok(cfg) => {
265                let ssr_state = SSRState::new(&cfg);
266                server.fallback(
267                    get(render_handler)
268                        .with_state(RenderHandleState::new(cfg, app).with_ssr_state(ssr_state)),
269                )
270            }
271            Err(err) => {
272                tracing::trace!("Failed to create render handler. This is expected if you are only using fullstack for desktop/mobile server functions: {}", err);
273                server
274            }
275        }
276    }
277}
278
279fn apply_request_parts_to_response<B>(
280    headers: hyper::header::HeaderMap,
281    response: &mut axum::response::Response<B>,
282) {
283    let mut_headers = response.headers_mut();
284    for (key, value) in headers.iter() {
285        mut_headers.insert(key, value.clone());
286    }
287}
288
289fn add_server_context(server_context: &DioxusServerContext, context_providers: &ContextProviders) {
290    for index in 0..context_providers.len() {
291        let context_providers = context_providers.clone();
292        server_context.insert_boxed_factory(Box::new(move || context_providers[index]()));
293    }
294}
295
296/// State used by [`render_handler`] to render a dioxus component with axum
297#[derive(Clone)]
298pub struct RenderHandleState {
299    config: ServeConfig,
300    build_virtual_dom: Arc<dyn Fn() -> VirtualDom + Send + Sync>,
301    ssr_state: once_cell::sync::OnceCell<SSRState>,
302}
303
304impl RenderHandleState {
305    /// Create a new [`RenderHandleState`]
306    pub fn new(config: ServeConfig, root: fn() -> Element) -> Self {
307        Self {
308            config,
309            build_virtual_dom: Arc::new(move || VirtualDom::new(root)),
310            ssr_state: Default::default(),
311        }
312    }
313
314    /// Create a new [`RenderHandleState`] with a custom [`VirtualDom`] factory. This method can be used to pass context into the root component of your application.
315    pub fn new_with_virtual_dom_factory(
316        config: ServeConfig,
317        build_virtual_dom: impl Fn() -> VirtualDom + Send + Sync + 'static,
318    ) -> Self {
319        Self {
320            config,
321            build_virtual_dom: Arc::new(build_virtual_dom),
322            ssr_state: Default::default(),
323        }
324    }
325
326    /// Set the [`ServeConfig`] for this [`RenderHandleState`]
327    pub fn with_config(mut self, config: ServeConfig) -> Self {
328        self.config = config;
329        self
330    }
331
332    /// Set the [`SSRState`] for this [`RenderHandleState`]. Sharing a [`SSRState`] between multiple [`RenderHandleState`]s is more efficient than creating a new [`SSRState`] for each [`RenderHandleState`].
333    pub fn with_ssr_state(mut self, ssr_state: SSRState) -> Self {
334        self.ssr_state = once_cell::sync::OnceCell::new();
335        if self.ssr_state.set(ssr_state).is_err() {
336            panic!("SSRState already set");
337        }
338        self
339    }
340
341    fn ssr_state(&self) -> &SSRState {
342        self.ssr_state.get_or_init(|| SSRState::new(&self.config))
343    }
344}
345
346/// SSR renderer handler for Axum with added context injection.
347///
348/// # Example
349/// ```rust,no_run
350/// #![allow(non_snake_case)]
351/// use std::sync::{Arc, Mutex};
352///
353/// use axum::routing::get;
354/// use dioxus::prelude::*;
355///
356/// fn app() -> Element {
357///     rsx! {
358///         "hello!"
359///     }
360/// }
361///
362/// #[tokio::main]
363/// async fn main() {
364///     let addr = dioxus::cli_config::fullstack_address_or_localhost();
365///     let router = axum::Router::new()
366///         // Register server functions, etc.
367///         // Note you can use `register_server_functions_with_context`
368///         // to inject the context into server functions running outside
369///         // of an SSR render context.
370///         .fallback(get(render_handler)
371///             .with_state(RenderHandleState::new(ServeConfig::new().unwrap(), app))
372///         )
373///         .into_make_service();
374///     let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
375///     axum::serve(listener, router).await.unwrap();
376/// }
377/// ```
378pub async fn render_handler(
379    State(state): State<RenderHandleState>,
380    request: Request<Body>,
381) -> impl IntoResponse {
382    // Only respond to requests for HTML
383    if let Some(mime) = request.headers().get("Accept") {
384        let mime = mime.to_str().map(|mime| mime.to_ascii_lowercase());
385        match mime {
386            Ok(accepts) if accepts.contains("text/html") => {}
387            _ => return Err(StatusCode::NOT_ACCEPTABLE),
388        }
389    }
390
391    let cfg = &state.config;
392    let ssr_state = state.ssr_state();
393    let build_virtual_dom = {
394        let build_virtual_dom = state.build_virtual_dom.clone();
395        let context_providers = state.config.context_providers.clone();
396        move || {
397            let mut vdom = build_virtual_dom();
398            for state in context_providers.as_slice() {
399                vdom.insert_any_root_context(state());
400            }
401            vdom
402        }
403    };
404
405    let (parts, _) = request.into_parts();
406    let url = parts
407        .uri
408        .path_and_query()
409        .ok_or(StatusCode::BAD_REQUEST)?
410        .to_string();
411    let parts: Arc<parking_lot::RwLock<http::request::Parts>> =
412        Arc::new(parking_lot::RwLock::new(parts));
413    // Create the server context with info from the request
414    let server_context = DioxusServerContext::from_shared_parts(parts.clone());
415    // Provide additional context from the render state
416    add_server_context(&server_context, &state.config.context_providers);
417
418    match ssr_state
419        .render(url, cfg, build_virtual_dom, &server_context)
420        .await
421    {
422        Ok((freshness, rx)) => {
423            let mut response = axum::response::Html::from(Body::from_stream(rx)).into_response();
424            freshness.write(response.headers_mut());
425            let headers = server_context.response_parts().headers.clone();
426            apply_request_parts_to_response(headers, &mut response);
427            Ok(response)
428        }
429        Err(e) => {
430            tracing::error!("Failed to render page: {}", e);
431            Ok(report_err(e).into_response())
432        }
433    }
434}
435
436fn report_err<E: std::fmt::Display>(e: E) -> Response<axum::body::Body> {
437    Response::builder()
438        .status(StatusCode::INTERNAL_SERVER_ERROR)
439        .body(body::Body::new(format!("Error: {}", e)))
440        .unwrap()
441}
442
443/// A handler for Dioxus server functions. This will run the server function and return the result.
444async fn handle_server_fns_inner(
445    path: &str,
446    additional_context: ContextProviders,
447    req: Request<Body>,
448) -> impl IntoResponse {
449    use server_fn::middleware::Service;
450
451    let path_string = path.to_string();
452
453    let future = move || async move {
454        let (parts, body) = req.into_parts();
455        let req = Request::from_parts(parts.clone(), body);
456
457        if let Some(mut service) =
458            server_fn::axum::get_server_fn_service(&path_string)
459        {
460            // Create the server context with info from the request
461            let server_context = DioxusServerContext::new(parts);
462            // Provide additional context from the render state
463            add_server_context(&server_context, &additional_context);
464
465            // store Accepts and Referrer in case we need them for redirect (below)
466            let accepts_html = req
467                .headers()
468                .get(ACCEPT)
469                .and_then(|v| v.to_str().ok())
470                .map(|v| v.contains("text/html"))
471                .unwrap_or(false);
472            let referrer = req.headers().get(REFERER).cloned();
473
474            // actually run the server fn (which may use the server context)
475            let fut = with_server_context(server_context.clone(), || service.run(req));
476            let mut res = ProvideServerContext::new(fut, server_context.clone()).await;
477
478            // it it accepts text/html (i.e., is a plain form post) and doesn't already have a
479            // Location set, then redirect to Referer
480            if accepts_html {
481                if let Some(referrer) = referrer {
482                    let has_location = res.headers().get(LOCATION).is_some();
483                    if !has_location {
484                        *res.status_mut() = StatusCode::FOUND;
485                        res.headers_mut().insert(LOCATION, referrer);
486                    }
487                }
488            }
489
490            // apply the response parts from the server context to the response
491            let mut res_options = server_context.response_parts_mut();
492            res.headers_mut().extend(res_options.headers.drain());
493
494            Ok(res)
495        } else {
496            Response::builder().status(StatusCode::BAD_REQUEST).body(
497                {
498                    #[cfg(target_family = "wasm")]
499                    {
500                        Body::from(format!(
501                            "No server function found for path: {path_string}\nYou may need to explicitly register the server function with `register_explicit`, rebuild your wasm binary to update a server function link or make sure the prefix your server and client use for server functions match.",
502                        ))
503                    }
504                    #[cfg(not(target_family = "wasm"))]
505                    {
506                        Body::from(format!(
507                            "No server function found for path: {path_string}\nYou may need to rebuild your wasm binary to update a server function link or make sure the prefix your server and client use for server functions match.",
508                        ))
509                    }
510                }
511            )
512        }
513        .expect("could not build Response")
514    };
515    #[cfg(target_arch = "wasm32")]
516    {
517        use futures_util::future::FutureExt;
518
519        let result = tokio::task::spawn_local(future);
520        let result = result.then(|f| async move { f.unwrap() });
521        result.await.unwrap_or_else(|e| {
522            use server_fn::error::NoCustomError;
523            use server_fn::error::ServerFnErrorSerde;
524            (
525                StatusCode::INTERNAL_SERVER_ERROR,
526                ServerFnError::<NoCustomError>::ServerError(e.to_string())
527                    .ser()
528                    .unwrap_or_default(),
529            )
530                .into_response()
531        })
532    }
533    #[cfg(not(target_arch = "wasm32"))]
534    {
535        future().await
536    }
537}