1use std::path::{Path, PathBuf};
8use std::sync::Arc;
9use std::time::Duration;
10
11use kvlar_core::{Engine, Policy};
12use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
13use tokio::sync::{RwLock, mpsc};
14use tokio::time::sleep;
15
/// Callback used while (re)loading policies to resolve a policy's `extends`
/// references.
///
/// `reload_policies` hands it to `Policy::resolve_extends`, which calls it with
/// a name (`&str`) and expects either a `String` or a `KvlarError` back.
/// NOTE(review): the returned `String` is presumably the source of the
/// referenced policy; confirm against `Policy::resolve_extends`.
///
/// It is reference-counted and `Send + Sync` because `spawn_watcher` moves it
/// into a spawned background task.
pub type ExtendsResolver =
    Arc<dyn Fn(&str) -> Result<String, kvlar_core::KvlarError> + Send + Sync>;
22
23pub fn spawn_watcher(
41 engine: Arc<RwLock<Engine>>,
42 policy_paths: Vec<PathBuf>,
43 extends_resolver: Option<ExtendsResolver>,
44) -> Result<(tokio::task::JoinHandle<()>, RecommendedWatcher), Box<dyn std::error::Error>> {
45 let (tx, mut rx) = mpsc::channel::<()>(16);
46
47 let mut watcher = notify::recommended_watcher(move |result: Result<Event, notify::Error>| {
49 match result {
50 Ok(event) => {
51 if matches!(
52 event.kind,
53 EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
54 ) {
55 let _ = tx.try_send(());
57 }
58 }
59 Err(e) => {
60 eprintln!(" [kvlar] watch error: {}", e);
61 }
62 }
63 })?;
64
65 for path in &policy_paths {
67 let watch_path = if path.is_file() {
68 path.parent().unwrap_or(Path::new("."))
69 } else {
70 path.as_path()
71 };
72 watcher.watch(watch_path, RecursiveMode::NonRecursive)?;
73 }
74
75 let paths = policy_paths.clone();
76 eprintln!(
77 " [kvlar] watching {} policy file(s) for changes",
78 paths.len()
79 );
80
81 let handle = tokio::spawn(async move {
83 loop {
84 if rx.recv().await.is_none() {
86 break; }
88
89 sleep(Duration::from_millis(300)).await;
91 while rx.try_recv().is_ok() {}
92
93 eprintln!(" [kvlar] policy change detected, reloading...");
95 match reload_policies(&paths, extends_resolver.as_ref()) {
96 Ok(new_engine) => {
97 let rule_count = new_engine.rule_count();
98 let policy_count = new_engine.policy_count();
99 let mut eng = engine.write().await;
100 *eng = new_engine;
101 drop(eng);
102 eprintln!(
103 " [kvlar] ✓ reloaded {} policies ({} rules)",
104 policy_count, rule_count
105 );
106 }
107 Err(e) => {
108 eprintln!(" [kvlar] ✗ reload failed, keeping previous policy: {}", e);
109 }
110 }
111 }
112 });
113
114 Ok((handle, watcher))
115}
116
117fn reload_policies(
122 paths: &[PathBuf],
123 extends_resolver: Option<&ExtendsResolver>,
124) -> Result<Engine, String> {
125 let mut engine = Engine::new();
126
127 for path in paths {
128 let mut policy = Policy::from_file(path)
129 .map_err(|e| format!("failed to load {}: {}", path.display(), e))?;
130
131 if let Some(resolver) = extends_resolver {
133 policy
134 .resolve_extends(&|name| resolver(name))
135 .map_err(|e| format!("failed to resolve extends in {}: {}", path.display(), e))?;
136 }
137
138 engine.load_policy(policy);
139 }
140
141 Ok(engine)
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Writes `content` to `dir/name` and returns the resulting path.
    fn write_policy(dir: &Path, name: &str, content: &str) -> PathBuf {
        let path = dir.join(name);
        fs::write(&path, content).unwrap();
        path
    }

    /// Polls `engine` every 25 ms until `done` holds or `timeout` elapses.
    ///
    /// Returns whether `done` was satisfied. Polling, rather than a single
    /// fixed sleep, keeps the watcher tests fast when reloads are quick. It also
    /// keeps them tolerant of slow filesystems or CI machines.
    async fn wait_until(
        engine: &RwLock<Engine>,
        timeout: Duration,
        done: impl Fn(&Engine) -> bool,
    ) -> bool {
        let deadline = tokio::time::Instant::now() + timeout;
        loop {
            if done(&*engine.read().await) {
                return true;
            }
            if tokio::time::Instant::now() >= deadline {
                return false;
            }
            sleep(Duration::from_millis(25)).await;
        }
    }

    const POLICY_V1: &str = r#"
name: test-policy
description: Version 1
version: "1"
rules:
  - id: allow-read
    description: Allow reads
    match_on:
      resources: ["read_file"]
    effect:
      type: allow
"#;

    const POLICY_V2: &str = r#"
name: test-policy
description: Version 2
version: "2"
rules:
  - id: allow-read
    description: Allow reads
    match_on:
      resources: ["read_file"]
    effect:
      type: allow
  - id: deny-write
    description: Deny writes
    match_on:
      resources: ["write_file"]
    effect:
      type: deny
      reason: "Write denied"
"#;

    const POLICY_INVALID: &str = r#"
name: broken
this is not valid YAML policy: [[[
"#;

    #[test]
    fn test_reload_policies_success() {
        let dir = TempDir::new().unwrap();
        let path = write_policy(dir.path(), "policy.yaml", POLICY_V1);

        let engine = reload_policies(&[path], None).expect("valid policy should load");
        assert_eq!(engine.rule_count(), 1);
    }

    #[test]
    fn test_reload_policies_invalid_keeps_error() {
        let dir = TempDir::new().unwrap();
        let path = write_policy(dir.path(), "bad.yaml", POLICY_INVALID);

        assert!(reload_policies(&[path], None).is_err());
    }

    #[tokio::test]
    async fn test_watcher_detects_change() {
        let dir = TempDir::new().unwrap();
        let path = write_policy(dir.path(), "policy.yaml", POLICY_V1);

        let mut initial_engine = Engine::new();
        let policy = Policy::from_file(&path).unwrap();
        initial_engine.load_policy(policy);
        assert_eq!(initial_engine.rule_count(), 1);

        let engine = Arc::new(RwLock::new(initial_engine));
        let (_handle, _watcher) = spawn_watcher(engine.clone(), vec![path.clone()], None).unwrap();

        // Give the OS watcher a moment to register before touching the file.
        sleep(Duration::from_millis(100)).await;
        fs::write(&path, POLICY_V2).unwrap();

        wait_until(&engine, Duration::from_secs(5), |eng| eng.rule_count() == 2).await;
        assert_eq!(
            engine.read().await.rule_count(),
            2,
            "engine should have reloaded with 2 rules"
        );
    }

    #[tokio::test]
    async fn test_watcher_keeps_old_on_invalid() {
        let dir = TempDir::new().unwrap();
        let path = write_policy(dir.path(), "policy.yaml", POLICY_V1);

        let mut initial_engine = Engine::new();
        let policy = Policy::from_file(&path).unwrap();
        initial_engine.load_policy(policy);

        let engine = Arc::new(RwLock::new(initial_engine));
        let (_handle, _watcher) = spawn_watcher(engine.clone(), vec![path.clone()], None).unwrap();

        sleep(Duration::from_millis(100)).await;
        fs::write(&path, POLICY_INVALID).unwrap();

        // A failed reload leaves nothing observable to poll for, so wait well
        // past the 300 ms debounce for the reload attempt to complete.
        sleep(Duration::from_millis(800)).await;
        assert_eq!(
            engine.read().await.rule_count(),
            1,
            "engine should keep old policy on parse error"
        );
    }
}