use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use axum::Router;
use axum::extract::State;
use axum::http::{HeaderValue, StatusCode, header};
use axum::response::{Html, IntoResponse, Response};
use axum::routing::get;
use tokio::task::JoinHandle;
use tower_http::services::ServeDir;
use tower_http::set_header::SetResponseHeaderLayer;
use super::livereload;
use super::state::AppState;
use super::watcher;
use crate::Error;
use crate::config::Config;
use crate::render::{RenderedSite, Theme};
/// HTTP server that serves the rendered site, with live-reload support and a
/// file watcher running alongside it.
///
/// Built with [`Server::new`] and consumed by [`Server::run`] (or one of the
/// `spawn_test*` helpers).
pub struct Server {
/// Shared application state; clones of this `Arc` are handed to the router
/// and to the file-watcher task.
state: Arc<AppState>,
/// Fully assembled axum router; moved into `axum::serve` when the server starts.
router: Router,
/// Configuration loaded from `config_path`; used to construct the content watcher.
config: Config,
/// Path the configuration was loaded from; passed to the watcher when it runs.
config_path: PathBuf,
}
impl Server {
pub fn new(config_path: &Path) -> Result<Self, Error> {
let config = Config::from_path(config_path)?;
let theme = Theme::load(&config)?;
let rendered = RenderedSite::build(&config, &theme, false)?;
let state = Arc::new(AppState::new(rendered));
let router = Self::build_router(Arc::clone(&state), &config);
Ok(Self {
state,
router,
config,
config_path: config_path.to_path_buf(),
})
}
pub async fn run(self, port: u16) -> Result<(), Error> {
let watcher_handle = self.start_watcher()?;
let addr = SocketAddr::from(([0, 0, 0, 0], port));
tracing::info!("serving at http://localhost:{port}");
let listener = tokio::net::TcpListener::bind(addr).await?;
let state_for_shutdown = Arc::clone(&self.state);
axum::serve(listener, self.router)
.with_graceful_shutdown(async move {
shutdown_signal().await;
state_for_shutdown.shutdown.notify_waiters();
})
.await?;
tracing::info!("stopping file watcher");
let _ = tokio::time::timeout(Duration::from_secs(5), watcher_handle).await;
tracing::info!("server shut down");
Ok(())
}
pub async fn spawn_test(self) -> Result<(u16, Arc<AppState>, JoinHandle<()>), Error> {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let port = listener.local_addr()?.port();
let state = Arc::clone(&self.state);
let router = self.router;
let handle = tokio::spawn(async move {
axum::serve(listener, router).await.ok();
});
Ok((port, state, handle))
}
pub async fn spawn_test_with_watcher(
self,
) -> Result<(u16, Arc<AppState>, JoinHandle<()>, JoinHandle<()>), Error> {
let watcher_handle = self.start_watcher()?;
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let port = listener.local_addr()?.port();
let state = Arc::clone(&self.state);
let router = self.router;
let server_handle = tokio::spawn(async move {
axum::serve(listener, router).await.ok();
});
Ok((port, state, server_handle, watcher_handle))
}
fn start_watcher(&self) -> Result<JoinHandle<()>, Error> {
let watch_state = Arc::clone(&self.state);
let watch_config_path = self.config_path.clone();
let mut content_watcher = watcher::ContentWatcher::new(&self.config)?;
Ok(tokio::task::spawn(async move {
if let Err(e) = content_watcher.run(&watch_config_path, &watch_state).await {
tracing::error!("file watcher failed: {e}");
}
}))
}
fn build_router(state: Arc<AppState>, config: &Config) -> Router {
let mut static_router = Router::new();
if let Some(theme_static) = config.theme_dir.as_ref().map(|d| d.join("static"))
&& theme_static.is_dir()
{
static_router = static_router.fallback_service(ServeDir::new(theme_static));
}
if config.static_dir.is_dir() {
static_router = Router::new()
.fallback_service(ServeDir::new(&config.static_dir).fallback(static_router));
}
Router::new()
.route("/ws", get(livereload::ws_handler))
.nest_service("/static", static_router)
.fallback(handle_page)
.with_state(state)
.layer(SetResponseHeaderLayer::overriding(
header::CACHE_CONTROL,
HeaderValue::from_static("no-store"),
))
}
}
/// Fallback handler that serves a rendered page for the request path.
///
/// Takes a read guard on `state.site`, looks the path up in the site and
/// responds with that page's HTML after passing it through
/// `livereload::inject_live_reload`. When the site has no entry for the path,
/// responds with `404 Not Found` and the site's `not_found_html`, processed
/// the same way.
async fn handle_page(State(state): State<Arc<AppState>>, req: axum::extract::Request) -> Response {
    // Copy the path out before awaiting so no borrow of `req` is held across
    // the await point.
    let requested = req.uri().path().to_string();
    let site = state.site.read().await;

    if let Some(html) = site.lookup(&requested) {
        return Html(livereload::inject_live_reload(html)).into_response();
    }

    let not_found = Html(livereload::inject_live_reload(&site.not_found_html));
    (StatusCode::NOT_FOUND, not_found).into_response()
}
async fn shutdown_signal() {
match tokio::signal::ctrl_c().await {
Ok(()) => tracing::info!("received ctrl-c, shutting down"),
Err(e) => tracing::error!("failed to install ctrl-c handler: {e}"),
}
}