Skip to main content

punch_kernel/
config_watcher.rs

1//! Config Hot Reload — poll-based config watcher with callback support.
2//!
//! The [`KernelConfigWatcher`] builds on the hot-reload helpers from `punch-types`
//! (`diff_configs` and `validate_config`) and adds a poll-based mtime check,
//! callback registration, and diff logging for the kernel layer. It distinguishes between hot-reloadable
6//! fields (rate limits, model defaults, channels, MCP servers, memory settings)
7//! and fields that require a restart (API listen address, database path, API key).
8
9use std::path::PathBuf;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12use std::time::Duration;
13
14use tokio::sync::RwLock;
15use tokio::task::JoinHandle;
16use tracing::{debug, info, warn};
17
18use punch_types::config::PunchConfig;
19use punch_types::hot_reload::{ConfigChange, diff_configs, validate_config, ValidationSeverity};
20
21// ---------------------------------------------------------------------------
22// ConfigDiff (kernel-level summary)
23// ---------------------------------------------------------------------------
24
25/// Summary of what changed between two configs — used by callbacks to react
26/// to specific categories of changes.
27#[derive(Debug, Clone, Default)]
28pub struct KernelConfigDiff {
29    /// Whether rate limit settings changed.
30    pub rate_limit_changed: bool,
31    /// Whether the default model changed.
32    pub model_changed: bool,
33    /// Channel names that were added, removed, or modified.
34    pub channels_changed: Vec<String>,
35    /// MCP server names that were added, removed, or modified.
36    pub mcp_servers_changed: Vec<String>,
37    /// Whether memory configuration changed.
38    pub memory_changed: bool,
39    /// Non-reloadable fields that changed (require restart).
40    pub requires_restart: Vec<String>,
41}
42
43impl KernelConfigDiff {
44    /// Build a `KernelConfigDiff` from the low-level `ConfigChange` list.
45    fn from_changes(changes: &[ConfigChange]) -> Self {
46        let mut diff = Self::default();
47
48        for change in changes {
49            match change {
50                ConfigChange::RateLimitChanged { .. } => {
51                    diff.rate_limit_changed = true;
52                }
53                ConfigChange::ModelChanged { .. } => {
54                    diff.model_changed = true;
55                }
56                ConfigChange::ChannelAdded(name) | ConfigChange::ChannelRemoved(name) => {
57                    if !diff.channels_changed.contains(name) {
58                        diff.channels_changed.push(name.clone());
59                    }
60                }
61                ConfigChange::McpServerAdded(name) | ConfigChange::McpServerRemoved(name) => {
62                    if !diff.mcp_servers_changed.contains(name) {
63                        diff.mcp_servers_changed.push(name.clone());
64                    }
65                }
66                ConfigChange::MemoryConfigChanged => {
67                    diff.memory_changed = true;
68                }
69                // Non-reloadable fields.
70                ConfigChange::ListenAddressChanged { .. } => {
71                    diff.requires_restart.push("api_listen".to_string());
72                }
73                ConfigChange::ApiKeyChanged => {
74                    diff.requires_restart.push("api_key".to_string());
75                }
76            }
77        }
78
79        diff
80    }
81
82    /// Returns true if any reloadable field changed.
83    pub fn has_reloadable_changes(&self) -> bool {
84        self.rate_limit_changed
85            || self.model_changed
86            || !self.channels_changed.is_empty()
87            || !self.mcp_servers_changed.is_empty()
88            || self.memory_changed
89    }
90}
91
92// ---------------------------------------------------------------------------
93// KernelConfigWatcher
94// ---------------------------------------------------------------------------
95
96/// Type alias for the callback collection to keep clippy happy.
97type ConfigCallbacks = Arc<RwLock<Vec<Box<dyn Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync>>>>;
98
99/// A poll-based config file watcher that detects changes and applies them
100/// without requiring a restart.
101///
102/// It polls the file's mtime every 5 seconds, re-reads and validates on change,
103/// and notifies registered callbacks with the new config and a diff summary.
104pub struct KernelConfigWatcher {
105    config: Arc<RwLock<PunchConfig>>,
106    config_path: PathBuf,
107    last_modified: AtomicU64,
108    callbacks: ConfigCallbacks,
109}
110
111impl KernelConfigWatcher {
112    /// Create a new watcher for the given config file path with an initial config.
113    pub fn new(config_path: PathBuf, initial_config: PunchConfig) -> Self {
114        let mtime = Self::file_mtime(&config_path).unwrap_or(0);
115
116        Self {
117            config: Arc::new(RwLock::new(initial_config)),
118            config_path,
119            last_modified: AtomicU64::new(mtime),
120            callbacks: Arc::new(RwLock::new(Vec::new())),
121        }
122    }
123
124    /// Register a callback that will be invoked when the config changes.
125    ///
126    /// Multiple callbacks can be registered. They are called in registration order
127    /// with a reference to the new config and the diff summary.
128    pub async fn on_change<F>(&self, callback: F)
129    where
130        F: Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync + 'static,
131    {
132        let mut cbs = self.callbacks.write().await;
133        cbs.push(Box::new(callback));
134    }
135
136    /// Get a clone of the current config.
137    pub async fn current_config(&self) -> PunchConfig {
138        self.config.read().await.clone()
139    }
140
141    /// Get a shared reference to the underlying config Arc.
142    pub fn config_arc(&self) -> Arc<RwLock<PunchConfig>> {
143        Arc::clone(&self.config)
144    }
145
146    /// Start the poll loop. Returns a `JoinHandle` for the spawned task.
147    ///
148    /// The task checks the config file's mtime every 5 seconds. On change:
149    /// 1. Reads and parses the file as TOML
150    /// 2. Validates the new config (keeps old config on error)
151    /// 3. Computes the diff and logs changes
152    /// 4. Warns about non-reloadable changes
153    /// 5. Swaps the config under the `RwLock`
154    /// 6. Notifies all registered callbacks
155    pub fn watch(&self) -> JoinHandle<()> {
156        let config = Arc::clone(&self.config);
157        let config_path = self.config_path.clone();
158        let last_modified = self.last_modified.load(Ordering::Relaxed);
159        let last_modified_atomic = Arc::new(AtomicU64::new(last_modified));
160        let callbacks = Arc::clone(&self.callbacks);
161
162        tokio::spawn(async move {
163            let mut interval = tokio::time::interval(Duration::from_secs(5));
164            // Skip the first immediate tick.
165            interval.tick().await;
166
167            info!(path = %config_path.display(), "config poll watcher started (5s interval)");
168
169            loop {
170                interval.tick().await;
171
172                let current_mtime = match Self::file_mtime(&config_path) {
173                    Some(m) => m,
174                    None => {
175                        debug!("config file not found or inaccessible, skipping check");
176                        continue;
177                    }
178                };
179
180                let prev_mtime = last_modified_atomic.load(Ordering::Relaxed);
181                if current_mtime == prev_mtime {
182                    continue;
183                }
184
185                debug!(
186                    old_mtime = prev_mtime,
187                    new_mtime = current_mtime,
188                    "config file mtime changed, reloading"
189                );
190
191                last_modified_atomic.store(current_mtime, Ordering::Relaxed);
192
193                // Read file content.
194                let content = match tokio::fs::read_to_string(&config_path).await {
195                    Ok(c) => c,
196                    Err(e) => {
197                        warn!(error = %e, "failed to read config file during hot reload");
198                        continue;
199                    }
200                };
201
202                // Parse TOML.
203                let new_config: PunchConfig = match toml::from_str(&content) {
204                    Ok(c) => c,
205                    Err(e) => {
206                        warn!(error = %e, "config parse error during hot reload — keeping old config");
207                        continue;
208                    }
209                };
210
211                // Validate.
212                let errors: Vec<_> = validate_config(&new_config)
213                    .into_iter()
214                    .filter(|v| matches!(v.severity, ValidationSeverity::Error))
215                    .collect();
216
217                if !errors.is_empty() {
218                    for err in &errors {
219                        warn!(field = %err.field, message = %err.message, "config validation error — keeping old config");
220                    }
221                    continue;
222                }
223
224                // Compute diff.
225                let old_config = config.read().await.clone();
226                let changes = diff_configs(&old_config, &new_config);
227
228                if changes.is_empty() {
229                    debug!("config file changed (mtime) but no effective differences");
230                    continue;
231                }
232
233                let diff = KernelConfigDiff::from_changes(&changes);
234
235                // Log each change.
236                for change in &changes {
237                    info!(change = ?change, "config hot reload: change detected");
238                }
239
240                // Warn about non-reloadable fields.
241                for field in &diff.requires_restart {
242                    warn!(
243                        field = %field,
244                        "config field changed but requires restart to take effect"
245                    );
246                }
247
248                // Swap config.
249                {
250                    let mut guard = config.write().await;
251                    *guard = new_config.clone();
252                }
253
254                // Notify callbacks.
255                let cbs = callbacks.read().await;
256                for cb in cbs.iter() {
257                    cb(&new_config, &diff);
258                }
259
260                info!(
261                    num_changes = changes.len(),
262                    "config hot reload complete"
263                );
264            }
265        })
266    }
267
268    /// Read the file's mtime as epoch seconds. Returns `None` if the file
269    /// cannot be stat'd.
270    fn file_mtime(path: &PathBuf) -> Option<u64> {
271        std::fs::metadata(path)
272            .ok()
273            .and_then(|m| m.modified().ok())
274            .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
275            .map(|d| d.as_secs())
276    }
277}
278
279// ---------------------------------------------------------------------------
280// Tests
281// ---------------------------------------------------------------------------
282
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::config::{MemoryConfig, ModelConfig, Provider};
    use std::collections::HashMap;
    use std::sync::atomic::AtomicBool;
    use std::sync::Mutex;

    /// Baseline config shared by every test.
    fn make_test_config() -> PunchConfig {
        PunchConfig {
            api_listen: "127.0.0.1:6660".to_string(),
            api_key: "test-key".to_string(),
            rate_limit_rpm: 60,
            default_model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".to_string(),
                api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            memory: MemoryConfig {
                db_path: "/tmp/punch-test.db".to_string(),
                knowledge_graph_enabled: true,
                max_entries: Some(10000),
            },
            channels: HashMap::new(),
            mcp_servers: HashMap::new(),
        }
    }

    /// Uniquely named temp directory that is removed on drop, so a failing
    /// assertion does not leak it.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(prefix: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("{}-{}", prefix, uuid::Uuid::new_v4()));
            std::fs::create_dir_all(&dir).expect("create temp dir");
            Self(dir)
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn kernel_config_diff_from_changes() {
        let changes = vec![
            ConfigChange::RateLimitChanged { old: 60, new: 120 },
            ConfigChange::ModelChanged {
                old_model: "a".to_string(),
                new_model: "b".to_string(),
            },
            ConfigChange::ChannelAdded("slack".to_string()),
            ConfigChange::McpServerRemoved("fs".to_string()),
            ConfigChange::ListenAddressChanged {
                old: "a".to_string(),
                new: "b".to_string(),
            },
            ConfigChange::ApiKeyChanged,
        ];

        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert_eq!(diff.channels_changed, vec!["slack".to_string()]);
        assert_eq!(diff.mcp_servers_changed, vec!["fs".to_string()]);
        assert_eq!(diff.requires_restart.len(), 2);
        assert!(diff.requires_restart.contains(&"api_listen".to_string()));
        assert!(diff.requires_restart.contains(&"api_key".to_string()));
    }

    #[test]
    fn kernel_config_diff_has_reloadable_changes() {
        let empty = KernelConfigDiff::default();
        assert!(!empty.has_reloadable_changes());

        let with_rate = KernelConfigDiff {
            rate_limit_changed: true,
            ..Default::default()
        };
        assert!(with_rate.has_reloadable_changes());

        let restart_only = KernelConfigDiff {
            requires_restart: vec!["api_listen".to_string()],
            ..Default::default()
        };
        assert!(!restart_only.has_reloadable_changes());
    }

    #[tokio::test]
    async fn watch_detects_file_change() {
        let dir = ScratchDir::new("punch-cfg-test");
        let config_path = dir.path().join("punch.toml");

        let initial = make_test_config();
        let toml_str = toml::to_string_pretty(&initial).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());

        let callback_fired = Arc::new(AtomicBool::new(false));
        let cb_flag = Arc::clone(&callback_fired);
        watcher
            .on_change(move |_cfg, _diff| {
                cb_flag.store(true, Ordering::Relaxed);
            })
            .await;

        let handle = watcher.watch();

        // Wait a bit then modify the file.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let mut modified = initial.clone();
        modified.rate_limit_rpm = 120;
        let new_toml = toml::to_string_pretty(&modified).expect("serialize modified config");

        // Ensure mtime differs (some filesystems have 1s granularity).
        tokio::time::sleep(Duration::from_secs(1)).await;
        std::fs::write(&config_path, &new_toml).expect("write modified config");

        // Wait for the poller to pick it up.
        tokio::time::sleep(Duration::from_secs(7)).await;

        assert!(
            callback_fired.load(Ordering::Relaxed),
            "callback should have been fired after config change"
        );

        // Verify the config was updated.
        let current = watcher.current_config().await;
        assert_eq!(current.rate_limit_rpm, 120);

        handle.abort();
    }

    #[tokio::test]
    async fn parse_error_keeps_old_config() {
        let dir = ScratchDir::new("punch-cfg-parse");
        let config_path = dir.path().join("punch.toml");

        let initial = make_test_config();
        let toml_str = toml::to_string_pretty(&initial).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());
        let handle = watcher.watch();

        tokio::time::sleep(Duration::from_secs(1)).await;

        // Write invalid TOML.
        std::fs::write(&config_path, "this is not valid toml {{{}}}").expect("write bad config");

        tokio::time::sleep(Duration::from_secs(7)).await;

        // Config should be unchanged.
        let current = watcher.current_config().await;
        assert_eq!(current.rate_limit_rpm, 60);

        handle.abort();
    }

    #[test]
    fn diff_correctly_identifies_changed_fields() {
        let old = make_test_config();
        let mut new = old.clone();
        new.rate_limit_rpm = 200;
        new.default_model.model = "gpt-4o".to_string();

        let changes = diff_configs(&old, &new);
        let diff = KernelConfigDiff::from_changes(&changes);

        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert!(diff.channels_changed.is_empty());
        assert!(diff.mcp_servers_changed.is_empty());
        assert!(diff.requires_restart.is_empty());
    }

    #[tokio::test]
    async fn callback_registration_and_invocation() {
        let config_path = PathBuf::from("/tmp/nonexistent-punch-test.toml");
        let watcher = KernelConfigWatcher::new(config_path, make_test_config());

        let counter = Arc::new(AtomicU64::new(0));
        let c1 = Arc::clone(&counter);
        watcher
            .on_change(move |_cfg, _diff| {
                c1.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        // Verify callback list has one entry.
        let cbs = watcher.callbacks.read().await;
        assert_eq!(cbs.len(), 1);

        // Invoke it the same way the poll loop does and check it actually ran.
        let cfg = make_test_config();
        let diff = KernelConfigDiff::default();
        for cb in cbs.iter() {
            cb(&cfg, &diff);
        }
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn multiple_callbacks_supported() {
        let config_path = PathBuf::from("/tmp/nonexistent-punch-multi.toml");
        let watcher = KernelConfigWatcher::new(config_path, make_test_config());

        // Records which callback ran, to check both invocation and ordering.
        let calls: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(Vec::new()));

        let first = Arc::clone(&calls);
        watcher
            .on_change(move |_cfg, _diff| {
                first.lock().expect("calls lock").push("first");
            })
            .await;

        let second = Arc::clone(&calls);
        watcher
            .on_change(move |_cfg, _diff| {
                second.lock().expect("calls lock").push("second");
            })
            .await;

        let cbs = watcher.callbacks.read().await;
        assert_eq!(cbs.len(), 2);

        let cfg = make_test_config();
        let diff = KernelConfigDiff::default();
        for cb in cbs.iter() {
            cb(&cfg, &diff);
        }
        assert_eq!(*calls.lock().expect("calls lock"), vec!["first", "second"]);
    }

    #[test]
    fn non_reloadable_fields_logged_as_requiring_restart() {
        let changes = vec![
            ConfigChange::ListenAddressChanged {
                old: "127.0.0.1:6660".to_string(),
                new: "0.0.0.0:8080".to_string(),
            },
            ConfigChange::ApiKeyChanged,
        ];

        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(!diff.has_reloadable_changes());
        assert_eq!(diff.requires_restart.len(), 2);
    }

    #[tokio::test]
    async fn concurrent_reads_during_reload() {
        let config = make_test_config();
        let watcher = KernelConfigWatcher::new(PathBuf::from("/tmp/test.toml"), config);
        let config_arc = watcher.config_arc();

        // Spawn multiple concurrent readers.
        let mut handles = Vec::new();
        for _ in 0..10 {
            let arc = Arc::clone(&config_arc);
            handles.push(tokio::spawn(async move {
                let cfg = arc.read().await;
                assert!(!cfg.api_listen.is_empty());
            }));
        }

        // Spawn a writer.
        let arc_w = Arc::clone(&config_arc);
        handles.push(tokio::spawn(async move {
            let mut cfg = arc_w.write().await;
            cfg.rate_limit_rpm = 999;
        }));

        for h in handles {
            h.await.expect("task should complete");
        }

        // Verify the write took effect.
        let final_cfg = config_arc.read().await;
        assert_eq!(final_cfg.rate_limit_rpm, 999);
    }

    #[test]
    fn memory_change_detected() {
        let changes = vec![ConfigChange::MemoryConfigChanged];
        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(diff.memory_changed);
        assert!(diff.has_reloadable_changes());
    }
}