use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use axum::Router;
use axum::extract::State;
use axum::http::{HeaderValue, StatusCode, header};
use axum::response::{Html, IntoResponse, Response};
use axum::routing::get;
use tokio::task::JoinHandle;
use tower_http::services::ServeDir;
use tower_http::set_header::SetResponseHeaderLayer;
use super::livereload;
use super::state::AppState;
use super::watcher;
use crate::Error;
use crate::config::Config;
use crate::generated::FaviconSet;
use crate::render::{Mode, RenderedSite, Theme};
/// Development server bundling the rendered site state, the router that serves it, and the
/// configuration needed to start the file watcher.
pub struct Server {
/// Shared application state; handed to the router at construction and cloned into the
/// watcher task when the watcher is started.
state: Arc<AppState>,
/// Router built once in `new` and consumed when the server starts serving.
router: Router,
/// Loaded configuration; used to create the content watcher.
config: Config,
/// Path the configuration was loaded from; passed to the watcher task when it runs.
config_path: PathBuf,
}
impl Server {
pub fn new(config_path: &Path) -> Result<Self, Error> {
let config = Config::from_path(config_path)?;
let theme = Theme::load(&config)?;
let favicon = config
.favicon
.as_ref()
.map(|p| FaviconSet::generate(p, &config.title))
.transpose()?;
let rendered =
RenderedSite::build_with_favicon(&config, &theme, Mode::Serve, favicon.clone())?;
let state = Arc::new(AppState::new(rendered, favicon));
let router = Self::build_router(Arc::clone(&state), &config);
Ok(Self {
state,
router,
config,
config_path: config_path.to_path_buf(),
})
}
pub async fn run(self, port: u16) -> Result<(), Error> {
let watcher_handle = self.start_watcher()?;
let addr = SocketAddr::from(([0, 0, 0, 0], port));
tracing::info!("serving at http://localhost:{port}");
let listener = tokio::net::TcpListener::bind(addr).await?;
let state_for_shutdown = Arc::clone(&self.state);
axum::serve(listener, self.router)
.with_graceful_shutdown(async move {
shutdown_signal().await;
state_for_shutdown.shutdown.notify_waiters();
})
.await?;
tracing::info!("stopping file watcher");
let _ = tokio::time::timeout(Duration::from_secs(5), watcher_handle).await;
tracing::info!("server shut down");
Ok(())
}
pub async fn spawn_test(self) -> Result<(u16, Arc<AppState>, JoinHandle<()>), Error> {
self.spawn_test_server().await
}
pub async fn spawn_test_with_watcher(
self,
) -> Result<(u16, Arc<AppState>, JoinHandle<()>, JoinHandle<()>), Error> {
let watcher_handle = self.start_watcher()?;
let (port, state, server_handle) = self.spawn_test_server().await?;
Ok((port, state, server_handle, watcher_handle))
}
async fn spawn_test_server(self) -> Result<(u16, Arc<AppState>, JoinHandle<()>), Error> {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let port = listener.local_addr()?.port();
let state = Arc::clone(&self.state);
let router = self.router;
let server_handle = tokio::spawn(async move {
axum::serve(listener, router).await.ok();
});
Ok((port, state, server_handle))
}
fn start_watcher(&self) -> Result<JoinHandle<()>, Error> {
let watch_state = Arc::clone(&self.state);
let watch_config_path = self.config_path.clone();
let mut content_watcher = watcher::ContentWatcher::new(&self.config)?;
Ok(tokio::task::spawn(async move {
if let Err(e) = content_watcher.run(&watch_config_path, &watch_state).await {
tracing::error!("file watcher failed: {e}");
}
}))
}
fn build_router(state: Arc<AppState>, config: &Config) -> Router {
let mut static_router = Router::new();
if let Some(theme_static) = config.theme_dir.as_ref().map(|d| d.join("static"))
&& theme_static.is_dir()
{
static_router = static_router.fallback_service(ServeDir::new(theme_static));
}
if config.static_dir.is_dir() {
static_router = Router::new()
.fallback_service(ServeDir::new(&config.static_dir).fallback(static_router));
}
Router::new()
.route("/ws", get(livereload::ws_handler))
.nest_service("/static", static_router)
.fallback(handle_page)
.with_state(state)
.layer(SetResponseHeaderLayer::overriding(
header::CACHE_CONTROL,
HeaderValue::from_static("no-store"),
))
}
}
/// Fallback handler for every request not matched by another route.
///
/// A path that names one of the site's root files is answered with that file's content and
/// a content type derived from its extension. Otherwise the rendered page for the path is
/// returned, or the not-found page with status 404; both get the live-reload snippet
/// injected.
async fn handle_page(State(state): State<Arc<AppState>>, req: axum::extract::Request) -> Response {
    // Take an owned copy of the path before awaiting the lock.
    let request_path = req.uri().path().to_owned();
    let site = state.site.read().await;

    // Root files are stored by bare name, without the leading slash.
    let root_name = request_path.trim_start_matches('/');
    let root_file = site.root_files.iter().find(|(name, _)| name == root_name);
    if let Some((_, body)) = root_file {
        let headers = [(header::CONTENT_TYPE, root_file_content_type(root_name))];
        return (headers, body.clone()).into_response();
    }

    let Some(page) = site.lookup(&request_path) else {
        let not_found = Html(livereload::inject_live_reload(&site.not_found_html));
        return (StatusCode::NOT_FOUND, not_found).into_response();
    };
    Html(livereload::inject_live_reload(page)).into_response()
}
/// Returns the `Content-Type` value for a root-level file, chosen by its extension.
///
/// The extension is compared case-insensitively, so `FAVICON.ICO` and `favicon.ico` get the
/// same type. Files with no extension or an unrecognised one are served as
/// `application/octet-stream`.
fn root_file_content_type(filename: &str) -> &'static str {
    let extension = Path::new(filename)
        .extension()
        .and_then(|extension| extension.to_str())
        .map(str::to_ascii_lowercase);

    match extension.as_deref() {
        Some("ico") => "image/x-icon",
        Some("png") => "image/png",
        Some("svg") => "image/svg+xml",
        Some("xml") => "application/xml",
        Some("txt") => "text/plain; charset=utf-8",
        Some("json") => "application/json",
        Some("webmanifest") => "application/manifest+json",
        _ => "application/octet-stream",
    }
}
async fn shutdown_signal() {
match tokio::signal::ctrl_c().await {
Ok(()) => tracing::info!("received ctrl-c, shutting down"),
Err(e) => tracing::error!("failed to install ctrl-c handler: {e}"),
}
}
#[cfg(test)]
mod tests {
    use super::root_file_content_type;

    /// Asserts that each `(filename, expected content type)` pair in `cases` holds.
    fn assert_content_types(cases: &[(&str, &str)]) {
        for &(filename, expected) in cases {
            assert_eq!(
                root_file_content_type(filename),
                expected,
                "content type for {filename}"
            );
        }
    }

    #[test]
    fn root_file_content_type_handles_known_extensions() {
        assert_content_types(&[
            ("favicon.ico", "image/x-icon"),
            ("favicon-32x32.png", "image/png"),
            ("sitemap.xml", "application/xml"),
            ("robots.txt", "text/plain; charset=utf-8"),
            ("site.webmanifest", "application/manifest+json"),
        ]);
    }

    #[test]
    fn root_file_content_type_falls_back_for_unknown_extensions() {
        assert_content_types(&[
            ("feed.unknown", "application/octet-stream"),
            ("LICENSE", "application/octet-stream"),
        ]);
    }
}