1use std::path::PathBuf;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::Duration;
13
14use tokio::sync::RwLock;
15use tokio::task::JoinHandle;
16use tracing::{debug, info, warn};
17
18use punch_types::config::PunchConfig;
19use punch_types::hot_reload::{ConfigChange, ValidationSeverity, diff_configs, validate_config};
20
/// Per-subsystem summary of what changed between two configurations.
///
/// Built from a list of `ConfigChange`s so that listeners can react to the
/// subsystems they care about without matching on every change variant.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct KernelConfigDiff {
    /// The request rate limit changed.
    pub rate_limit_changed: bool,
    /// The default model configuration changed.
    pub model_changed: bool,
    /// Names of channels that were added or removed (each listed once).
    pub channels_changed: Vec<String>,
    /// Names of MCP servers that were added or removed (each listed once).
    pub mcp_servers_changed: Vec<String>,
    /// The memory configuration changed.
    pub memory_changed: bool,
    /// Config fields that changed but only take effect after a restart.
    pub requires_restart: Vec<String>,
}
42
43impl KernelConfigDiff {
44 fn from_changes(changes: &[ConfigChange]) -> Self {
46 let mut diff = Self::default();
47
48 for change in changes {
49 match change {
50 ConfigChange::RateLimitChanged { .. } => {
51 diff.rate_limit_changed = true;
52 }
53 ConfigChange::ModelChanged { .. } => {
54 diff.model_changed = true;
55 }
56 ConfigChange::ChannelAdded(name) | ConfigChange::ChannelRemoved(name) => {
57 if !diff.channels_changed.contains(name) {
58 diff.channels_changed.push(name.clone());
59 }
60 }
61 ConfigChange::McpServerAdded(name) | ConfigChange::McpServerRemoved(name) => {
62 if !diff.mcp_servers_changed.contains(name) {
63 diff.mcp_servers_changed.push(name.clone());
64 }
65 }
66 ConfigChange::MemoryConfigChanged => {
67 diff.memory_changed = true;
68 }
69 ConfigChange::ListenAddressChanged { .. } => {
71 diff.requires_restart.push("api_listen".to_string());
72 }
73 ConfigChange::ApiKeyChanged => {
74 diff.requires_restart.push("api_key".to_string());
75 }
76 }
77 }
78
79 diff
80 }
81
82 pub fn has_reloadable_changes(&self) -> bool {
84 self.rate_limit_changed
85 || self.model_changed
86 || !self.channels_changed.is_empty()
87 || !self.mcp_servers_changed.is_empty()
88 || self.memory_changed
89 }
90}
91
92type ConfigCallbacks = Arc<RwLock<Vec<Box<dyn Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync>>>>;
98
99pub struct KernelConfigWatcher {
105 config: Arc<RwLock<PunchConfig>>,
106 config_path: PathBuf,
107 last_modified: AtomicU64,
108 callbacks: ConfigCallbacks,
109}
110
111impl KernelConfigWatcher {
112 pub fn new(config_path: PathBuf, initial_config: PunchConfig) -> Self {
114 let mtime = Self::file_mtime(&config_path).unwrap_or(0);
115
116 Self {
117 config: Arc::new(RwLock::new(initial_config)),
118 config_path,
119 last_modified: AtomicU64::new(mtime),
120 callbacks: Arc::new(RwLock::new(Vec::new())),
121 }
122 }
123
124 pub async fn on_change<F>(&self, callback: F)
129 where
130 F: Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync + 'static,
131 {
132 let mut cbs = self.callbacks.write().await;
133 cbs.push(Box::new(callback));
134 }
135
136 pub async fn current_config(&self) -> PunchConfig {
138 self.config.read().await.clone()
139 }
140
141 pub fn config_arc(&self) -> Arc<RwLock<PunchConfig>> {
143 Arc::clone(&self.config)
144 }
145
146 pub fn watch(&self) -> JoinHandle<()> {
156 let config = Arc::clone(&self.config);
157 let config_path = self.config_path.clone();
158 let last_modified = self.last_modified.load(Ordering::Relaxed);
159 let last_modified_atomic = Arc::new(AtomicU64::new(last_modified));
160 let callbacks = Arc::clone(&self.callbacks);
161
162 tokio::spawn(async move {
163 let mut interval = tokio::time::interval(Duration::from_secs(5));
164 interval.tick().await;
166
167 info!(path = %config_path.display(), "config poll watcher started (5s interval)");
168
169 loop {
170 interval.tick().await;
171
172 let current_mtime = match Self::file_mtime(&config_path) {
173 Some(m) => m,
174 None => {
175 debug!("config file not found or inaccessible, skipping check");
176 continue;
177 }
178 };
179
180 let prev_mtime = last_modified_atomic.load(Ordering::Relaxed);
181 if current_mtime == prev_mtime {
182 continue;
183 }
184
185 debug!(
186 old_mtime = prev_mtime,
187 new_mtime = current_mtime,
188 "config file mtime changed, reloading"
189 );
190
191 last_modified_atomic.store(current_mtime, Ordering::Relaxed);
192
193 let content = match tokio::fs::read_to_string(&config_path).await {
195 Ok(c) => c,
196 Err(e) => {
197 warn!(error = %e, "failed to read config file during hot reload");
198 continue;
199 }
200 };
201
202 let new_config: PunchConfig = match toml::from_str(&content) {
204 Ok(c) => c,
205 Err(e) => {
206 warn!(error = %e, "config parse error during hot reload — keeping old config");
207 continue;
208 }
209 };
210
211 let errors: Vec<_> = validate_config(&new_config)
213 .into_iter()
214 .filter(|v| matches!(v.severity, ValidationSeverity::Error))
215 .collect();
216
217 if !errors.is_empty() {
218 for err in &errors {
219 warn!(field = %err.field, message = %err.message, "config validation error — keeping old config");
220 }
221 continue;
222 }
223
224 let old_config = config.read().await.clone();
226 let changes = diff_configs(&old_config, &new_config);
227
228 if changes.is_empty() {
229 debug!("config file changed (mtime) but no effective differences");
230 continue;
231 }
232
233 let diff = KernelConfigDiff::from_changes(&changes);
234
235 for change in &changes {
237 info!(change = ?change, "config hot reload: change detected");
238 }
239
240 for field in &diff.requires_restart {
242 warn!(
243 field = %field,
244 "config field changed but requires restart to take effect"
245 );
246 }
247
248 {
250 let mut guard = config.write().await;
251 *guard = new_config.clone();
252 }
253
254 let cbs = callbacks.read().await;
256 for cb in cbs.iter() {
257 cb(&new_config, &diff);
258 }
259
260 info!(num_changes = changes.len(), "config hot reload complete");
261 }
262 })
263 }
264
265 fn file_mtime(path: &PathBuf) -> Option<u64> {
268 std::fs::metadata(path)
269 .ok()
270 .and_then(|m| m.modified().ok())
271 .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
272 .map(|d| d.as_secs())
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::config::{MemoryConfig, ModelConfig, Provider};
    use std::collections::HashMap;
    use std::sync::atomic::AtomicBool;

    /// Unique scratch directory under the system temp dir.
    ///
    /// It is removed on drop, so it is cleaned up even when an assertion in
    /// the owning test fails.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(prefix: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("{}-{}", prefix, uuid::Uuid::new_v4()));
            std::fs::create_dir_all(&dir).expect("create temp dir");
            Self(dir)
        }

        fn join(&self, file: &str) -> PathBuf {
            self.0.join(file)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Path in the system temp dir for tests that never create the file.
    fn unused_config_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(name)
    }

    fn make_test_config() -> PunchConfig {
        PunchConfig {
            api_listen: "127.0.0.1:6660".to_string(),
            api_key: "test-key".to_string(),
            rate_limit_rpm: 60,
            default_model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".to_string(),
                api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            memory: MemoryConfig {
                db_path: "/tmp/punch-test.db".to_string(),
                knowledge_graph_enabled: true,
                max_entries: Some(10000),
            },
            tunnel: None,
            channels: HashMap::new(),
            mcp_servers: HashMap::new(),
        }
    }

    #[test]
    fn kernel_config_diff_from_changes() {
        let changes = vec![
            ConfigChange::RateLimitChanged { old: 60, new: 120 },
            ConfigChange::ModelChanged {
                old_model: "a".to_string(),
                new_model: "b".to_string(),
            },
            ConfigChange::ChannelAdded("slack".to_string()),
            ConfigChange::McpServerRemoved("fs".to_string()),
            ConfigChange::ListenAddressChanged {
                old: "a".to_string(),
                new: "b".to_string(),
            },
            ConfigChange::ApiKeyChanged,
        ];

        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert_eq!(diff.channels_changed, vec!["slack".to_string()]);
        assert_eq!(diff.mcp_servers_changed, vec!["fs".to_string()]);
        assert_eq!(diff.requires_restart.len(), 2);
        assert!(diff.requires_restart.contains(&"api_listen".to_string()));
        assert!(diff.requires_restart.contains(&"api_key".to_string()));
    }

    #[test]
    fn kernel_config_diff_has_reloadable_changes() {
        let empty = KernelConfigDiff::default();
        assert!(!empty.has_reloadable_changes());

        let with_rate = KernelConfigDiff {
            rate_limit_changed: true,
            ..Default::default()
        };
        assert!(with_rate.has_reloadable_changes());

        let restart_only = KernelConfigDiff {
            requires_restart: vec!["api_listen".to_string()],
            ..Default::default()
        };
        assert!(!restart_only.has_reloadable_changes());
    }

    #[tokio::test]
    async fn watch_detects_file_change() {
        let dir = ScratchDir::new("punch-cfg-test");
        let config_path = dir.join("punch.toml");

        let initial = make_test_config();
        let toml_str = toml::to_string_pretty(&initial).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());

        let callback_fired = Arc::new(AtomicBool::new(false));
        let cb_flag = Arc::clone(&callback_fired);
        watcher
            .on_change(move |_cfg, _diff| {
                cb_flag.store(true, Ordering::Relaxed);
            })
            .await;

        let handle = watcher.watch();

        tokio::time::sleep(Duration::from_millis(200)).await;

        let mut modified = initial.clone();
        modified.rate_limit_rpm = 120;
        let new_toml = toml::to_string_pretty(&modified).expect("serialize modified config");

        // Ensure the rewrite gets a distinct mtime even where timestamps have
        // one-second resolution.
        tokio::time::sleep(Duration::from_secs(1)).await;
        std::fs::write(&config_path, &new_toml).expect("write modified config");

        // Longer than one default (5s) poll period.
        tokio::time::sleep(Duration::from_secs(7)).await;

        assert!(
            callback_fired.load(Ordering::Relaxed),
            "callback should have been fired after config change"
        );

        let current = watcher.current_config().await;
        assert_eq!(current.rate_limit_rpm, 120);

        handle.abort();
    }

    #[tokio::test]
    async fn parse_error_keeps_old_config() {
        let dir = ScratchDir::new("punch-cfg-parse");
        let config_path = dir.join("punch.toml");

        let initial = make_test_config();
        let toml_str = toml::to_string_pretty(&initial).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());
        let handle = watcher.watch();

        tokio::time::sleep(Duration::from_secs(1)).await;

        std::fs::write(&config_path, "this is not valid toml {{{}}}").expect("write bad config");

        // Longer than one default (5s) poll period.
        tokio::time::sleep(Duration::from_secs(7)).await;

        let current = watcher.current_config().await;
        assert_eq!(current.rate_limit_rpm, 60);

        handle.abort();
    }

    #[test]
    fn diff_correctly_identifies_changed_fields() {
        let old = make_test_config();
        let mut new = old.clone();
        new.rate_limit_rpm = 200;
        new.default_model.model = "gpt-4o".to_string();

        let changes = diff_configs(&old, &new);
        let diff = KernelConfigDiff::from_changes(&changes);

        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert!(diff.channels_changed.is_empty());
        assert!(diff.mcp_servers_changed.is_empty());
        assert!(diff.requires_restart.is_empty());
    }

    #[tokio::test]
    async fn callback_registration_and_invocation() {
        let watcher = KernelConfigWatcher::new(
            unused_config_path("nonexistent-punch-test.toml"),
            make_test_config(),
        );

        let counter = Arc::new(AtomicU64::new(0));
        let c1 = Arc::clone(&counter);
        watcher
            .on_change(move |_cfg, _diff| {
                c1.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        let cfg = watcher.current_config().await;
        let diff = KernelConfigDiff::default();
        let cbs = watcher.callbacks.read().await;
        assert_eq!(cbs.len(), 1);

        for cb in cbs.iter() {
            cb(&cfg, &diff);
        }
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn multiple_callbacks_supported() {
        let watcher = KernelConfigWatcher::new(
            unused_config_path("nonexistent-punch-multi.toml"),
            make_test_config(),
        );

        let c1 = Arc::new(AtomicU64::new(0));
        let c2 = Arc::new(AtomicU64::new(0));

        let c1_clone = Arc::clone(&c1);
        let c2_clone = Arc::clone(&c2);

        watcher
            .on_change(move |_cfg, _diff| {
                c1_clone.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        watcher
            .on_change(move |_cfg, _diff| {
                c2_clone.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        let cfg = watcher.current_config().await;
        let diff = KernelConfigDiff::default();
        let cbs = watcher.callbacks.read().await;
        assert_eq!(cbs.len(), 2);

        for cb in cbs.iter() {
            cb(&cfg, &diff);
        }
        assert_eq!(c1.load(Ordering::Relaxed), 1);
        assert_eq!(c2.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn non_reloadable_fields_logged_as_requiring_restart() {
        let changes = vec![
            ConfigChange::ListenAddressChanged {
                old: "127.0.0.1:6660".to_string(),
                new: "0.0.0.0:8080".to_string(),
            },
            ConfigChange::ApiKeyChanged,
        ];

        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(!diff.has_reloadable_changes());
        assert_eq!(diff.requires_restart.len(), 2);
    }

    #[tokio::test]
    async fn concurrent_reads_during_reload() {
        let config = make_test_config();
        let watcher =
            KernelConfigWatcher::new(unused_config_path("nonexistent-punch-concurrent.toml"), config);
        let config_arc = watcher.config_arc();

        let mut handles = Vec::new();
        for _ in 0..10 {
            let arc = Arc::clone(&config_arc);
            handles.push(tokio::spawn(async move {
                let cfg = arc.read().await;
                assert!(!cfg.api_listen.is_empty());
            }));
        }

        let arc_w = Arc::clone(&config_arc);
        handles.push(tokio::spawn(async move {
            let mut cfg = arc_w.write().await;
            cfg.rate_limit_rpm = 999;
        }));

        for h in handles {
            h.await.expect("task should complete");
        }

        let final_cfg = config_arc.read().await;
        assert_eq!(final_cfg.rate_limit_rpm, 999);
    }

    #[test]
    fn memory_change_detected() {
        let changes = vec![ConfigChange::MemoryConfigChanged];
        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(diff.memory_changed);
        assert!(diff.has_reloadable_changes());
    }
}