use axum::Router;
use std::net::SocketAddr;
#[cfg(feature = "tls")]
use std::path::Path;
use tokio::net::TcpListener;
use tokio::signal;
use tracing::{info, warn};
/// Network settings for the server: where to bind and whether to serve TLS.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerConfig {
    /// IP address to bind to, e.g. `"0.0.0.0"`, `"::1"` or the bracketed form `"[::1]"`.
    /// Hostnames such as `"localhost"` are NOT resolved; [`ServerConfig::address`] rejects them.
    pub host: String,
    /// TCP port to listen on.
    pub port: u16,
    /// TLS material; `None` means the server speaks plaintext HTTP.
    pub tls: Option<TlsConfig>,
}

/// Filesystem locations of the PEM-encoded TLS certificate and private key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsConfig {
    /// Path to the PEM certificate file.
    pub cert_path: String,
    /// Path to the PEM private key file.
    pub key_path: String,
}

impl ServerConfig {
    /// Builds the socket address to bind from `host` and `port`.
    ///
    /// `host` must be a literal IPv4 or IPv6 address. IPv6 may optionally be wrapped in
    /// square brackets (`"[::1]"`), as it appears in URLs.
    ///
    /// # Errors
    ///
    /// Returns [`std::net::AddrParseError`] if `host` is not an IP address literal.
    pub fn address(&self) -> Result<SocketAddr, std::net::AddrParseError> {
        // Strip URL-style brackets only when both are present; anything else is parsed as-is.
        let host = self
            .host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(self.host.as_str());
        Ok(SocketAddr::new(host.parse()?, self.port))
    }

    /// Plaintext HTTP on all IPv4 interfaces, port 8000.
    pub fn default_http() -> Self {
        Self {
            host: "0.0.0.0".to_string(),
            port: 8000,
            tls: None,
        }
    }

    /// HTTPS on all IPv4 interfaces, port 8443, using the given PEM certificate and key paths.
    pub fn default_https(cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
        Self {
            host: "0.0.0.0".to_string(),
            port: 8443,
            tls: Some(TlsConfig {
                cert_path: cert_path.into(),
                key_path: key_path.into(),
            }),
        }
    }

