use std::sync::Arc;
use std::time::Duration;
use autumn_web::prelude::AppState as AutumnAppState;
use axum::Router;
use axum::http::{HeaderName, HeaderValue, Method};
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
use tower_http::set_header::SetResponseHeaderLayer;
use tower_http::trace::TraceLayer;
use crate::AletheiaDB;
use crate::http::config::{CorsConfig, ServerConfig};
use crate::http::handlers::all_routes;
use crate::http::state::AppState;
/// Starts the AletheiaDB HTTP server on the autumn web framework.
///
/// Steps, in order:
/// 1. Validates the rate-limit configuration.
/// 2. Builds the database and the shared [`AppState`].
/// 3. Logs the bind address, the data directory (or an in-memory warning) and a
///    permissive-CORS warning to stderr.
/// 4. Exports the configuration to autumn through `AUTUMN_*` environment variables.
/// 5. Runs the autumn app with every route from `all_routes()`.
///
/// On startup the shared state is installed as an autumn extension. On shutdown
/// `persist_indexes()` is run on a blocking thread, but only when a data directory
/// is configured. Shutdown failures are logged to stderr, not returned.
///
/// # Errors
///
/// Returns an `io::Error` if the rate-limit configuration fails validation or if
/// the database cannot be constructed.
pub async fn run_server(config: ServerConfig) -> std::io::Result<()> {
config
.rate_limit()
.validate()
.map_err(std::io::Error::other)?;
let db = Arc::new(build_database(&config)?);
let our_state = AppState::new(db);
// Separate handles are moved into the startup and shutdown hooks below.
let startup_state = our_state.clone();
let shutdown_state = our_state.clone();
// Only persist on shutdown when there is a data directory to persist into.
let persist_on_shutdown = config.data_dir().is_some();
eprintln!(
"Starting AletheiaDB HTTP server on {}",
config.bind_address()
);
match config.data_dir() {
Some(path) => eprintln!("Data directory: {}", path.display()),
None => eprintln!(
"WARNING: no data directory configured — running in-memory; state is lost on shutdown."
),
}
if config.cors().is_permissive() {
eprintln!(
"WARNING: CORS is configured in permissive mode (any origin allowed). \
This is not recommended for production."
);
}
// SAFETY: `apply_autumn_env` mutates the process environment, which requires
// that no other thread reads or writes env vars concurrently.
// NOTE(review): this runs inside an async fn, so runtime worker threads may
// already exist. Confirm that nothing touches the environment while this runs,
// or move the call before the runtime is started.
unsafe {
apply_autumn_env(&config);
}
autumn_web::app()
// Make our AppState available to handlers as an autumn extension.
.on_startup(move |autumn_state| {
let installed = startup_state.clone();
async move {
autumn_state.insert_extension(installed);
Ok(())
}
})
.on_shutdown(move || {
let db = shutdown_state.db_arc();
let should_persist = persist_on_shutdown;
async move {
if !should_persist {
return;
}
// persist_indexes is a synchronous call, so it runs off the async executor.
// NOTE(review): a JoinError can also mean the task was cancelled, not only
// that it panicked. The last message may be misleading in that case.
match tokio::task::spawn_blocking(move || db.persist_indexes()).await {
Ok(Ok(())) => eprintln!("Shutdown: indexes persisted."),
Ok(Err(e)) => eprintln!("Shutdown: persist_indexes failed: {e}"),
Err(e) => eprintln!("Shutdown: persist_indexes task panicked: {e}"),
}
}
})
.routes(all_routes())
.run()
// NOTE(review): the value of `.run().await` is discarded. If it can report a
// bind or serve error, that error is not propagated to the caller. Confirm.
.await;
Ok(())
}
/// Constructs the [`AletheiaDB`] instance described by `config`.
///
/// When `config.to_unified_config()` yields a unified configuration, the database
/// is built from it with [`AletheiaDB::with_unified_config`]. Otherwise
/// [`AletheiaDB::new`] is used.
///
/// # Errors
///
/// Any construction failure is converted into an [`std::io::Error`] whose message
/// is the original error's string form.
fn build_database(config: &ServerConfig) -> std::io::Result<AletheiaDB> {
    // The two constructors may report different error types, so the adapter is
    // generic over anything that can be rendered as a string.
    fn to_io_error<E: ToString>(err: E) -> std::io::Error {
        std::io::Error::other(err.to_string())
    }

    if let Some(unified) = config.to_unified_config() {
        AletheiaDB::with_unified_config(unified).map_err(to_io_error)
    } else {
        AletheiaDB::new().map_err(to_io_error)
    }
}
/// Exports the server configuration as `AUTUMN_*` environment variables for the
/// autumn framework to read.
///
/// Variables are set in this order:
/// - host and port;
/// - allowed CORS origins: the single wildcard `*` in permissive mode, otherwise
///   the configured origins joined with `,`;
/// - allowed methods and allowed headers, each joined with `,`;
/// - CORS max-age;
/// - finally `AUTUMN_SECURITY__CSRF__ENABLED=false`.
///
/// # Safety
///
/// This writes to the process environment with [`std::env::set_var`], which is
/// only sound when no other thread reads or writes the environment concurrently.
/// The caller must guarantee that for the duration of this call.
unsafe fn apply_autumn_env(config: &ServerConfig) {
    // Compute every value up front so the unsafe block contains only the
    // environment writes themselves.
    let host = config.host();
    let port = config.port().to_string();
    let cors = config.cors();
    let origins = if cors.is_permissive() {
        String::from("*")
    } else {
        cors.allowed_origins()
            .iter()
            .map(String::as_str)
            .collect::<Vec<&str>>()
            .join(",")
    };
    let methods = cors.get_allowed_methods().join(",");
    let headers = cors.get_allowed_headers().join(",");
    let max_age = cors.get_max_age().to_string();

    // SAFETY: the caller upholds this function's `# Safety` contract, so no
    // other thread accesses the environment while these writes happen.
    unsafe {
        std::env::set_var("AUTUMN_SERVER__HOST", host);
        std::env::set_var("AUTUMN_SERVER__PORT", port);
        std::env::set_var("AUTUMN_CORS__ALLOWED_ORIGINS", origins);
        std::env::set_var("AUTUMN_CORS__ALLOWED_METHODS", methods);
        std::env::set_var("AUTUMN_CORS__ALLOWED_HEADERS", headers);
        std::env::set_var("AUTUMN_CORS__MAX_AGE_SECS", max_age);
        // NOTE(review): CSRF protection is switched off unconditionally. Confirm
        // this is intended for every deployment of this API.
        std::env::set_var("AUTUMN_SECURITY__CSRF__ENABLED", "false");
    }
}
/// Builds an axum [`Router`] without starting autumn's server.
///
/// The router contains every route from `all_routes()` and is backed by a detached
/// autumn state. `state` is installed on that detached autumn state as an
/// extension.
///
/// Layers, from innermost to outermost:
/// - `X-Content-Type-Options`, `X-Frame-Options` and `Content-Security-Policy`
///   headers, each added only when the response does not already carry it;
/// - the CORS layer derived from `config`;
/// - HTTP tracing.
///
/// The rate-limit section of `config` is validated, but no rate-limiting layer is
/// attached here.
///
/// # Errors
///
/// Returns the rate-limit validation error when `config.rate_limit()` is invalid.
pub fn build_test_router(state: AppState, config: &ServerConfig) -> Result<Router, String> {
    config.rate_limit().validate()?;

    let shared = AutumnAppState::detached();
    shared.insert_extension(state);

    // Security headers, applied in this order (first entry is innermost).
    const SECURITY_HEADERS: [(&str, &str); 3] = [
        ("x-content-type-options", "nosniff"),
        ("x-frame-options", "DENY"),
        (
            "content-security-policy",
            "default-src 'none'; frame-ancestors 'none'",
        ),
    ];

    let routed = all_routes()
        .into_iter()
        .fold(Router::<AutumnAppState>::new(), |acc, route| {
            acc.route(route.path, route.handler)
        });

    let hardened = SECURITY_HEADERS
        .into_iter()
        .fold(routed, |acc, (name, value)| {
            acc.layer(SetResponseHeaderLayer::if_not_present(
                HeaderName::from_static(name),
                HeaderValue::from_static(value),
            ))
        });

    let finished = hardened
        .layer(build_cors_layer(config.cors()))
        .layer(TraceLayer::new_for_http());

    Ok(finished.with_state(shared))
}
/// Translates a [`CorsConfig`] into a tower-http [`CorsLayer`].
///
/// Origins: a permissive config allows any origin. Otherwise only the configured
/// origins are allowed, and an empty list means no cross-origin access.
///
/// Methods, headers and max-age always come from the config. The preflight
/// max-age is interpreted as seconds.
///
/// Entries that fail to parse are skipped. This applies to header values
/// (origins), HTTP methods and header names. Each skipped entry is reported on
/// stderr rather than dropped silently, so a typo in the configuration does not
/// quietly narrow the CORS policy.
fn build_cors_layer(cors: &CorsConfig) -> CorsLayer {
    /// Parses every entry with `parse` and keeps the successes. Each failure is
    /// logged as a warning naming the entry kind, the raw text and the parse error.
    fn parse_cors_entries<I, S, T, E>(
        kind: &str,
        entries: I,
        parse: impl Fn(&str) -> Result<T, E>,
    ) -> Vec<T>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
        E: std::fmt::Display,
    {
        entries
            .into_iter()
            .filter_map(|entry| {
                let raw: &str = entry.as_ref();
                match parse(raw) {
                    Ok(value) => Some(value),
                    Err(err) => {
                        eprintln!("WARNING: ignoring invalid CORS {kind} {raw:?}: {err}");
                        None
                    }
                }
            })
            .collect()
    }

    let origin = if cors.is_permissive() {
        AllowOrigin::any()
    } else {
        AllowOrigin::list(parse_cors_entries(
            "origin",
            cors.allowed_origins().iter(),
            HeaderValue::from_str,
        ))
    };
    let methods: Vec<Method> = parse_cors_entries(
        "method",
        cors.get_allowed_methods().iter(),
        |raw| raw.parse::<Method>(),
    );
    let headers: Vec<HeaderName> = parse_cors_entries(
        "header",
        cors.get_allowed_headers().iter(),
        |raw| raw.parse::<HeaderName>(),
    );
    CorsLayer::new()
        .allow_origin(origin)
        .allow_methods(AllowMethods::list(methods))
        .allow_headers(AllowHeaders::list(headers))
        .max_age(Duration::from_secs(u64::from(cors.get_max_age())))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::config::RateLimitConfig;

    /// Shared state wrapping a database created with `AletheiaDB::new()`.
    fn fresh_state() -> AppState {
        AppState::new(Arc::new(AletheiaDB::new().unwrap()))
    }

    #[test]
    fn build_test_router_succeeds_with_default_config() {
        let outcome = build_test_router(fresh_state(), &ServerConfig::default());
        assert!(outcome.is_ok());
    }

    #[test]
    fn build_test_router_rejects_invalid_rate_limit() {
        let state = fresh_state();
        let bad_config = ServerConfig::builder()
            .rate_limit(RateLimitConfig::new(0, 1))
            .build();
        let outcome = build_test_router(state, &bad_config);
        assert!(outcome.is_err());
    }

    #[test]
    fn build_cors_layer_permissive_runs() {
        drop(build_cors_layer(&CorsConfig::permissive()));
    }

    #[test]
    fn build_cors_layer_restrictive_runs() {
        drop(build_cors_layer(&CorsConfig::restrictive()));
    }
}