use crate::{HttpConfig, HttpError, HttpResult};
use elif_core::{Container, app_config::AppConfigTrait};
use axum::{
Router,
routing::{get, IntoMakeService},
extract::State,
response::Json,
};
use serde_json::{json, Value};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::signal;
use tower::ServiceBuilder;
use tower_http::{
trace::TraceLayer,
timeout::TimeoutLayer,
limit::RequestBodyLimitLayer,
};
use tracing::{info, warn};
/// State bundle handed to the router and to the health-check handler.
///
/// Cloning copies the `Arc` handle to the container and clones the
/// `HttpConfig` value.
#[derive(Clone)]
pub struct AppState {
/// Shared handle to the DI container (gives access to `config()` and `database()`).
pub container: Arc<Container>,
/// HTTP-layer settings (health-check path, timeouts, tracing flag, ...).
pub config: HttpConfig,
}
/// HTTP server that pairs an axum `Router` with an [`AppState`] and the
/// socket address it will bind to when run.
pub struct StatefulHttpServer {
/// Fully assembled router: health-check route, merged custom routes and middleware layers.
router: Router,
/// Container and HTTP configuration passed to `Router::with_state` when the server starts.
state: AppState,
/// Address parsed from the container's `server.host` / `server.port` settings.
addr: SocketAddr,
}
/// Step-by-step builder for [`StatefulHttpServer`].
///
/// The container is mandatory; the HTTP configuration is optional and is
/// loaded via `HttpConfig::from_env()` by `build` when it was not supplied.
pub struct StatefulHttpServerBuilder {
/// DI container; `build` fails with a config error while this is `None`.
container: Option<Arc<Container>>,
/// Explicit HTTP configuration, if one was provided.
http_config: Option<HttpConfig>,
/// Extra routers merged into the server's router, in insertion order.
custom_routes: Vec<Router>,
}
impl StatefulHttpServerBuilder {
    /// Creates a builder with no container, no HTTP configuration and no
    /// custom routes.
    pub fn new() -> Self {
        Self {
            container: None,
            http_config: None,
            custom_routes: Vec::new(),
        }
    }

    /// Sets the DI container the server will use (required before `build`).
    pub fn container(mut self, container: Arc<Container>) -> Self {
        self.container = Some(container);
        self
    }

    /// Sets an explicit HTTP configuration, replacing any previously set one.
    pub fn http_config(mut self, config: HttpConfig) -> Self {
        self.http_config = Some(config);
        self
    }

    /// Appends a router whose routes are merged into the server's router.
    pub fn add_routes(mut self, routes: Router) -> Self {
        self.custom_routes.push(routes);
        self
    }

    /// Consumes the builder and constructs the server.
    ///
    /// # Errors
    ///
    /// * `HttpError::config("Container is required")` when no container was set.
    /// * Whatever `HttpConfig::from_env()` returns when no configuration was
    ///   set and loading it from the environment fails.
    /// * Whatever `HttpConfig::validate()` or `StatefulHttpServer::new` return.
    pub fn build(self) -> HttpResult<StatefulHttpServer> {
        let container = self
            .container
            .ok_or_else(|| HttpError::config("Container is required"))?;

        let http_config = match self.http_config {
            Some(config) => config,
            None => HttpConfig::from_env()?,
        };
        http_config.validate()?;

        // The builder owns the only handle it needs, so the `Arc` is moved
        // rather than cloned, and the constructor's result is returned as-is.
        StatefulHttpServer::new(container, http_config, self.custom_routes)
    }
}
impl Default for StatefulHttpServerBuilder {
fn default() -> Self {
Self::new()
}
}
impl StatefulHttpServer {
    /// Assembles a server from the DI container, the HTTP configuration and
    /// any additional routers.
    ///
    /// * The listen address is built from the container configuration's
    ///   `server.host` and `server.port`.
    /// * A health-check route is registered at `http_config.health_check_path`.
    /// * Every router in `custom_routes` is merged in order.
    /// * A `TraceLayer` is added only when `http_config.enable_tracing` is set;
    ///   the request-body limit and timeout layers are always applied.
    ///
    /// # Errors
    ///
    /// Returns `HttpError::config` when the host/port pair cannot be parsed as
    /// a socket address.
    pub fn new(
        container: Arc<Container>,
        http_config: HttpConfig,
        custom_routes: Vec<Router>,
    ) -> HttpResult<Self> {
        let app_config = container.config();
        let addr = format!("{}:{}", app_config.server.host, app_config.server.port)
            .parse::<SocketAddr>()
            .or_else(|err| {
                // A bare IPv6 host such as `::1` only parses as a socket
                // address in bracketed form (`[::1]:8080`). Retry that way and
                // keep the original error if the retry fails as well, so the
                // reported message is unchanged for genuinely invalid hosts.
                format!("[{}]:{}", app_config.server.host, app_config.server.port)
                    .parse::<SocketAddr>()
                    .map_err(|_| err)
            })
            .map_err(|e| HttpError::config(format!("Invalid server address: {}", e)))?;

        let state = AppState {
            container,
            config: http_config.clone(),
        };

        // The health handler captures its own copies of the container handle
        // and configuration instead of going through axum's `State` extractor.
        let health_handler = {
            let container = Arc::clone(&state.container);
            let config = state.config.clone();
            move || {
                let container = Arc::clone(&container);
                let config = config.clone();
                async move { stateless_health_check_with_context(container, config).await }
            }
        };

        let mut router =
            Router::new().route(&http_config.health_check_path, get(health_handler));
        for custom_router in custom_routes {
            router = router.merge(custom_router);
        }

        let middleware_stack = ServiceBuilder::new()
            .layer(RequestBodyLimitLayer::new(http_config.max_request_size))
            .layer(TimeoutLayer::new(http_config.request_timeout()));

        if http_config.enable_tracing {
            router = router.layer(TraceLayer::new_for_http());
        }
        router = router.layer(middleware_stack);

        Ok(Self {
            router,
            state,
            addr,
        })
    }

