use anyhow::Result;
use bytes::Bytes;
use http_body_util::Full;
use hyper::header::{self, HeaderName, HeaderValue};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use socket2::{Domain, Protocol, Socket, Type};
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::net::TcpListener;
use tracing::{debug, info, warn};
use crate::routes::NanoWeb;
/// Runtime configuration for [`start_server`].
///
/// `Debug` is derived alongside `Clone` so the effective configuration can be
/// logged or included in diagnostics.
#[derive(Debug, Clone)]
pub struct ServeConfig {
    /// Directory whose contents are loaded into the route table and served.
    pub public_dir: PathBuf,
    /// TCP port the server binds on `0.0.0.0`.
    pub port: u16,
    /// When `true`, each request first asks the route table to refresh the
    /// requested path if it was modified.
    pub dev: bool,
    /// When `true`, requests that match no route fall back to the `/` route.
    pub spa_mode: bool,
    /// Prefix forwarded to route population and refresh.
    // NOTE(review): presumably selects which environment/config values are
    // exposed to served content — confirm against `NanoWeb::populate_routes`.
    pub config_prefix: String,
    /// When `true`, an `info`-level log line is emitted for requests that
    /// reach the route lookup.
    pub log_requests: bool,
}
/// State shared by all connections; `start_server` wraps it in an `Arc` and
/// hands a clone to every request handler.
struct AppState {
/// Route table that `handle_request` queries via `get_response`.
server: Arc<NanoWeb>,
/// Copy of the configuration the server was started with.
config: ServeConfig,
}
/// Creates a non-blocking, listening TCP socket bound to `addr`.
///
/// `SO_REUSEADDR` is always enabled and, on Unix, `SO_REUSEPORT` as well.
/// The socket family is derived from `addr`, so both IPv4 and IPv6 addresses
/// can be bound (previously the family was hard-coded to IPv4, which made
/// binding an IPv6 address fail).
///
/// # Errors
///
/// Returns the underlying I/O error if the socket cannot be created,
/// configured, bound, or put into the listening state.
fn create_reuse_port_listener(addr: SocketAddr) -> Result<std::net::TcpListener> {
    /// Pending-connection queue length requested from the kernel.
    const LISTEN_BACKLOG: std::os::raw::c_int = 8192;

    let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
    socket.set_reuse_address(true)?;
    #[cfg(unix)]
    socket.set_reuse_port(true)?;
    // The listener is later handed to `tokio::net::TcpListener::from_std`,
    // which expects the socket to already be in non-blocking mode.
    socket.set_nonblocking(true)?;
    socket.bind(&addr.into())?;
    socket.listen(LISTEN_BACKLOG)?;
    Ok(socket.into())
}
/// Loads the route table from `config.public_dir` and serves it over HTTP/1.1
/// on `0.0.0.0:<config.port>` until a shutdown signal is received.
///
/// Each accepted connection is served on its own spawned task. Per-connection
/// accept failures (for example a peer that aborts during the handshake) are
/// logged and skipped instead of shutting the whole server down.
///
/// # Errors
///
/// Returns an error if the routes cannot be populated, the listening socket
/// cannot be created or registered with Tokio, or `accept` fails with an error
/// that is not specific to a single connection.
pub async fn start_server(config: ServeConfig) -> Result<()> {
    let server = Arc::new(NanoWeb::new());
    server.populate_routes(&config.public_dir, &config.config_prefix)?;
    let state = Arc::new(AppState {
        server,
        config: config.clone(),
    });
    info!("Routes loaded: {}", state.server.route_count());

    let addr: SocketAddr = ([0, 0, 0, 0], config.port).into();
    let std_listener = create_reuse_port_listener(addr)?;
    let listener = TcpListener::from_std(std_listener)?;
    info!("Starting server on http://{}", addr);
    info!("Serving directory: {:?}", config.public_dir);

    // The shutdown future is polled across loop iterations, so it must be
    // pinned once outside the loop rather than recreated per iteration.
    let shutdown = shutdown_signal();
    tokio::pin!(shutdown);

    loop {
        tokio::select! {
            result = listener.accept() => {
                let stream = match result {
                    Ok((stream, _peer)) => stream,
                    // The failure belongs to one connection only; the
                    // listener itself is still healthy, so keep serving.
                    Err(e) if is_transient_accept_error(&e) => {
                        debug!("Connection dropped during accept: {}", e);
                        continue;
                    }
                    Err(e) => return Err(e.into()),
                };
                let io = TokioIo::new(stream);
                let state = Arc::clone(&state);
                tokio::spawn(async move {
                    let service = service_fn(move |req| {
                        let state = Arc::clone(&state);
                        async move { handle_request(req, state) }
                    });
                    if let Err(e) = http1::Builder::new()
                        .keep_alive(true)
                        .pipeline_flush(true)
                        .serve_connection(io, service)
                        .await
                    {
                        debug!("Connection error: {}", e);
                    }
                });
            }
            () = &mut shutdown => {
                info!("Shutdown signal received, stopping server");
                break;
            }
        }
    }
    Ok(())
}

