use crate::{BehaviorLoader, BoxedBehavior};
use anyhow::{Context as AnyhowContext, Result};
use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, error, info, warn};
/// Settings that control whether, where, and how behavior-tree files are hot-reloaded.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct HotReloadConfig {
    /// Master switch. When `false`, starting the watcher installs no file watch at all.
    pub enabled: bool,
    /// Roots to watch recursively. Entries that do not exist at start-up are skipped
    /// with a warning rather than treated as an error.
    pub watch_paths: Vec<PathBuf>,
    /// Per-path debounce window, in milliseconds, applied to bursts of change events.
    /// Falls back to `default_debounce_duration()` when missing from serialized input.
    #[serde(default = "default_debounce_duration")]
    pub debounce_duration_ms: u64,
}
/// Debounce window (milliseconds) used when a config does not specify one.
///
/// Also serves as the serde default for `HotReloadConfig::debounce_duration_ms`.
fn default_debounce_duration() -> u64 {
    const DEFAULT_DEBOUNCE_MS: u64 = 500;
    DEFAULT_DEBOUNCE_MS
}
impl Default for HotReloadConfig {
    /// Hot-reload switched off, a single `behaviors` watch root, and the
    /// standard debounce window.
    fn default() -> Self {
        let watch_paths = vec![PathBuf::from("behaviors")];
        let debounce_duration_ms = default_debounce_duration();
        Self {
            enabled: false,
            watch_paths,
            debounce_duration_ms,
        }
    }
}
/// Outcome of one reload attempt, delivered on the channel returned by
/// `HotReloadWatcher::start`.
#[derive(Debug)]
pub enum ReloadEvent {
    /// `BehaviorLoader::load_from_file` succeeded for `path`.
    Success {
        /// File that was reloaded.
        path: PathBuf,
        /// Behavior produced by the loader from the file's current contents.
        behavior: BoxedBehavior,
    },
    /// `BehaviorLoader::load_from_file` failed for `path`.
    Error {
        /// File whose reload failed.
        path: PathBuf,
        /// Loader error rendered with the alternate (`{:#}`) formatter.
        error: String,
    },
}
/// Watches the configured paths and turns changes to behavior-tree files into
/// [`ReloadEvent`]s.
pub struct HotReloadWatcher {
    /// Enable flag, watch roots, and debounce window.
    config: HotReloadConfig,
    /// Loader used to re-read a behavior-tree file after it changes.
    loader: BehaviorLoader,
    /// Per-path timestamps used to debounce bursts of change events; shared with
    /// the background task spawned by `start`.
    pending_reloads: Arc<RwLock<HashMap<PathBuf, tokio::time::Instant>>>,
}
impl HotReloadWatcher {
    /// Creates a watcher from `config`, using `loader` to re-read changed files.
    ///
    /// Nothing is watched until [`HotReloadWatcher::start`] is called.
    pub fn new(config: HotReloadConfig, loader: BehaviorLoader) -> Self {
        Self {
            config,
            loader,
            pending_reloads: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Starts watching the configured paths and returns a receiver of reload outcomes.
    ///
    /// When hot-reload is disabled, no watcher is installed. The returned receiver's
    /// sender has already been dropped, so `recv()` yields `None` right away.
    /// Configured paths that do not exist are skipped with a warning.
    ///
    /// The background task, and the OS-level watch it owns, shuts down once the
    /// returned receiver has been dropped. This is detected on the next file-system
    /// event received after the drop.
    ///
    /// # Errors
    ///
    /// Returns an error if the platform file watcher cannot be created, or if an
    /// existing watch path cannot be registered with it.
    pub async fn start(self) -> Result<mpsc::UnboundedReceiver<ReloadEvent>> {
        if !self.config.enabled {
            info!("Hot-reload is disabled");
            let (_tx, rx) = mpsc::unbounded_channel();
            return Ok(rx);
        }

        let (reload_tx, reload_rx) = mpsc::unbounded_channel();
        // Bridges notify's synchronous callback thread into the async task below.
        let (fs_tx, mut fs_rx) = mpsc::unbounded_channel();
        let mut watcher: RecommendedWatcher =
            notify::recommended_watcher(move |res: Result<Event, _>| match res {
                Ok(event) => {
                    if let Err(e) = fs_tx.send(event) {
                        error!("Failed to send file system event: {}", e);
                    }
                }
                Err(e) => {
                    error!("File system watch error: {}", e);
                }
            })?;

        for watch_path in &self.config.watch_paths {
            if !watch_path.exists() {
                warn!("Watch path does not exist: {:?}", watch_path);
                continue;
            }
            watcher
                .watch(watch_path, RecursiveMode::Recursive)
                .with_context(|| format!("Failed to watch path: {:?}", watch_path))?;
            info!("Watching for changes: {:?}", watch_path);
        }

        let loader = self.loader.clone();
        let pending_reloads = self.pending_reloads.clone();
        let debounce_duration = Duration::from_millis(self.config.debounce_duration_ms);

        tokio::spawn(async move {
            // The watcher must live as long as this task; dropping it ends the OS watch.
            let _watcher = watcher;
            while let Some(event) = fs_rx.recv().await {
                // `fs_tx` lives inside `_watcher`, so `fs_rx` can never close while this
                // task runs. Without this check, the task and the OS watch would keep
                // running forever after the consumer drops the reload receiver.
                if reload_tx.is_closed() {
                    break;
                }
                if let Err(e) =
                    handle_fs_event(event, &loader, &reload_tx, &pending_reloads, debounce_duration)
                        .await
                {
                    error!("Error handling file system event: {}", e);
                }
            }
            info!("Hot-reload watcher stopped");
        });

        Ok(reload_rx)
    }

    /// Returns whether hot-reload is enabled in the configuration.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Returns the configured watch roots, including any that do not exist on disk.
    pub fn watch_paths(&self) -> &[PathBuf] {
        &self.config.watch_paths
    }
}
async fn handle_fs_event(
event: Event,
loader: &BehaviorLoader,
reload_tx: &mpsc::UnboundedSender<ReloadEvent>,
pending_reloads: &Arc<RwLock<HashMap<PathBuf, tokio::time::Instant>>>,
debounce_duration: Duration,
) -> Result<()> {
match event.kind {
EventKind::Modify(_) | EventKind::Create(_) => {}
_ => return Ok(()),
}
for path in event.paths {
if !is_behavior_tree_file(&path) {
continue;
}
debug!("File change detected: {:?}", path);
let now = tokio::time::Instant::now();
let mut pending = pending_reloads.write().await;
if let Some(last_reload) = pending.get(&path) {
if now.duration_since(*last_reload) < debounce_duration {
debug!("Debouncing reload for: {:?}", path);
continue;
}
}
pending.insert(path.clone(), now);
drop(pending);
reload_behavior_tree(&path, loader, reload_tx).await;
}
Ok(())
}
/// Returns `true` when `path` has the exact (case-sensitive) extension `json`.
///
/// Dot-files such as `.json` have no extension per `Path::extension` and are rejected.
fn is_behavior_tree_file(path: &Path) -> bool {
    matches!(path.extension(), Some(ext) if ext == "json")
}
async fn reload_behavior_tree(path: &Path, loader: &BehaviorLoader, reload_tx: &mpsc::UnboundedSender<ReloadEvent>) {
info!("Reloading behavior tree: {:?}", path);
tokio::time::sleep(Duration::from_millis(50)).await;
match loader.load_from_file(path) {
Ok(behavior) => {
info!("Successfully reloaded: {:?}", path);
let event = ReloadEvent::Success {
path: path.to_path_buf(),
behavior,
};
if let Err(e) = reload_tx.send(event) {
error!("Failed to send reload event: {}", e);
}
}
Err(e) => {
error!("Failed to reload {:?}: {}", path, e);
let event = ReloadEvent::Error {
path: path.to_path_buf(),
error: format!("{:#}", e),
};
if let Err(e) = reload_tx.send(event) {
error!("Failed to send reload error event: {}", e);
}
}
}
}