use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use adk_core::Agent;
use notify::{EventKind, RecursiveMode, Watcher};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use super::loader::AgentConfigLoader;
const DEFAULT_DEBOUNCE: Duration = Duration::from_millis(500);
pub struct HotReloadWatcher {
loader: Arc<AgentConfigLoader>,
active_agents: Arc<RwLock<HashMap<PathBuf, Arc<dyn Agent>>>>,
debounce: Duration,
}
impl HotReloadWatcher {
pub fn new(loader: Arc<AgentConfigLoader>) -> Self {
Self {
loader,
active_agents: Arc::new(RwLock::new(HashMap::new())),
debounce: DEFAULT_DEBOUNCE,
}
}
pub fn with_debounce(loader: Arc<AgentConfigLoader>, debounce: Duration) -> Self {
Self { loader, active_agents: Arc::new(RwLock::new(HashMap::new())), debounce }
}
pub async fn watch(&self, dir: &Path) -> adk_core::Result<tokio::task::JoinHandle<()>> {
let dir = dir.to_path_buf();
let dir_display = dir.display().to_string();
let agents = self.loader.load_directory(&dir).await?;
{
let mut active = self.active_agents.write().await;
let yaml_files = collect_yaml_files(&dir)?;
for file_path in yaml_files {
if let Ok(agent) = self.loader.load_file(&file_path).await {
active.insert(file_path, agent);
}
}
}
info!("hot reload watcher initialized with {} agents from {dir_display}", agents.len());
let (tx, mut rx) = tokio::sync::mpsc::channel::<PathBuf>(100);
let dir_for_watcher = dir.clone();
let notify_tx = tx.clone();
let _watcher_handle = std::thread::spawn(move || {
let rt_tx = notify_tx;
let mut watcher = match notify::recommended_watcher(
move |res: Result<notify::Event, notify::Error>| {
if let Ok(event) = res {
match event.kind {
EventKind::Modify(_) | EventKind::Create(_) => {
for path in event.paths {
if is_yaml_file(&path) {
let _ = rt_tx.blocking_send(path);
}
}
}
_ => {}
}
}
},
) {
Ok(w) => w,
Err(e) => {
warn!("failed to create filesystem watcher: {e}");
return;
}
};
if let Err(e) = watcher.watch(&dir_for_watcher, RecursiveMode::NonRecursive) {
warn!("failed to watch directory {}: {e}", dir_for_watcher.display());
return;
}
debug!("filesystem watcher started for {}", dir_for_watcher.display());
loop {
std::thread::park();
}
});
let loader = Arc::clone(&self.loader);
let active_agents = Arc::clone(&self.active_agents);
let debounce = self.debounce;
let handle = tokio::spawn(async move {
let mut pending: HashMap<PathBuf, tokio::time::Instant> = HashMap::new();
loop {
let next_deadline = pending.values().min().copied();
tokio::select! {
Some(path) = rx.recv() => {
let deadline = tokio::time::Instant::now() + debounce;
pending.insert(path, deadline);
}
_ = async {
match next_deadline {
Some(deadline) => tokio::time::sleep_until(deadline).await,
None => std::future::pending::<()>().await,
}
} => {
let now = tokio::time::Instant::now();
let ready: Vec<PathBuf> = pending
.iter()
.filter(|(_, deadline)| **deadline <= now)
.map(|(path, _)| path.clone())
.collect();
for path in ready {
pending.remove(&path);
reload_agent(&loader, &active_agents, &path).await;
}
}
}
}
});
Ok(handle)
}
pub async fn get_agent(&self, name: &str) -> Option<Arc<dyn Agent>> {
let agents = self.active_agents.read().await;
agents.values().find(|agent| agent.name() == name).cloned()
}
pub async fn all_agents(&self) -> Vec<Arc<dyn Agent>> {
self.active_agents.read().await.values().cloned().collect()
}
}
/// Re-reads one agent configuration file and swaps the result into `active_agents`.
///
/// On success, the entry keyed by `path` is inserted or replaced. On failure, a warning is
/// logged and any agent previously registered for `path` is left untouched. A broken edit
/// therefore never removes a working agent.
async fn reload_agent(
    loader: &AgentConfigLoader,
    active_agents: &RwLock<HashMap<PathBuf, Arc<dyn Agent>>>,
    path: &Path,
) {
    let path_display = path.display();
    info!("reloading agent from {path_display}");

    let agent = match loader.reload_file(path).await {
        Ok(fresh) => fresh,
        Err(e) => {
            warn!("failed to reload agent from {path_display}: {e}");
            return;
        }
    };

    // Capture the name before the agent is moved into the map.
    let agent_name = agent.name().to_string();
    {
        let mut registry = active_agents.write().await;
        registry.insert(path.to_path_buf(), agent);
    }
    info!("successfully reloaded agent '{agent_name}' from {path_display}");
}
/// Reports whether `path` has a `.yaml` or `.yml` extension, compared case-insensitively.
///
/// Paths with no extension, or with an extension that is not valid UTF-8, are not YAML.
fn is_yaml_file(path: &Path) -> bool {
    match path.extension().and_then(|raw| raw.to_str()) {
        Some(ext) => ext.eq_ignore_ascii_case("yaml") || ext.eq_ignore_ascii_case("yml"),
        None => false,
    }
}
/// Lists the regular files directly inside `dir` whose extension is `.yaml` or `.yml`.
///
/// The scan is not recursive. The result is sorted so callers see files in a stable order.
///
/// # Errors
///
/// Returns a configuration error if `dir` cannot be opened or an entry cannot be read. The
/// first unreadable entry aborts the scan.
fn collect_yaml_files(dir: &Path) -> adk_core::Result<Vec<PathBuf>> {
    let listing = std::fs::read_dir(dir).map_err(|e| {
        adk_core::AdkError::config(format!("failed to read directory '{}': {e}", dir.display()))
    })?;

    let mut yaml_paths = listing
        .map(|entry| {
            entry.map(|dir_entry| dir_entry.path()).map_err(|e| {
                adk_core::AdkError::config(format!(
                    "failed to read directory entry in '{}': {e}",
                    dir.display()
                ))
            })
        })
        .filter(|candidate| match candidate {
            Ok(path) => path.is_file() && is_yaml_file(path),
            // Keep errors so the `collect` below short-circuits on them.
            Err(_) => true,
        })
        .collect::<adk_core::Result<Vec<PathBuf>>>()?;

    yaml_paths.sort();
    Ok(yaml_paths)
}
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::{Llm, LlmRequest, Tool};
    use async_trait::async_trait;
    use std::io::Write;

    use super::super::loader::ModelFactory;

    /// Model factory that never contacts a provider. Every model is a `MockLlm` named
    /// `"{provider}/{model_id}"`.
    struct MockModelFactory;

    #[async_trait]
    impl ModelFactory for MockModelFactory {
        async fn create_model(
            &self,
            provider: &str,
            model_id: &str,
        ) -> adk_core::Result<Arc<dyn Llm>> {
            Ok(Arc::new(MockLlm { name: format!("{provider}/{model_id}") }))
        }
    }

    /// LLM stub that only reports its name; generating content is not supported.
    struct MockLlm {
        name: String,
    }

    #[async_trait]
    impl Llm for MockLlm {
        fn name(&self) -> &str {
            &self.name
        }

        async fn generate_content(
            &self,
            _request: LlmRequest,
            _stream: bool,
        ) -> adk_core::Result<adk_core::LlmResponseStream> {
            unimplemented!("mock LLM")
        }
    }

    /// Tool registry that knows no tools.
    struct MockToolRegistry;

    impl adk_core::ToolRegistry for MockToolRegistry {
        fn resolve(&self, _tool_name: &str) -> Option<Arc<dyn Tool>> {
            None
        }

        fn available_tools(&self) -> Vec<String> {
            vec![]
        }
    }

    /// Builds a loader backed by the mock tool registry and model factory.
    fn test_loader() -> Arc<AgentConfigLoader> {
        let registry = Arc::new(MockToolRegistry);
        let factory = Arc::new(MockModelFactory);
        Arc::new(AgentConfigLoader::new(registry, factory))
    }

    /// An empty agent registry of the same type the watcher uses internally.
    fn empty_registry() -> Arc<RwLock<HashMap<PathBuf, Arc<dyn Agent>>>> {
        Arc::new(RwLock::new(HashMap::new()))
    }

    /// Creates (or truncates) `dir/filename`, writes `content` to it, and returns its path.
    fn write_yaml(dir: &Path, filename: &str, content: &str) -> PathBuf {
        let path = dir.join(filename);
        let mut file = std::fs::File::create(&path).unwrap();
        file.write_all(content.as_bytes()).unwrap();
        path
    }

    /// Data address of an agent, ignoring trait-object metadata. Tests use it to tell
    /// whether a registry entry was replaced by a new instance.
    fn agent_addr(agent: &Arc<dyn Agent>) -> *const () {
        Arc::as_ptr(agent).cast::<()>()
    }

    #[test]
    fn test_is_yaml_file() {
        assert!(is_yaml_file(Path::new("agent.yaml")));
        assert!(is_yaml_file(Path::new("agent.yml")));
        assert!(is_yaml_file(Path::new("agent.YAML")));
        assert!(is_yaml_file(Path::new("agent.YML")));
        assert!(!is_yaml_file(Path::new("agent.json")));
        assert!(!is_yaml_file(Path::new("agent.txt")));
        assert!(!is_yaml_file(Path::new("agent")));
    }

    #[tokio::test]
    async fn test_watcher_initial_load() {
        let dir = tempfile::tempdir().unwrap();
        write_yaml(
            dir.path(),
            "agent.yaml",
            r#"
name: test_agent
model:
  provider: gemini
  model_id: gemini-2.0-flash
instructions: "Hello"
"#,
        );
        let watcher = HotReloadWatcher::new(test_loader());
        let handle = watcher.watch(dir.path()).await.unwrap();
        let agent = watcher.get_agent("test_agent").await;
        assert!(agent.is_some());
        assert_eq!(agent.unwrap().name(), "test_agent");
        handle.abort();
    }

    #[tokio::test]
    async fn test_watcher_get_agent_not_found() {
        let dir = tempfile::tempdir().unwrap();
        write_yaml(
            dir.path(),
            "agent.yaml",
            r#"
name: test_agent
model:
  provider: gemini
  model_id: gemini-2.0-flash
"#,
        );
        let watcher = HotReloadWatcher::new(test_loader());
        let handle = watcher.watch(dir.path()).await.unwrap();
        let agent = watcher.get_agent("nonexistent").await;
        assert!(agent.is_none());
        handle.abort();
    }

    #[tokio::test]
    async fn test_watcher_all_agents() {
        let dir = tempfile::tempdir().unwrap();
        write_yaml(
            dir.path(),
            "agent_a.yaml",
            r#"
name: agent_a
model:
  provider: gemini
  model_id: gemini-2.0-flash
"#,
        );
        write_yaml(
            dir.path(),
            "agent_b.yml",
            r#"
name: agent_b
model:
  provider: openai
  model_id: gpt-4
"#,
        );
        let watcher = HotReloadWatcher::new(test_loader());
        let handle = watcher.watch(dir.path()).await.unwrap();
        let agents = watcher.all_agents().await;
        assert_eq!(agents.len(), 2);
        let names: Vec<&str> = agents.iter().map(|a| a.name()).collect();
        assert!(names.contains(&"agent_a"));
        assert!(names.contains(&"agent_b"));
        handle.abort();
    }

    #[tokio::test]
    async fn test_reload_agent_success() {
        let dir = tempfile::tempdir().unwrap();
        let path = write_yaml(
            dir.path(),
            "agent.yaml",
            r#"
name: reloadable
model:
  provider: gemini
  model_id: gemini-2.0-flash
instructions: "Version 1"
"#,
        );
        let loader = test_loader();
        let active_agents = empty_registry();
        let agent = loader.load_file(&path).await.unwrap();
        active_agents.write().await.insert(path.clone(), agent);
        // Keep the original instance alive, so the reloaded agent cannot reuse its address.
        let before = Arc::clone(&active_agents.read().await[&path]);

        std::fs::write(
            &path,
            r#"
name: reloadable
model:
  provider: gemini
  model_id: gemini-2.0-flash
instructions: "Version 2"
"#,
        )
        .unwrap();
        reload_agent(&loader, &active_agents, &path).await;

        let agents = active_agents.read().await;
        let after = &agents[&path];
        assert_eq!(after.name(), "reloadable");
        assert_ne!(agent_addr(&before), agent_addr(after), "reload should install a new agent");
    }

    #[tokio::test]
    async fn test_reload_agent_validation_failure_keeps_previous() {
        let dir = tempfile::tempdir().unwrap();
        let path = write_yaml(
            dir.path(),
            "agent.yaml",
            r#"
name: stable_agent
model:
  provider: gemini
  model_id: gemini-2.0-flash
instructions: "Valid agent"
"#,
        );
        let loader = test_loader();
        let active_agents = empty_registry();
        let agent = loader.load_file(&path).await.unwrap();
        active_agents.write().await.insert(path.clone(), agent);
        let before = Arc::clone(&active_agents.read().await[&path]);

        std::fs::write(&path, "invalid: yaml: content: [broken").unwrap();
        reload_agent(&loader, &active_agents, &path).await;

        let agents = active_agents.read().await;
        let kept = &agents[&path];
        assert_eq!(kept.name(), "stable_agent");
        assert_eq!(agent_addr(&before), agent_addr(kept), "failed reload must keep the old agent");
    }

    #[test]
    fn test_collect_yaml_files() {
        let dir = tempfile::tempdir().unwrap();
        write_yaml(dir.path(), "a.yaml", "name: a\n");
        write_yaml(dir.path(), "b.yml", "name: b\n");
        write_yaml(dir.path(), "c.json", "{}");
        write_yaml(dir.path(), "d.txt", "hello");
        let files = collect_yaml_files(dir.path()).unwrap();
        assert_eq!(files.len(), 2);
        assert!(files.iter().any(|f| f.file_name().unwrap() == "a.yaml"));
        assert!(files.iter().any(|f| f.file_name().unwrap() == "b.yml"));
    }

    #[tokio::test]
    async fn test_watcher_default_debounce() {
        let watcher = HotReloadWatcher::new(test_loader());
        assert_eq!(watcher.debounce, DEFAULT_DEBOUNCE);
    }

    #[tokio::test]
    async fn test_watcher_with_custom_debounce() {
        let watcher = HotReloadWatcher::with_debounce(test_loader(), Duration::from_millis(100));
        assert_eq!(watcher.debounce, Duration::from_millis(100));
    }
}