use crate::config::SearchConfig;
use anyhow::{Context, Result};
use notify::{Config, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
#[derive(Error, Debug)]
pub enum HotReloadError {
#[error("Failed to setup file watcher: {0}")]
WatcherSetupError(String),
#[error("File watch error: {0}")]
WatchError(String),
#[error("Configuration reload failed: {0}")]
ReloadError(String),
}
/// Reloads fusion settings from a configuration file into a shared [`SearchConfig`],
/// either on demand or whenever the file changes on disk.
pub struct ConfigReloader {
    /// Shared configuration that successful reloads write into.
    config: Arc<RwLock<SearchConfig>>,
    /// Location of the configuration file, stored exactly as passed to `new`.
    config_path: PathBuf,
    /// Active file watcher. `None` until `watch` is called; holding it here keeps the
    /// watcher (and its event channel) alive.
    watcher: Option<RecommendedWatcher>,
}
impl ConfigReloader {
pub fn new<P: AsRef<Path>>(config: Arc<RwLock<SearchConfig>>, config_path: P) -> Result<Self> {
let config_path = config_path.as_ref().to_path_buf();
if !config_path.exists() {
warn!(
"Configuration file does not exist: {}",
config_path.display()
);
}
Ok(Self {
config,
config_path,
watcher: None,
})
}
pub async fn watch(&mut self) -> Result<()> {
info!(
"Starting hot reload watcher for: {}",
self.config_path.display()
);
let (tx, mut rx) = tokio::sync::mpsc::channel(100);
let mut watcher = RecommendedWatcher::new(
move |result: Result<Event, notify::Error>| {
let _ = tx.blocking_send(result);
},
Config::default(),
)
.map_err(|e| HotReloadError::WatcherSetupError(e.to_string()))?;
watcher
.watch(&self.config_path, RecursiveMode::NonRecursive)
.map_err(|e| HotReloadError::WatchError(e.to_string()))?;
self.watcher = Some(watcher);
while let Some(result) = rx.recv().await {
match result {
Ok(event) => {
if self.should_reload(&event) {
debug!("Configuration file changed: {:?}", event);
if let Err(e) = self.reload_config().await {
error!("Failed to reload configuration: {}", e);
}
}
}
Err(e) => {
warn!("File watch error: {}", e);
}
}
}
Ok(())
}
fn should_reload(&self, event: &Event) -> bool {
match event.kind {
EventKind::Modify(_) | EventKind::Create(_) => {
event.paths.iter().any(|p| p.ends_with(&self.config_path))
}
_ => false,
}
}
async fn reload_config(&self) -> Result<()> {
info!(
"Reloading configuration from: {}",
self.config_path.display()
);
let new_config = SearchConfig::load_from_file(&self.config_path).await?;
new_config
.validate()
.context("New configuration validation failed")?;
let old_weights = {
let cfg = self.config.read().await;
cfg.fusion.weights.clone()
};
let weights_changed = old_weights.fts != new_config.fusion.weights.fts
|| old_weights.vector != new_config.fusion.weights.vector
|| old_weights.graph != new_config.fusion.weights.graph
|| old_weights.recency != new_config.fusion.weights.recency
|| old_weights.churn != new_config.fusion.weights.churn;
if weights_changed {
info!("Fusion weights changed:");
info!(
" Old weights: fts={:.3}, vector={:.3}, graph={:.3}, recency={:.3}, churn={:.3}",
old_weights.fts,
old_weights.vector,
old_weights.graph,
old_weights.recency,
old_weights.churn
);
info!(
" New weights: fts={:.3}, vector={:.3}, graph={:.3}, recency={:.3}, churn={:.3}",
new_config.fusion.weights.fts,
new_config.fusion.weights.vector,
new_config.fusion.weights.graph,
new_config.fusion.weights.recency,
new_config.fusion.weights.churn
);
}
{
let mut cfg = self.config.write().await;
cfg.fusion.weights = new_config.fusion.weights;
cfg.fusion.rrf_k = new_config.fusion.rrf_k;
cfg.fusion.method = new_config.fusion.method;
info!("Configuration reloaded successfully");
}
Ok(())
}
pub async fn reload(&self) -> Result<()> {
self.reload_config().await
}
pub fn config_path(&self) -> &Path {
&self.config_path
}
}
impl Drop for ConfigReloader {
    /// Detaches the watcher (if `watch` ever installed one) from the configuration file.
    ///
    /// A failure to unwatch is only logged; `drop` cannot report errors.
    fn drop(&mut self) {
        let mut active = match self.watcher.take() {
            Some(w) => w,
            // `watch` was never called (or failed before storing a watcher): nothing to undo.
            None => return,
        };
        if let Err(err) = active.unwatch(&self.config_path) {
            warn!("Failed to unwatch configuration file: {}", err);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Environment variables for the fusion weights. They are cleared before loading
    /// so they cannot mask the file under test.
    /// NOTE(review): assumes `SearchConfig::load_from_file` applies these as overrides.
    const FUSION_WEIGHT_ENV_VARS: [&str; 5] = [
        "MAPROOM_SEARCH_FUSION_WEIGHTS_FTS",
        "MAPROOM_SEARCH_FUSION_WEIGHTS_VECTOR",
        "MAPROOM_SEARCH_FUSION_WEIGHTS_GRAPH",
        "MAPROOM_SEARCH_FUSION_WEIGHTS_RECENCY",
        "MAPROOM_SEARCH_FUSION_WEIGHTS_CHURN",
    ];

    /// Removes every fusion-weight environment override.
    ///
    /// Callers are `#[serial]`, so no other serial test mutates the process environment
    /// concurrently. No extra lock is needed: a `std::sync::Mutex` guard held across
    /// `.await` also trips `clippy::await_holding_lock`.
    fn clear_fusion_weight_env() {
        for var in &FUSION_WEIGHT_ENV_VARS {
            std::env::remove_var(var);
        }
    }

    /// Writes a complete test configuration to a temp file, using `fts` verbatim as the
    /// FTS fusion weight. The file is deleted when the returned handle is dropped.
    fn write_config(fts: &str) -> NamedTempFile {
        let yaml = format!(
            r#"
embedding:
  provider: openai
  model_name: test-model
  dimension: 1536
  cache_size: 1000
  cache_ttl_seconds: 3600
fusion:
  method: rrf
  rrf_k: 60
  weights:
    fts: {}
    vector: 0.3
    graph: 0.1
    recency: 0.08
    churn: 0.02
performance:
  max_candidates_per_method: 100
  final_result_limit: 20
  timeout_ms: 1000
  parallel_execution: true
index:
  ivfflat_lists: 100
  ivfflat_probes: 10
  refresh_interval_seconds: 3600
feature_flags:
  enable_vector_search: true
  enable_hybrid_fusion: true
  enable_graph_signals: true
  enable_temporal_signals: true
  enable_query_cache: true
  enable_hot_reload: true
"#,
            fts
        );
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(yaml.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    #[tokio::test]
    #[serial]
    async fn test_config_reloader_creation() {
        let config = Arc::new(RwLock::new(SearchConfig::default()));
        let temp_file = NamedTempFile::new().unwrap();
        let reloader = ConfigReloader::new(config, temp_file.path());
        assert!(reloader.is_ok());
    }

    #[tokio::test]
    #[serial]
    async fn test_manual_reload() {
        clear_fusion_weight_env();
        let config = Arc::new(RwLock::new(SearchConfig::default()));
        let temp_file = write_config("0.5");
        let reloader = ConfigReloader::new(config.clone(), temp_file.path()).unwrap();
        reloader.reload().await.unwrap();
        let cfg = config.read().await;
        assert_eq!(cfg.fusion.weights.fts, 0.5);
        assert_eq!(cfg.fusion.weights.vector, 0.3);
    }

    #[tokio::test]
    #[serial]
    async fn test_invalid_config_rejected() {
        clear_fusion_weight_env();
        let config = Arc::new(RwLock::new(SearchConfig::default()));
        let temp_file = write_config("-0.5");
        let reloader = ConfigReloader::new(config.clone(), temp_file.path()).unwrap();
        let result = reloader.reload().await;
        assert!(result.is_err());
        // A rejected file must leave the existing weights untouched. 0.4 is expected to be
        // the default FTS weight of `SearchConfig::default()`.
        let cfg = config.read().await;
        assert_eq!(cfg.fusion.weights.fts, 0.4);
    }

    #[test]
    fn test_should_reload() {
        let config = Arc::new(RwLock::new(SearchConfig::default()));
        let temp_file = NamedTempFile::new().unwrap();
        let reloader = ConfigReloader::new(config, temp_file.path()).unwrap();
        let event = Event::new(EventKind::Modify(notify::event::ModifyKind::Any))
            .add_path(temp_file.path().to_path_buf());
        assert!(reloader.should_reload(&event));
        let event = Event::new(EventKind::Create(notify::event::CreateKind::Any))
            .add_path(temp_file.path().to_path_buf());
        assert!(reloader.should_reload(&event));
        let event = Event::new(EventKind::Remove(notify::event::RemoveKind::Any))
            .add_path(temp_file.path().to_path_buf());
        assert!(!reloader.should_reload(&event));
    }
}