dioxus_fullstack/server/mod.rs
1//! Dioxus utilities for the [Axum](https://docs.rs/axum/latest/axum/index.html) server framework.
2//!
3//! # Example
4//! ```rust, no_run
5//! #![allow(non_snake_case)]
6//! use dioxus::prelude::*;
7//!
8//! fn main() {
9//! #[cfg(feature = "web")]
10//! // Hydrate the application on the client
11//! dioxus::launch(app);
12//! #[cfg(feature = "server")]
13//! {
14//! tokio::runtime::Runtime::new()
15//! .unwrap()
16//! .block_on(async move {
17//! // Get the address the server should run on. If the CLI is running, the CLI proxies fullstack into the main address
18//! // and we use the generated address the CLI gives us
19//! let address = dioxus::cli_config::fullstack_address_or_localhost();
20//! let listener = tokio::net::TcpListener::bind(address)
21//! .await
22//! .unwrap();
23//! axum::serve(
24//! listener,
25//! axum::Router::new()
26//! // Server side render the application, serve static assets, and register server functions
27//! .serve_dioxus_application(ServeConfigBuilder::default(), app)
28//! .into_make_service(),
29//! )
30//! .await
31//! .unwrap();
32//! });
33//! }
34//! }
35//!
36//! fn app() -> Element {
37//! let mut text = use_signal(|| "...".to_string());
38//!
39//! rsx! {
40//! button {
41//! onclick: move |_| async move {
42//! if let Ok(data) = get_server_data().await {
43//! text.set(data);
44//! }
45//! },
46//! "Run a server function"
47//! }
48//! "Server said: {text}"
49//! }
50//! }
51//!
52//! #[server(GetServerData)]
53//! async fn get_server_data() -> Result<String, ServerFnError> {
54//! Ok("Hello from the server!".to_string())
55//! }
56//! ```
57
58pub mod launch;
59
/// Shared, type-erased context factories.
///
/// Each entry produces a boxed value that is inserted as context: as root context of the
/// virtual DOM during server-side rendering, and as factories on the
/// [`DioxusServerContext`] seen by server functions.
#[allow(unused)]
pub(crate) type ContextProviders = Arc<
    Vec<
        Box<
            dyn Fn() -> Box<dyn std::any::Any> + Send + Sync + 'static,
        >,
    >,