/// Returns `true` for `accept` errors that concern a single incoming
/// connection rather than the listening socket itself.
fn is_transient_accept_error(err: &std::io::Error) -> bool {
    matches!(
        err.kind(),
        std::io::ErrorKind::ConnectionRefused
            | std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
            | std::io::ErrorKind::Interrupted
    )
}
/// Completes once the process is asked to stop: on Ctrl+C everywhere, and
/// additionally on SIGTERM when compiled for Unix.
///
/// # Panics
///
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    let interrupt = async {
        let installed = tokio::signal::ctrl_c().await;
        installed.expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        let kind = tokio::signal::unix::SignalKind::terminate();
        let mut stream =
            tokio::signal::unix::signal(kind).expect("failed to install SIGTERM handler");
        stream.recv().await;
    };
    // Platforms without SIGTERM get a future that never resolves, so only
    // Ctrl+C can end the wait there.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    tokio::select! {
        () = interrupt => {},
        () = sigterm => {},
    }
}
/// Response type returned by the handlers in this file: a hyper response
/// whose body is a single, fully buffered `Bytes` chunk.
type HyperResponse = Response<Full<Bytes>>;
#[allow(clippy::needless_pass_by_value, clippy::unnecessary_wraps)]
fn handle_request(
req: Request<hyper::body::Incoming>,
state: Arc<AppState>,
) -> Result<HyperResponse, std::convert::Infallible> {
let is_head = req.method() == Method::HEAD;
if req.method() != Method::GET && !is_head {
return Ok(response(
StatusCode::METHOD_NOT_ALLOWED,
"Method Not Allowed",
));
}
let raw_path = req.uri().path();
if raw_path == "/_health" {
let body = format!(
r#"{{"status":"ok","timestamp":"{}"}}"#,
httpdate::fmt_http_date(std::time::SystemTime::now())
);
return Ok(Response::builder()
.status(StatusCode::OK)
.header("content-type", "application/json")
.body(Full::new(Bytes::from(body)))
.expect("health check response"));
}
let path = match crate::path::validate_request_path(raw_path) {
Ok(sanitized) => sanitized,
Err(e) => {
warn!("Path validation failed for '{}': {}", raw_path, e);
return Ok(response(StatusCode::BAD_REQUEST, "Bad Request"));
}
};
if state.config.dev {
let _ = state.server.refresh_if_modified(
&path,
&state.config.public_dir,
&state.config.config_prefix,
);
}
let accept_encoding = req
.headers()
.get("accept-encoding")
.and_then(|h| h.to_str().ok())
.unwrap_or("");
let if_none_match = req
.headers()
.get("if-none-match")
.and_then(|h| h.to_str().ok());
let mut buf = state.server.get_response(&path, accept_encoding);
if buf.is_none() && !path.ends_with('/') {
let with_slash = format!("{path}/");
buf = state.server.get_response(&with_slash, accept_encoding);
}
if buf.is_none() && state.config.spa_mode {
debug!("SPA fallback for: {}", path);
buf = state.server.get_response("/", accept_encoding);
}
let resp = if let Some(ref b) = buf {
if let Some(etag) = if_none_match {
if etag == b.etag.as_ref() {
return Ok(Response::builder()
.status(StatusCode::NOT_MODIFIED)
.header("etag", b.etag.as_ref())
.header("cache-control", b.cache_control.as_ref())
.body(Full::new(Bytes::new()))
.expect("304 response"));
}
}
build_response(b, is_head)
} else {
debug!("Route not found: {path}");
response(StatusCode::NOT_FOUND, "Not Found")
};
if state.config.log_requests {
info!(
method = %req.method(),
path = %path,
status = resp.status().as_u16(),
"request"
);
}
Ok(resp)
}
fn response(status: StatusCode, body: &'static str) -> HyperResponse {
Response::builder()
.status(status)
.body(Full::new(Bytes::from_static(body.as_bytes())))
.expect("error response")
}
/// Converts a buffered route entry into a `200 OK` response.
///
/// The representation headers stored on `buf` are emitted first, followed by
/// a fixed set of security-related headers, then `Content-Encoding` and
/// `Vary` when the buffer asks for them. With `head_only` set the headers are
/// identical but the body is left empty.
///
/// # Panics
///
/// Panics if any header value taken from `buf` is not a valid header value.
fn build_response(buf: &crate::response_buffer::ResponseBuffer, head_only: bool) -> HyperResponse {
    let mut resp = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, buf.content_type.as_ref())
        .header(header::ETAG, buf.etag.as_ref())
        .header(header::LAST_MODIFIED, buf.last_modified.as_ref())
        .header(header::CACHE_CONTROL, buf.cache_control.as_ref())
        .header(header::CONTENT_LENGTH, buf.content_length.as_ref());

    // Static headers attached to every successful response, in wire order.
    let fixed_headers = [
        (header::X_CONTENT_TYPE_OPTIONS, "nosniff"),
        (header::X_FRAME_OPTIONS, "SAMEORIGIN"),
        (header::REFERRER_POLICY, "strict-origin-when-cross-origin"),
        (
            header::STRICT_TRANSPORT_SECURITY,
            "max-age=63072000; includeSubDomains",
        ),
        (
            HeaderName::from_static("permissions-policy"),
            "camera=(), microphone=(), geolocation=()",
        ),
        (HeaderName::from_static("x-dns-prefetch-control"), "off"),
    ];
    for (name, value) in fixed_headers {
        resp = resp.header(name, HeaderValue::from_static(value));
    }

    if let Some(encoding) = buf.content_encoding {
        resp = resp.header(header::CONTENT_ENCODING, HeaderValue::from_static(encoding));
    }
    if buf.vary_encoding {
        resp = resp.header(header::VARY, HeaderValue::from_static("Accept-Encoding"));
    }

    let payload = if !head_only {
        buf.body.clone()
    } else {
        Bytes::new()
    };
    resp.body(Full::new(payload)).expect("response body")
}