use notify::{Config, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, mpsc};
use tracing::{debug, error, info};
use crate::database::Database;
use crate::yaml::parser::parse_yaml_database;
use crate::yaml::schema::AuthConfig;
/// Watches a YAML database file and reloads it into shared state when it changes.
///
/// The `database` and `auth_config` handles are supplied by the caller; a reload
/// replaces the values behind them rather than the handles themselves.
pub struct HotReloadManager {
    /// Path of the YAML file being watched.
    yaml_path: PathBuf,
    /// Shared database whose contents are replaced on every successful reload.
    database: Arc<RwLock<Database>>,
    /// Shared auth configuration, replaced alongside the database on reload.
    auth_config: Arc<RwLock<Option<AuthConfig>>>,
    /// Active filesystem watcher; `None` until watching starts or after it stops.
    watcher: Option<RecommendedWatcher>,
    /// Sender used to tell the background reload task to stop; `None` while idle.
    shutdown_tx: Option<mpsc::Sender<()>>,
}
impl HotReloadManager {
pub fn new(
yaml_path: PathBuf,
database: Arc<RwLock<Database>>,
auth_config: Arc<RwLock<Option<AuthConfig>>>,
) -> Self {
Self {
yaml_path,
database,
auth_config,
watcher: None,
shutdown_tx: None,
}
}
pub async fn start_watching(&mut self) -> crate::Result<()> {
info!(
"Starting hot reload watcher for: {}",
self.yaml_path.display()
);
let (tx, mut rx) = mpsc::channel(100);
let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1);
self.shutdown_tx = Some(shutdown_tx);
let yaml_path = self.yaml_path.clone();
let database = Arc::clone(&self.database);
let auth_config = Arc::clone(&self.auth_config);
let mut watcher = RecommendedWatcher::new(
move |res: Result<Event, notify::Error>| {
match res {
Ok(event) => {
if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_)) {
let _ = tx.blocking_send(event);
}
}
Err(e) => error!("Watch error: {:?}", e),
}
},
Config::default()
.with_poll_interval(Duration::from_secs(1))
.with_compare_contents(true),
)
.map_err(|e| {
crate::YamlBaseError::Io(std::io::Error::other(format!(
"Failed to create file watcher: {}",
e
)))
})?;
let watch_path = self.yaml_path.parent().unwrap_or(Path::new("."));
watcher
.watch(watch_path, RecursiveMode::NonRecursive)
.map_err(|e| {
crate::YamlBaseError::Io(std::io::Error::other(format!(
"Failed to watch path: {}",
e
)))
})?;
self.watcher = Some(watcher);
tokio::spawn(async move {
let mut last_reload = std::time::Instant::now();
loop {
tokio::select! {
Some(event) = rx.recv() => {
let affects_our_file = event.paths.iter().any(|p| {
p.ends_with(&yaml_path) || p == &yaml_path
});
if !affects_our_file {
continue;
}
let now = std::time::Instant::now();
if now.duration_since(last_reload) < Duration::from_millis(500) {
debug!("Ignoring file change event (debouncing)");
continue;
}
last_reload = now;
info!("📁 Detected change in YAML file, reloading...");
tokio::time::sleep(Duration::from_millis(100)).await;
match Self::reload_database(&yaml_path, &database, &auth_config).await {
Ok(stats) => {
info!(
"✅ Hot reload successful! Loaded {} tables with {} total rows",
stats.table_count, stats.total_rows
);
}
Err(e) => {
error!("❌ Hot reload failed: {}", e);
error!("Database will continue with previous data");
}
}
}
_ = shutdown_rx.recv() => {
info!("Shutting down hot reload watcher");
break;
}
}
}
});
Ok(())
}
async fn reload_database(
yaml_path: &Path,
database: &Arc<RwLock<Database>>,
auth_config: &Arc<RwLock<Option<AuthConfig>>>,
) -> crate::Result<ReloadStats> {
let (new_db, new_auth) = parse_yaml_database(yaml_path).await?;
let stats = ReloadStats {
table_count: new_db.tables.len(),
total_rows: new_db.tables.values().map(|t| t.rows.len()).sum(),
};
{
let mut db_guard = database.write().await;
*db_guard = new_db;
}
{
let mut auth_guard = auth_config.write().await;
*auth_guard = new_auth;
}
Ok(stats)
}
pub async fn stop_watching(&mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(()).await;
}
self.watcher = None;
info!("Hot reload watcher stopped");
}
}
/// Summary of a successful reload, returned by `HotReloadManager::reload_database`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ReloadStats {
    /// Number of tables in the reloaded database.
    table_count: usize,
    /// Sum of row counts across all reloaded tables.
    total_rows: usize,
}
impl Database {
    /// Checks whether replacing `self` with `new_db` would change the SQL type of
    /// any existing column.
    ///
    /// Only tables present in both databases are compared, and columns are matched
    /// by name. Added or removed tables and columns are not reported.
    ///
    /// # Errors
    ///
    /// Returns one human-readable message per column whose `sql_type` differs,
    /// in iteration order of `self.tables` and then of each table's columns.
    pub fn validate_schema_compatibility(&self, new_db: &Database) -> Result<(), Vec<String>> {
        let errors: Vec<String> = self
            .tables
            .iter()
            .filter_map(|(table_name, old_table)| {
                new_db
                    .tables
                    .get(table_name)
                    .map(|new_table| (table_name, old_table, new_table))
            })
            .flat_map(|(table_name, old_table, new_table)| {
                old_table.columns.iter().filter_map(move |old_col| {
                    let new_col = new_table
                        .columns
                        .iter()
                        .find(|c| c.name == old_col.name)?;
                    (old_col.sql_type != new_col.sql_type).then(|| {
                        format!(
                            "Table '{}': Column '{}' type changed from {:?} to {:?}",
                            table_name, old_col.name, old_col.sql_type, new_col.sql_type
                        )
                    })
                })
            })
            .collect();

        if !errors.is_empty() {
            return Err(errors);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the reload path directly, without the filesystem watcher.
    ///
    /// The test writes an initial YAML file and parses it, then rewrites the file
    /// with an extra row. It checks that `reload_database` reports the new counts and
    /// swaps the new data into the shared handle.
    #[tokio::test]
    async fn test_hot_reload_detection() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("test.yaml");
        // Initial contents: a `users` table with a single row.
        tokio::fs::write(
            &file_path,
            r#"database:
name: test_db
tables:
users:
columns:
id: "INTEGER PRIMARY KEY"
name: "VARCHAR(100)"
data:
- id: 1
name: "Alice"
"#,
        )
        .await
        .unwrap();
        let (db, auth) = parse_yaml_database(&file_path).await.unwrap();
        let database = Arc::new(RwLock::new(db));
        let auth_config = Arc::new(RwLock::new(auth));
        // Sanity-check the initial load before triggering a reload.
        {
            let db = database.read().await;
            assert_eq!(db.tables.len(), 1);
            assert!(db.tables.contains_key("users"));
            let users_table = db.tables.get("users").unwrap();
            assert_eq!(users_table.rows.len(), 1);
        }
        // Updated contents: same table, now with a second row.
        tokio::fs::write(
            &file_path,
            r#"database:
name: test_db
tables:
users:
columns:
id: "INTEGER PRIMARY KEY"
name: "VARCHAR(100)"
data:
- id: 1
name: "Alice"
- id: 2
name: "Bob"
"#,
        )
        .await
        .unwrap();
        // Call the reload routine directly; the returned stats describe the new file.
        let stats = HotReloadManager::reload_database(&file_path, &database, &auth_config)
            .await
            .unwrap();
        assert_eq!(stats.table_count, 1);
        assert_eq!(stats.total_rows, 2);
        // The shared handle must now expose the reloaded data.
        let db = database.read().await;
        assert_eq!(db.tables.len(), 1);
        let users_table = db.tables.get("users").unwrap();
        assert_eq!(users_table.rows.len(), 2);
    }
}