pub mod launch;
/// Shared, thread-safe list of context factories.
///
/// Each factory produces a type-erased (`Box<dyn Any>`) context value. The list lives behind
/// an `Arc`, so cloning a `ContextProviders` only bumps a reference count.
///
/// In this file the factories are installed both as root contexts on each server-rendered
/// `VirtualDom` (`render_handler`) and as factories on each request's `DioxusServerContext`
/// (`add_server_context`).
// NOTE(review): `#[allow(unused)]` presumably covers feature/target combinations in which
// nothing references this alias — confirm.
#[allow(unused)]
pub(crate) type ContextProviders =
    Arc<Vec<Box<dyn Fn() -> Box<dyn std::any::Any> + Send + Sync + 'static>>>;
use axum::routing::*;
use axum::{
body::{self, Body},
extract::State,
http::{Request, Response, StatusCode},
response::IntoResponse,
};
use dioxus_lib::prelude::{Element, VirtualDom};
use http::header::*;
use std::sync::Arc;
use crate::prelude::*;
pub trait DioxusRouterExt<S> {
fn register_server_functions(self) -> Self
where
Self: Sized,
{
self.register_server_functions_with_context(Default::default())
}
fn register_server_functions_with_context(self, context_providers: ContextProviders) -> Self;
fn serve_static_assets(self) -> Self
where
Self: Sized;
fn serve_dioxus_application<Cfg, Error>(self, cfg: Cfg, app: fn() -> Element) -> Self
where
Cfg: TryInto<ServeConfig, Error = Error>,
Error: std::error::Error,
Self: Sized;
}
impl<S> DioxusRouterExt<S> for Router<S>
where
S: Send + Sync + Clone + 'static,
{
fn register_server_functions_with_context(
mut self,
context_providers: ContextProviders,
) -> Self {
use http::method::Method;
for (path, method) in server_fn::axum::server_fn_paths() {
tracing::trace!("Registering server function: {} {}", method, path);
let context_providers = context_providers.clone();
let handler = move |req| handle_server_fns_inner(path, context_providers, req);
self = match method {
Method::GET => self.route(path, get(handler)),
Method::POST => self.route(path, post(handler)),
Method::PUT => self.route(path, put(handler)),
_ => unimplemented!("Unsupported server function method: {}", method),
};
}
self
}
fn serve_static_assets(mut self) -> Self {
use tower_http::services::{ServeDir, ServeFile};
let public_path = crate::public_path();
if !public_path.exists() {
return self;
}
let dir = std::fs::read_dir(&public_path).unwrap_or_else(|e| {
panic!(
"Couldn't read public directory at {:?}: {}",
&public_path, e
)
});
for entry in dir.flatten() {
let path = entry.path();
if path.ends_with("index.html") {
continue;
}
let route = path
.strip_prefix(&public_path)
.unwrap()
.iter()
.map(|segment| {
segment.to_str().unwrap_or_else(|| {
panic!("Failed to convert path segment {:?} to string", segment)
})
})
.collect::<Vec<_>>()
.join("/");
let route = format!("/{}", route);
if path.is_dir() {
self = self.nest_service(&route, ServeDir::new(path).precompressed_br());
} else {
self = self.nest_service(&route, ServeFile::new(path).precompressed_br());
}
}
self
}
fn serve_dioxus_application<Cfg, Error>(self, cfg: Cfg, app: fn() -> Element) -> Self
where
Cfg: TryInto<ServeConfig, Error = Error>,
Error: std::error::Error,
{
let cfg = cfg.try_into();
let context_providers = cfg
.as_ref()
.map(|cfg| cfg.context_providers.clone())
.unwrap_or_default();
let server = self
.serve_static_assets()
.register_server_functions_with_context(context_providers);
match cfg {
Ok(cfg) => {
let ssr_state = SSRState::new(&cfg);
server.fallback(
get(render_handler)
.with_state(RenderHandleState::new(cfg, app).with_ssr_state(ssr_state)),
)
}
Err(err) => {
tracing::trace!("Failed to create render handler. This is expected if you are only using fullstack for desktop/mobile server functions: {}", err);
server
}
}
}
}
/// Merges a header map into an outgoing response.
///
/// `render_handler` passes the headers accumulated on the server context's response parts.
/// For every header name present in `headers`, existing values on `response` are replaced.
/// All values from `headers` are kept, so multi-valued headers such as `Set-Cookie` are
/// preserved.
fn apply_request_parts_to_response<B>(
    headers: hyper::header::HeaderMap,
    response: &mut axum::response::Response<B>,
) {
    // `HeaderMap`'s `Extend` impl replaces existing entries per name but appends every
    // incoming value. Calling `insert` once per value would keep only the last value of each
    // multi-valued header.
    response.headers_mut().extend(headers);
}
fn add_server_context(server_context: &DioxusServerContext, context_providers: &ContextProviders) {
for index in 0..context_providers.len() {
let context_providers = context_providers.clone();
server_context.insert_boxed_factory(Box::new(move || context_providers[index]()));
}
}
/// State handed to [`render_handler`]: the serve configuration, a factory for the
/// `VirtualDom` to render, and the (possibly lazily created) `SSRState`.
#[derive(Clone)]
pub struct RenderHandleState {
    /// Serve configuration. `render_handler` reads its `context_providers` and passes it to
    /// `SSRState::render`.
    config: ServeConfig,
    /// Constructs the `VirtualDom` that `render_handler` hands to the renderer (wrapped in a
    /// closure that also installs the root contexts).
    build_virtual_dom: Arc<dyn Fn() -> VirtualDom + Send + Sync>,
    /// Either set explicitly via `with_ssr_state`, or created from `config` on first access
    /// through `ssr_state()`.
    ssr_state: once_cell::sync::OnceCell<SSRState>,
}
impl RenderHandleState {
    /// Creates render state whose `VirtualDom` is built from the `root` component.
    ///
    /// The `SSRState` is created lazily from `config` unless supplied via
    /// [`RenderHandleState::with_ssr_state`].
    pub fn new(config: ServeConfig, root: fn() -> Element) -> Self {
        Self::new_with_virtual_dom_factory(config, move || VirtualDom::new(root))
    }

