Skip to main content

punch_kernel/
config_watcher.rs

//! Config Hot Reload — poll-based config watcher with callback support.
//!
//! The [`KernelConfigWatcher`] builds on the config diffing and validation
//! helpers from `punch-types` (`diff_configs` and `validate_config`) and adds a
//! poll-based mtime check, callback registration, and diff logging for the
//! kernel layer. It distinguishes between hot-reloadable fields (rate limits,
//! model defaults, channels, MCP servers, memory settings) and fields that
//! require a restart (API listen address, database path, API key).
8
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::Duration;
13
14use tokio::sync::RwLock;
15use tokio::task::JoinHandle;
16use tracing::{debug, info, warn};
17
18use punch_types::config::PunchConfig;
19use punch_types::hot_reload::{ConfigChange, ValidationSeverity, diff_configs, validate_config};
20
21// ---------------------------------------------------------------------------
22// ConfigDiff (kernel-level summary)
23// ---------------------------------------------------------------------------
24
/// Kernel-level digest of a config reload.
///
/// Change callbacks receive this alongside the new config so they can react
/// only to the categories of change they care about.
#[derive(Debug, Clone, Default)]
pub struct KernelConfigDiff {
    /// `true` when the rate limit settings differ.
    pub rate_limit_changed: bool,
    /// `true` when the default model differs.
    pub model_changed: bool,
    /// Names of the channels affected by the reload (each listed once).
    pub channels_changed: Vec<String>,
    /// Names of the MCP servers affected by the reload (each listed once).
    pub mcp_servers_changed: Vec<String>,
    /// `true` when the memory configuration differs.
    pub memory_changed: bool,
    /// Changed fields that cannot be hot-reloaded and only take effect after
    /// a restart.
    pub requires_restart: Vec<String>,
}
42
43impl KernelConfigDiff {
44    /// Build a `KernelConfigDiff` from the low-level `ConfigChange` list.
45    fn from_changes(changes: &[ConfigChange]) -> Self {
46        let mut diff = Self::default();
47
48        for change in changes {
49            match change {
50                ConfigChange::RateLimitChanged { .. } => {
51                    diff.rate_limit_changed = true;
52                }
53                ConfigChange::ModelChanged { .. } => {
54                    diff.model_changed = true;
55                }
56                ConfigChange::ChannelAdded(name) | ConfigChange::ChannelRemoved(name) => {
57                    if !diff.channels_changed.contains(name) {
58                        diff.channels_changed.push(name.clone());
59                    }
60                }
61                ConfigChange::McpServerAdded(name) | ConfigChange::McpServerRemoved(name) => {
62                    if !diff.mcp_servers_changed.contains(name) {
63                        diff.mcp_servers_changed.push(name.clone());
64                    }
65                }
66                ConfigChange::MemoryConfigChanged => {
67                    diff.memory_changed = true;
68                }
69                // Non-reloadable fields.
70                ConfigChange::ListenAddressChanged { .. } => {
71                    diff.requires_restart.push("api_listen".to_string());
72                }
73                ConfigChange::ApiKeyChanged => {
74                    diff.requires_restart.push("api_key".to_string());
75                }
76            }
77        }
78
79        diff
80    }
81
82    /// Returns true if any reloadable field changed.
83    pub fn has_reloadable_changes(&self) -> bool {
84        self.rate_limit_changed
85            || self.model_changed
86            || !self.channels_changed.is_empty()
87            || !self.mcp_servers_changed.is_empty()
88            || self.memory_changed
89    }
90}
91
92// ---------------------------------------------------------------------------
93// KernelConfigWatcher
94// ---------------------------------------------------------------------------
95
96/// Type alias for the callback collection to keep clippy happy.
97type ConfigCallbacks = Arc<RwLock<Vec<Box<dyn Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync>>>>;
98
99/// A poll-based config file watcher that detects changes and applies them
100/// without requiring a restart.
101///
102/// It polls the file's mtime every 5 seconds, re-reads and validates on change,
103/// and notifies registered callbacks with the new config and a diff summary.
104pub struct KernelConfigWatcher {
105    config: Arc<RwLock<PunchConfig>>,
106    config_path: PathBuf,
107    last_modified: AtomicU64,
108    callbacks: ConfigCallbacks,
109}
110
111impl KernelConfigWatcher {
112    /// Create a new watcher for the given config file path with an initial config.
113    pub fn new(config_path: PathBuf, initial_config: PunchConfig) -> Self {
114        let mtime = Self::file_mtime(&config_path).unwrap_or(0);
115
116        Self {
117            config: Arc::new(RwLock::new(initial_config)),
118            config_path,
119            last_modified: AtomicU64::new(mtime),
120            callbacks: Arc::new(RwLock::new(Vec::new())),
121        }
122    }
123
124    /// Register a callback that will be invoked when the config changes.
125    ///
126    /// Multiple callbacks can be registered. They are called in registration order
127    /// with a reference to the new config and the diff summary.
128    pub async fn on_change<F>(&self, callback: F)
129    where
130        F: Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync + 'static,
131    {
132        let mut cbs = self.callbacks.write().await;
133        cbs.push(Box::new(callback));
134    }
135
136    /// Get a clone of the current config.
137    pub async fn current_config(&self) -> PunchConfig {
138        self.config.read().await.clone()
139    }
140
141    /// Get a shared reference to the underlying config Arc.
142    pub fn config_arc(&self) -> Arc<RwLock<PunchConfig>> {
143        Arc::clone(&self.config)
144    }
145
146    /// Start the poll loop. Returns a `JoinHandle` for the spawned task.
147    ///
148    /// The task checks the config file's mtime every 5 seconds. On change:
149    /// 1. Reads and parses the file as TOML
150    /// 2. Validates the new config (keeps old config on error)
151    /// 3. Computes the diff and logs changes
152    /// 4. Warns about non-reloadable changes
153    /// 5. Swaps the config under the `RwLock`
154    /// 6. Notifies all registered callbacks
155    pub fn watch(&self) -> JoinHandle<()> {
156        let config = Arc::clone(&self.config);
157        let config_path = self.config_path.clone();
158        let last_modified = self.last_modified.load(Ordering::Relaxed);
159        let last_modified_atomic = Arc::new(AtomicU64::new(last_modified));
160        let callbacks = Arc::clone(&self.callbacks);
161
162        tokio::spawn(async move {
163            let mut interval = tokio::time::interval(Duration::from_secs(5));
164            // Skip the first immediate tick.
165            interval.tick().await;
166
167            info!(path = %config_path.display(), "config poll watcher started (5s interval)");
168
169            loop {
170                interval.tick().await;
171
172                let current_mtime = match Self::file_mtime(&config_path) {
173                    Some(m) => m,
174                    None => {
175                        debug!("config file not found or inaccessible, skipping check");
176                        continue;
177                    }
178                };
179
180                let prev_mtime = last_modified_atomic.load(Ordering::Relaxed);
181                if current_mtime == prev_mtime {
182                    continue;
183                }
184
185                debug!(
186                    old_mtime = prev_mtime,
187                    new_mtime = current_mtime,
188                    "config file mtime changed, reloading"
189                );
190
191                last_modified_atomic.store(current_mtime, Ordering::Relaxed);
192
193                // Read file content.
194                let content = match tokio::fs::read_to_string(&config_path).await {
195                    Ok(c) => c,
196                    Err(e) => {
197                        warn!(error = %e, "failed to read config file during hot reload");
198                        continue;
199                    }
200                };
201
202                // Parse TOML.
203                let new_config: PunchConfig = match toml::from_str(&content) {
204                    Ok(c) => c,
205                    Err(e) => {
206                        warn!(error = %e, "config parse error during hot reload — keeping old config");
207                        continue;
208                    }
209                };
210
211                // Validate.
212                let errors: Vec<_> = validate_config(&new_config)
213                    .into_iter()
214                    .filter(|v| matches!(v.severity, ValidationSeverity::Error))
215                    .collect();
216
217                if !errors.is_empty() {
218                    for err in &errors {
219                        warn!(field = %err.field, message = %err.message, "config validation error — keeping old config");
220                    }
221                    continue;
222                }
223
224                // Compute diff.
225                let old_config = config.read().await.clone();
226                let changes = diff_configs(&old_config, &new_config);
227
228                if changes.is_empty() {
229                    debug!("config file changed (mtime) but no effective differences");
230                    continue;
231                }
232
233                let diff = KernelConfigDiff::from_changes(&changes);
234
235                // Log each change.
236                for change in &changes {
237                    info!(change = ?change, "config hot reload: change detected");
238                }
239
240                // Warn about non-reloadable fields.
241                for field in &diff.requires_restart {
242                    warn!(
243                        field = %field,
244                        "config field changed but requires restart to take effect"
245                    );
246                }
247
248                // Swap config.
249                {
250                    let mut guard = config.write().await;
251                    *guard = new_config.clone();
252                }
253
254                // Notify callbacks.
255                let cbs = callbacks.read().await;
256                for cb in cbs.iter() {
257                    cb(&new_config, &diff);
258                }
259
260                info!(num_changes = changes.len(), "config hot reload complete");
261            }
262        })
263    }
264
265    /// Read the file's mtime as epoch seconds. Returns `None` if the file
266    /// cannot be stat'd.
267    fn file_mtime(path: &PathBuf) -> Option<u64> {
268        std::fs::metadata(path)
269            .ok()
270            .and_then(|m| m.modified().ok())
271            .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
272            .map(|d| d.as_secs())
273    }
274}
275
276// ---------------------------------------------------------------------------
277// Tests
278// ---------------------------------------------------------------------------
279
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::config::{MemoryConfig, ModelConfig, Provider};
    use std::collections::HashMap;
    use std::sync::atomic::AtomicBool;

    /// Baseline config shared by every test in this module.
    fn make_test_config() -> PunchConfig {
        let default_model = ModelConfig {
            provider: Provider::Anthropic,
            model: String::from("claude-sonnet-4-20250514"),
            api_key_env: Some(String::from("ANTHROPIC_API_KEY")),
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        let memory = MemoryConfig {
            db_path: String::from("/tmp/punch-test.db"),
            knowledge_graph_enabled: true,
            max_entries: Some(10000),
        };

        PunchConfig {
            api_listen: String::from("127.0.0.1:6660"),
            api_key: String::from("test-key"),
            rate_limit_rpm: 60,
            default_model,
            memory,
            tunnel: None,
            channels: HashMap::new(),
            mcp_servers: HashMap::new(),
            model_routing: Default::default(),
            budget: Default::default(),
        }
    }

    /// Create `dir`, serialize `config` into `dir/punch.toml`, and return the
    /// path of that file.
    fn seed_config_file(dir: &std::path::Path, config: &PunchConfig) -> PathBuf {
        std::fs::create_dir_all(dir).expect("create temp dir");
        let config_path = dir.join("punch.toml");
        let toml_str = toml::to_string_pretty(config).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");
        config_path
    }

    /// Every change category lands in the matching field of the summary.
    #[test]
    fn kernel_config_diff_from_changes() {
        let diff = KernelConfigDiff::from_changes(&[
            ConfigChange::RateLimitChanged { old: 60, new: 120 },
            ConfigChange::ModelChanged {
                old_model: String::from("a"),
                new_model: String::from("b"),
            },
            ConfigChange::ChannelAdded(String::from("slack")),
            ConfigChange::McpServerRemoved(String::from("fs")),
            ConfigChange::ListenAddressChanged {
                old: String::from("a"),
                new: String::from("b"),
            },
            ConfigChange::ApiKeyChanged,
        ]);

        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert_eq!(diff.channels_changed, vec![String::from("slack")]);
        assert_eq!(diff.mcp_servers_changed, vec![String::from("fs")]);
        assert_eq!(diff.requires_restart.len(), 2);
        assert!(diff.requires_restart.iter().any(|f| f == "api_listen"));
        assert!(diff.requires_restart.iter().any(|f| f == "api_key"));
    }

    /// Only hot-reloadable categories count as reloadable changes.
    #[test]
    fn kernel_config_diff_has_reloadable_changes() {
        assert!(!KernelConfigDiff::default().has_reloadable_changes());

        let rate_only = KernelConfigDiff {
            rate_limit_changed: true,
            ..Default::default()
        };
        assert!(rate_only.has_reloadable_changes());

        let restart_only = KernelConfigDiff {
            requires_restart: vec![String::from("api_listen")],
            ..Default::default()
        };
        assert!(!restart_only.has_reloadable_changes());
    }

    /// End-to-end: a rewritten file is picked up by the poller, the callback
    /// fires, and the shared config is replaced.
    #[tokio::test]
    async fn watch_detects_file_change() {
        let dir = std::env::temp_dir().join(format!("punch-cfg-test-{}", uuid::Uuid::new_v4()));
        let initial = make_test_config();
        let config_path = seed_config_file(&dir, &initial);

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());

        let callback_fired = Arc::new(AtomicBool::new(false));
        {
            let fired = Arc::clone(&callback_fired);
            watcher
                .on_change(move |_cfg, _diff| fired.store(true, Ordering::Relaxed))
                .await;
        }

        let poll_task = watcher.watch();

        // Wait a bit then modify the file.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let modified = {
            let mut cfg = initial.clone();
            cfg.rate_limit_rpm = 120;
            cfg
        };
        let new_toml = toml::to_string_pretty(&modified).expect("serialize modified config");

        // Ensure mtime differs (some filesystems have 1s granularity).
        tokio::time::sleep(Duration::from_secs(1)).await;
        std::fs::write(&config_path, &new_toml).expect("write modified config");

        // Wait for the poller to pick it up.
        tokio::time::sleep(Duration::from_secs(7)).await;

        assert!(
            callback_fired.load(Ordering::Relaxed),
            "callback should have been fired after config change"
        );

        // Verify the config was updated.
        assert_eq!(watcher.current_config().await.rate_limit_rpm, 120);

        poll_task.abort();
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// A file that is not valid TOML must leave the live config untouched.
    #[tokio::test]
    async fn parse_error_keeps_old_config() {
        let dir = std::env::temp_dir().join(format!("punch-cfg-parse-{}", uuid::Uuid::new_v4()));
        let initial = make_test_config();
        let config_path = seed_config_file(&dir, &initial);

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());
        let poll_task = watcher.watch();

        tokio::time::sleep(Duration::from_secs(1)).await;

        // Write invalid TOML.
        std::fs::write(&config_path, "this is not valid toml {{{}}}").expect("write bad config");

        tokio::time::sleep(Duration::from_secs(7)).await;

        // Config should be unchanged.
        assert_eq!(watcher.current_config().await.rate_limit_rpm, 60);

        poll_task.abort();
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// `diff_configs` plus `from_changes` flag exactly the edited fields.
    #[test]
    fn diff_correctly_identifies_changed_fields() {
        let old = make_test_config();
        let new = {
            let mut cfg = old.clone();
            cfg.rate_limit_rpm = 200;
            cfg.default_model.model = String::from("gpt-4o");
            cfg
        };

        let diff = KernelConfigDiff::from_changes(&diff_configs(&old, &new));

        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert!(diff.channels_changed.is_empty());
        assert!(diff.mcp_servers_changed.is_empty());
        assert!(diff.requires_restart.is_empty());
    }

    /// Registering one callback stores exactly one entry.
    #[tokio::test]
    async fn callback_registration_and_invocation() {
        let watcher = KernelConfigWatcher::new(
            PathBuf::from("/tmp/nonexistent-punch-test.toml"),
            make_test_config(),
        );

        let counter = Arc::new(AtomicU64::new(0));
        let hits = Arc::clone(&counter);
        watcher
            .on_change(move |_cfg, _diff| {
                hits.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        // Verify callback list has one entry.
        assert_eq!(watcher.callbacks.read().await.len(), 1);
    }

    /// Registering two callbacks stores two entries.
    #[tokio::test]
    async fn multiple_callbacks_supported() {
        let watcher = KernelConfigWatcher::new(
            PathBuf::from("/tmp/nonexistent-punch-multi.toml"),
            make_test_config(),
        );

        let counters = [Arc::new(AtomicU64::new(0)), Arc::new(AtomicU64::new(0))];
        for counter in &counters {
            let hits = Arc::clone(counter);
            watcher
                .on_change(move |_cfg, _diff| {
                    hits.fetch_add(1, Ordering::Relaxed);
                })
                .await;
        }

        assert_eq!(watcher.callbacks.read().await.len(), 2);
    }

    /// Restart-only changes are listed but do not count as reloadable.
    #[test]
    fn non_reloadable_fields_logged_as_requiring_restart() {
        let diff = KernelConfigDiff::from_changes(&[
            ConfigChange::ListenAddressChanged {
                old: String::from("127.0.0.1:6660"),
                new: String::from("0.0.0.0:8080"),
            },
            ConfigChange::ApiKeyChanged,
        ]);

        assert!(!diff.has_reloadable_changes());
        assert_eq!(diff.requires_restart.len(), 2);
    }

    /// Readers and a writer can share the config `Arc` concurrently.
    #[tokio::test]
    async fn concurrent_reads_during_reload() {
        let watcher = KernelConfigWatcher::new(PathBuf::from("/tmp/test.toml"), make_test_config());
        let config_arc = watcher.config_arc();

        // Spawn multiple concurrent readers.
        let mut tasks: Vec<_> = (0..10)
            .map(|_| {
                let reader = Arc::clone(&config_arc);
                tokio::spawn(async move {
                    let cfg = reader.read().await;
                    assert!(!cfg.api_listen.is_empty());
                })
            })
            .collect();

        // Spawn a writer.
        let writer = Arc::clone(&config_arc);
        tasks.push(tokio::spawn(async move {
            writer.write().await.rate_limit_rpm = 999;
        }));

        for task in tasks {
            task.await.expect("task should complete");
        }

        // Verify the write took effect.
        assert_eq!(config_arc.read().await.rate_limit_rpm, 999);
    }

    /// A memory-config change is flagged and counts as reloadable.
    #[test]
    fn memory_change_detected() {
        let diff = KernelConfigDiff::from_changes(&[ConfigChange::MemoryConfigChanged]);

        assert!(diff.memory_changed);
        assert!(diff.has_reloadable_changes());
    }
}