    /// Runs the server until `stateful_shutdown_signal` completes (Ctrl+C, or
    /// a terminate signal on Unix), then shuts down gracefully.
    ///
    /// # Errors
    ///
    /// Returns `HttpError::startup` when binding the listener fails or when
    /// serving ends with an error.
    pub async fn run(self) -> HttpResult<()> {
        info!(
            "Starting stateful HTTP server on {} with DI container integration",
            self.addr
        );
        self.serve_until(stateful_shutdown_signal()).await
    }

    /// Runs the server until the caller-supplied `shutdown` future completes,
    /// then shuts down gracefully.
    ///
    /// # Errors
    ///
    /// Returns `HttpError::startup` when binding the listener fails or when
    /// serving ends with an error.
    pub async fn run_with_shutdown<F>(self, shutdown: F) -> HttpResult<()>
    where
        F: std::future::Future<Output = ()> + Send + 'static,
    {
        info!(
            "Starting stateful HTTP server on {} with custom shutdown handler",
            self.addr
        );
        self.serve_until(shutdown).await
    }

    /// Returns the socket address this server binds to.
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// Binds the listener, serves the router and waits for `shutdown`; shared
    /// by `run` and `run_with_shutdown` so both paths behave identically.
    async fn serve_until<F>(self, shutdown: F) -> HttpResult<()>
    where
        F: std::future::Future<Output = ()> + Send + 'static,
    {
        // NOTE(review): `router` is declared as plain `Router`; confirm that
        // passing an `AppState` to `with_state` type-checks against the axum
        // version in use.
        let service = self.router.with_state(self.state);

        let listener = tokio::net::TcpListener::bind(self.addr)
            .await
            .map_err(|e| {
                HttpError::startup(format!("Failed to bind to {}: {}", self.addr, e))
            })?;
        info!("Stateful HTTP server listening on {}", self.addr);

        axum::serve(listener, service)
            .with_graceful_shutdown(shutdown)
            .await
            .map_err(|e| HttpError::startup(format!("Server failed: {}", e)))?;

        info!("Stateful HTTP server stopped gracefully");
        Ok(())
    }
}
async fn stateful_health_check(State(state): State<AppState>) -> Result<Json<Value>, HttpError> {
stateless_health_check_with_context(state.container, state.config).await
}
/// Builds the health-check JSON payload from the DI container and the HTTP
/// configuration.
///
/// The payload reports a fixed `"healthy"` status, the current UTC timestamp
/// (RFC 3339), a version string, the `Debug` rendering of the configured
/// environment, per-service statuses and a subset of the HTTP configuration.
///
/// # Errors
///
/// Returns `HttpError::health_check("Database connection unavailable")` when
/// the container's database reports that it is not connected.
async fn stateless_health_check_with_context(
    container: Arc<Container>,
    config: HttpConfig,
) -> Result<Json<Value>, HttpError> {
    if !container.database().is_connected() {
        warn!("Health check failed: database not connected");
        return Err(HttpError::health_check("Database connection unavailable"));
    }

    let app_config = container.config();
    let response = json!({
        "status": "healthy",
        "timestamp": chrono::Utc::now().to_rfc3339(),
        "version": "0.1.0",
        "environment": format!("{:?}", app_config.environment),
        "server": "stateful",
        "services": {
            // Reaching this point means the connectivity check above passed,
            // so the database can only be reported as healthy here.
            "database": "healthy",
            "container": "healthy"
        },
        "config": {
            "request_timeout": config.request_timeout_secs,
            "health_check_path": config.health_check_path,
            "tracing_enabled": config.enable_tracing
        }
    });

    Ok(Json(response))
}
/// Completes as soon as the process receives Ctrl+C or, on Unix, a terminate
/// signal, logging which of the two triggered the shutdown.
///
/// # Panics
///
/// Panics if the Ctrl+C handler or the Unix terminate-signal handler cannot
/// be installed.
async fn stateful_shutdown_signal() {
    /// Resolves once Ctrl+C has been received.
    async fn interrupted() {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    }

    /// Resolves once the terminate signal has been received.
    #[cfg(unix)]
    async fn terminated() {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        stream.recv().await;
    }

    /// There is no terminate signal off Unix, so this never resolves.
    #[cfg(not(unix))]
    async fn terminated() {
        std::future::pending::<()>().await
    }

    tokio::select! {
        _ = interrupted() => {
            info!("Received Ctrl+C, initiating graceful shutdown");
        },
        _ = terminated() => {
            info!("Received terminate signal, initiating graceful shutdown");
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use elif_core::container::test_implementations::*;
    use std::sync::Arc;

    /// Builds a container from `create_test_config()` and a `TestDatabase`.
    fn create_test_container() -> Arc<Container> {
        let config = Arc::new(create_test_config());
        let database = Arc::new(TestDatabase::new()) as Arc<dyn elif_core::DatabaseConnection>;

        Container::builder()
            .config(config)
            .database(database)
            .build()
            .unwrap()
            .into()
    }

    /// Builds an `AppState` around the test container and default HTTP config.
    fn create_test_state() -> AppState {
        AppState {
            container: create_test_container(),
            config: HttpConfig::default(),
        }
    }

    #[test]
    fn test_stateful_server_builder() {
        let server = StatefulHttpServerBuilder::new()
            .container(create_test_container())
            .http_config(HttpConfig::default())
            .build();

        assert!(server.is_ok());
        let server = server.unwrap();
        // 8080 is presumably the port configured by `create_test_config()`.
        assert_eq!(server.addr().port(), 8080);
    }

    #[test]
    fn test_stateful_server_builder_missing_container() {
        let result = StatefulHttpServerBuilder::new().build();

        assert!(result.is_err());
        // Match on the `Result` itself: `unwrap_err` would require the `Ok`
        // type (`StatefulHttpServer`) to implement `Debug`, which it does not
        // derive in this file.
        assert!(matches!(result, Err(HttpError::ConfigError { .. })));
    }

    #[test]
    fn test_stateful_server_with_custom_routes() {
        let custom_routes = Router::new().route("/api/test", get(|| async { "test" }));

        let server = StatefulHttpServerBuilder::new()
            .container(create_test_container())
            .http_config(HttpConfig::default())
            .add_routes(custom_routes)
            .build();

        assert!(server.is_ok());
    }

    #[tokio::test]
    async fn test_stateful_health_check_handler() {
        let result = stateful_health_check(State(create_test_state())).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        let status = response.0.get("status").and_then(|v| v.as_str()).unwrap();
        assert_eq!(status, "healthy");

        let server_type = response.0.get("server").and_then(|v| v.as_str()).unwrap();
        assert_eq!(server_type, "stateful");

        assert!(response.0.get("services").is_some());
        assert!(response.0.get("config").is_some());
    }

    #[test]
    fn test_app_state_clone() {
        let state = create_test_state();
        let cloned_state = state.clone();

        assert_eq!(cloned_state.config.health_check_path, "/health");
        assert_eq!(cloned_state.container.config().name, "test-app");
    }
}