    /// Creates render state that uses `build_virtual_dom` to construct the `VirtualDom` to
    /// render.
    ///
    /// The `SSRState` is created lazily from `config` unless supplied via
    /// [`RenderHandleState::with_ssr_state`].
    pub fn new_with_virtual_dom_factory(
        config: ServeConfig,
        build_virtual_dom: impl Fn() -> VirtualDom + Send + Sync + 'static,
    ) -> Self {
        Self {
            config,
            build_virtual_dom: Arc::new(build_virtual_dom),
            ssr_state: Default::default(),
        }
    }

    /// Replaces the serve configuration.
    pub fn with_config(mut self, config: ServeConfig) -> Self {
        self.config = config;
        self
    }

    /// Uses `ssr_state` for rendering, replacing any previously set or initialized state.
    pub fn with_ssr_state(mut self, ssr_state: SSRState) -> Self {
        // Build an already-initialized cell directly. The previous `new()` + `set()` sequence
        // could never fail, so its panic branch was unreachable.
        self.ssr_state = once_cell::sync::OnceCell::with_value(ssr_state);
        self
    }

    /// Returns the SSR state, creating it from the current config on first use.
    fn ssr_state(&self) -> &SSRState {
        self.ssr_state.get_or_init(|| SSRState::new(&self.config))
    }
}
pub async fn render_handler(
State(state): State<RenderHandleState>,
request: Request<Body>,
) -> impl IntoResponse {
if let Some(mime) = request.headers().get("Accept") {
let mime = mime.to_str().map(|mime| mime.to_ascii_lowercase());
match mime {
Ok(accepts) if accepts.contains("text/html") => {}
_ => return Err(StatusCode::NOT_ACCEPTABLE),
}
}
let cfg = &state.config;
let ssr_state = state.ssr_state();
let build_virtual_dom = {
let build_virtual_dom = state.build_virtual_dom.clone();
let context_providers = state.config.context_providers.clone();
move || {
let mut vdom = build_virtual_dom();
for state in context_providers.as_slice() {
vdom.insert_any_root_context(state());
}
vdom
}
};
let (parts, _) = request.into_parts();
let url = parts
.uri
.path_and_query()
.ok_or(StatusCode::BAD_REQUEST)?
.to_string();
let parts: Arc<parking_lot::RwLock<http::request::Parts>> =
Arc::new(parking_lot::RwLock::new(parts));
let server_context = DioxusServerContext::from_shared_parts(parts.clone());
add_server_context(&server_context, &state.config.context_providers);
match ssr_state
.render(url, cfg, build_virtual_dom, &server_context)
.await
{
Ok((freshness, rx)) => {
let mut response = axum::response::Html::from(Body::from_stream(rx)).into_response();
freshness.write(response.headers_mut());
let headers = server_context.response_parts().headers.clone();
apply_request_parts_to_response(headers, &mut response);
Ok(response)
}
Err(e) => {
tracing::error!("Failed to render page: {}", e);
Ok(report_err(e).into_response())
}
}
}
/// Builds a `500 Internal Server Error` response whose body is `Error: <e>`.
///
/// No headers are set on the response.
fn report_err<E: std::fmt::Display>(e: E) -> Response<axum::body::Body> {
    let message = format!("Error: {e}");
    let mut response = Response::new(body::Body::new(message));
    *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
    response
}
/// Runs the server function registered at `path` for an incoming request.
///
/// The request head is captured into a `DioxusServerContext`, which is seeded from
/// `additional_context` and provided while the server function's service runs. After it
/// completes:
/// - if the client's `Accept` header mentions `text/html` (case-insensitively), it sent a
///   `Referer`, and the response has no `Location`, the response becomes a `302 Found`
///   pointing at the referrer;
/// - headers accumulated on the server context are merged into the response, replacing
///   same-named headers.
///
/// If no server function is registered for `path`, a `400 Bad Request` with an explanatory
/// body is returned.
async fn handle_server_fns_inner(
    path: &str,
    additional_context: ContextProviders,
    req: Request<Body>,
) -> impl IntoResponse {
    use server_fn::middleware::Service;

    // Keep a copy of the request head for the server context; the full request goes to the
    // server function.
    let (parts, body) = req.into_parts();
    let req = Request::from_parts(parts.clone(), body);

    let mut service = match server_fn::axum::get_server_fn_service(path) {
        Some(service) => service,
        None => {
            let path_string = path;
            return Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .body({
                    #[cfg(target_family = "wasm")]
                    {
                        Body::from(format!(
                            "No server function found for path: {path_string}\nYou may need to explicitly register the server function with `register_explicit`, rebuild your wasm binary to update a server function link or make sure the prefix your server and client use for server functions match.",
                        ))
                    }
                    #[cfg(not(target_family = "wasm"))]
                    {
                        Body::from(format!(
                            "No server function found for path: {path_string}\nYou may need to rebuild your wasm binary to update a server function link or make sure the prefix your server and client use for server functions match.",
                        ))
                    }
                })
                .expect("could not build Response");
        }
    };

    let server_context = DioxusServerContext::new(parts);
    add_server_context(&server_context, &additional_context);

    // Media types are case-insensitive; match render_handler by lowercasing first.
    let accepts_html = req
        .headers()
        .get(ACCEPT)
        .and_then(|value| value.to_str().ok())
        .map(|value| value.to_ascii_lowercase().contains("text/html"))
        .unwrap_or(false);
    let referrer = req.headers().get(REFERER).cloned();

    // The server function is awaited directly on every target. The former wasm32 branch
    // handed the closure (not a future) to `spawn_local` and could not compile.
    let fut = with_server_context(server_context.clone(), || service.run(req));
    let mut res = ProvideServerContext::new(fut, server_context.clone()).await;

    if accepts_html {
        if let Some(referrer) = referrer {
            if !res.headers().contains_key(LOCATION) {
                *res.status_mut() = StatusCode::FOUND;
                res.headers_mut().insert(LOCATION, referrer);
            }
        }
    }

    {
        // `Extend` replaces same-named headers on `res` but keeps every drained value.
        let mut res_options = server_context.response_parts_mut();
        res.headers_mut().extend(res_options.headers.drain());
    }

    res
}