use crate::broker::config::BrokerConfig;
use crate::error::{MqttError, Result};
use crate::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::fs;
use tokio::sync::{broadcast, RwLock};
use tracing::{debug, error, info, warn};
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned instead of panicking.
fn unix_timestamp_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Notification broadcast to subscribers whenever the broker configuration changes.
///
/// Built by `HotReloadManager` for file reloads and for partial updates, and delivered
/// through the channel returned by `HotReloadManager::subscribe_to_changes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigChangeEvent {
    /// Seconds since the Unix epoch when the event was created (`0` if the system clock
    /// reports a time before the epoch).
    pub timestamp: u64,
    /// Kind of change; reloads from the config file always report `FullReload`.
    pub change_type: ConfigChangeType,
    /// Path of the configuration file the manager is bound to. Partial updates also report
    /// this path even though they do not read the file.
    pub config_path: PathBuf,
    /// Hash of the configuration before the change, as produced by
    /// `HotReloadManager::calculate_config_hash` (a `DefaultHasher` digest of the JSON
    /// serialization; not guaranteed stable across Rust releases, so do not persist it).
    pub previous_hash: u64,
    /// Hash of the configuration after the change, computed the same way.
    pub new_hash: u64,
}
/// Category of a configuration change, carried in `ConfigChangeEvent::change_type`.
///
/// `HotReloadManager` itself only emits [`ConfigChangeType::FullReload`] (for reloads of the
/// config file). The remaining variants are labels chosen by callers of
/// `HotReloadManager::apply_partial_config` so subscribers can tell which area changed; the
/// manager does not interpret them.
///
/// `PartialEq`/`Eq`/`Hash` are derived so subscribers can compare with `==` or key maps by
/// change type instead of resorting to `matches!`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConfigChangeType {
    /// The whole configuration was replaced by re-reading the config file.
    FullReload,
    /// Label for a change to the authentication settings.
    AuthConfig,
    /// Label for a change to the TLS settings.
    TlsConfig,
    /// Label for a change to resource limits (e.g. the client limit).
    ResourceLimits,
    /// Label for a change to the WebSocket settings.
    WebSocketConfig,
    /// Label for a change to the bridge settings.
    BridgeConfig,
    /// Label for a change to the storage settings.
    StorageConfig,
}
/// Watches a broker configuration file and keeps an in-memory copy in sync with it.
///
/// The current configuration, its content hash and the file's last observed modification
/// time live behind `Arc<RwLock<_>>`s so the background watcher task spawned by `start` can
/// update them. Applied changes are announced on a `tokio::sync::broadcast` channel.
pub struct HotReloadManager {
    /// The configuration currently in effect; shared with the watcher task.
    current_config: Arc<RwLock<BrokerConfig>>,
    /// File that is polled and re-read on change (parsed as TOML or JSON depending on its
    /// extension).
    config_path: PathBuf,
    /// Sending half of the change-notification channel; receivers come from `subscribe`.
    change_sender: broadcast::Sender<ConfigChangeEvent>,
    /// Handle of the polling task; `None` until `start` has been called.
    watcher_handle: Option<tokio::task::JoinHandle<()>>,
    /// Modification time seen at the most recent check, used to detect edits to the file.
    last_modified: Arc<RwLock<Option<crate::time::SystemTime>>>,
    /// Content hash of the tracked configuration; compared against freshly parsed files so
    /// that unchanged content does not trigger a reload.
    config_hash: Arc<RwLock<u64>>,
}
impl HotReloadManager {
pub fn new(config: BrokerConfig, config_path: PathBuf) -> Result<Self> {
let (change_sender, _) = broadcast::channel(100);
let initial_hash = Self::calculate_config_hash(&config);
let manager = Self {
current_config: Arc::new(RwLock::new(config)),
config_path,
change_sender,
watcher_handle: None,
last_modified: Arc::new(RwLock::new(None)),
config_hash: Arc::new(RwLock::new(initial_hash)),
};
Ok(manager)
}
pub async fn start(&mut self) -> Result<()> {
info!("Starting configuration hot-reload system");
if let Ok(metadata) = fs::metadata(&self.config_path).await {
if let Ok(modified) = metadata.modified() {
*self.last_modified.write().await = Some(modified);
}
}
let watcher = self.start_file_watcher();
self.watcher_handle = Some(watcher);
info!(
"Configuration hot-reload system started, monitoring: {:?}",
self.config_path
);
Ok(())
}
fn start_file_watcher(&self) -> tokio::task::JoinHandle<()> {
let config_path = self.config_path.clone();
let last_modified = self.last_modified.clone();
let config_hash = self.config_hash.clone();
let current_config = self.current_config.clone();
let change_sender = self.change_sender.clone();
let handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(crate::time::Duration::from_secs(5));
loop {
interval.tick().await;
match Self::check_file_changed(&config_path, &last_modified).await {
Ok(true) => {
info!("Configuration file changed, reloading: {:?}", config_path);
match Self::reload_config_file(&config_path).await {
Ok(new_config) => {
let new_hash = Self::calculate_config_hash(&new_config);
let old_hash = *config_hash.read().await;
if new_hash == old_hash {
debug!("Configuration file changed but content hash unchanged");
} else {
if let Err(e) = new_config.validate() {
error!(
"Invalid configuration file, ignoring reload: {}",
e
);
continue;
}
*current_config.write().await = new_config;
*config_hash.write().await = new_hash;
let event = ConfigChangeEvent {
timestamp: unix_timestamp_secs(),
change_type: ConfigChangeType::FullReload,
config_path: config_path.clone(),
previous_hash: old_hash,
new_hash,
};
if let Err(e) = change_sender.send(event) {
warn!("Failed to send config change notification: {e}");
}
info!("Configuration successfully reloaded");
}
}
Err(e) => {
error!("Failed to reload configuration: {e}");
}
}
}
Ok(false) => {
}
Err(e) => {
warn!("Error checking configuration file: {e}");
}
}
}
});
handle
}
async fn check_file_changed(
config_path: &Path,
last_modified: &Arc<RwLock<Option<crate::time::SystemTime>>>,
) -> Result<bool> {
let metadata = fs::metadata(config_path)
.await
.map_err(|e| MqttError::Io(format!("Failed to read config file metadata: {e}")))?;
let current_modified = metadata
.modified()
.map_err(|e| MqttError::Io(format!("Failed to get file modification time: {e}")))?;
let mut last_mod = last_modified.write().await;
if let Some(last) = *last_mod {
if current_modified > last {
*last_mod = Some(current_modified);
return Ok(true);
}
} else {
*last_mod = Some(current_modified);
}
Ok(false)
}
pub async fn reload_config_file(config_path: &Path) -> Result<BrokerConfig> {
let config_content = fs::read_to_string(config_path)
.await
.map_err(|e| MqttError::Io(format!("Failed to read config file: {e}")))?;
let config = if config_path.extension().and_then(|s| s.to_str()) == Some("toml") {
toml::from_str(&config_content)
.map_err(|e| MqttError::Configuration(format!("Invalid TOML config: {e}")))?
} else {
serde_json::from_str(&config_content)
.map_err(|e| MqttError::Configuration(format!("Invalid JSON config: {e}")))?
};
Ok(config)
}
fn calculate_config_hash(config: &BrokerConfig) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
let json = serde_json::to_string(config).unwrap_or_default();
json.hash(&mut hasher);
hasher.finish()
}
pub async fn get_config(&self) -> BrokerConfig {
self.current_config.read().await.clone()
}
pub async fn reload_now(&self) -> Result<bool> {
info!("Manually triggering configuration reload");
let new_config = Self::reload_config_file(&self.config_path).await?;
let new_hash = Self::calculate_config_hash(&new_config);
let old_hash = *self.config_hash.read().await;
if new_hash == old_hash {
info!("Configuration unchanged, no reload needed");
Ok(false)
} else {
new_config.validate()?;
*self.current_config.write().await = new_config;
*self.config_hash.write().await = new_hash;
let event = ConfigChangeEvent {
timestamp: unix_timestamp_secs(),
change_type: ConfigChangeType::FullReload,
config_path: self.config_path.clone(),
previous_hash: old_hash,
new_hash,
};
if let Err(e) = self.change_sender.send(event) {
warn!("Failed to send config change notification: {e}");
}
info!("Configuration manually reloaded successfully");
Ok(true)
}
}
#[must_use]
pub fn config_path(&self) -> &Path {
&self.config_path
}
#[must_use]
pub fn current_config_handle(&self) -> Arc<RwLock<BrokerConfig>> {
Arc::clone(&self.current_config)
}
#[must_use]
pub fn subscribe_to_changes(&self) -> broadcast::Receiver<ConfigChangeEvent> {
self.change_sender.subscribe()
}
pub async fn apply_partial_config(
&self,
change_type: ConfigChangeType,
update_fn: impl FnOnce(&mut BrokerConfig),
) -> Result<()> {
info!("Applying partial configuration change: {:?}", change_type);
let mut config = self.current_config.write().await;
let old_hash = Self::calculate_config_hash(&config);
update_fn(&mut config);
config.validate()?;
let new_hash = Self::calculate_config_hash(&config);
let event = ConfigChangeEvent {
timestamp: unix_timestamp_secs(),
change_type,
config_path: self.config_path.clone(),
previous_hash: old_hash,
new_hash,
};
if let Err(e) = self.change_sender.send(event) {
warn!("Failed to send config change notification: {e}");
}
info!("Partial configuration change applied successfully");
Ok(())
}
#[must_use]
pub fn get_stats(&self) -> HotReloadStats {
HotReloadStats {
config_path: self.config_path.clone(),
current_hash: futures::executor::block_on(async { *self.config_hash.read().await }),
subscribers: self.change_sender.receiver_count(),
}
}
}
/// Point-in-time snapshot returned by `HotReloadManager::get_stats`.
///
/// `Clone`/`PartialEq`/`Eq` are derived so snapshots can be stored and compared; every
/// field type supports them.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct HotReloadStats {
    /// Configuration file being watched.
    pub config_path: PathBuf,
    /// Content hash of the tracked configuration when the snapshot was taken.
    pub current_hash: u64,
    /// Number of live change-notification receivers.
    pub subscribers: usize,
}
/// A named receiver of [`ConfigChangeEvent`]s; the name only appears in log messages.
///
/// `Debug` is derived so the subscriber can be embedded in other `Debug` types and logged.
#[derive(Debug)]
pub struct ConfigSubscriber {
    /// Receiving half of the manager's change broadcast channel.
    receiver: broadcast::Receiver<ConfigChangeEvent>,
    /// Label identifying the subscribing component in log output.
    component_name: String,
}
impl ConfigSubscriber {
#[allow(clippy::must_use_candidate)]
pub fn new(receiver: broadcast::Receiver<ConfigChangeEvent>, component_name: String) -> Self {
Self {
receiver,
component_name,
}
}
pub async fn wait_for_change(&mut self) -> Result<ConfigChangeEvent> {
loop {
match self.receiver.recv().await {
Ok(event) => {
debug!(
"Component '{}' received config change: {:?}",
self.component_name, event.change_type
);
return Ok(event);
}
Err(broadcast::error::RecvError::Closed) => {
return Err(MqttError::InvalidState(
"Config change channel closed".to_string(),
));
}
Err(broadcast::error::RecvError::Lagged(skipped)) => {
warn!(
"Component '{}' lagged behind, skipped {} config changes",
self.component_name, skipped
);
}
}
}
}
pub fn try_recv_change(&mut self) -> Option<ConfigChangeEvent> {
match self.receiver.try_recv() {
Ok(event) => {
debug!(
"Component '{}' received config change: {:?}",
self.component_name, event.change_type
);
Some(event)
}
Err(_) => None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Builds a manager over `BrokerConfig::default()` bound to a fresh temp file.
    ///
    /// The temp file is returned too, so the caller keeps it (and its path) alive.
    fn default_manager() -> (NamedTempFile, HotReloadManager) {
        let backing_file = NamedTempFile::new().unwrap();
        let manager =
            HotReloadManager::new(BrokerConfig::default(), backing_file.path().to_path_buf())
                .unwrap();
        (backing_file, manager)
    }

    /// Wall-clock time in whole seconds since the Unix epoch.
    fn epoch_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[tokio::test]
    async fn test_hot_reload_manager() {
        let backing_file = NamedTempFile::new().unwrap();
        let watched_path = backing_file.path().to_path_buf();
        let baseline = BrokerConfig::default();
        let baseline_json = serde_json::to_string_pretty(&baseline).unwrap();
        tokio::fs::write(&watched_path, baseline_json).await.unwrap();

        let mut manager = HotReloadManager::new(baseline.clone(), watched_path.clone()).unwrap();
        let mut subscriber =
            ConfigSubscriber::new(manager.subscribe_to_changes(), "test".to_string());
        manager.start().await.unwrap();

        // Rewrite the file with a higher client limit.
        let mut edited = baseline;
        edited.max_clients = 5000;
        let edited_json = serde_json::to_string_pretty(&edited).unwrap();
        tokio::fs::write(&watched_path, edited_json).await.unwrap();

        let outcome = tokio::time::timeout(
            crate::time::Duration::from_secs(2),
            subscriber.wait_for_change(),
        )
        .await;

        if let Ok(received) = outcome {
            let event = match received {
                Ok(event) => event,
                Err(e) => panic!("Failed to receive config change: {e}"),
            };
            assert!(matches!(event.change_type, ConfigChangeType::FullReload));
        } else {
            // The watcher only polls every 5 s and may not notice the edit within 2 s;
            // exercise the same reload path explicitly instead.
            println!("File watcher timeout, testing manual reload");
            assert!(manager.reload_now().await.unwrap());
        }
        assert_eq!(manager.get_config().await.max_clients, 5000);
    }

    #[tokio::test]
    async fn test_partial_config_update() {
        let (_backing_file, manager) = default_manager();
        manager
            .apply_partial_config(ConfigChangeType::ResourceLimits, |cfg| {
                cfg.max_clients = 10000;
            })
            .await
            .unwrap();
        assert_eq!(manager.get_config().await.max_clients, 10000);
    }

    #[tokio::test]
    async fn test_config_validation() {
        let (_backing_file, manager) = default_manager();
        let outcome = manager
            .apply_partial_config(ConfigChangeType::ResourceLimits, |cfg| {
                cfg.max_clients = 0;
            })
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_config_change_event_timestamp() {
        let (_backing_file, manager) = default_manager();
        let mut subscriber =
            ConfigSubscriber::new(manager.subscribe_to_changes(), "timestamp_test".to_string());

        let lower_bound = epoch_secs();
        manager
            .apply_partial_config(ConfigChangeType::ResourceLimits, |cfg| {
                cfg.max_clients = 3000;
            })
            .await
            .unwrap();
        let upper_bound = epoch_secs();

        let event = subscriber.wait_for_change().await.unwrap();
        assert!((lower_bound..=upper_bound).contains(&event.timestamp));
        assert!(matches!(
            event.change_type,
            ConfigChangeType::ResourceLimits
        ));
    }
}