use std::net::SocketAddr;
use axum::Router;
use axum::http::{Request, StatusCode};
use axum::middleware;
use axum::response::{IntoResponse, Response};
use tokio::net::TcpListener;
use tower_http::compression::CompressionLayer;
use tower_http::services::ServeDir;
use tower_http::set_header::SetResponseHeaderLayer;
use tower_http::trace::TraceLayer;
use tracing::info;
use crate::auth::extractor::JwtSecret;
use crate::config::Config;
use crate::db::DbPool;
use crate::error::Result;
use crate::middleware::request_id::request_id;
use crate::middleware::security_headers::security_headers;
/// Builder for the HTTP server.
///
/// Collects the configuration, application routes, an optional static-file
/// directory, an optional database pool, and whether migrations should run
/// before serving. Consumed by `App::serve`.
pub struct App {
    /// Server configuration; `host`, `port` and `jwt_secret()` are read from it.
    config: Config,
    /// Application routes; health and static-file routes are merged in when building.
    router: Router,
    /// Directory served under the `/static` URL prefix, when set.
    static_dir: Option<String>,
    /// Database pool given to the health routes and inserted into request extensions.
    db: Option<DbPool>,
    /// When true, `serve` runs `crate::db::migrate` before binding (requires `db`).
    run_migrations: bool,
}
impl App {
pub fn new(config: Config) -> Self {
Self {
config,
router: Router::new(),
static_dir: None,
db: None,
run_migrations: false,
}
}
pub fn db(mut self, pool: DbPool) -> Self {
self.db = Some(pool);
self
}
pub fn router(mut self, router: Router) -> Self {
self.router = router;
self
}
pub fn run_migrations(mut self) -> Self {
self.run_migrations = true;
self
}
pub fn static_dir(mut self, path: impl Into<String>) -> Self {
self.static_dir = Some(path.into());
self
}
fn build_router(self) -> Router {
let router = attach_static_files(self.router, self.static_dir);
let health_routes = axum::Router::new()
.route("/_ping", axum::routing::get(crate::health::ping))
.route("/_health", axum::routing::get(crate::health::check))
.with_state(self.db.clone());
let auth_ctx = AuthExtensions {
jwt_secret: self.config.jwt_secret().map(|s| s.to_owned()),
db: self.db.clone(),
};
router
.merge(health_routes)
.layer(CompressionLayer::new())
.layer(middleware::from_fn(security_headers))
.layer(middleware::from_fn_with_state(
auth_ctx,
inject_auth_extensions,
))
.layer(middleware::from_fn(request_id))
.layer(TraceLayer::new_for_http())
}
pub async fn serve(self) -> Result<()> {
if self.run_migrations {
if let Some(pool) = &self.db {
crate::db::migrate(pool).await?;
}
}
let addr = format!("{}:{}", self.config.host, self.config.port);
let router = self.build_router();
let listener = TcpListener::bind(&addr).await?;
info!("Blixt server running on http://{addr}");
axum::serve(
listener,
router.into_make_service_with_connect_info::<SocketAddr>(),
)
.await?;
Ok(())
}
}
/// State for the `inject_auth_extensions` middleware (axum middleware state
/// must be `Clone`). Holds the values copied into each request's extensions.
#[derive(Clone)]
struct AuthExtensions {
    /// Owned copy of the configured JWT secret, if one is set.
    jwt_secret: Option<String>,
    /// Database pool, if the app was given one.
    db: Option<DbPool>,
}
async fn inject_auth_extensions(
axum::extract::State(ctx): axum::extract::State<AuthExtensions>,
mut request: Request<axum::body::Body>,
next: axum::middleware::Next,
) -> Response {
if let Some(secret) = &ctx.jwt_secret {
request.extensions_mut().insert(JwtSecret(secret.clone()));
}
if let Some(pool) = &ctx.db {
request.extensions_mut().insert(pool.clone());
}
next.run(request).await
}
/// Nests a `ServeDir` for `static_dir` under `/static`, or returns `router`
/// unchanged when no directory is configured.
///
/// Successful and `304 Not Modified` responses get a one-year `immutable`
/// `Cache-Control` header. Error responses (404 for a missing file, 416, ...)
/// do not, so browsers and CDNs will not cache a failure for a year.
/// `block_dotfiles` is the outer layer, so it runs before the file lookup and
/// its 404s bypass the cache-header layer entirely.
fn attach_static_files(router: Router, static_dir: Option<String>) -> Router {
    match static_dir {
        Some(dir) => {
            let cache_header = SetResponseHeaderLayer::overriding(
                axum::http::header::CACHE_CONTROL,
                |response: &axum::response::Response| {
                    let status = response.status();
                    let cacheable = status.is_success()
                        || status == axum::http::StatusCode::NOT_MODIFIED;
                    // Returning `None` leaves the response headers untouched.
                    cacheable.then(|| {
                        axum::http::HeaderValue::from_static(
                            "public, max-age=31536000, immutable",
                        )
                    })
                },
            );
            let static_router = Router::new()
                .fallback_service(ServeDir::new(dir))
                .layer(cache_header)
                .layer(middleware::from_fn(block_dotfiles));
            router.nest("/static", static_router)
        }
        None => router,
    }
}
/// Middleware that answers `404 Not Found` for any request whose path
/// `is_dotfile_path` flags, without forwarding it. All other requests are
/// passed to the next handler unchanged.
async fn block_dotfiles(
    request: Request<axum::body::Body>,
    next: axum::middleware::Next,
) -> Response {
    let hidden = is_dotfile_path(request.uri().path());
    if hidden {
        StatusCode::NOT_FOUND.into_response()
    } else {
        next.run(request).await
    }
}
/// Returns `true` if any segment of a request path names a hidden entry
/// (starts with `.`). This also covers `.` and `..` traversal segments.
///
/// The path is percent-decoded before it is inspected, because tower-http's
/// `ServeDir` decodes it before mapping it onto the filesystem. Without
/// decoding, `/%2Eenv` or `/css%2F.hidden` would pass this check and still
/// reach the hidden file. Segments are split on both `/` and `\`; the latter
/// is a path separator on Windows.
fn is_dotfile_path(path: &str) -> bool {
    percent_decode_bytes(path)
        .split(|&byte| byte == b'/' || byte == b'\\')
        .any(|segment| segment.first() == Some(&b'.'))
}

