#[cfg(feature = "tls")]
use std::path::PathBuf;
use std::{net::SocketAddr, sync::Arc, time::Duration};
#[cfg(feature = "tls")]
use axum_server::tls_rustls::RustlsConfig;
use axum_server::{Address, Handle};
use color_eyre::eyre::{self, Context};
use tokio::{signal, time::sleep};
use tokio_util::sync::CancellationToken;
use tracing::{info, instrument};
use crate::{
api::{middleware::ratelimiter::rate_limiter_layer, router::app, state::AppState},
config::Config,
service::db::KorrosyncServiceRedb,
};
use crate::logging::init_logging;
/// HTTP API layer; this file uses its `middleware::ratelimiter`, `router` and `state` submodules.
pub mod api;
// NOTE(review): presumably command-line argument handling — inferred from the name only; confirm.
pub mod cli;
/// Server configuration; provides the [`config::Config`] consumed by [`run_server`].
pub mod config;
/// Logging setup; provides `init_logging`, called first thing in [`run_server`].
pub mod logging;
// NOTE(review): presumably shared domain/data types — inferred from the name only; confirm.
pub mod model;
/// Service layer; its `db` submodule provides the redb-backed `KorrosyncServiceRedb`.
pub mod service;
/// Graceful-shutdown grace period in seconds.
///
/// Passed to `Handle::graceful_shutdown` as the timeout and used as the number
/// of one-second ticks in the countdown loop of `shutdown_signal`.
const SHUTDOWN_DURATION_SECS: u64 = 30;
/// Builds the application from `cfg` and serves it until the server stops.
///
/// In order, this function:
/// 1. initialises logging,
/// 2. parses `cfg.server.address` into a [`SocketAddr`],
/// 3. opens the database-backed sync service at `cfg.db.path`,
/// 4. wraps the router in the rate-limiter layer,
/// 5. spawns the signal listener that triggers a graceful shutdown,
/// 6. serves over TLS (when the `tls` feature is enabled and
///    `cfg.server.use_tls` is set) or plain TCP otherwise,
/// 7. cancels the rate limiter's cleanup task and waits for it to finish.
///
/// Step 7 runs even when step 6 fails, so the cleanup task is never left
/// running after this function returns.
///
/// # Errors
///
/// Returns an error if the bind address cannot be parsed, the database cannot
/// be initialised, the TLS key material cannot be loaded, the server fails, or
/// the rate-limiter cleanup task fails. A server error takes precedence over a
/// cleanup-task error.
pub async fn run_server(cfg: Config) -> eyre::Result<()> {
    init_logging();

    let addr: SocketAddr = cfg
        .server
        .address
        .parse()
        .context("Error parsing binding address")?;

    let state = AppState {
        sync: Arc::new(KorrosyncServiceRedb::new(cfg.db.path).context("DB Init Error")?),
    };

    // The token is how we tell the rate limiter's cleanup task to stop.
    let shutdown_token_cleanup = CancellationToken::new();
    let (rate_limiter, cleanup_task) =
        rate_limiter_layer(shutdown_token_cleanup.clone(), &cfg.rate_limit);

    let app = app(state)
        .layer(rate_limiter)
        .into_make_service_with_connect_info::<SocketAddr>();

    let shutdown_handle = Handle::new();
    tokio::spawn(shutdown_signal(shutdown_handle.clone()));

    // Capture the serving outcome instead of returning early with `?`:
    // the cleanup task below must be cancelled and awaited on every path.
    let serve_result: eyre::Result<()>;

    #[cfg(feature = "tls")]
    {
        serve_result = if cfg.server.use_tls {
            info!("TLS Server listening on {}", &addr);
            let tls_config = RustlsConfig::from_pem_file(
                PathBuf::from(cfg.server.cert_path),
                PathBuf::from(cfg.server.key_path),
            )
            .await
            .context("Error loading TLS keys");

            match tls_config {
                Ok(tls_config) => axum_server::bind_rustls(addr, tls_config)
                    .handle(shutdown_handle)
                    .serve(app)
                    .await
                    .context("Failed to start TLS server"),
                Err(e) => Err(e),
            }
        } else {
            info!("Server listening on {}", &addr);
            axum_server::bind(addr)
                .handle(shutdown_handle)
                .serve(app)
                .await
                .context("Failed to start server")
        };
    }

    #[cfg(not(feature = "tls"))]
    {
        info!("Server listening on {}", &addr);
        serve_result = axum_server::bind(addr)
            .handle(shutdown_handle)
            .serve(app)
            .await
            .context("Failed to start server");
    }

    // Always stop the rate limiter's cleanup task, whether serving ended
    // cleanly or with an error.
    shutdown_token_cleanup.cancel();
    let cleanup_result = cleanup_task.await.map_err(|e| {
        tracing::error!("Rate limiter cleanup task failed: {}", e);
        e
    });

    // Report the server error first; it is the more likely root cause.
    serve_result?;
    cleanup_result?;

    info!("Server shutdown complete");
    Ok(())
}
/// Waits for a termination signal, then drives a graceful shutdown of the
/// server behind `handle`.
///
/// The function resolves its wait on the first of:
/// - Ctrl-C (all platforms),
/// - `SIGINT` or `SIGTERM` (Unix only; on other platforms these branches are
///   futures that never complete).
///
/// It then asks the server to shut down gracefully with a timeout of
/// [`SHUTDOWN_DURATION_SECS`] seconds and logs the live connection count once
/// per second until either no connections remain or the grace period is used
/// up.
///
/// # Panics
///
/// Panics if a signal handler cannot be installed.
#[instrument(fields(graceful_shutdown), skip(handle))]
async fn shutdown_signal<A: Address>(handle: Handle<A>) {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let interrupt = async {
        signal::unix::signal(signal::unix::SignalKind::interrupt())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    // Non-Unix targets have no SIGINT/SIGTERM streams; use futures that never
    // resolve so the `select!` below keeps the same shape on every platform.
    #[cfg(not(unix))]
    let interrupt = std::future::pending::<()>();

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = interrupt => info!("Got SIGINT"),
        _ = ctrl_c => info!("Got Ctrl-C"),
        _ = terminate => info!("Got SIGTERM"),
    }

    info!("Server is shutting down...");
    handle.graceful_shutdown(Some(Duration::from_secs(SHUTDOWN_DURATION_SECS)));

    // Report progress once per second. `remaining` is computed from the time
    // already slept, so the logged value reflects what is actually left of the
    // grace period at the moment of logging.
    for elapsed in 1..=SHUTDOWN_DURATION_SECS {
        sleep(Duration::from_secs(1)).await;

        let remaining = SHUTDOWN_DURATION_SECS - elapsed;
        let connections = handle.connection_count();
        tracing::info!("{connections} live connections left ({remaining}s left)");

        if connections == 0 {
            break;
        }
        if remaining == 0 {
            tracing::warn!("Forcing shutdown with live connections");
        }
    }
}