1use std::path::PathBuf;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12use std::time::Duration;
13
14use tokio::sync::RwLock;
15use tokio::task::JoinHandle;
16use tracing::{debug, info, warn};
17
18use punch_types::config::PunchConfig;
19use punch_types::hot_reload::{ConfigChange, diff_configs, validate_config, ValidationSeverity};
20
/// Kernel-facing summary of the differences between the running configuration and a
/// reloaded one, built from a list of `ConfigChange`s.
///
/// Fields flag which parts of the configuration changed; `requires_restart` lists changed
/// fields that are not applied until the process restarts.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct KernelConfigDiff {
    /// `true` if the rate-limit setting changed.
    pub rate_limit_changed: bool,
    /// `true` if the default model changed.
    pub model_changed: bool,
    /// Names of channels that were added or removed.
    pub channels_changed: Vec<String>,
    /// Names of MCP servers that were added or removed.
    pub mcp_servers_changed: Vec<String>,
    /// `true` if the memory configuration changed.
    pub memory_changed: bool,
    /// Names of changed fields that need a restart to take effect (e.g. `"api_listen"`).
    pub requires_restart: Vec<String>,
}
42
43impl KernelConfigDiff {
44 fn from_changes(changes: &[ConfigChange]) -> Self {
46 let mut diff = Self::default();
47
48 for change in changes {
49 match change {
50 ConfigChange::RateLimitChanged { .. } => {
51 diff.rate_limit_changed = true;
52 }
53 ConfigChange::ModelChanged { .. } => {
54 diff.model_changed = true;
55 }
56 ConfigChange::ChannelAdded(name) | ConfigChange::ChannelRemoved(name) => {
57 if !diff.channels_changed.contains(name) {
58 diff.channels_changed.push(name.clone());
59 }
60 }
61 ConfigChange::McpServerAdded(name) | ConfigChange::McpServerRemoved(name) => {
62 if !diff.mcp_servers_changed.contains(name) {
63 diff.mcp_servers_changed.push(name.clone());
64 }
65 }
66 ConfigChange::MemoryConfigChanged => {
67 diff.memory_changed = true;
68 }
69 ConfigChange::ListenAddressChanged { .. } => {
71 diff.requires_restart.push("api_listen".to_string());
72 }
73 ConfigChange::ApiKeyChanged => {
74 diff.requires_restart.push("api_key".to_string());
75 }
76 }
77 }
78
79 diff
80 }
81
82 pub fn has_reloadable_changes(&self) -> bool {
84 self.rate_limit_changed
85 || self.model_changed
86 || !self.channels_changed.is_empty()
87 || !self.mcp_servers_changed.is_empty()
88 || self.memory_changed
89 }
90}
91
92type ConfigCallbacks = Arc<RwLock<Vec<Box<dyn Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync>>>>;
98
99pub struct KernelConfigWatcher {
105 config: Arc<RwLock<PunchConfig>>,
106 config_path: PathBuf,
107 last_modified: AtomicU64,
108 callbacks: ConfigCallbacks,
109}
110
111impl KernelConfigWatcher {
112 pub fn new(config_path: PathBuf, initial_config: PunchConfig) -> Self {
114 let mtime = Self::file_mtime(&config_path).unwrap_or(0);
115
116 Self {
117 config: Arc::new(RwLock::new(initial_config)),
118 config_path,
119 last_modified: AtomicU64::new(mtime),
120 callbacks: Arc::new(RwLock::new(Vec::new())),
121 }
122 }
123
124 pub async fn on_change<F>(&self, callback: F)
129 where
130 F: Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync + 'static,
131 {
132 let mut cbs = self.callbacks.write().await;
133 cbs.push(Box::new(callback));
134 }
135
136 pub async fn current_config(&self) -> PunchConfig {
138 self.config.read().await.clone()
139 }
140
141 pub fn config_arc(&self) -> Arc<RwLock<PunchConfig>> {
143 Arc::clone(&self.config)
144 }
145
146 pub fn watch(&self) -> JoinHandle<()> {
156 let config = Arc::clone(&self.config);
157 let config_path = self.config_path.clone();
158 let last_modified = self.last_modified.load(Ordering::Relaxed);
159 let last_modified_atomic = Arc::new(AtomicU64::new(last_modified));
160 let callbacks = Arc::clone(&self.callbacks);
161
162 tokio::spawn(async move {
163 let mut interval = tokio::time::interval(Duration::from_secs(5));
164 interval.tick().await;
166
167 info!(path = %config_path.display(), "config poll watcher started (5s interval)");
168
169 loop {
170 interval.tick().await;
171
172 let current_mtime = match Self::file_mtime(&config_path) {
173 Some(m) => m,
174 None => {
175 debug!("config file not found or inaccessible, skipping check");
176 continue;
177 }
178 };
179
180 let prev_mtime = last_modified_atomic.load(Ordering::Relaxed);
181 if current_mtime == prev_mtime {
182 continue;
183 }
184
185 debug!(
186 old_mtime = prev_mtime,
187 new_mtime = current_mtime,
188 "config file mtime changed, reloading"
189 );
190
191 last_modified_atomic.store(current_mtime, Ordering::Relaxed);
192
193 let content = match tokio::fs::read_to_string(&config_path).await {
195 Ok(c) => c,
196 Err(e) => {
197 warn!(error = %e, "failed to read config file during hot reload");
198 continue;
199 }
200 };
201
202 let new_config: PunchConfig = match toml::from_str(&content) {
204 Ok(c) => c,
205 Err(e) => {
206 warn!(error = %e, "config parse error during hot reload — keeping old config");
207 continue;
208 }
209 };
210
211 let errors: Vec<_> = validate_config(&new_config)
213 .into_iter()
214 .filter(|v| matches!(v.severity, ValidationSeverity::Error))
215 .collect();
216
217 if !errors.is_empty() {
218 for err in &errors {
219 warn!(field = %err.field, message = %err.message, "config validation error — keeping old config");
220 }
221 continue;
222 }
223
224 let old_config = config.read().await.clone();
226 let changes = diff_configs(&old_config, &new_config);
227
228 if changes.is_empty() {
229 debug!("config file changed (mtime) but no effective differences");
230 continue;
231 }
232
233 let diff = KernelConfigDiff::from_changes(&changes);
234
235 for change in &changes {
237 info!(change = ?change, "config hot reload: change detected");
238 }
239
240 for field in &diff.requires_restart {
242 warn!(
243 field = %field,
244 "config field changed but requires restart to take effect"
245 );
246 }
247
248 {
250 let mut guard = config.write().await;
251 *guard = new_config.clone();
252 }
253
254 let cbs = callbacks.read().await;
256 for cb in cbs.iter() {
257 cb(&new_config, &diff);
258 }
259
260 info!(
261 num_changes = changes.len(),
262 "config hot reload complete"
263 );
264 }
265 })
266 }
267
268 fn file_mtime(path: &PathBuf) -> Option<u64> {
271 std::fs::metadata(path)
272 .ok()
273 .and_then(|m| m.modified().ok())
274 .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
275 .map(|d| d.as_secs())
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::config::{MemoryConfig, ModelConfig, Provider};
    use std::collections::HashMap;
    use std::sync::atomic::AtomicBool;

    /// Baseline configuration shared by the tests (`rate_limit_rpm` is 60).
    fn make_test_config() -> PunchConfig {
        PunchConfig {
            api_listen: "127.0.0.1:6660".to_string(),
            api_key: "test-key".to_string(),
            rate_limit_rpm: 60,
            default_model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".to_string(),
                api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            memory: MemoryConfig {
                db_path: "/tmp/punch-test.db".to_string(),
                knowledge_graph_enabled: true,
                max_entries: Some(10000),
            },
            channels: HashMap::new(),
            mcp_servers: HashMap::new(),
        }
    }

    /// Creates a fresh, uniquely named directory under the system temp dir.
    fn scratch_dir(prefix: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!("{}-{}", prefix, uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&dir).expect("create temp dir");
        dir
    }

    #[test]
    fn kernel_config_diff_from_changes() {
        let changes = vec![
            ConfigChange::RateLimitChanged { old: 60, new: 120 },
            ConfigChange::ModelChanged {
                old_model: "a".to_string(),
                new_model: "b".to_string(),
            },
            ConfigChange::ChannelAdded("slack".to_string()),
            ConfigChange::McpServerRemoved("fs".to_string()),
            ConfigChange::ListenAddressChanged {
                old: "a".to_string(),
                new: "b".to_string(),
            },
            ConfigChange::ApiKeyChanged,
        ];

        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert_eq!(diff.channels_changed, vec!["slack".to_string()]);
        assert_eq!(diff.mcp_servers_changed, vec!["fs".to_string()]);
        assert_eq!(diff.requires_restart.len(), 2);
        assert!(diff.requires_restart.contains(&"api_listen".to_string()));
        assert!(diff.requires_restart.contains(&"api_key".to_string()));
    }

    #[test]
    fn kernel_config_diff_has_reloadable_changes() {
        let empty = KernelConfigDiff::default();
        assert!(!empty.has_reloadable_changes());

        let with_rate = KernelConfigDiff {
            rate_limit_changed: true,
            ..Default::default()
        };
        assert!(with_rate.has_reloadable_changes());

        let restart_only = KernelConfigDiff {
            requires_restart: vec!["api_listen".to_string()],
            ..Default::default()
        };
        assert!(!restart_only.has_reloadable_changes());
    }

    #[tokio::test]
    async fn watch_detects_file_change() {
        let dir = scratch_dir("punch-cfg-test");
        let config_path = dir.join("punch.toml");

        let initial = make_test_config();
        let toml_str = toml::to_string_pretty(&initial).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());

        let callback_fired = Arc::new(AtomicBool::new(false));
        let cb_flag = Arc::clone(&callback_fired);
        watcher
            .on_change(move |_cfg, _diff| {
                cb_flag.store(true, Ordering::Relaxed);
            })
            .await;

        let handle = watcher.watch();

        // Let the watcher task start before touching the file.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let mut modified = initial.clone();
        modified.rate_limit_rpm = 120;
        let new_toml = toml::to_string_pretty(&modified).expect("serialize modified config");

        // Give filesystems with coarse mtime resolution time to register a distinct
        // modification time for the rewrite.
        tokio::time::sleep(Duration::from_secs(1)).await;
        std::fs::write(&config_path, &new_toml).expect("write modified config");

        // Longer than one 5 s poll period.
        tokio::time::sleep(Duration::from_secs(7)).await;

        assert!(
            callback_fired.load(Ordering::Relaxed),
            "callback should have been fired after config change"
        );

        let current = watcher.current_config().await;
        assert_eq!(current.rate_limit_rpm, 120);

        handle.abort();
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[tokio::test]
    async fn parse_error_keeps_old_config() {
        let dir = scratch_dir("punch-cfg-parse");
        let config_path = dir.join("punch.toml");

        let initial = make_test_config();
        let toml_str = toml::to_string_pretty(&initial).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());
        let handle = watcher.watch();

        tokio::time::sleep(Duration::from_secs(1)).await;

        std::fs::write(&config_path, "this is not valid toml {{{}}}").expect("write bad config");

        // Longer than one 5 s poll period, so the bad file has been seen.
        tokio::time::sleep(Duration::from_secs(7)).await;

        let current = watcher.current_config().await;
        assert_eq!(current.rate_limit_rpm, 60);

        handle.abort();
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn diff_correctly_identifies_changed_fields() {
        let old = make_test_config();
        let mut new = old.clone();
        new.rate_limit_rpm = 200;
        new.default_model.model = "gpt-4o".to_string();

        let changes = diff_configs(&old, &new);
        let diff = KernelConfigDiff::from_changes(&changes);

        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert!(diff.channels_changed.is_empty());
        assert!(diff.mcp_servers_changed.is_empty());
        assert!(diff.requires_restart.is_empty());
    }

    #[tokio::test]
    async fn callback_registration_and_invocation() {
        let config_path = std::env::temp_dir().join("nonexistent-punch-test.toml");
        let watcher = KernelConfigWatcher::new(config_path, make_test_config());

        let counter = Arc::new(AtomicU64::new(0));
        let c1 = Arc::clone(&counter);
        watcher
            .on_change(move |_cfg, _diff| {
                c1.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        let cfg = watcher.current_config().await;
        let diff = KernelConfigDiff::default();

        let cbs = watcher.callbacks.read().await;
        assert_eq!(cbs.len(), 1);

        // Actually invoke the registered callback, as the watcher task would.
        for cb in cbs.iter() {
            cb(&cfg, &diff);
        }
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn multiple_callbacks_supported() {
        let config_path = std::env::temp_dir().join("nonexistent-punch-multi.toml");
        let watcher = KernelConfigWatcher::new(config_path, make_test_config());

        let c1 = Arc::new(AtomicU64::new(0));
        let c2 = Arc::new(AtomicU64::new(0));

        let c1_clone = Arc::clone(&c1);
        let c2_clone = Arc::clone(&c2);

        watcher
            .on_change(move |_cfg, _diff| {
                c1_clone.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        watcher
            .on_change(move |_cfg, _diff| {
                c2_clone.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        let cfg = watcher.current_config().await;
        let diff = KernelConfigDiff::default();

        let cbs = watcher.callbacks.read().await;
        assert_eq!(cbs.len(), 2);

        for cb in cbs.iter() {
            cb(&cfg, &diff);
        }
        assert_eq!(c1.load(Ordering::Relaxed), 1);
        assert_eq!(c2.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn non_reloadable_fields_logged_as_requiring_restart() {
        let changes = vec![
            ConfigChange::ListenAddressChanged {
                old: "127.0.0.1:6660".to_string(),
                new: "0.0.0.0:8080".to_string(),
            },
            ConfigChange::ApiKeyChanged,
        ];

        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(!diff.has_reloadable_changes());
        assert_eq!(diff.requires_restart.len(), 2);
    }

    #[tokio::test]
    async fn concurrent_reads_during_reload() {
        let config_path = std::env::temp_dir().join("punch-concurrent-test.toml");
        let watcher = KernelConfigWatcher::new(config_path, make_test_config());
        let config_arc = watcher.config_arc();

        // Several readers racing with one writer on the shared config.
        let mut handles = Vec::new();
        for _ in 0..10 {
            let arc = Arc::clone(&config_arc);
            handles.push(tokio::spawn(async move {
                let cfg = arc.read().await;
                assert!(!cfg.api_listen.is_empty());
            }));
        }

        let arc_w = Arc::clone(&config_arc);
        handles.push(tokio::spawn(async move {
            let mut cfg = arc_w.write().await;
            cfg.rate_limit_rpm = 999;
        }));

        for h in handles {
            h.await.expect("task should complete");
        }

        let final_cfg = config_arc.read().await;
        assert_eq!(final_cfg.rate_limit_rpm, 999);
    }

    #[test]
    fn memory_change_detected() {
        let changes = vec![ConfigChange::MemoryConfigChanged];
        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(diff.memory_changed);
        assert!(diff.has_reloadable_changes());
    }
}