/// Decodes `%XX` escapes in `input` into raw bytes. Malformed or truncated
/// escapes (e.g. `%zz`, a trailing `%`) are copied through unchanged.
fn percent_decode_bytes(input: &str) -> Vec<u8> {
    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        let escaped = match (bytes[i], bytes.get(i + 1), bytes.get(i + 2)) {
            (b'%', Some(&hi), Some(&lo)) => hex_value(hi).zip(hex_value(lo)),
            _ => None,
        };
        match escaped {
            Some((hi, lo)) => {
                decoded.push((hi << 4) | lo);
                i += 3;
            }
            None => {
                decoded.push(bytes[i]);
                i += 1;
            }
        }
    }
    decoded
}

/// Returns the numeric value of an ASCII hex digit, or `None` otherwise.
fn hex_value(byte: u8) -> Option<u8> {
    match byte {
        b'0'..=b'9' => Some(byte - b'0'),
        b'a'..=b'f' => Some(byte - b'a' + 10),
        b'A'..=b'F' => Some(byte - b'A' + 10),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::test_config;
    use axum::body::Body;
    use axum::routing::get;
    use tower::ServiceExt;

    /// Fixture directory served under `/static` by the static-file tests.
    const STATIC_FIXTURES: &str = "tests/fixtures/static";

    /// Builds the full application router with a single `/health` route,
    /// optionally serving `static_dir` under `/static`.
    fn build_test_app(static_dir: Option<&str>) -> Router {
        let routes = Router::new().route("/health", get(|| async { "ok" }));
        let base = App::new(test_config()).router(routes);
        let app = match static_dir {
            Some(dir) => base.static_dir(dir),
            None => base,
        };
        app.build_router()
    }

    /// Sends a body-less GET request for `uri` through `app` and returns the response.
    async fn send_get(app: Router, uri: &str) -> Response {
        let request = Request::builder()
            .uri(uri)
            .body(Body::empty())
            .expect("failed to build request");
        app.oneshot(request).await.expect("failed to send request")
    }

    #[tokio::test]
    async fn response_includes_all_security_headers() {
        let response = send_get(build_test_app(None), "/health").await;
        let headers = response.headers();
        for name in [
            "content-security-policy",
            "strict-transport-security",
            "x-content-type-options",
            "x-frame-options",
            "referrer-policy",
            "permissions-policy",
        ] {
            assert!(headers.contains_key(name), "missing security header: {name}");
        }
    }

    #[tokio::test]
    async fn dotfile_request_returns_404() {
        let response = send_get(build_test_app(Some(STATIC_FIXTURES)), "/static/.env").await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn path_traversal_returns_404() {
        let response =
            send_get(build_test_app(Some(STATIC_FIXTURES)), "/static/../Cargo.toml").await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn valid_static_file_returns_200() {
        let response =
            send_get(build_test_app(Some(STATIC_FIXTURES)), "/static/css/test.css").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn is_dotfile_path_detects_dotfiles() {
        for hidden in ["/.env", "/css/.hidden", "/../etc/passwd"] {
            assert!(is_dotfile_path(hidden), "expected {hidden} to be flagged");
        }
        for visible in ["/css/style.css", "/js/app.js"] {
            assert!(!is_dotfile_path(visible), "expected {visible} to pass");
        }
    }
}