Skip to main content

punch_kernel/
config_watcher.rs

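//! Config Hot Reload — poll-based config watcher with callback support.
//!
//! The [`KernelConfigWatcher`] polls the config file's modification time and,
//! when it changes, re-reads the file, then validates and diffs it against the
//! current config using `punch_types::hot_reload` (`validate_config`,
//! `diff_configs`). It logs the differences, swaps in the new config and
//! notifies registered callbacks. It distinguishes between hot-reloadable
//! fields (rate limits, default model, channels, MCP servers, memory settings)
//! and fields that require a restart (API listen address, API key). The
//! database path has no dedicated change entry here; presumably a
//! `memory.db_path` edit is reported as a memory-config change — confirm in
//! `diff_configs`.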
8
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::Duration;
13
14use tokio::sync::RwLock;
15use tokio::task::JoinHandle;
16use tracing::{debug, info, warn};
17
18use punch_types::config::PunchConfig;
19use punch_types::hot_reload::{ConfigChange, ValidationSeverity, diff_configs, validate_config};
20
21// ---------------------------------------------------------------------------
22// ConfigDiff (kernel-level summary)
23// ---------------------------------------------------------------------------
24
/// Summary of what changed between two configs — used by callbacks to react
/// to specific categories of changes.
///
/// Built from the low-level `ConfigChange` list produced by
/// `punch_types::hot_reload::diff_configs`. Derives `PartialEq`/`Eq` so two
/// summaries can be compared directly (e.g. in tests or change de-duplication).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct KernelConfigDiff {
    /// Whether rate limit settings changed.
    pub rate_limit_changed: bool,
    /// Whether the default model changed.
    pub model_changed: bool,
    /// Names of channels reported as added or removed (the underlying
    /// `ConfigChange` list only reports channel additions and removals).
    pub channels_changed: Vec<String>,
    /// Names of MCP servers reported as added or removed.
    pub mcp_servers_changed: Vec<String>,
    /// Whether memory configuration changed.
    pub memory_changed: bool,
    /// Non-reloadable fields that changed and require a restart to take
    /// effect (e.g. `"api_listen"`, `"api_key"`).
    pub requires_restart: Vec<String>,
}
42
43impl KernelConfigDiff {
44    /// Build a `KernelConfigDiff` from the low-level `ConfigChange` list.
45    fn from_changes(changes: &[ConfigChange]) -> Self {
46        let mut diff = Self::default();
47
48        for change in changes {
49            match change {
50                ConfigChange::RateLimitChanged { .. } => {
51                    diff.rate_limit_changed = true;
52                }
53                ConfigChange::ModelChanged { .. } => {
54                    diff.model_changed = true;
55                }
56                ConfigChange::ChannelAdded(name) | ConfigChange::ChannelRemoved(name) => {
57                    if !diff.channels_changed.contains(name) {
58                        diff.channels_changed.push(name.clone());
59                    }
60                }
61                ConfigChange::McpServerAdded(name) | ConfigChange::McpServerRemoved(name) => {
62                    if !diff.mcp_servers_changed.contains(name) {
63                        diff.mcp_servers_changed.push(name.clone());
64                    }
65                }
66                ConfigChange::MemoryConfigChanged => {
67                    diff.memory_changed = true;
68                }
69                // Non-reloadable fields.
70                ConfigChange::ListenAddressChanged { .. } => {
71                    diff.requires_restart.push("api_listen".to_string());
72                }
73                ConfigChange::ApiKeyChanged => {
74                    diff.requires_restart.push("api_key".to_string());
75                }
76            }
77        }
78
79        diff
80    }
81
82    /// Returns true if any reloadable field changed.
83    pub fn has_reloadable_changes(&self) -> bool {
84        self.rate_limit_changed
85            || self.model_changed
86            || !self.channels_changed.is_empty()
87            || !self.mcp_servers_changed.is_empty()
88            || self.memory_changed
89    }
90}
91
92// ---------------------------------------------------------------------------
93// KernelConfigWatcher
94// ---------------------------------------------------------------------------
95
96/// Type alias for the callback collection to keep clippy happy.
97type ConfigCallbacks = Arc<RwLock<Vec<Box<dyn Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync>>>>;
98
99/// A poll-based config file watcher that detects changes and applies them
100/// without requiring a restart.
101///
102/// It polls the file's mtime every 5 seconds, re-reads and validates on change,
103/// and notifies registered callbacks with the new config and a diff summary.
104pub struct KernelConfigWatcher {
105    config: Arc<RwLock<PunchConfig>>,
106    config_path: PathBuf,
107    last_modified: AtomicU64,
108    callbacks: ConfigCallbacks,
109}
110
111impl KernelConfigWatcher {
112    /// Create a new watcher for the given config file path with an initial config.
113    pub fn new(config_path: PathBuf, initial_config: PunchConfig) -> Self {
114        let mtime = Self::file_mtime(&config_path).unwrap_or(0);
115
116        Self {
117            config: Arc::new(RwLock::new(initial_config)),
118            config_path,
119            last_modified: AtomicU64::new(mtime),
120            callbacks: Arc::new(RwLock::new(Vec::new())),
121        }
122    }
123
124    /// Register a callback that will be invoked when the config changes.
125    ///
126    /// Multiple callbacks can be registered. They are called in registration order
127    /// with a reference to the new config and the diff summary.
128    pub async fn on_change<F>(&self, callback: F)
129    where
130        F: Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync + 'static,
131    {
132        let mut cbs = self.callbacks.write().await;
133        cbs.push(Box::new(callback));
134    }
135
136    /// Get a clone of the current config.
137    pub async fn current_config(&self) -> PunchConfig {
138        self.config.read().await.clone()
139    }
140
141    /// Get a shared reference to the underlying config Arc.
142    pub fn config_arc(&self) -> Arc<RwLock<PunchConfig>> {
143        Arc::clone(&self.config)
144    }
145
146    /// Start the poll loop. Returns a `JoinHandle` for the spawned task.
147    ///
148    /// The task checks the config file's mtime every 5 seconds. On change:
149    /// 1. Reads and parses the file as TOML
150    /// 2. Validates the new config (keeps old config on error)
151    /// 3. Computes the diff and logs changes
152    /// 4. Warns about non-reloadable changes
153    /// 5. Swaps the config under the `RwLock`
154    /// 6. Notifies all registered callbacks
155    pub fn watch(&self) -> JoinHandle<()> {
156        let config = Arc::clone(&self.config);
157        let config_path = self.config_path.clone();
158        let last_modified = self.last_modified.load(Ordering::Relaxed);
159        let last_modified_atomic = Arc::new(AtomicU64::new(last_modified));
160        let callbacks = Arc::clone(&self.callbacks);
161
162        tokio::spawn(async move {
163            let mut interval = tokio::time::interval(Duration::from_secs(5));
164            // Skip the first immediate tick.
165            interval.tick().await;
166
167            info!(path = %config_path.display(), "config poll watcher started (5s interval)");
168
169            loop {
170                interval.tick().await;
171
172                let current_mtime = match Self::file_mtime(&config_path) {
173                    Some(m) => m,
174                    None => {
175                        debug!("config file not found or inaccessible, skipping check");
176                        continue;
177                    }
178                };
179
180                let prev_mtime = last_modified_atomic.load(Ordering::Relaxed);
181                if current_mtime == prev_mtime {
182                    continue;
183                }
184
185                debug!(
186                    old_mtime = prev_mtime,
187                    new_mtime = current_mtime,
188                    "config file mtime changed, reloading"
189                );
190
191                last_modified_atomic.store(current_mtime, Ordering::Relaxed);
192
193                // Read file content.
194                let content = match tokio::fs::read_to_string(&config_path).await {
195                    Ok(c) => c,
196                    Err(e) => {
197                        warn!(error = %e, "failed to read config file during hot reload");
198                        continue;
199                    }
200                };
201
202                // Parse TOML.
203                let new_config: PunchConfig = match toml::from_str(&content) {
204                    Ok(c) => c,
205                    Err(e) => {
206                        warn!(error = %e, "config parse error during hot reload — keeping old config");
207                        continue;
208                    }
209                };
210
211                // Validate.
212                let errors: Vec<_> = validate_config(&new_config)
213                    .into_iter()
214                    .filter(|v| matches!(v.severity, ValidationSeverity::Error))
215                    .collect();
216
217                if !errors.is_empty() {
218                    for err in &errors {
219                        warn!(field = %err.field, message = %err.message, "config validation error — keeping old config");
220                    }
221                    continue;
222                }
223
224                // Compute diff.
225                let old_config = config.read().await.clone();
226                let changes = diff_configs(&old_config, &new_config);
227
228                if changes.is_empty() {
229                    debug!("config file changed (mtime) but no effective differences");
230                    continue;
231                }
232
233                let diff = KernelConfigDiff::from_changes(&changes);
234
235                // Log each change.
236                for change in &changes {
237                    info!(change = ?change, "config hot reload: change detected");
238                }
239
240                // Warn about non-reloadable fields.
241                for field in &diff.requires_restart {
242                    warn!(
243                        field = %field,
244                        "config field changed but requires restart to take effect"
245                    );
246                }
247
248                // Swap config.
249                {
250                    let mut guard = config.write().await;
251                    *guard = new_config.clone();
252                }
253
254                // Notify callbacks.
255                let cbs = callbacks.read().await;
256                for cb in cbs.iter() {
257                    cb(&new_config, &diff);
258                }
259
260                info!(num_changes = changes.len(), "config hot reload complete");
261            }
262        })
263    }
264
265    /// Read the file's mtime as epoch seconds. Returns `None` if the file
266    /// cannot be stat'd.
267    fn file_mtime(path: &PathBuf) -> Option<u64> {
268        std::fs::metadata(path)
269            .ok()
270            .and_then(|m| m.modified().ok())
271            .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
272            .map(|d| d.as_secs())
273    }
274}
275
276// ---------------------------------------------------------------------------
277// Tests
278// ---------------------------------------------------------------------------
279
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::config::{MemoryConfig, ModelConfig, Provider};
    use std::collections::HashMap;
    use std::sync::atomic::AtomicBool;

    /// Baseline config shared by the tests below.
    fn make_test_config() -> PunchConfig {
        PunchConfig {
            api_listen: "127.0.0.1:6660".to_string(),
            api_key: "test-key".to_string(),
            rate_limit_rpm: 60,
            default_model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".to_string(),
                api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            memory: MemoryConfig {
                db_path: "/tmp/punch-test.db".to_string(),
                knowledge_graph_enabled: true,
                max_entries: Some(10000),
            },
            tunnel: None,
            channels: HashMap::new(),
            mcp_servers: HashMap::new(),
        }
    }

    #[test]
    fn kernel_config_diff_from_changes() {
        let changes = vec![
            ConfigChange::RateLimitChanged { old: 60, new: 120 },
            ConfigChange::ModelChanged {
                old_model: "a".to_string(),
                new_model: "b".to_string(),
            },
            ConfigChange::ChannelAdded("slack".to_string()),
            ConfigChange::McpServerRemoved("fs".to_string()),
            ConfigChange::ListenAddressChanged {
                old: "a".to_string(),
                new: "b".to_string(),
            },
            ConfigChange::ApiKeyChanged,
        ];

        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert_eq!(diff.channels_changed, vec!["slack".to_string()]);
        assert_eq!(diff.mcp_servers_changed, vec!["fs".to_string()]);
        assert_eq!(diff.requires_restart.len(), 2);
        assert!(diff.requires_restart.contains(&"api_listen".to_string()));
        assert!(diff.requires_restart.contains(&"api_key".to_string()));
    }

    #[test]
    fn kernel_config_diff_has_reloadable_changes() {
        let empty = KernelConfigDiff::default();
        assert!(!empty.has_reloadable_changes());

        let with_rate = KernelConfigDiff {
            rate_limit_changed: true,
            ..Default::default()
        };
        assert!(with_rate.has_reloadable_changes());

        let restart_only = KernelConfigDiff {
            requires_restart: vec!["api_listen".to_string()],
            ..Default::default()
        };
        assert!(!restart_only.has_reloadable_changes());
    }

    #[tokio::test]
    async fn watch_detects_file_change() {
        let dir = std::env::temp_dir().join(format!("punch-cfg-test-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&dir).expect("create temp dir");
        let config_path = dir.join("punch.toml");

        let initial = make_test_config();
        let toml_str = toml::to_string_pretty(&initial).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());

        let callback_fired = Arc::new(AtomicBool::new(false));
        let cb_flag = Arc::clone(&callback_fired);
        watcher
            .on_change(move |_cfg, _diff| {
                cb_flag.store(true, Ordering::Relaxed);
            })
            .await;

        let handle = watcher.watch();

        // Wait a bit then modify the file.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let mut modified = initial.clone();
        modified.rate_limit_rpm = 120;
        let new_toml = toml::to_string_pretty(&modified).expect("serialize modified config");

        // Ensure mtime differs (some filesystems have 1s granularity).
        tokio::time::sleep(Duration::from_secs(1)).await;
        std::fs::write(&config_path, &new_toml).expect("write modified config");

        // Wait for the poller to pick it up (default 5s poll interval).
        tokio::time::sleep(Duration::from_secs(7)).await;

        assert!(
            callback_fired.load(Ordering::Relaxed),
            "callback should have been fired after config change"
        );

        // Verify the config was updated.
        let current = watcher.current_config().await;
        assert_eq!(current.rate_limit_rpm, 120);

        handle.abort();
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[tokio::test]
    async fn parse_error_keeps_old_config() {
        let dir = std::env::temp_dir().join(format!("punch-cfg-parse-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&dir).expect("create temp dir");
        let config_path = dir.join("punch.toml");

        let initial = make_test_config();
        let toml_str = toml::to_string_pretty(&initial).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());
        let handle = watcher.watch();

        tokio::time::sleep(Duration::from_secs(1)).await;

        // Write invalid TOML.
        std::fs::write(&config_path, "this is not valid toml {{{}}}").expect("write bad config");

        tokio::time::sleep(Duration::from_secs(7)).await;

        // Config should be unchanged.
        let current = watcher.current_config().await;
        assert_eq!(current.rate_limit_rpm, 60);

        handle.abort();
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn diff_correctly_identifies_changed_fields() {
        let old = make_test_config();
        let mut new = old.clone();
        new.rate_limit_rpm = 200;
        new.default_model.model = "gpt-4o".to_string();

        let changes = diff_configs(&old, &new);
        let diff = KernelConfigDiff::from_changes(&changes);

        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert!(diff.channels_changed.is_empty());
        assert!(diff.mcp_servers_changed.is_empty());
        assert!(diff.requires_restart.is_empty());
    }

    #[tokio::test]
    async fn callback_registration_and_invocation() {
        let config_path = PathBuf::from("/tmp/nonexistent-punch-test.toml");
        let config = make_test_config();
        let watcher = KernelConfigWatcher::new(config_path, config);

        let counter = Arc::new(AtomicU64::new(0));
        let c1 = Arc::clone(&counter);
        watcher
            .on_change(move |_cfg, _diff| {
                c1.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        let current = watcher.current_config().await;
        let diff = KernelConfigDiff::default();

        // Verify callback list has one entry.
        let cbs = watcher.callbacks.read().await;
        assert_eq!(cbs.len(), 1);

        // Invoke it the same way the poll loop does and check that it ran.
        for cb in cbs.iter() {
            cb(&current, &diff);
        }
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn multiple_callbacks_supported() {
        let config_path = PathBuf::from("/tmp/nonexistent-punch-multi.toml");
        let config = make_test_config();
        let watcher = KernelConfigWatcher::new(config_path, config);

        let c1 = Arc::new(AtomicU64::new(0));
        let c2 = Arc::new(AtomicU64::new(0));
        // Records which callback ran, to check registration order.
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));

        let c1_clone = Arc::clone(&c1);
        let c2_clone = Arc::clone(&c2);
        let order1 = Arc::clone(&order);
        let order2 = Arc::clone(&order);

        watcher
            .on_change(move |_cfg, _diff| {
                c1_clone.fetch_add(1, Ordering::Relaxed);
                order1.lock().expect("order lock poisoned").push(1);
            })
            .await;

        watcher
            .on_change(move |_cfg, _diff| {
                c2_clone.fetch_add(1, Ordering::Relaxed);
                order2.lock().expect("order lock poisoned").push(2);
            })
            .await;

        let current = watcher.current_config().await;
        let diff = KernelConfigDiff::default();

        let cbs = watcher.callbacks.read().await;
        assert_eq!(cbs.len(), 2);

        // Invoke them the way the poll loop does: every callback runs once,
        // in registration order.
        for cb in cbs.iter() {
            cb(&current, &diff);
        }
        assert_eq!(c1.load(Ordering::Relaxed), 1);
        assert_eq!(c2.load(Ordering::Relaxed), 1);
        assert_eq!(*order.lock().expect("order lock poisoned"), vec![1, 2]);
    }

    #[test]
    fn non_reloadable_fields_logged_as_requiring_restart() {
        let changes = vec![
            ConfigChange::ListenAddressChanged {
                old: "127.0.0.1:6660".to_string(),
                new: "0.0.0.0:8080".to_string(),
            },
            ConfigChange::ApiKeyChanged,
        ];

        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(!diff.has_reloadable_changes());
        assert_eq!(diff.requires_restart.len(), 2);
    }

    #[tokio::test]
    async fn concurrent_reads_during_reload() {
        let config = make_test_config();
        let watcher = KernelConfigWatcher::new(PathBuf::from("/tmp/test.toml"), config);
        let config_arc = watcher.config_arc();

        // Spawn multiple concurrent readers.
        let mut handles = Vec::new();
        for _ in 0..10 {
            let arc = Arc::clone(&config_arc);
            handles.push(tokio::spawn(async move {
                let cfg = arc.read().await;
                assert!(!cfg.api_listen.is_empty());
            }));
        }

        // Spawn a writer.
        let arc_w = Arc::clone(&config_arc);
        handles.push(tokio::spawn(async move {
            let mut cfg = arc_w.write().await;
            cfg.rate_limit_rpm = 999;
        }));

        for h in handles {
            h.await.expect("task should complete");
        }

        // Verify the write took effect.
        let final_cfg = config_arc.read().await;
        assert_eq!(final_cfg.rate_limit_rpm, 999);
    }

    #[test]
    fn memory_change_detected() {
        let changes = vec![ConfigChange::MemoryConfigChanged];
        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(diff.memory_changed);
        assert!(diff.has_reloadable_changes());
    }
}