1use std::path::PathBuf;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::Duration;
13
14use tokio::sync::RwLock;
15use tokio::task::JoinHandle;
16use tracing::{debug, info, warn};
17
18use punch_types::config::PunchConfig;
19use punch_types::hot_reload::{ConfigChange, ValidationSeverity, diff_configs, validate_config};
20
/// Summary of how a freshly loaded [`PunchConfig`] differs from the active one,
/// grouped by the part of the kernel that has to react.
///
/// Built from the list of [`ConfigChange`]s returned by `diff_configs` and
/// handed to every callback registered via `KernelConfigWatcher::on_change`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct KernelConfigDiff {
    /// The rate limit (`rate_limit_rpm`) changed.
    pub rate_limit_changed: bool,
    /// The default model configuration changed.
    pub model_changed: bool,
    /// Names of channels that were added or removed.
    pub channels_changed: Vec<String>,
    /// Names of MCP servers that were added or removed.
    pub mcp_servers_changed: Vec<String>,
    /// The memory configuration changed.
    pub memory_changed: bool,
    /// Names of changed config fields that only take effect after a restart
    /// (currently `"api_listen"` and `"api_key"`).
    pub requires_restart: Vec<String>,
}
42
43impl KernelConfigDiff {
44 fn from_changes(changes: &[ConfigChange]) -> Self {
46 let mut diff = Self::default();
47
48 for change in changes {
49 match change {
50 ConfigChange::RateLimitChanged { .. } => {
51 diff.rate_limit_changed = true;
52 }
53 ConfigChange::ModelChanged { .. } => {
54 diff.model_changed = true;
55 }
56 ConfigChange::ChannelAdded(name) | ConfigChange::ChannelRemoved(name) => {
57 if !diff.channels_changed.contains(name) {
58 diff.channels_changed.push(name.clone());
59 }
60 }
61 ConfigChange::McpServerAdded(name) | ConfigChange::McpServerRemoved(name) => {
62 if !diff.mcp_servers_changed.contains(name) {
63 diff.mcp_servers_changed.push(name.clone());
64 }
65 }
66 ConfigChange::MemoryConfigChanged => {
67 diff.memory_changed = true;
68 }
69 ConfigChange::ListenAddressChanged { .. } => {
71 diff.requires_restart.push("api_listen".to_string());
72 }
73 ConfigChange::ApiKeyChanged => {
74 diff.requires_restart.push("api_key".to_string());
75 }
76 }
77 }
78
79 diff
80 }
81
82 pub fn has_reloadable_changes(&self) -> bool {
84 self.rate_limit_changed
85 || self.model_changed
86 || !self.channels_changed.is_empty()
87 || !self.mcp_servers_changed.is_empty()
88 || self.memory_changed
89 }
90}
91
92type ConfigCallbacks = Arc<RwLock<Vec<Box<dyn Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync>>>>;
98
99pub struct KernelConfigWatcher {
105 config: Arc<RwLock<PunchConfig>>,
106 config_path: PathBuf,
107 last_modified: AtomicU64,
108 callbacks: ConfigCallbacks,
109}
110
111impl KernelConfigWatcher {
112 pub fn new(config_path: PathBuf, initial_config: PunchConfig) -> Self {
114 let mtime = Self::file_mtime(&config_path).unwrap_or(0);
115
116 Self {
117 config: Arc::new(RwLock::new(initial_config)),
118 config_path,
119 last_modified: AtomicU64::new(mtime),
120 callbacks: Arc::new(RwLock::new(Vec::new())),
121 }
122 }
123
124 pub async fn on_change<F>(&self, callback: F)
129 where
130 F: Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync + 'static,
131 {
132 let mut cbs = self.callbacks.write().await;
133 cbs.push(Box::new(callback));
134 }
135
136 pub async fn current_config(&self) -> PunchConfig {
138 self.config.read().await.clone()
139 }
140
141 pub fn config_arc(&self) -> Arc<RwLock<PunchConfig>> {
143 Arc::clone(&self.config)
144 }
145
146 pub fn watch(&self) -> JoinHandle<()> {
156 let config = Arc::clone(&self.config);
157 let config_path = self.config_path.clone();
158 let last_modified = self.last_modified.load(Ordering::Relaxed);
159 let last_modified_atomic = Arc::new(AtomicU64::new(last_modified));
160 let callbacks = Arc::clone(&self.callbacks);
161
162 tokio::spawn(async move {
163 let mut interval = tokio::time::interval(Duration::from_secs(5));
164 interval.tick().await;
166
167 info!(path = %config_path.display(), "config poll watcher started (5s interval)");
168
169 loop {
170 interval.tick().await;
171
172 let current_mtime = match Self::file_mtime(&config_path) {
173 Some(m) => m,
174 None => {
175 debug!("config file not found or inaccessible, skipping check");
176 continue;
177 }
178 };
179
180 let prev_mtime = last_modified_atomic.load(Ordering::Relaxed);
181 if current_mtime == prev_mtime {
182 continue;
183 }
184
185 debug!(
186 old_mtime = prev_mtime,
187 new_mtime = current_mtime,
188 "config file mtime changed, reloading"
189 );
190
191 last_modified_atomic.store(current_mtime, Ordering::Relaxed);
192
193 let content = match tokio::fs::read_to_string(&config_path).await {
195 Ok(c) => c,
196 Err(e) => {
197 warn!(error = %e, "failed to read config file during hot reload");
198 continue;
199 }
200 };
201
202 let new_config: PunchConfig = match toml::from_str(&content) {
204 Ok(c) => c,
205 Err(e) => {
206 warn!(error = %e, "config parse error during hot reload — keeping old config");
207 continue;
208 }
209 };
210
211 let errors: Vec<_> = validate_config(&new_config)
213 .into_iter()
214 .filter(|v| matches!(v.severity, ValidationSeverity::Error))
215 .collect();
216
217 if !errors.is_empty() {
218 for err in &errors {
219 warn!(field = %err.field, message = %err.message, "config validation error — keeping old config");
220 }
221 continue;
222 }
223
224 let old_config = config.read().await.clone();
226 let changes = diff_configs(&old_config, &new_config);
227
228 if changes.is_empty() {
229 debug!("config file changed (mtime) but no effective differences");
230 continue;
231 }
232
233 let diff = KernelConfigDiff::from_changes(&changes);
234
235 for change in &changes {
237 info!(change = ?change, "config hot reload: change detected");
238 }
239
240 for field in &diff.requires_restart {
242 warn!(
243 field = %field,
244 "config field changed but requires restart to take effect"
245 );
246 }
247
248 {
250 let mut guard = config.write().await;
251 *guard = new_config.clone();
252 }
253
254 let cbs = callbacks.read().await;
256 for cb in cbs.iter() {
257 cb(&new_config, &diff);
258 }
259
260 info!(num_changes = changes.len(), "config hot reload complete");
261 }
262 })
263 }
264
265 fn file_mtime(path: &PathBuf) -> Option<u64> {
268 std::fs::metadata(path)
269 .ok()
270 .and_then(|m| m.modified().ok())
271 .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
272 .map(|d| d.as_secs())
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::config::{MemoryConfig, ModelConfig, Provider};
    use std::collections::HashMap;
    use std::sync::atomic::AtomicBool;

    /// Unique temporary directory that is deleted when dropped. This cleans
    /// it up even when an assertion in the test panics.
    struct TempDir(PathBuf);

    impl TempDir {
        fn new(prefix: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("{}-{}", prefix, uuid::Uuid::new_v4()));
            std::fs::create_dir_all(&dir).expect("create temp dir");
            Self(dir)
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a leftover temp dir must not fail a test.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    fn make_test_config() -> PunchConfig {
        PunchConfig {
            api_listen: "127.0.0.1:6660".to_string(),
            api_key: "test-key".to_string(),
            rate_limit_rpm: 60,
            default_model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".to_string(),
                api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            memory: MemoryConfig {
                db_path: "/tmp/punch-test.db".to_string(),
                knowledge_graph_enabled: true,
                max_entries: Some(10000),
            },
            tunnel: None,
            channels: HashMap::new(),
            mcp_servers: HashMap::new(),
            model_routing: Default::default(),
            budget: Default::default(),
        }
    }

    #[test]
    fn kernel_config_diff_from_changes() {
        let changes = vec![
            ConfigChange::RateLimitChanged { old: 60, new: 120 },
            ConfigChange::ModelChanged {
                old_model: "a".to_string(),
                new_model: "b".to_string(),
            },
            ConfigChange::ChannelAdded("slack".to_string()),
            ConfigChange::McpServerRemoved("fs".to_string()),
            ConfigChange::ListenAddressChanged {
                old: "a".to_string(),
                new: "b".to_string(),
            },
            ConfigChange::ApiKeyChanged,
        ];

        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert_eq!(diff.channels_changed, vec!["slack".to_string()]);
        assert_eq!(diff.mcp_servers_changed, vec!["fs".to_string()]);
        assert_eq!(diff.requires_restart.len(), 2);
        assert!(diff.requires_restart.contains(&"api_listen".to_string()));
        assert!(diff.requires_restart.contains(&"api_key".to_string()));
    }

    #[test]
    fn kernel_config_diff_has_reloadable_changes() {
        let empty = KernelConfigDiff::default();
        assert!(!empty.has_reloadable_changes());

        let with_rate = KernelConfigDiff {
            rate_limit_changed: true,
            ..Default::default()
        };
        assert!(with_rate.has_reloadable_changes());

        let restart_only = KernelConfigDiff {
            requires_restart: vec!["api_listen".to_string()],
            ..Default::default()
        };
        assert!(!restart_only.has_reloadable_changes());
    }

    #[tokio::test]
    async fn watch_detects_file_change() {
        let dir = TempDir::new("punch-cfg-test");
        let config_path = dir.path().join("punch.toml");

        let initial = make_test_config();
        let toml_str = toml::to_string_pretty(&initial).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());

        let callback_fired = Arc::new(AtomicBool::new(false));
        let cb_flag = Arc::clone(&callback_fired);
        watcher
            .on_change(move |_cfg, _diff| {
                cb_flag.store(true, Ordering::Relaxed);
            })
            .await;

        let handle = watcher.watch();

        // Let the spawned watcher task start.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let mut modified = initial.clone();
        modified.rate_limit_rpm = 120;
        let new_toml = toml::to_string_pretty(&modified).expect("serialize modified config");

        // Modification times may only have whole-second resolution, so make
        // sure the rewrite lands in a later second than the initial write.
        tokio::time::sleep(Duration::from_secs(1)).await;
        std::fs::write(&config_path, &new_toml).expect("write modified config");

        // The watcher polls every 5 seconds, so 7 seconds guarantees at least
        // one poll after the write.
        tokio::time::sleep(Duration::from_secs(7)).await;

        assert!(
            callback_fired.load(Ordering::Relaxed),
            "callback should have been fired after config change"
        );

        let current = watcher.current_config().await;
        assert_eq!(current.rate_limit_rpm, 120);

        handle.abort();
    }

    #[tokio::test]
    async fn parse_error_keeps_old_config() {
        let dir = TempDir::new("punch-cfg-parse");
        let config_path = dir.path().join("punch.toml");

        let initial = make_test_config();
        let toml_str = toml::to_string_pretty(&initial).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());
        let handle = watcher.watch();

        // Ensure the bad write gets a different (possibly whole-second) mtime.
        tokio::time::sleep(Duration::from_secs(1)).await;

        std::fs::write(&config_path, "this is not valid toml {{{}}}").expect("write bad config");

        // Wait past the next 5-second poll.
        tokio::time::sleep(Duration::from_secs(7)).await;

        let current = watcher.current_config().await;
        assert_eq!(current.rate_limit_rpm, 60);

        handle.abort();
    }

    #[test]
    fn diff_correctly_identifies_changed_fields() {
        let old = make_test_config();
        let mut new = old.clone();
        new.rate_limit_rpm = 200;
        new.default_model.model = "gpt-4o".to_string();

        let changes = diff_configs(&old, &new);
        let diff = KernelConfigDiff::from_changes(&changes);

        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert!(diff.channels_changed.is_empty());
        assert!(diff.mcp_servers_changed.is_empty());
        assert!(diff.requires_restart.is_empty());
    }

    #[tokio::test]
    async fn callback_registration_and_invocation() {
        let config_path = PathBuf::from("/tmp/nonexistent-punch-test.toml");
        let config = make_test_config();
        let watcher = KernelConfigWatcher::new(config_path, config.clone());

        let counter = Arc::new(AtomicU64::new(0));
        let c1 = Arc::clone(&counter);
        watcher
            .on_change(move |_cfg, _diff| {
                c1.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        let cbs = watcher.callbacks.read().await;
        assert_eq!(cbs.len(), 1);

        // Invoke the registered callbacks the same way the watch task does.
        let diff = KernelConfigDiff::default();
        for cb in cbs.iter() {
            cb(&config, &diff);
        }
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn multiple_callbacks_supported() {
        let config_path = PathBuf::from("/tmp/nonexistent-punch-multi.toml");
        let config = make_test_config();
        let watcher = KernelConfigWatcher::new(config_path, config.clone());

        let c1 = Arc::new(AtomicU64::new(0));
        let c2 = Arc::new(AtomicU64::new(0));

        let c1_clone = Arc::clone(&c1);
        let c2_clone = Arc::clone(&c2);

        watcher
            .on_change(move |_cfg, _diff| {
                c1_clone.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        watcher
            .on_change(move |_cfg, _diff| {
                c2_clone.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        let cbs = watcher.callbacks.read().await;
        assert_eq!(cbs.len(), 2);

        // Every registered callback must run exactly once per notification.
        let diff = KernelConfigDiff::default();
        for cb in cbs.iter() {
            cb(&config, &diff);
        }
        assert_eq!(c1.load(Ordering::Relaxed), 1);
        assert_eq!(c2.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn non_reloadable_fields_logged_as_requiring_restart() {
        let changes = vec![
            ConfigChange::ListenAddressChanged {
                old: "127.0.0.1:6660".to_string(),
                new: "0.0.0.0:8080".to_string(),
            },
            ConfigChange::ApiKeyChanged,
        ];

        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(!diff.has_reloadable_changes());
        assert_eq!(diff.requires_restart.len(), 2);
    }

    #[tokio::test]
    async fn concurrent_reads_during_reload() {
        let config = make_test_config();
        let watcher = KernelConfigWatcher::new(PathBuf::from("/tmp/test.toml"), config);
        let config_arc = watcher.config_arc();

        let mut handles = Vec::new();
        // Several concurrent readers...
        for _ in 0..10 {
            let arc = Arc::clone(&config_arc);
            handles.push(tokio::spawn(async move {
                let cfg = arc.read().await;
                assert!(!cfg.api_listen.is_empty());
            }));
        }

        // ...racing with one writer, as during a reload.
        let arc_w = Arc::clone(&config_arc);
        handles.push(tokio::spawn(async move {
            let mut cfg = arc_w.write().await;
            cfg.rate_limit_rpm = 999;
        }));

        for h in handles {
            h.await.expect("task should complete");
        }

        let final_cfg = config_arc.read().await;
        assert_eq!(final_cfg.rate_limit_rpm, 999);
    }

    #[test]
    fn memory_change_detected() {
        let changes = vec![ConfigChange::MemoryConfigChanged];
        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(diff.memory_changed);
        assert!(diff.has_reloadable_changes());
    }
}