1use std::path::{Path, PathBuf};
10use std::sync::Arc;
11use std::time::Duration;
12
13use anyhow::{Context, Result};
14use arc_swap::ArcSwap;
15use notify::{Event, EventKind, RecursiveMode, Watcher};
16use tokio::sync::mpsc;
17use tracing::{error, info, warn};
18
19use crate::build_runtime;
20use crate::config::Config;
21use crate::proxy::Runtime;
22
23pub async fn watch(path: PathBuf, runtime: Arc<ArcSwap<Runtime>>) -> Result<()> {
26 let (tx, mut rx) = mpsc::channel::<()>(8);
27
28 let target_name = path.file_name().map(|n| n.to_os_string());
33
34 let mut watcher = notify::recommended_watcher(move |res: notify::Result<Event>| match res {
38 Ok(event) if is_modifying(&event) && touches(&event, target_name.as_deref()) => {
39 let _ = tx.try_send(());
40 }
41 Ok(_) => {}
42 Err(e) => warn!(error = %e, "config watcher error"),
43 })
44 .context("creating config watcher")?;
45
46 let watch_dir = path
49 .parent()
50 .filter(|p| !p.as_os_str().is_empty())
51 .map(Path::to_path_buf)
52 .unwrap_or_else(|| PathBuf::from("."));
53 watcher
54 .watch(&watch_dir, RecursiveMode::NonRecursive)
55 .with_context(|| format!("watching {}", watch_dir.display()))?;
56
57 info!(path = %path.display(), "config hot-reload enabled");
58
59 while rx.recv().await.is_some() {
60 tokio::time::sleep(Duration::from_millis(200)).await;
62 while rx.try_recv().is_ok() {}
63
64 match reload(&path, &runtime) {
65 Ok(()) => info!(path = %path.display(), "config reloaded"),
66 Err(e) => {
67 error!(
68 error = format!("{e:#}"),
69 "config reload failed; keeping previous config"
70 )
71 }
72 }
73 }
74 Ok(())
75}
76
77fn reload(path: &Path, runtime: &ArcSwap<Runtime>) -> Result<()> {
80 let cfg = Config::load(path.to_str()).context("reloading config")?;
81 let new_runtime = build_runtime(Arc::new(cfg)).context("rebuilding runtime")?;
82 runtime.store(Arc::new(new_runtime));
83 Ok(())
84}
85
86fn is_modifying(event: &Event) -> bool {
87 matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_))
88}
89
90fn touches(event: &Event, target_name: Option<&std::ffi::OsStr>) -> bool {
93 match target_name {
94 Some(name) => event.paths.iter().any(|p| p.file_name() == Some(name)),
95 None => true,
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    /// Removes the wrapped directory on drop, so the temp dir is cleaned up even
    /// when an assertion fails part-way through a test.
    struct TempDirGuard(std::path::PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn reload_swaps_in_new_policy_and_rejects_bad_config() {
        let dir = std::env::temp_dir().join(format!("edgeguard-reload-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        // Bound to a named variable (not `_`) so it lives until the end of the test.
        let _cleanup = TempDirGuard(dir.clone());
        let path = dir.join("edgeguard.toml");

        // Start with rate limiting disabled: no IP limiter is built.
        std::fs::write(&path, "[ratelimit]\nenabled = false\n").unwrap();
        let initial = build_runtime(Arc::new(Config::load(path.to_str()).unwrap())).unwrap();
        assert!(initial.ip_limiter.is_none());
        let swap = ArcSwap::from_pointee(initial);

        // A valid config enabling the limiter is picked up by `reload`.
        std::fs::write(
            &path,
            "[ratelimit]\nenabled = true\nrate = \"10/sec\"\nburst = 5\n",
        )
        .unwrap();
        reload(&path, &swap).unwrap();
        assert!(swap.load().ip_limiter.is_some());

        // A rejected config must leave the previously stored runtime in place.
        std::fs::write(
            &path,
            "[ratelimit]\nenabled = true\nrate = \"0/sec\"\nburst = 5\n",
        )
        .unwrap();
        assert!(reload(&path, &swap).is_err());
        assert!(swap.load().ip_limiter.is_some());
    }
}