    /// Returns `true` when TLS settings are present.
    pub fn is_tls(&self) -> bool {
        self.tls.is_some()
    }
}
/// Runs the server described by `config` until a shutdown signal arrives.
///
/// Dispatches to the HTTPS path when `config.tls` is set, otherwise to plaintext HTTP.
///
/// # Errors
///
/// Returns an error if the address is invalid, binding/serving fails, TLS material cannot be
/// loaded, or TLS is requested but the crate was built without the `tls` feature.
pub async fn start_server(
    app: Router,
    config: ServerConfig,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    match config.tls.as_ref() {
        None => start_server_http(app, &config).await,
        #[cfg(feature = "tls")]
        Some(tls) => start_server_tls(app, &config, tls).await,
        #[cfg(not(feature = "tls"))]
        Some(_) => Err("TLS support not compiled in. Enable the 'tls' feature.".into()),
    }
}
/// Serves `app` over plaintext HTTP on `config.address()` until a shutdown signal is received.
///
/// Logs a warning at startup because traffic is unencrypted.
async fn start_server_http(
    app: Router,
    config: &ServerConfig,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let bind_addr = config.address()?;
    let tcp = TcpListener::bind(bind_addr).await?;

    warn!(
        "⚠️ Server starting at http://{} (PLAINTEXT - NOT SECURE)",
        bind_addr
    );
    warn!("⚠️ Use --tls-cert and --tls-key for production deployments");

    // The shutdown future is lazy; it starts listening once axum polls it.
    let on_shutdown = handle_shutdown_signal();
    axum::serve(tcp, app)
        .with_graceful_shutdown(on_shutdown)
        .await?;

    info!("🛑 Server has shut down gracefully");
    Ok(())
}
/// Serves `app` over HTTPS (rustls) on `config.address()` until a shutdown signal is received.
///
/// On shutdown, in-flight connections get a bounded grace period before being closed.
///
/// # Errors
///
/// Returns an error if the address is invalid, the PEM certificate/key cannot be loaded,
/// or the listener fails to bind or serve.
#[cfg(feature = "tls")]
async fn start_server_tls(
    app: Router,
    config: &ServerConfig,
    tls_config: &TlsConfig,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    use axum_server::tls_rustls::RustlsConfig;
    use axum_server::Handle;

    /// How long open connections may keep running after a shutdown signal.
    const SHUTDOWN_GRACE: std::time::Duration = std::time::Duration::from_secs(30);

    // Install ring as the process-wide rustls crypto provider. The result is deliberately
    // ignored: it only errors when a provider has already been installed.
    let _ = rustls::crypto::ring::default_provider().install_default();

    let addr = config.address()?;
    let rustls_config = RustlsConfig::from_pem_file(&tls_config.cert_path, &tls_config.key_path)
        .await
        .map_err(|e| format!("Failed to load TLS certificates: {}", e))?;

    info!("🔐 HTTPS server starting at https://{}", addr);
    info!(" Certificate: {}", tls_config.cert_path);
    info!(" Private key: {}", tls_config.key_path);

    let handle = Handle::new();
    let shutdown_handle = handle.clone();
    let signal_task = tokio::spawn(async move {
        handle_shutdown_signal().await;
        info!("Initiating graceful shutdown...");
        shutdown_handle.graceful_shutdown(Some(SHUTDOWN_GRACE));
    });

    let served = axum_server::bind_rustls(addr, rustls_config)
        .handle(handle)
        .serve(app.into_make_service())
        .await;

    // If serving ended without a signal (e.g. bind failure), the listener task would otherwise
    // stay detached for the rest of the process. After a normal shutdown it has already
    // finished, so this is a no-op.
    signal_task.abort();
    served?;

    info!("🛑 Server has shut down gracefully");
    Ok(())
}
async fn handle_shutdown_signal() {
#[cfg(unix)]
{
use tokio::signal::unix::{signal, SignalKind};
let mut terminate =
signal(SignalKind::terminate()).expect("Failed to install SIGTERM handler");
tokio::select! {
_ = signal::ctrl_c() => {
info!("Received Ctrl+C signal");
},
_ = terminate.recv() => {
info!("Received SIGTERM signal");
},
}
}
#[cfg(not(unix))]
{
if let Err(e) = signal::ctrl_c().await {
tracing::error!("Failed to listen for Ctrl+C: {}", e);
} else {
info!("Received Ctrl+C");
}
}
}
/// Generates a self-signed certificate and private key, returned as `(cert_pem, key_pem)`.
///
/// `common_name` is passed to rcgen as a subject alternative name. `localhost` is also added
/// unless `common_name` already is `localhost` (case-insensitively).
/// NOTE(review): whether `common_name` also appears in the subject DN depends on rcgen's
/// `generate_simple_self_signed` defaults — confirm if clients check the CN.
///
/// Intended for development/testing only; browsers and clients will not trust it by default.
///
/// # Errors
///
/// Returns an error if rcgen fails to generate the key pair or certificate.
#[cfg(feature = "tls")]
pub fn generate_self_signed_cert(
    common_name: &str,
) -> Result<(String, String), Box<dyn std::error::Error + Send + Sync>> {
    use rcgen::{generate_simple_self_signed, CertifiedKey};

    // DNS names are case-insensitive, so avoid emitting a duplicate `localhost` SAN.
    let mut subject_alt_names = vec![common_name.to_string()];
    if !common_name.eq_ignore_ascii_case("localhost") {
        subject_alt_names.push("localhost".to_string());
    }

    let CertifiedKey { cert, key_pair } = generate_simple_self_signed(subject_alt_names)?;
    Ok((cert.pem(), key_pair.serialize_pem()))
}
/// Generates a self-signed certificate for `common_name` and writes the PEM files to disk.
///
/// The certificate file is written with default permissions. On Unix, the private key file is
/// restricted to owner read/write (0600), including when it overwrites an existing file.
///
/// # Errors
///
/// Returns an error if generation fails or either file cannot be created or written.
#[cfg(feature = "tls")]
pub fn write_self_signed_cert(
    common_name: &str,
    cert_path: impl AsRef<Path>,
    key_path: impl AsRef<Path>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let (cert_pem, key_pem) = generate_self_signed_cert(common_name)?;
    let cert_path = cert_path.as_ref();
    let key_path = key_path.as_ref();

    std::fs::write(cert_path, cert_pem)?;
    write_private_key(key_path, key_pem.as_bytes())?;

    info!("Generated self-signed certificate: {}", cert_path.display());
    info!("Generated private key: {}", key_path.display());
    Ok(())
}

/// Writes `pem` to `path` (create or truncate); on Unix the file ends up with mode 0600.
#[cfg(feature = "tls")]
fn write_private_key(path: &Path, pem: &[u8]) -> std::io::Result<()> {
    use std::io::Write;

    let mut options = std::fs::OpenOptions::new();
    options.write(true).create(true).truncate(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        // Only affects newly created files (and is further masked by the umask).
        options.mode(0o600);
    }

    let mut file = options.open(path)?;
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        // A pre-existing file keeps its old mode on open, so tighten it before writing the key.
        file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
    }
    file.write_all(pem)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_config_address() {
        let config = ServerConfig::default_http();
        let addr = config.address().unwrap();
        assert_eq!(addr.port(), 8000);
        assert!(addr.ip().is_unspecified());
    }

    #[test]
    fn test_server_config_invalid_host() {
        let config = ServerConfig {
            host: "not an ip".to_string(),
            ..ServerConfig::default_http()
        };
        assert!(config.address().is_err());
    }

    #[test]
    fn test_tls_detection() {
        let http_config = ServerConfig::default_http();
        assert!(!http_config.is_tls());

        let https_config = ServerConfig::default_https("cert.pem", "key.pem");
        assert!(https_config.is_tls());
        assert_eq!(https_config.port, 8443);

        let tls = https_config.tls.as_ref().expect("default_https sets TLS config");
        assert_eq!(tls.cert_path, "cert.pem");
        assert_eq!(tls.key_path, "key.pem");
    }

    #[cfg(feature = "tls")]
    #[test]
    fn test_generate_self_signed_cert() {
        let (cert, key) = generate_self_signed_cert("test.local").unwrap();
        assert!(cert.contains("BEGIN CERTIFICATE"));
        assert!(key.contains("BEGIN PRIVATE KEY"));
    }
}