Skip to main content

kvlar_proxy/
watcher.rs

1//! Policy hot-reload via filesystem watcher.
2//!
3//! Watches policy files for changes and atomically swaps the engine
4//! when valid new policies are detected. Uses `notify` for cross-platform
5//! filesystem watching with debouncing.
6
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9use std::time::Duration;
10
11use kvlar_core::{Engine, Policy};
12use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
13use tokio::sync::{RwLock, mpsc};
14use tokio::time::sleep;
15
/// Callback type for resolving `extends` directives in policies.
///
/// Takes a template/file name and returns the YAML string content, or a
/// `kvlar_core::KvlarError` on failure.
/// This keeps kvlar-core pure (no I/O) — the caller provides the resolver.
///
/// The callback is held behind an `Arc` and bounded by `Send + Sync`, so a
/// single resolver can be cloned cheaply and shared across threads or tasks.
pub type ExtendsResolver =
    Arc<dyn Fn(&str) -> Result<String, kvlar_core::KvlarError> + Send + Sync>;
22
23/// Spawns a filesystem watcher task that reloads policies on change.
24///
25/// Returns a `JoinHandle` for the watcher task and the underlying
26/// `RecommendedWatcher` (which must be kept alive for watching to continue).
27///
28/// # Arguments
29///
30/// * `engine` — Shared engine to swap on successful reload
31/// * `policy_paths` — Paths to policy YAML files to watch
32/// * `extends_resolver` — Optional callback to resolve `extends` directives
33///
34/// # Behavior
35///
36/// - Debounces filesystem events (300ms) to coalesce rapid saves
37/// - On change: re-reads all policy files, builds a new engine, swaps atomically
38/// - On parse error: logs the error to stderr, keeps the previous valid engine
39/// - All output goes to stderr (safe for stdio proxy mode)
40pub fn spawn_watcher(
41    engine: Arc<RwLock<Engine>>,
42    policy_paths: Vec<PathBuf>,
43    extends_resolver: Option<ExtendsResolver>,
44) -> Result<(tokio::task::JoinHandle<()>, RecommendedWatcher), Box<dyn std::error::Error>> {
45    let (tx, mut rx) = mpsc::channel::<()>(16);
46
47    // Set up the filesystem watcher
48    let mut watcher = notify::recommended_watcher(move |result: Result<Event, notify::Error>| {
49        match result {
50            Ok(event) => {
51                if matches!(
52                    event.kind,
53                    EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
54                ) {
55                    // Signal the reload task (non-blocking, drop if channel full)
56                    let _ = tx.try_send(());
57                }
58            }
59            Err(e) => {
60                eprintln!("  [kvlar] watch error: {}", e);
61            }
62        }
63    })?;
64
65    // Watch each policy file's parent directory (to catch renames/recreations)
66    for path in &policy_paths {
67        let watch_path = if path.is_file() {
68            path.parent().unwrap_or(Path::new("."))
69        } else {
70            path.as_path()
71        };
72        watcher.watch(watch_path, RecursiveMode::NonRecursive)?;
73    }
74
75    let paths = policy_paths.clone();
76    eprintln!(
77        "  [kvlar] watching {} policy file(s) for changes",
78        paths.len()
79    );
80
81    // Spawn the reload task
82    let handle = tokio::spawn(async move {
83        loop {
84            // Wait for a change signal
85            if rx.recv().await.is_none() {
86                break; // Channel closed, watcher dropped
87            }
88
89            // Debounce: drain any additional signals within 300ms
90            sleep(Duration::from_millis(300)).await;
91            while rx.try_recv().is_ok() {}
92
93            // Reload all policies
94            eprintln!("  [kvlar] policy change detected, reloading...");
95            match reload_policies(&paths, extends_resolver.as_ref()) {
96                Ok(new_engine) => {
97                    let rule_count = new_engine.rule_count();
98                    let policy_count = new_engine.policy_count();
99                    let mut eng = engine.write().await;
100                    *eng = new_engine;
101                    drop(eng);
102                    eprintln!(
103                        "  [kvlar] ✓ reloaded {} policies ({} rules)",
104                        policy_count, rule_count
105                    );
106                }
107                Err(e) => {
108                    eprintln!("  [kvlar] ✗ reload failed, keeping previous policy: {}", e);
109                }
110            }
111        }
112    });
113
114    Ok((handle, watcher))
115}
116
117/// Reads all policy files and builds a new engine.
118///
119/// If any file fails to parse, returns an error — the caller should
120/// keep the previous engine (fail-safe).
121fn reload_policies(
122    paths: &[PathBuf],
123    extends_resolver: Option<&ExtendsResolver>,
124) -> Result<Engine, String> {
125    let mut engine = Engine::new();
126
127    for path in paths {
128        let mut policy = Policy::from_file(path)
129            .map_err(|e| format!("failed to load {}: {}", path.display(), e))?;
130
131        // Resolve extends if a resolver is provided
132        if let Some(resolver) = extends_resolver {
133            policy
134                .resolve_extends(&|name| resolver(name))
135                .map_err(|e| format!("failed to resolve extends in {}: {}", path.display(), e))?;
136        }
137
138        engine.load_policy(policy);
139    }
140
141    Ok(engine)
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;
    use tokio::time::timeout;

    /// Upper bound on how long a positive test waits for a reload to land.
    const RELOAD_DEADLINE: Duration = Duration::from_secs(5);
    /// How often the engine is polled while waiting for a reload.
    const POLL_INTERVAL: Duration = Duration::from_millis(25);
    /// Grace period that lets the watcher register before the file is touched.
    const WATCHER_STARTUP: Duration = Duration::from_millis(100);
    /// Fixed wait for the negative test: 300ms debounce plus generous buffer.
    const SETTLE_TIME: Duration = Duration::from_millis(800);

    const POLICY_V1: &str = r#"
name: test-policy
description: Version 1
version: "1"
rules:
  - id: allow-read
    description: Allow reads
    match_on:
      resources: ["read_file"]
    effect:
      type: allow
"#;

    const POLICY_V2: &str = r#"
name: test-policy
description: Version 2
version: "2"
rules:
  - id: allow-read
    description: Allow reads
    match_on:
      resources: ["read_file"]
    effect:
      type: allow
  - id: deny-write
    description: Deny writes
    match_on:
      resources: ["write_file"]
    effect:
      type: deny
      reason: "Write denied"
"#;

    const POLICY_INVALID: &str = r#"
name: broken
this is not valid YAML policy: [[[
"#;

    /// Helper: create a policy file with given content.
    fn write_policy(dir: &Path, name: &str, content: &str) -> PathBuf {
        let path = dir.join(name);
        fs::write(&path, content).unwrap();
        path
    }

    /// Helper: build a shared engine pre-loaded with the policy at `path`.
    fn shared_engine_from(path: &Path) -> Arc<RwLock<Engine>> {
        let mut engine = Engine::new();
        engine.load_policy(Policy::from_file(path).unwrap());
        Arc::new(RwLock::new(engine))
    }

    // `reload_policies` is synchronous, so these two tests need no runtime.

    #[test]
    fn test_reload_policies_success() {
        let dir = TempDir::new().unwrap();
        let path = write_policy(dir.path(), "policy.yaml", POLICY_V1);

        let engine = reload_policies(&[path], None).expect("valid policy should load");
        assert_eq!(engine.rule_count(), 1);
    }

    #[test]
    fn test_reload_policies_invalid_keeps_error() {
        let dir = TempDir::new().unwrap();
        let path = write_policy(dir.path(), "bad.yaml", POLICY_INVALID);

        assert!(reload_policies(&[path], None).is_err());
    }

    #[tokio::test]
    async fn test_watcher_detects_change() {
        let dir = TempDir::new().unwrap();
        let path = write_policy(dir.path(), "policy.yaml", POLICY_V1);

        // Create engine with v1
        let engine = shared_engine_from(&path);
        assert_eq!(engine.read().await.rule_count(), 1);

        // Start watcher; both bindings must stay alive for the whole test.
        let (_handle, _watcher) = spawn_watcher(engine.clone(), vec![path.clone()], None).unwrap();
        sleep(WATCHER_STARTUP).await;

        // Write v2 (adds a second rule)
        fs::write(&path, POLICY_V2).unwrap();

        // Poll under a deadline instead of sleeping a fixed amount: the test
        // finishes as soon as the reload lands and tolerates slow machines.
        let reloaded = timeout(RELOAD_DEADLINE, async {
            loop {
                // Read the count in its own statement so the read guard is
                // released before sleeping and never blocks the writer.
                let count = engine.read().await.rule_count();
                if count == 2 {
                    break;
                }
                sleep(POLL_INTERVAL).await;
            }
        })
        .await;

        assert!(
            reloaded.is_ok(),
            "engine should have reloaded with 2 rules"
        );
    }

    #[tokio::test]
    async fn test_watcher_keeps_old_on_invalid() {
        let dir = TempDir::new().unwrap();
        let path = write_policy(dir.path(), "policy.yaml", POLICY_V1);

        let engine = shared_engine_from(&path);

        let (_handle, _watcher) = spawn_watcher(engine.clone(), vec![path.clone()], None).unwrap();
        sleep(WATCHER_STARTUP).await;

        // Write invalid YAML
        fs::write(&path, POLICY_INVALID).unwrap();

        // A negative assertion cannot be polled for; wait out the debounce
        // window plus buffer so a (failed) reload has certainly been attempted.
        sleep(SETTLE_TIME).await;

        // Engine should still have original 1 rule
        let eng = engine.read().await;
        assert_eq!(
            eng.rule_count(),
            1,
            "engine should keep old policy on parse error"
        );
    }
}