use std::path::{Path, PathBuf};
use std::sync::Arc;
use arc_swap::ArcSwap;
use notify::{Event, RecursiveMode, Watcher};
use parking_lot::RwLock;
use thiserror::Error;
use tokio::task::JoinHandle;
use tonic::transport::{Identity, ServerTlsConfig};
use tracing::{error, info, warn};
use amaters_net::tls_acceptor::{TlsCredsRef, build_rustls_config};
use crate::config::{ConfigError, ReloadableConfig, ServerConfig};
/// Errors surfaced by the hot-reload helpers in this module.
///
/// Variants marked `#[from]` are produced automatically by the `?` operator
/// from the wrapped error type; the `String` variants carry a pre-rendered
/// message.
#[derive(Debug, Error)]
pub enum HotReloadError {
/// The `notify` file watcher could not be created or a path could not be watched.
#[error("File watcher error: {0}")]
Watch(#[from] notify::Error),
/// Reading a file from disk failed.
#[error("IO error reading TLS file: {0}")]
Io(#[from] std::io::Error),
/// A TLS credential problem described by the contained message.
#[error("TLS credential error: {0}")]
Tls(String),
/// Building a rustls configuration failed; holds the stringified cause.
#[error("rustls config error: {0}")]
Rustls(String),
/// An error propagated from the configuration layer.
#[error("Config error: {0}")]
Config(#[from] ConfigError),
}
/// In-memory copy of a TLS certificate / private-key pair.
///
/// Both fields hold the raw file contents exactly as read from disk; nothing
/// in this type validates that they are well-formed PEM.
///
/// `Debug` is implemented by hand so that formatting a value with `{:?}`
/// (for example in a log line or a panic message) never prints the private
/// key material. Only the byte lengths are shown.
#[derive(Clone)]
pub struct TlsCreds {
    /// Bytes of the certificate file (expected to be PEM).
    pub cert_pem: Vec<u8>,
    /// Bytes of the private-key file (expected to be PEM). Sensitive.
    pub key_pem: Vec<u8>,
}

impl std::fmt::Debug for TlsCreds {
    /// Formats the credentials without exposing any certificate or key bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TlsCreds")
            .field(
                "cert_pem",
                &format_args!("<{} bytes>", self.cert_pem.len()),
            )
            .field(
                "key_pem",
                &format_args!("<redacted, {} bytes>", self.key_pem.len()),
            )
            .finish()
    }
}
impl TlsCreds {
    /// Reads the certificate and key files into memory.
    ///
    /// The certificate is read first, then the key; the first failing read
    /// is returned as [`HotReloadError::Io`]. File contents are not parsed
    /// or validated.
    pub fn load_from_files(cert_path: &Path, key_path: &Path) -> Result<Self, HotReloadError> {
        // Struct fields are evaluated in source order, so the certificate is
        // still read before the key.
        Ok(Self {
            cert_pem: std::fs::read(cert_path)?,
            key_pem: std::fs::read(key_path)?,
        })
    }

    /// Wraps the stored bytes in a tonic [`Identity`] and returns a fresh
    /// [`ServerTlsConfig`] carrying it.
    pub fn to_server_tls_config(&self) -> ServerTlsConfig {
        ServerTlsConfig::new().identity(Identity::from_pem(&self.cert_pem, &self.key_pem))
    }
}
/// Loads a certificate/key pair from disk and converts it into a tonic
/// [`ServerTlsConfig`].
///
/// # Errors
///
/// Returns whatever [`TlsCreds::load_from_files`] returns when either file
/// cannot be read.
pub fn build_server_tls_config(
    cert_path: &Path,
    key_path: &Path,
) -> Result<ServerTlsConfig, HotReloadError> {
    let creds = TlsCreds::load_from_files(cert_path, key_path)?;
    Ok(creds.to_server_tls_config())
}
/// Spawns a background task that reloads the server configuration whenever
/// the process receives `SIGHUP` (Unix only).
///
/// A [`ReloadableConfig`] is seeded with a clone of the current `config`
/// contents and pointed at `config_path`. On every `SIGHUP` the task calls
/// `reload_from_stored_path()`; when the returned report has `success` set,
/// the new snapshot is written into `config`. On any failure the previous
/// configuration is left untouched and an error is logged.
///
/// On non-Unix targets the task only logs a warning and finishes.
///
/// # Returns
///
/// The [`JoinHandle`] of the spawned task. The task ends early if the
/// signal handler cannot be registered or the signal stream closes.
pub async fn spawn_config_reloader(
    config_path: PathBuf,
    config: Arc<RwLock<ServerConfig>>,
) -> JoinHandle<()> {
    // The read guard is a temporary and is released at the end of this statement.
    let initial = config.read().clone();
    let reloadable = ReloadableConfig::new(initial);
    reloadable.set_config_path(config_path.clone());
    let reloadable_for_task = reloadable.clone();
    let config_for_task = config.clone();
    tokio::spawn(async move {
        #[cfg(unix)]
        {
            use tokio::signal::unix::{SignalKind, signal};
            let mut hangup = match signal(SignalKind::hangup()) {
                Ok(s) => s,
                Err(e) => {
                    warn!("Failed to register SIGHUP signal handler: {}", e);
                    return;
                }
            };
            // `recv` resolves to `None` once the stream can deliver no more
            // signals. Looping only while it yields `Some` prevents a hot
            // loop of back-to-back reloads in that situation.
            while hangup.recv().await.is_some() {
                info!("SIGHUP received — reloading config from {:?}", config_path);
                match reloadable_for_task.reload_from_stored_path() {
                    Ok(report) if report.success => {
                        let updated = reloadable_for_task.snapshot();
                        *config_for_task.write() = updated;
                        info!("Config reloaded successfully: {}", report);
                    }
                    Ok(report) => {
                        error!("Config reload failed — keeping old config: {}", report);
                    }
                    Err(e) => {
                        error!("Config reload error — keeping old config: {}", e);
                    }
                }
            }
            warn!("SIGHUP signal stream closed — config reloader task exiting");
        }
        #[cfg(not(unix))]
        {
            warn!(
                "SIGHUP config reload is only supported on Unix platforms. \
                 Use ReloadableConfig::manual_reload() as an alternative."
            );
        }
    })
}
pub async fn spawn_tls_reloader(
cert_path: PathBuf,
key_path: PathBuf,
tls_creds: Arc<ArcSwap<TlsCreds>>,
) -> Result<(), HotReloadError> {
let (tx, mut rx) = tokio::sync::mpsc::channel::<notify::Result<Event>>(16);
let mut watcher = notify::recommended_watcher(move |event: notify::Result<Event>| {
let _ = tx.blocking_send(event);
})?;
let cert_dir = cert_path.parent().unwrap_or_else(|| Path::new("."));
watcher.watch(cert_dir, RecursiveMode::NonRecursive)?;
let cert_path_task = cert_path.clone();
let key_path_task = key_path.clone();
tokio::spawn(async move {
let _watcher = watcher;
while let Some(event) = rx.recv().await {
match event {
Ok(e) => {
let relevant = e
.paths
.iter()
.any(|p| p == &cert_path_task || p == &key_path_task);
if !relevant {
continue;
}
match TlsCreds::load_from_files(&cert_path_task, &key_path_task) {
Ok(new_creds) => {
tls_creds.store(Arc::new(new_creds));
info!("TLS credentials reloaded from {:?}", cert_path_task);
}
Err(e) => {
error!("TLS reload failed — keeping existing credentials: {}", e);
}
}
}
Err(e) => {
warn!("File-watcher error (TLS reloader): {}", e);
}
}
}
});
Ok(())
}
/// Builds a new `rustls::ServerConfig` from `creds` and publishes it in
/// `store`.
///
/// The store is only touched when the build succeeds, so a bad credential
/// pair leaves the previously stored configuration in place.
///
/// # Errors
///
/// Returns [`HotReloadError::Rustls`] holding the stringified build error.
pub fn swap_rustls_config(
    store: &Arc<ArcSwap<rustls::ServerConfig>>,
    creds: &TlsCreds,
) -> Result<(), HotReloadError> {
    let pem_view = TlsCredsRef::new(&creds.cert_pem, &creds.key_pem);
    match build_rustls_config(&pem_view) {
        Ok(fresh) => {
            store.store(Arc::new(fresh));
            Ok(())
        }
        Err(cause) => Err(HotReloadError::Rustls(cause.to_string())),
    }
}
pub async fn spawn_tls_reloader_with_rustls_store(
cert_path: PathBuf,
key_path: PathBuf,
tls_creds: Arc<ArcSwap<TlsCreds>>,
rustls_store: Arc<ArcSwap<rustls::ServerConfig>>,
) -> Result<(), HotReloadError> {
let (tx, mut rx) = tokio::sync::mpsc::channel::<notify::Result<Event>>(16);
let mut watcher = notify::recommended_watcher(move |event: notify::Result<Event>| {
let _ = tx.blocking_send(event);
})?;
let cert_dir = cert_path.parent().unwrap_or_else(|| Path::new("."));
watcher.watch(cert_dir, RecursiveMode::NonRecursive)?;
let cert_path_task = cert_path.clone();
let key_path_task = key_path.clone();
tokio::spawn(async move {
let _watcher = watcher;
while let Some(event) = rx.recv().await {
match event {
Ok(e) => {
let relevant = e
.paths
.iter()
.any(|p| p == &cert_path_task || p == &key_path_task);
if !relevant {
continue;
}
let new_creds = match TlsCreds::load_from_files(&cert_path_task, &key_path_task)
{
Ok(c) => c,
Err(e) => {
error!(
"TLS reload failed (file read) — keeping existing credentials: {e}",
);
continue;
}
};
if let Err(e) = swap_rustls_config(&rustls_store, &new_creds) {
error!("TLS reload failed (rustls build) — keeping existing config: {e}",);
continue;
}
tls_creds.store(Arc::new(new_creds));
info!(
"TLS credentials reloaded (legacy + rustls) from {:?}",
cert_path_task
);
}
Err(e) => {
warn!("File-watcher error (TLS reloader): {e}");
}
}
}
});
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
use std::fs;
/// Returns a default `ServerConfig` whose bind address is replaced by `bind`.
fn make_config(bind: &str) -> ServerConfig {
    let mut cfg = ServerConfig::default();
    cfg.server.bind_address = String::from(bind);
    cfg
}
/// Diffing a configuration against itself must report no changes.
#[test]
fn test_config_diff_empty_when_identical() {
    use crate::config::diff;
    let cfg = make_config("127.0.0.1:7878");
    let delta = diff(&cfg, &cfg);
    assert!(
        delta.is_empty(),
        "Diff of identical configs should be empty, got {:?}",
        delta
    );
}
/// Changing only `logging.level` must be reported as a reloadable
/// `Logging` change.
#[test]
fn test_config_diff_detects_log_level_change() {
    use crate::config::{ReloadableSection, diff};
    let before = make_config("127.0.0.1:7878");
    let mut after = before.clone();
    after.logging.level = "debug".to_string();
    let delta = diff(&before, &after);
    assert!(
        delta.reloadable_changes.contains(&ReloadableSection::Logging),
        "Expected Logging in reloadable_changes, got {:?}",
        delta.reloadable_changes
    );
}
/// Raising `server.max_connections` must be reported as a reloadable
/// `RateLimit` change.
#[test]
fn test_config_diff_detects_rate_limit_change() {
    use crate::config::{ReloadableSection, diff};
    let before = make_config("127.0.0.1:7878");
    let mut after = before.clone();
    after.server.max_connections = before.server.max_connections + 500;
    let delta = diff(&before, &after);
    assert!(
        delta
            .reloadable_changes
            .contains(&ReloadableSection::RateLimit),
        "Expected RateLimit in reloadable_changes, got {:?}",
        delta.reloadable_changes
    );
}
/// A different bind address must be classified as a non-reloadable
/// `BindAddress` change.
#[test]
fn test_config_diff_non_reloadable_bind_address() {
    use crate::config::{NonReloadableSection, diff};
    let before = make_config("127.0.0.1:7878");
    let after = make_config("127.0.0.1:9999");
    let delta = diff(&before, &after);
    let flagged = delta
        .non_reloadable_changes
        .contains(&NonReloadableSection::BindAddress);
    assert!(
        flagged,
        "Expected BindAddress in non_reloadable_changes, got {:?}",
        delta.non_reloadable_changes
    );
}
/// `manual_reload` must pick up a log-level change written to the stored
/// config path.
#[test]
fn test_manual_reload_applies_log_level_change() {
    let cfg_file = env::temp_dir().join("amaters_hot_reload_test_manual.toml");

    // Seed the file and the reloadable wrapper with the same initial config.
    let seed = make_config("127.0.0.1:7878");
    seed.save_to_file(&cfg_file).expect("save initial config");
    let reloadable = ReloadableConfig::new(seed.clone());
    reloadable.set_config_path(cfg_file.clone());

    // Overwrite the file with a config that differs only in log level.
    let mut edited = seed.clone();
    edited.logging.level = "warn".to_string();
    edited.save_to_file(&cfg_file).expect("save updated config");

    let outcome = reloadable.manual_reload().expect("manual_reload succeeded");
    assert!(outcome.success, "Expected reload success: {:?}", outcome);
    assert_eq!(
        reloadable.snapshot().logging.level,
        "warn",
        "Log level should be updated to 'warn'"
    );
    fs::remove_file(&cfg_file).ok();
}
/// Loading credentials from paths that do not exist must fail.
#[test]
fn test_tls_creds_load_missing_file() {
    let missing_cert = Path::new("/nonexistent/cert.pem");
    let missing_key = Path::new("/nonexistent/key.pem");
    let outcome = TlsCreds::load_from_files(missing_cert, missing_key);
    assert!(outcome.is_err(), "Expected error for missing files");
}
/// Loading credentials from two existing files must yield non-empty buffers.
#[test]
fn test_tls_creds_load_valid_files() {
    let tmp = env::temp_dir();
    let cert_file = tmp.join("amaters_hot_reload_test_cert.pem");
    let key_file = tmp.join("amaters_hot_reload_test_key.pem");

    let cert_body: &[u8] = b"-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----\n";
    let key_body: &[u8] = b"-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----\n";
    fs::write(&cert_file, cert_body).expect("write cert");
    fs::write(&key_file, key_body).expect("write key");

    let loaded = TlsCreds::load_from_files(&cert_file, &key_file).expect("load creds");
    assert!(!loaded.cert_pem.is_empty());
    assert!(!loaded.key_pem.is_empty());

    fs::remove_file(&cert_file).ok();
    fs::remove_file(&key_file).ok();
}
/// Storing freshly loaded credentials into an `ArcSwap` must make the new
/// certificate bytes visible to subsequent loads.
#[test]
fn test_tls_creds_arc_swap() {
    let tmp = env::temp_dir();
    let cert_file = tmp.join("amaters_arc_swap_cert.pem");
    let key_file = tmp.join("amaters_arc_swap_key.pem");

    // Version 1 on disk -> initial store contents.
    fs::write(&cert_file, b"cert_v1").expect("write cert");
    fs::write(&key_file, b"key_v1").expect("write key");
    let first = TlsCreds::load_from_files(&cert_file, &key_file).expect("load v1");
    let shared: Arc<ArcSwap<TlsCreds>> = Arc::new(ArcSwap::from_pointee(first));
    assert_eq!(shared.load().cert_pem, b"cert_v1");

    // Version 2 on disk -> swapped in.
    fs::write(&cert_file, b"cert_v2").expect("write cert v2");
    fs::write(&key_file, b"key_v2").expect("write key v2");
    let second = TlsCreds::load_from_files(&cert_file, &key_file).expect("load v2");
    shared.store(Arc::new(second));
    assert_eq!(shared.load().cert_pem, b"cert_v2");

    fs::remove_file(&cert_file).ok();
    fs::remove_file(&key_file).ok();
}
/// Missing files must surface as `HotReloadError::Io` from
/// `build_server_tls_config`.
#[test]
fn test_build_server_tls_config_file_error() {
    let missing_cert = Path::new("/nonexistent/cert.pem");
    let missing_key = Path::new("/nonexistent/key.pem");
    let outcome = build_server_tls_config(missing_cert, missing_key);
    assert!(
        matches!(outcome, Err(HotReloadError::Io(_))),
        "Expected Io error, got {:?}",
        outcome
    );
}
/// Credentials that are not PEM must be rejected with
/// `HotReloadError::Rustls`.
#[test]
fn test_swap_rustls_config_rejects_invalid_pem() {
    let tmp = env::temp_dir();
    // Random suffixes keep concurrently running tests from sharing files.
    let cert_name = format!("amaters_swap_rustls_cert_{}.pem", uuid::Uuid::new_v4());
    let cert_file = tmp.join(cert_name);
    let key_name = format!("amaters_swap_rustls_key_{}.pem", uuid::Uuid::new_v4());
    let key_file = tmp.join(key_name);
    fs::write(&cert_file, b"not-pem").expect("write cert");
    fs::write(&key_file, b"not-pem").expect("write key");

    let bogus = TlsCreds::load_from_files(&cert_file, &key_file).expect("load creds");
    let shared: Arc<ArcSwap<rustls::ServerConfig>> =
        Arc::new(ArcSwap::from_pointee(make_placeholder_server_config()));
    let outcome = swap_rustls_config(&shared, &bogus);
    assert!(
        matches!(outcome, Err(HotReloadError::Rustls(_))),
        "Expected Rustls error, got {:?}",
        outcome
    );

    fs::remove_file(&cert_file).ok();
    fs::remove_file(&key_file).ok();
}
/// A freshly generated self-signed pair must be accepted by
/// `swap_rustls_config`.
#[test]
fn test_swap_rustls_config_accepts_valid_pem() {
    // Installing the provider may fail if one is already set; that is fine.
    let _ = rustls::crypto::ring::default_provider().install_default();
    let (cert_pem, key_pem) = generate_pem_pair("swap.test");
    let fresh = TlsCreds { cert_pem, key_pem };
    let shared: Arc<ArcSwap<rustls::ServerConfig>> =
        Arc::new(ArcSwap::from_pointee(make_placeholder_server_config()));
    swap_rustls_config(&shared, &fresh).expect("swap should succeed");
    let _ = shared.load();
}
/// Builds a throwaway `rustls::ServerConfig` from a generated self-signed
/// pair, for tests that need an initial value in the store.
fn make_placeholder_server_config() -> rustls::ServerConfig {
    // Installing the provider may fail if one is already set; that is fine.
    let _ = rustls::crypto::ring::default_provider().install_default();
    let (cert, key) = generate_pem_pair("placeholder.test");
    let pem_view = TlsCredsRef::new(&cert, &key);
    build_rustls_config(&pem_view).expect("placeholder rustls config")
}
/// Generates a self-signed certificate for `cn` (with `cn` and `localhost`
/// as SANs) and returns `(certificate PEM, private-key PEM)`.
///
/// Panics if generation fails or the key is of a kind not handled below.
fn generate_pem_pair(cn: &str) -> (Vec<u8>, Vec<u8>) {
    use amaters_net::tls::SelfSignedGenerator;
    use rustls::pki_types::PrivateKeyDer;

    let (cert_der, key_der) = SelfSignedGenerator::new(cn)
        .with_san(cn)
        .with_san("localhost")
        .generate()
        .expect("generate cert");

    // Each key encoding uses its own PEM label.
    let key_pem = match key_der {
        PrivateKeyDer::Pkcs8(pkcs8) => pem_encode("PRIVATE KEY", pkcs8.secret_pkcs8_der()),
        PrivateKeyDer::Pkcs1(pkcs1) => pem_encode("RSA PRIVATE KEY", pkcs1.secret_pkcs1_der()),
        PrivateKeyDer::Sec1(sec1) => pem_encode("EC PRIVATE KEY", sec1.secret_sec1_der()),
        _ => panic!("unsupported key kind"),
    };
    (pem_encode("CERTIFICATE", cert_der.as_ref()), key_pem)
}
/// Wraps DER bytes in a PEM envelope: a `BEGIN` line, the base64 body split
/// into lines of at most 64 characters, and an `END` line.
fn pem_encode(label: &str, der: &[u8]) -> Vec<u8> {
    let body = base64_encode_test(der);
    let mut pem = Vec::new();
    pem.extend_from_slice(format!("-----BEGIN {label}-----\n").as_bytes());
    body.as_bytes().chunks(64).for_each(|line| {
        pem.extend_from_slice(line);
        pem.push(b'\n');
    });
    pem.extend_from_slice(format!("-----END {label}-----\n").as_bytes());
    pem
}
/// Minimal standard-alphabet base64 encoder with `=` padding, used only to
/// build PEM fixtures in tests.
fn base64_encode_test(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Picks the 6-bit group starting at bit `shift` of a 24-bit word.
    let sextet = |word: u32, shift: u32| ALPHABET[((word >> shift) & 0x3f) as usize] as char;

    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);
    for group in data.chunks(3) {
        // Pack up to three bytes, most significant first; absent bytes stay zero.
        let word = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (pos, &byte)| {
                acc | (u32::from(byte) << (16 - 8 * pos as u32))
            });
        encoded.push(sextet(word, 18));
        encoded.push(sextet(word, 12));
        match group.len() {
            3 => {
                encoded.push(sextet(word, 6));
                encoded.push(sextet(word, 0));
            }
            2 => {
                encoded.push(sextet(word, 6));
                encoded.push('=');
            }
            _ => encoded.push_str("=="),
        }
    }
    encoded
}
/// The reloader must hand back a task handle that is not yet finished right
/// after spawning.
#[tokio::test]
async fn test_spawn_config_reloader_returns_handle() {
    let cfg_file = env::temp_dir().join("amaters_sighup_test_config.toml");
    let seed = make_config("127.0.0.1:7878");
    seed.save_to_file(&cfg_file).expect("save config");

    let shared = Arc::new(RwLock::new(seed.clone()));
    let task = spawn_config_reloader(cfg_file.clone(), Arc::clone(&shared)).await;
    assert!(!task.is_finished(), "Reloader task should be running");

    task.abort();
    fs::remove_file(&cfg_file).ok();
}
/// End-to-end check: after the config file is rewritten and the process
/// receives `SIGHUP`, the shared config must show the new log level.
#[cfg(unix)]
#[tokio::test]
#[ignore = "Integration test — sends a real SIGHUP; run manually with --ignored"]
async fn test_sighup_reloads_config() {
    use std::time::Duration;

    let cfg_file = env::temp_dir().join("amaters_sighup_integration_test.toml");
    let seed = make_config("127.0.0.1:7878");
    seed.save_to_file(&cfg_file).expect("save config");
    let shared = Arc::new(RwLock::new(seed.clone()));
    let task = spawn_config_reloader(cfg_file.clone(), Arc::clone(&shared)).await;

    // Give the spawned task a moment to register its signal handler.
    tokio::time::sleep(Duration::from_millis(50)).await;

    let mut edited = seed.clone();
    edited.logging.level = "debug".to_string();
    edited.save_to_file(&cfg_file).expect("save updated config");

    // Deliver SIGHUP to this very process via the external `kill` command.
    let own_pid = std::process::id().to_string();
    let _ = std::process::Command::new("kill")
        .args(["-HUP", &own_pid])
        .status()
        .expect("failed to invoke kill command");

    // Leave time for the reload to run before inspecting the result.
    tokio::time::sleep(Duration::from_millis(200)).await;
    assert_eq!(
        shared.read().logging.level,
        "debug",
        "Expected log level to be 'debug' after SIGHUP reload"
    );

    task.abort();
    fs::remove_file(&cfg_file).ok();
}
}