use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::watch;
use tracing::info;
use crate::service::PredictionService;
use super::routes::routes;
/// Settings that control how the HTTP server listens and shuts down.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Host portion of the listen address.
    pub host: String,
    /// TCP port portion of the listen address.
    pub port: u16,
    /// When `true`, the shutdown future does not wait on SIGTERM (unix) and
    /// relies on the remaining shutdown triggers instead.
    pub await_explicit_shutdown: bool,
}

impl Default for ServerConfig {
    /// Returns a configuration for `0.0.0.0:5000` with
    /// `await_explicit_shutdown` turned off.
    fn default() -> Self {
        let host = String::from("0.0.0.0");
        let port = 5000;
        ServerConfig {
            host,
            port,
            await_explicit_shutdown: false,
        }
    }
}
pub async fn serve(config: ServerConfig, service: Arc<PredictionService>) -> anyhow::Result<()> {
let shutdown_rx = service.shutdown_rx();
let app = routes(service.clone());
let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
let listener = TcpListener::bind(addr).await?;
let actual_addr = listener.local_addr()?;
info!("Starting coglet server on {}", actual_addr);
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal(config.await_explicit_shutdown, shutdown_rx))
.await?;
info!("Server shutdown complete");
service.shutdown().await;
Ok(())
}
/// Completes when the first shutdown trigger fires.
///
/// Three triggers are raced against each other:
/// - Ctrl+C,
/// - SIGTERM (unix only, and only when `await_explicit_shutdown` is `false`),
/// - the `shutdown_rx` watch channel holding `true`.
///
/// # Panics
///
/// Panics if the Ctrl+C handler, or the SIGTERM handler when it is used,
/// cannot be installed.
async fn shutdown_signal(await_explicit_shutdown: bool, mut shutdown_rx: watch::Receiver<bool>) {
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler - is tokio runtime configured correctly?");
    };

    #[cfg(unix)]
    let sigterm = async {
        if !await_explicit_shutdown {
            let mut stream =
                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()).expect(
                    "failed to install SIGTERM handler - is tokio runtime configured correctly?",
                );
            stream.recv().await;
        } else {
            // NOTE(review): no SIGTERM handler is installed on this path; this
            // future simply never completes. Whether the process really ignores
            // SIGTERM depends on a handler being set up elsewhere - confirm.
            tracing::info!("await_explicit_shutdown enabled, ignoring SIGTERM");
            std::future::pending::<()>().await
        }
    };
    // There is no SIGTERM off unix, so this trigger never fires there.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    let requested = async {
        loop {
            if *shutdown_rx.borrow() {
                break;
            }
            if shutdown_rx.changed().await.is_err() {
                // The sender is gone, so the flag can never change: park this
                // trigger forever and leave shutdown to the signal triggers.
                std::future::pending::<()>().await;
            }
        }
    };

    tokio::select! {
        _ = interrupt => {
            info!("Received SIGINT, shutting down...");
        }
        _ = sigterm => {
            info!("Received SIGTERM, shutting down...");
        }
        _ = requested => {
            info!("Shutdown requested via /shutdown endpoint...");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ServerConfig::default()` yields `0.0.0.0`, port 5000, and
    /// explicit-shutdown mode disabled.
    #[test]
    fn server_config_default() {
        let ServerConfig {
            host,
            port,
            await_explicit_shutdown,
        } = ServerConfig::default();

        assert_eq!(host, "0.0.0.0");
        assert_eq!(port, 5000);
        assert!(!await_explicit_shutdown);
    }
}