use super::{hot_reload::*, WsReloadState};
use crate::{server::HotReloadState, Result};
use axum::{
body::Body,
extract::{
ws::{Message, WebSocket},
Extension, WebSocketUpgrade,
},
http::{
self,
header::{HeaderName, HeaderValue},
Method, Response, StatusCode,
},
response::IntoResponse,
routing::{get, get_service},
Router,
};
use axum_server::tls_rustls::RustlsConfig;
use dioxus_cli_config::{CrateConfig, WebHttpsConfig};
use std::{fs, io, process::Command, sync::Arc};
use tower::ServiceBuilder;
use tower_http::{
cors::{Any, CorsLayer},
services::fs::{ServeDir, ServeFileSystemResponseBody},
ServiceBuilderExt,
};
pub async fn setup_router(
config: CrateConfig,
ws_reload: Arc<WsReloadState>,
hot_reload: Option<HotReloadState>,
) -> Result<Router> {
let cors = CorsLayer::new()
.allow_methods([Method::GET, Method::POST])
.allow_origin(Any)
.allow_headers(Any);
let (coep, coop) = if config.cross_origin_policy {
(
HeaderValue::from_static("require-corp"),
HeaderValue::from_static("same-origin"),
)
} else {
(
HeaderValue::from_static("unsafe-none"),
HeaderValue::from_static("unsafe-none"),
)
};
let file_service_config = config.clone();
let file_service = ServiceBuilder::new()
.override_response_header(
HeaderName::from_static("cross-origin-embedder-policy"),
coep,
)
.override_response_header(HeaderName::from_static("cross-origin-opener-policy"), coop)
.and_then(move |response| async move { Ok(no_cache(file_service_config, response)) })
.service(ServeDir::new(config.out_dir()));
let mut router = Router::new();
for proxy_config in config.dioxus_config.web.proxy {
router = super::proxy::add_proxy(router, &proxy_config)?;
}
router = router.fallback(get_service(file_service).handle_error(
|error: std::convert::Infallible| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
));
router = if let Some(base_path) = config.dioxus_config.web.app.base_path.clone() {
let base_path = format!("/{}", base_path.trim_matches('/'));
Router::new()
.nest(&base_path, router)
.fallback(get(move || {
let base_path = base_path.clone();
async move { format!("Outside of the base path: {}", base_path) }
}))
} else {
router
};
router = router.route("/_dioxus/ws", get(ws_handler));
router = router
.route("/_dioxus/hot_reload", get(hot_reload_handler))
.layer(cors)
.layer(Extension(ws_reload));
if let Some(hot_reload) = hot_reload {
router = router.layer(Extension(hot_reload))
}
Ok(router)
}
fn no_cache(
file_service_config: CrateConfig,
response: Response<ServeFileSystemResponseBody>,
) -> Response<Body> {
let mut response = if file_service_config.dioxus_config.web.watcher.index_on_404
&& response.status() == StatusCode::NOT_FOUND
{
let body = Body::from(
std::fs::read_to_string(file_service_config.out_dir().join("index.html"))
.ok()
.unwrap(),
);
Response::builder()
.status(StatusCode::OK)
.body(body)
.unwrap()
} else {
response.into_response()
};
let headers = response.headers_mut();
headers.insert(
http::header::CACHE_CONTROL,
HeaderValue::from_static("no-cache"),
);
headers.insert(http::header::PRAGMA, HeaderValue::from_static("no-cache"));
headers.insert(http::header::EXPIRES, HeaderValue::from_static("0"));
response
}
async fn ws_handler(
ws: WebSocketUpgrade,
Extension(state): Extension<Arc<WsReloadState>>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| ws_reload_handler(socket, state))
}
async fn ws_reload_handler(mut socket: WebSocket, state: Arc<WsReloadState>) {
let mut rx = state.update.subscribe();
let reload_watcher = tokio::spawn(async move {
loop {
rx.recv().await.unwrap();
let _ = socket.send(Message::Text(String::from("reload"))).await;
println!("forcing reload");
rx = rx.resubscribe();
}
});
reload_watcher.await.unwrap();
}
const DEFAULT_KEY_PATH: &str = "ssl/key.pem";
const DEFAULT_CERT_PATH: &str = "ssl/cert.pem";
pub async fn get_rustls(config: &CrateConfig) -> Result<Option<RustlsConfig>> {
let web_config = &config.dioxus_config.web.https;
if web_config.enabled != Some(true) {
return Ok(None);
}
let (cert_path, key_path) = if let Some(true) = web_config.mkcert {
get_rustls_with_mkcert(web_config)?
} else {
get_rustls_without_mkcert(web_config)?
};
Ok(Some(
RustlsConfig::from_pem_file(cert_path, key_path).await?,
))
}
pub fn get_rustls_with_mkcert(web_config: &WebHttpsConfig) -> Result<(String, String)> {
let key_path = web_config
.key_path
.clone()
.unwrap_or(DEFAULT_KEY_PATH.to_string());
let cert_path = web_config
.cert_path
.clone()
.unwrap_or(DEFAULT_CERT_PATH.to_string());
if key_path == DEFAULT_KEY_PATH && cert_path == DEFAULT_CERT_PATH {
_ = fs::create_dir("ssl");
}
let cmd = Command::new("mkcert")
.args([
"-install",
"-key-file",
&key_path,
"-cert-file",
&cert_path,
"localhost",
"::1",
"127.0.0.1",
])
.spawn();
match cmd {
Err(e) => {
match e.kind() {
io::ErrorKind::NotFound => tracing::error!("mkcert is not installed. See https://github.com/FiloSottile/mkcert#installation for installation instructions."),
e => tracing::error!("an error occurred while generating mkcert certificates: {}", e.to_string()),
};
return Err("failed to generate mkcert certificates".into());
}
Ok(mut cmd) => {
cmd.wait()?;
}
}
Ok((cert_path, key_path))
}
pub fn get_rustls_without_mkcert(web_config: &WebHttpsConfig) -> Result<(String, String)> {
if let (Some(key), Some(cert)) = (web_config.key_path.clone(), web_config.cert_path.clone()) {
Ok((cert, key))
} else {
Err("https is enabled but cert or key path is missing".into())
}
}