use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use axum::Router;
use axum::extract::State;
use axum::http::{HeaderValue, StatusCode, header};
use axum::response::{Html, IntoResponse, Response};
use axum::routing::get;
use tokio::task::JoinHandle;
use tower_http::services::ServeDir;
use tower_http::set_header::SetResponseHeaderLayer;
use super::livereload;
use super::rebuilder::{Rebuilder, favicon_mtime};
use super::state::AppState;
use super::watcher;
use crate::Error;
use crate::config::Config;
use crate::favicon::FaviconSet;
use crate::render::{BuiltSite, Theme};
pub struct Server {
state: Arc<AppState>,
router: Router,
config: Config,
rebuilder: Option<Rebuilder>,
}
impl Server {
pub fn new(config_path: &Path) -> Result<Self, Error> {
let config = Config::from_path(config_path)?;
let theme = Theme::load(&config)?;
let initial_favicon = match &config.favicon {
Some(path) => {
let mtime = favicon_mtime(path)?;
let set = FaviconSet::generate(path, &config.title)?;
Some((set, path.clone(), mtime))
}
None => None,
};
let favicon_set = initial_favicon.as_ref().map(|(set, _, _)| set.clone());
let built = BuiltSite::build_with_favicon(&config, &theme, favicon_set)?;
log_diagnostics(&built);
let state = Arc::new(AppState::new(built));
let rebuilder = Rebuilder::with_initial_favicon(config_path.to_path_buf(), initial_favicon);
let router = Self::build_router(Arc::clone(&state), &config);
Ok(Self {
state,
router,
config,
rebuilder: Some(rebuilder),
})
}
pub async fn run(mut self, port: u16) -> Result<(), Error> {
let watcher_handle = self.start_watcher()?;
let addr = SocketAddr::from(([0, 0, 0, 0], port));
tracing::info!("serving at http://localhost:{port}");
let listener = tokio::net::TcpListener::bind(addr).await?;
let state_for_shutdown = Arc::clone(&self.state);
axum::serve(listener, self.router)
.with_graceful_shutdown(async move {
shutdown_signal().await;
state_for_shutdown.shutdown.notify_waiters();
})
.await?;
tracing::info!("stopping file watcher");
let _ = tokio::time::timeout(Duration::from_secs(5), watcher_handle).await;
tracing::info!("server shut down");
Ok(())
}
pub async fn spawn_test(self) -> Result<(u16, Arc<AppState>, JoinHandle<()>), Error> {
self.spawn_test_server().await
}
pub async fn spawn_test_with_watcher(
mut self,
) -> Result<(u16, Arc<AppState>, JoinHandle<()>, JoinHandle<()>), Error> {
let watcher_handle = self.start_watcher()?;
let (port, state, server_handle) = self.spawn_test_server().await?;
Ok((port, state, server_handle, watcher_handle))
}
async fn spawn_test_server(self) -> Result<(u16, Arc<AppState>, JoinHandle<()>), Error> {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let port = listener.local_addr()?.port();
let state = Arc::clone(&self.state);
let router = self.router;
let server_handle = tokio::spawn(async move {
axum::serve(listener, router).await.ok();
});
Ok((port, state, server_handle))
}
fn start_watcher(&mut self) -> Result<JoinHandle<()>, Error> {
let mut rebuilder = self
.rebuilder
.take()
.expect("start_watcher called twice on the same Server");
let watch_state = Arc::clone(&self.state);
let mut content_watcher =
watcher::ContentWatcher::new(&self.config, rebuilder.config_path())?;
Ok(tokio::task::spawn(async move {
if let Err(e) = content_watcher.run(&mut rebuilder, &watch_state).await {
tracing::error!("file watcher failed: {e}");
}
}))
}
fn build_router(state: Arc<AppState>, config: &Config) -> Router {
let mut static_router = Router::new();
if let Some(theme_static) = config.theme_dir.as_ref().map(|d| d.join("static"))
&& theme_static.is_dir()
{
static_router = static_router.fallback_service(ServeDir::new(theme_static));
}
if config.static_dir.is_dir() {
static_router = Router::new()
.fallback_service(ServeDir::new(&config.static_dir).fallback(static_router));
}
Router::new()
.route("/ws", get(livereload::ws_handler))
.nest_service("/static", static_router)
.fallback(handle_page)
.with_state(state)
.layer(SetResponseHeaderLayer::overriding(
header::CACHE_CONTROL,
HeaderValue::from_static("no-store"),
))
}
}
/// Fallback handler for every request not matched by another route.
///
/// First checks the site's root files by name (the request path without its
/// leading slashes) and returns the raw bytes with a content type derived
/// from the extension. Otherwise it looks the path up as a page and returns
/// its HTML with the live-reload snippet injected, or the site's 404 page
/// (also with live reload) and status 404.
async fn handle_page(State(state): State<Arc<AppState>>, req: axum::extract::Request) -> Response {
    let request_path = req.uri().path().to_owned();
    let site = state.site.read().await;

    let root_name = request_path.trim_start_matches('/');
    let root_body = site
        .root_files
        .iter()
        .find(|(name, _)| name == root_name)
        .map(|(_, body)| body.clone());
    if let Some(body) = root_body {
        let mime = root_file_content_type(root_name);
        return ([(header::CONTENT_TYPE, mime)], body).into_response();
    }

    if let Some(html) = site.lookup(&request_path) {
        return Html(livereload::inject_live_reload(html)).into_response();
    }

    let missing_page = livereload::inject_live_reload(&site.not_found_html);
    (StatusCode::NOT_FOUND, Html(missing_page)).into_response()
}
/// Returns the `Content-Type` to serve a root file (favicon, sitemap,
/// robots.txt, web manifest, …) with, based on its file extension.
///
/// The extension is compared case-insensitively, so `FAVICON.ICO` and
/// `favicon.ico` get the same type. Files with no extension, a non-UTF-8
/// extension, or an unrecognised one fall back to `application/octet-stream`.
fn root_file_content_type(filename: &str) -> &'static str {
    let extension = Path::new(filename)
        .extension()
        .and_then(|extension| extension.to_str())
        .map(str::to_ascii_lowercase);
    match extension.as_deref() {
        Some("ico") => "image/x-icon",
        Some("png") => "image/png",
        Some("svg") => "image/svg+xml",
        Some("xml") => "application/xml",
        Some("json") => "application/json",
        Some("txt") => "text/plain; charset=utf-8",
        Some("webmanifest") => "application/manifest+json",
        _ => "application/octet-stream",
    }
}
async fn shutdown_signal() {
match tokio::signal::ctrl_c().await {
Ok(()) => tracing::info!("received ctrl-c, shutting down"),
Err(e) => tracing::error!("failed to install ctrl-c handler: {e}"),
}
}
/// Emits one `warn`-level tracing event per broken wiki-link recorded in the
/// build diagnostics, tagged with the source page and the missing target.
pub(crate) fn log_diagnostics(built: &BuiltSite) {
    let broken_links = &built.diagnostics.broken_wiki_links;
    broken_links.iter().for_each(|broken| {
        tracing::warn!(page = %broken.source, target = %broken.target, "broken wiki-link");
    });
}
#[cfg(test)]
mod tests {
    use super::root_file_content_type;

    /// Asserts that every `(filename, expected content type)` pair matches.
    fn check(cases: &[(&str, &str)]) {
        for &(filename, expected) in cases {
            assert_eq!(
                root_file_content_type(filename),
                expected,
                "unexpected content type for {filename:?}"
            );
        }
    }

    #[test]
    fn root_file_content_type_handles_known_extensions() {
        check(&[
            ("favicon.ico", "image/x-icon"),
            ("favicon-32x32.png", "image/png"),
            ("sitemap.xml", "application/xml"),
            ("robots.txt", "text/plain; charset=utf-8"),
            ("site.webmanifest", "application/manifest+json"),
        ]);
    }

    #[test]
    fn root_file_content_type_falls_back_for_unknown_extensions() {
        check(&[
            ("feed.unknown", "application/octet-stream"),
            ("LICENSE", "application/octet-stream"),
        ]);
    }
}