>;
63
64use axum::routing::*;
65use axum::{
66 body::{self, Body},
67 extract::State,
68 http::{Request, Response, StatusCode},
69 response::IntoResponse,
70};
71use dioxus_lib::prelude::{Element, VirtualDom};
72use http::header::*;
73
74use std::sync::Arc;
75
76use crate::prelude::*;
77
78/// A extension trait with utilities for integrating Dioxus with your Axum router.
79pub trait DioxusRouterExt<S> {
80 /// Registers server functions with the default handler. This handler function will pass an empty [`DioxusServerContext`] to your server functions.
81 ///
82 /// # Example
83 /// ```rust, no_run
84 /// # use dioxus_lib::prelude::*;
85 /// # use dioxus_fullstack::prelude::*;
86 /// #[tokio::main]
87 /// async fn main() {
88 /// let addr = dioxus::cli_config::fullstack_address_or_localhost();
89 /// let router = axum::Router::new()
90 /// // Register server functions routes with the default handler
91 /// .register_server_functions()
92 /// .into_make_service();
93 /// let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
94 /// axum::serve(listener, router).await.unwrap();
95 /// }
96 /// ```
97 fn register_server_functions(self) -> Self
98 where
99 Self: Sized,
100 {
101 self.register_server_functions_with_context(Default::default())
102 }
103
104 /// Registers server functions with some additional context to insert into the [`DioxusServerContext`] for that handler.
105 ///
106 /// # Example
107 /// ```rust, no_run
108 /// # use dioxus_lib::prelude::*;
109 /// # use dioxus_fullstack::prelude::*;
110 /// # use std::sync::Arc;
111 /// #[tokio::main]
112 /// async fn main() {
113 /// let addr = dioxus::cli_config::fullstack_address_or_localhost();
114 /// let router = axum::Router::new()
115 /// // Register server functions routes with the default handler
116 /// .register_server_functions_with_context(Arc::new(vec![Box::new(|| Box::new(1234567890u32))]))
117 /// .into_make_service();
118 /// let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
119 /// axum::serve(listener, router).await.unwrap();
120 /// }
121 /// ```
122 fn register_server_functions_with_context(self, context_providers: ContextProviders) -> Self;
123
124 /// Serves the static WASM for your Dioxus application (except the generated index.html).
125 ///
126 /// # Example
127 /// ```rust, no_run
128 /// # #![allow(non_snake_case)]
129 /// # use dioxus_lib::prelude::*;
130 /// # use dioxus_fullstack::prelude::*;
131 /// #[tokio::main]
132 /// async fn main() {
133 /// let addr = dioxus::cli_config::fullstack_address_or_localhost();
134 /// let router = axum::Router::new()
135 /// // Server side render the application, serve static assets, and register server functions
136 /// .serve_static_assets()
137 /// // Server render the application
138 /// // ...
139 /// .into_make_service();
140 /// let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
141 /// axum::serve(listener, router).await.unwrap();
142 /// }
143 /// ```
144 fn serve_static_assets(self) -> Self
145 where
146 Self: Sized;
147
148 /// Serves the Dioxus application. This will serve a complete server side rendered application.
149 /// This will serve static assets, server render the application, register server functions, and integrate with hot reloading.
150 ///
151 /// # Example
152 /// ```rust, no_run
153 /// # #![allow(non_snake_case)]
154 /// # use dioxus_lib::prelude::*;
155 /// # use dioxus_fullstack::prelude::*;
156 /// #[tokio::main]
157 /// async fn main() {
158 /// let addr = dioxus::cli_config::fullstack_address_or_localhost();
159 /// let router = axum::Router::new()
160 /// // Server side render the application, serve static assets, and register server functions
161 /// .serve_dioxus_application(ServeConfig::new().unwrap(), app)
162 /// .into_make_service();
163 /// let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
164 /// axum::serve(listener, router).await.unwrap();
165 /// }
166 ///
167 /// fn app() -> Element {
168 /// rsx! { "Hello World" }
169 /// }
170 /// ```
171 fn serve_dioxus_application<Cfg, Error>(self, cfg: Cfg, app: fn() -> Element) -> Self
172 where
173 Cfg: TryInto<ServeConfig, Error = Error>,
174 Error: std::error::Error,
175 Self: Sized;
176}
177
178impl<S> DioxusRouterExt<S> for Router<S>
179where
180 S: Send + Sync + Clone + 'static,
181{
182 fn register_server_functions_with_context(
183 mut self,
184 context_providers: ContextProviders,
185 ) -> Self {
186 use http::method::Method;
187
188 for (path, method) in server_fn::axum::server_fn_paths() {
189 tracing::trace!("Registering server function: {} {}", method, path);
190 let context_providers = context_providers.clone();
191 let handler = move |req| handle_server_fns_inner(path, context_providers, req);
192 self = match method {
193 Method::GET => self.route(path, get(handler)),
194 Method::POST => self.route(path, post(handler)),
195 Method::PUT => self.route(path, put(handler)),
196 _ => unimplemented!("Unsupported server function method: {}", method),
197 };
198 }
199
200 self
201 }
202
203 fn serve_static_assets(mut self) -> Self {
204 use tower_http::services::{ServeDir, ServeFile};
205
206 let public_path = crate::public_path();
207
208 if !public_path.exists() {
209 return self;
210 }
211
212 // Serve all files in public folder except index.html
213 let dir = std::fs::read_dir(&public_path).unwrap_or_else(|e| {
214 panic!(
215 "Couldn't read public directory at {:?}: {}",
216 &public_path, e
217 )
218 });
219
220 for entry in dir.flatten() {
221 let path = entry.path();
222 if path.ends_with("index.html") {
223 continue;
224 }
225 let route = path
226 .strip_prefix(&public_path)
227 .unwrap()
228 .iter()
229 .map(|segment| {
230 segment.to_str().unwrap_or_else(|| {
231 panic!("Failed to convert path segment {:?} to string", segment)
232 })
233 })
234 .collect::<Vec<_>>()
235 .join("/");
236 let route = format!("/{}", route);
237 if path.is_dir() {
238 self = self.nest_service(&route, ServeDir::new(path).precompressed_br());
239 } else {
240 self = self.nest_service(&route, ServeFile::new(path).precompressed_br());
241 }
242 }
243
244 self
245 }
246
247 fn serve_dioxus_application<Cfg, Error>(self, cfg: Cfg, app: fn() -> Element) -> Self
248 where
249 Cfg: TryInto<ServeConfig, Error = Error>,
250 Error: std::error::Error,
251 {
252 let cfg = cfg.try_into();
253 let context_providers = cfg
254 .as_ref()
255 .map(|cfg| cfg.context_providers.clone())
256 .unwrap_or_default();
257
258 // Add server functions and render index.html
259 let server = self
260 .serve_static_assets()
261 .register_server_functions_with_context(context_providers);
262
263 match cfg {
264 Ok(cfg) => {
265 let ssr_state = SSRState::new(&cfg);
266 server.fallback(
267 get(render_handler)
268 .with_state(RenderHandleState::new(cfg, app).with_ssr_state(ssr_state)),
269 )
270 }
271 Err(err) => {
272 tracing::trace!("Failed to create render handler. This is expected if you are only using fullstack for desktop/mobile server functions: {}", err);
273 server
274 }
275 }
276 }
277}
278
279fn apply_request_parts_to_response<B>(
280 headers: hyper::header::HeaderMap,
281 response: &mut axum::response::Response<B>,
282) {
283 let mut_headers = response.headers_mut();
284 for (key, value) in headers.iter() {
285 mut_headers.insert(key, value.clone());
286 }
287}
288
289fn add_server_context(server_context: &DioxusServerContext, context_providers: &ContextProviders) {
290 for index in 0..context_providers.len() {
291 let context_providers = context_providers.clone();
292 server_context.insert_boxed_factory(Box::new(move || context_providers[index]()));
293 }
294}
295
296/// State used by [`render_handler`] to render a dioxus component with axum
297#[derive(Clone)]
298pub struct RenderHandleState {
299 config: ServeConfig,
300 build_virtual_dom: Arc<dyn Fn() -> VirtualDom + Send + Sync>,
301 ssr_state: once_cell::sync::OnceCell<SSRState>,
302}
303
304impl RenderHandleState {
305 /// Create a new [`RenderHandleState`]
306 pub fn new(config: ServeConfig, root: fn() -> Element) -> Self {
307 Self {
308 config,
309 build_virtual_dom: Arc::new(move || VirtualDom::new(root)),
310 ssr_state: Default::default(),
311 }
312 }
313
314 /// Create a new [`RenderHandleState`] with a custom [`VirtualDom`] factory. This method can be used to pass context into the root component of your application.
315 pub fn new_with_virtual_dom_factory(
316 config: ServeConfig,
317 build_virtual_dom: impl Fn() -> VirtualDom + Send + Sync + 'static,
318 ) -> Self {
319 Self {
320 config,
321 build_virtual_dom: Arc::new(build_virtual_dom),
322 ssr_state: Default::default(),
323 }
324 }
325
326 /// Set the [`ServeConfig`] for this [`RenderHandleState`]
327 pub fn with_config(mut self, config: ServeConfig) -> Self {
328 self.config = config;
329 self
330 }
331
332 /// Set the [`SSRState`] for this [`RenderHandleState`]. Sharing a [`SSRState`] between multiple [`RenderHandleState`]s is more efficient than creating a new [`SSRState`] for each [`RenderHandleState`].
333 pub fn with_ssr_state(mut self, ssr_state: SSRState) -> Self {
334 self.ssr_state = once_cell::sync::OnceCell::new();
335 if self.ssr_state.set(ssr_state).is_err() {
336 panic!("SSRState already set");
337 }
338 self
339 }
340
341 fn ssr_state(&self) -> &SSRState {
342 self.ssr_state.get_or_init(|| SSRState::new(&self.config))
343 }
344}
345
346/// SSR renderer handler for Axum with added context injection.
347///
348/// # Example
349/// ```rust,no_run
350/// #![allow(non_snake_case)]
351/// use std::sync::{Arc, Mutex};
352///
353/// use axum::routing::get;
354/// use dioxus::prelude::*;
355///
356/// fn app() -> Element {
357/// rsx! {
358/// "hello!"
359/// }
360/// }
361///
362/// #[tokio::main]
363/// async fn main() {
364/// let addr = dioxus::cli_config::fullstack_address_or_localhost();
365/// let router = axum::Router::new()
366/// // Register server functions, etc.
367/// // Note you can use `register_server_functions_with_context`
368/// // to inject the context into server functions running outside
369/// // of an SSR render context.
370/// .fallback(get(render_handler)
371/// .with_state(RenderHandleState::new(ServeConfig::new().unwrap(), app))
372/// )
373/// .into_make_service();
374/// let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
375/// axum::serve(listener, router).await.unwrap();
376/// }
377/// ```
378pub async fn render_handler(
379 State(state): State<RenderHandleState>,
380 request: Request<Body>,
381) -> impl IntoResponse {
382 // Only respond to requests for HTML
383 if let Some(mime) = request.headers().get("Accept") {
384 let mime = mime.to_str().map(|mime| mime.to_ascii_lowercase());
385 match mime {
386 Ok(accepts) if accepts.contains("text/html") => {}
387 _ => return Err(StatusCode::NOT_ACCEPTABLE),
388 }
389 }
390
391 let cfg = &state.config;
392 let ssr_state = state.ssr_state();
393 let build_virtual_dom = {
394 let build_virtual_dom = state.build_virtual_dom.clone();
395 let context_providers = state.config.context_providers.clone();
396 move || {
397 let mut vdom = build_virtual_dom();
398 for state in context_providers.as_slice() {
399 vdom.insert_any_root_context(state());
400 }
401 vdom
402 }
403 };
404
405 let (parts, _) = request.into_parts();
406 let url = parts
407 .uri
408 .path_and_query()
409 .ok_or(StatusCode::BAD_REQUEST)?
410 .to_string();
411 let parts: Arc<parking_lot::RwLock<http::request::Parts>> =
412 Arc::new(parking_lot::RwLock::new(parts));
413 // Create the server context with info from the request
414 let server_context = DioxusServerContext::from_shared_parts(parts.clone());
415 // Provide additional context from the render state
416 add_server_context(&server_context, &state.config.context_providers);
417
418 match ssr_state
419 .render(url, cfg, build_virtual_dom, &server_context)
420 .await
421 {
422 Ok((freshness, rx)) => {
423 let mut response = axum::response::Html::from(Body::from_stream(rx)).into_response();
424 freshness.write(response.headers_mut());
425 let headers = server_context.response_parts().headers.clone();
426 apply_request_parts_to_response(headers, &mut response);
427 Ok(response)
428 }
429 Err(e) => {
430 tracing::error!("Failed to render page: {}", e);
431 Ok(report_err(e).into_response())
432 }
433 }
434}
435
436fn report_err<E: std::fmt::Display>(e: E) -> Response<axum::body::Body> {
437 Response::builder()
438 .status(StatusCode::INTERNAL_SERVER_ERROR)
439 .body(body::Body::new(format!("Error: {}", e)))
440 .unwrap()
441}
442
443/// A handler for Dioxus server functions. This will run the server function and return the result.
444async fn handle_server_fns_inner(
445 path: &str,
446 additional_context: ContextProviders,
447 req: Request<Body>,
448) -> impl IntoResponse {
449 use server_fn::middleware::Service;
450
451 let path_string = path.to_string();
452
453 let future = move || async move {
454 let (parts, body) = req.into_parts();
455 let req = Request::from_parts(parts.clone(), body);
456
457 if let Some(mut service) =
458 server_fn::axum::get_server_fn_service(&path_string)
459 {
460 // Create the server context with info from the request
461 let server_context = DioxusServerContext::new(parts);
462 // Provide additional context from the render state
463 add_server_context(&server_context, &additional_context);
464
465 // store Accepts and Referrer in case we need them for redirect (below)
466 let accepts_html = req
467 .headers()
468 .get(ACCEPT)
469 .and_then(|v| v.to_str().ok())
470 .map(|v| v.contains("text/html"))
471 .unwrap_or(false);
472 let referrer = req.headers().get(REFERER).cloned();
473
474 // actually run the server fn (which may use the server context)
475 let fut = with_server_context(server_context.clone(), || service.run(req));
476 let mut res = ProvideServerContext::new(fut, server_context.clone()).await;
477
478 // it it accepts text/html (i.e., is a plain form post) and doesn't already have a
479 // Location set, then redirect to Referer
480 if accepts_html {
481 if let Some(referrer) = referrer {
482 let has_location = res.headers().get(LOCATION).is_some();
483 if !has_location {
484 *res.status_mut() = StatusCode::FOUND;
485 res.headers_mut().insert(LOCATION, referrer);
486 }
487 }
488 }
489
490 // apply the response parts from the server context to the response
491 let mut res_options = server_context.response_parts_mut();
492 res.headers_mut().extend(res_options.headers.drain());
493
494 Ok(res)
495 } else {
496 Response::builder().status(StatusCode::BAD_REQUEST).body(
497 {
498 #[cfg(target_family = "wasm")]
499 {
500 Body::from(format!(
501 "No server function found for path: {path_string}\nYou may need to explicitly register the server function with `register_explicit`, rebuild your wasm binary to update a server function link or make sure the prefix your server and client use for server functions match.",
502 ))
503 }
504 #[cfg(not(target_family = "wasm"))]
505 {
506 Body::from(format!(
507 "No server function found for path: {path_string}\nYou may need to rebuild your wasm binary to update a server function link or make sure the prefix your server and client use for server functions match.",
508 ))
509 }
510 }
511 )
512 }
513 .expect("could not build Response")
514 };
515 #[cfg(target_arch = "wasm32")]
516 {
517 use futures_util::future::FutureExt;
518
519 let result = tokio::task::spawn_local(future);
520 let result = result.then(|f| async move { f.unwrap() });
521 result.await.unwrap_or_else(|e| {
522 use server_fn::error::NoCustomError;
523 use server_fn::error::ServerFnErrorSerde;
524 (
525 StatusCode::INTERNAL_SERVER_ERROR,
526 ServerFnError::<NoCustomError>::ServerError(e.to_string())
527 .ser()
528 .unwrap_or_default(),
529 )
530 .into_response()
531 })
532 }
533 #[cfg(not(target_arch = "wasm32"))]
534 {
535 future().await
536 }
537}