use axum::{
Router,
body::Body,
extract::{
Request, State, WebSocketUpgrade,
ws::{Message, Utf8Bytes, WebSocket},
},
http::{HeaderMap, HeaderValue, StatusCode, header},
routing::get,
};
use futures::{sink::SinkExt, stream::StreamExt};
use mime_guess::mime;
use std::{
fs,
io::ErrorKind,
path::{Path, PathBuf},
sync::Arc,
};
use tokio::{net::TcpListener, sync::broadcast};
use crate::{
http_layer::template::{error_html, index_html},
utils::is_ignored,
};
/// JavaScript source embedded at compile time from `../templates/websocket.js`.
///
/// `format_script` injects it as `<script>{WEBSOCKET_FUNCTION}(<hard_reload>)</script>`,
/// i.e. the text is followed by a call with the `hard_reload` flag as its argument.
const WEBSOCKET_FUNCTION: &str = include_str!("../templates/websocket.js");
/// JavaScript source embedded at compile time from `../templates/reload.js`.
///
/// `format_script` wraps it in a `<script>` tag for non-error responses to requests whose
/// query string is exactly `reload`.
const RELOAD_PAYLOAD: &str = include_str!("../templates/reload.js");
/// Serves `router` on the already-bound `tcp_listener`.
///
/// # Panics
///
/// Panics if the future returned by `axum::serve` resolves to an error.
pub(crate) async fn serve(tcp_listener: TcpListener, router: Router) {
    let outcome = axum::serve(tcp_listener, router).await;
    outcome.unwrap();
}
/// Public configuration flags for the server.
///
/// NOTE(review): this file never reads `Options` directly; the fields carry the same names as
/// the ones on `AppState`, so they are presumably copied over where the state is built —
/// confirm at the construction site.
pub struct Options {
    /// Hard-reload flag (defaults to `false`).
    pub hard_reload: bool,
    /// Directory index listing flag (defaults to `true`).
    pub index_listing: bool,
    /// Ignore-file filtering flag (defaults to `false`).
    pub auto_ignore: bool,
}
/// Shared state handed to the router and read by the request handlers.
pub(crate) struct AppState {
    /// Passed as the argument of the injected websocket script; when set, successful non-HTML
    /// responses additionally carry `Cache-Control: max-age=30`.
    pub(crate) hard_reload: bool,
    /// When set, a requested directory whose `index.html` is not found is answered with a
    /// generated listing instead of a 404.
    pub(crate) index_listing: bool,
    /// When set, paths reported as ignored by `is_ignored` are refused with 403 and left out of
    /// directory listings.
    pub(crate) auto_ignore: bool,
    /// Broadcast channel every websocket connection subscribes to; each received `()` results in
    /// an empty text message being sent to that client.
    pub(crate) tx: Arc<broadcast::Sender<()>>,
    /// Directory that request paths are joined onto to locate files.
    pub(crate) root: PathBuf,
}
impl Default for Options {
    /// Returns the default configuration: index listing enabled, hard reload and ignore
    /// filtering disabled.
    fn default() -> Self {
        let (hard_reload, index_listing, auto_ignore) = (false, true, false);
        Self {
            hard_reload,
            index_listing,
            auto_ignore,
        }
    }
}
/// Builds the router: every `GET` path is served by `static_assets`, except `/live-server-ws`,
/// which upgrades the connection to a websocket driven by `on_websocket_upgrade`.
///
/// `state` is wrapped in an `Arc` and installed as the router state.
pub(crate) fn create_server(state: AppState) -> Router {
    // The websocket handler only needs the broadcast sender, so hand it its own handle
    // before `state` is moved into the router.
    let reload_tx = Arc::clone(&state.tx);

    let websocket_handler = move |upgrade: WebSocketUpgrade| async move {
        upgrade
            .on_failed_upgrade(|error| {
                log::error!("Failed to upgrade websocket: {error}");
            })
            .on_upgrade(move |socket: WebSocket| on_websocket_upgrade(socket, reload_tx))
    };

    let shared_state = Arc::new(state);
    Router::new()
        .route("/", get(static_assets))
        .route("/{*path}", get(static_assets))
        .route("/live-server-ws", get(websocket_handler))
        .with_state(shared_state)
}
async fn on_websocket_upgrade(socket: WebSocket, tx: Arc<broadcast::Sender<()>>) {
let (mut sender, mut receiver) = socket.split();
let mut rx = tx.subscribe();
let mut send_task = tokio::spawn(async move {
while rx.recv().await.is_ok() {
sender
.send(Message::Text(Utf8Bytes::default()))
.await
.unwrap();
}
});
let mut recv_task =
tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} });
tokio::select! {
_ = (&mut send_task) => recv_task.abort(),
_ = (&mut recv_task) => send_task.abort(),
};
}
fn get_index_listing(uri_path: &str, root: &Path, auto_ignore: bool) -> String {
let is_root = uri_path == "/";
let path = root.join(&uri_path[1..]);
let entries = fs::read_dir(path).unwrap();
let mut entry_names = entries
.into_iter()
.filter_map(|e| {
if let Ok(entry) = e {
if auto_ignore {
match is_ignored(root, &entry.path()) {
Ok(ignored) => {
if ignored {
return None;
}
}
Err(err) => {
log::error!("Failed to check ignore files: {err}");
return None;
}
}
}
let is_dir = entry.metadata().ok()?.is_dir();
let trailing = if is_dir { "/" } else { "" };
entry
.file_name()
.to_str()
.map(|name| format!("{name}{trailing}"))
} else {
None
}
})
.collect::<Vec<String>>();
entry_names.sort();
if !is_root {
entry_names.insert(0, "..".to_string());
}
entry_names
.into_iter()
.map(|en| format!("<li><a href=\"{en}\">{en}</a></li>"))
.collect::<Vec<String>>()
.join("\n")
}
/// Serves the file or directory addressed by the request path, relative to `state.root`.
///
/// Behaviour, in order:
/// 1. Paths starting with `//` are redirected to the same path with a single leading `/`.
/// 2. Paths containing anything other than plain names (`..`, a rooted or drive-prefixed
///    remainder) are refused with 403 so requests cannot leave the served directory.
/// 3. Directories requested without a trailing `/` are redirected to the slash-terminated
///    path; otherwise `index.html` inside the directory is served.
/// 4. With `auto_ignore`, ignored paths yield 403 and a failing ignore check yields 500.
/// 5. A missing `index.html` of a directory yields a generated listing when `index_listing`
///    is enabled; other read errors yield 404 (not found) or 500.
/// 6. HTML files get the reload/websocket script appended; with `hard_reload`, other files
///    get `Cache-Control: max-age=30`.
///
/// A query string that is exactly `reload` selects the reload payload instead of the websocket
/// script (see `format_script`).
async fn static_assets(
    state: State<Arc<AppState>>,
    req: Request<Body>,
) -> (StatusCode, HeaderMap, Body) {
    let is_reload = req.uri().query().is_some_and(|x| x == "reload");
    let uri_path = req.uri().path();

    // `root.join("/abs")` would replace the root entirely, so collapse leading slashes first.
    if uri_path.starts_with("//") {
        let redirect = format!("/{}", uri_path.trim_start_matches('/'));
        return redirect_response(&redirect);
    }

    let relative = uri_path.strip_prefix('/').unwrap_or(uri_path);
    // Only plain names (and `.`) may appear: `..`, a root or a drive prefix would let the
    // joined path point outside of `state.root`.
    let escapes_root = Path::new(relative).components().any(|component| {
        !matches!(
            component,
            std::path::Component::Normal(_) | std::path::Component::CurDir
        )
    });
    if escapes_root {
        log::warn!("Rejected request for \"{uri_path}\": path escapes the served directory");
        let err_msg = "Unable to access a path outside of the served directory";
        let body = generate_error_body(err_msg, state.hard_reload, is_reload);
        return (StatusCode::FORBIDDEN, HeaderMap::new(), body);
    }

    let mut path = state.root.join(relative);
    let is_accessing_dir = path.is_dir();
    if is_accessing_dir {
        // Relative links inside a directory page only resolve correctly with a trailing slash.
        if !uri_path.ends_with('/') {
            return redirect_response(&format!("{uri_path}/"));
        }
        path.push("index.html");
    }

    let mime_type = mime_guess::from_path(&path).first_or_text_plain();
    let is_html = mime_type == "text/html";
    let content_type = if mime_type.type_() == mime::TEXT {
        format!("{}; charset=utf-8", mime_type.as_ref())
    } else {
        mime_type.as_ref().to_string()
    };
    let mut headers = HeaderMap::new();
    headers.append(
        header::CONTENT_TYPE,
        HeaderValue::from_str(&content_type).expect("MIME types are valid header values"),
    );

    if state.auto_ignore {
        match is_ignored(&state.root, &path) {
            Ok(true) => {
                let err_msg =
                    "Unable to access ignored or hidden file, because `--ignore` is enabled";
                let body = generate_error_body(err_msg, state.hard_reload, is_reload);
                return (StatusCode::FORBIDDEN, HeaderMap::new(), body);
            }
            Ok(false) => {}
            Err(err) => {
                let err_msg = format!("Failed to check ignore files: {err}");
                let body = generate_error_body(&err_msg, state.hard_reload, is_reload);
                log::error!("{err_msg}");
                return (StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new(), body);
            }
        }
    }

    let mut file = match fs::read(&path) {
        Ok(file) => file,
        Err(err) => {
            let status_code = match err.kind() {
                ErrorKind::NotFound => {
                    // A directory without `index.html` falls back to the generated listing.
                    if state.index_listing && is_accessing_dir {
                        let script = format_script(state.hard_reload, is_reload, false);
                        let html = index_html(
                            uri_path,
                            &script,
                            &get_index_listing(uri_path, &state.root, state.auto_ignore),
                        );
                        return (StatusCode::OK, headers, html);
                    }
                    StatusCode::NOT_FOUND
                }
                _ => StatusCode::INTERNAL_SERVER_ERROR,
            };
            match path.to_str() {
                Some(path) => log::warn!("Failed to read \"{path}\": {err}"),
                None => log::warn!("Failed to read file with invalid path: {err}"),
            }
            let body = if is_html {
                generate_error_body(&err.to_string(), state.hard_reload, is_reload)
            } else {
                Body::from(err.to_string())
            };
            return (status_code, headers, body);
        }
    };

    if is_html {
        let text = match String::from_utf8(file) {
            Ok(text) => text,
            Err(err) => {
                log::error!("Failed to read {path:?} as utf-8: {err}");
                let html = generate_error_body(&err.to_string(), state.hard_reload, is_reload);
                return (StatusCode::INTERNAL_SERVER_ERROR, headers, html);
            }
        };
        let script = format_script(state.hard_reload, is_reload, false);
        file = format!("{text}{script}").into_bytes();
    } else if state.hard_reload {
        headers.append(
            header::CACHE_CONTROL,
            HeaderValue::from_static("max-age=30"),
        );
    }

    (StatusCode::OK, headers, Body::from(file))
}

/// Builds a 307 redirect to `location`, or an empty 400 response when `location` cannot be
/// represented as a header value.
fn redirect_response(location: &str) -> (StatusCode, HeaderMap, Body) {
    let mut headers = HeaderMap::new();
    match HeaderValue::from_str(location) {
        Ok(value) => {
            headers.append(header::LOCATION, value);
            (StatusCode::TEMPORARY_REDIRECT, headers, Body::empty())
        }
        Err(err) => {
            log::warn!("Cannot redirect to {location:?}: {err}");
            (StatusCode::BAD_REQUEST, headers, Body::empty())
        }
    }
}
/// Chooses the `<script>` snippet to embed in a generated or served page.
///
/// * Not a reload request: the websocket script, called with `hard_reload` as its argument
///   (regardless of `is_error`).
/// * Reload request for an error page: nothing.
/// * Reload request otherwise: the reload payload.
fn format_script(hard_reload: bool, is_reload: bool, is_error: bool) -> String {
    if !is_reload {
        // `bool` formats as `true` / `false`, which is exactly the argument text needed.
        return format!("<script>{WEBSOCKET_FUNCTION}({hard_reload})</script>");
    }
    if is_error {
        String::new()
    } else {
        format!("<script>{RELOAD_PAYLOAD}</script>")
    }
}
/// Renders `err_msg` through `error_html`, together with the script `format_script` selects
/// for an error page.
fn generate_error_body(err_msg: &str, hard_reload: bool, is_reload: bool) -> Body {
    error_html(&format_script(hard_reload, is_reload, true), err_msg)
}