1use std::path::PathBuf;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::Duration;
13
14use tokio::sync::RwLock;
15use tokio::task::JoinHandle;
16use tracing::{debug, info, warn};
17
18use punch_types::config::PunchConfig;
19use punch_types::hot_reload::{ConfigChange, ValidationSeverity, diff_configs, validate_config};
20
/// Summary of a configuration reload, grouped by the kind of setting that changed.
///
/// Built from a list of [`ConfigChange`] values by `KernelConfigDiff::from_changes`
/// and handed to every callback registered on the config watcher.
#[derive(Debug, Clone, Default)]
pub struct KernelConfigDiff {
    /// `true` when the change list contained a `RateLimitChanged` entry.
    pub rate_limit_changed: bool,
    /// `true` when the change list contained a `ModelChanged` entry.
    pub model_changed: bool,
    /// Names of channels that were added or removed, each listed once.
    pub channels_changed: Vec<String>,
    /// Names of MCP servers that were added or removed, each listed once.
    pub mcp_servers_changed: Vec<String>,
    /// `true` when the change list contained a `MemoryConfigChanged` entry.
    pub memory_changed: bool,
    /// Names of changed config fields (`"api_listen"`, `"api_key"`) that are
    /// not counted as reloadable; the watcher logs a warning that they need a
    /// restart to take effect.
    pub requires_restart: Vec<String>,
}
42
43impl KernelConfigDiff {
44 fn from_changes(changes: &[ConfigChange]) -> Self {
46 let mut diff = Self::default();
47
48 for change in changes {
49 match change {
50 ConfigChange::RateLimitChanged { .. } => {
51 diff.rate_limit_changed = true;
52 }
53 ConfigChange::ModelChanged { .. } => {
54 diff.model_changed = true;
55 }
56 ConfigChange::ChannelAdded(name) | ConfigChange::ChannelRemoved(name) => {
57 if !diff.channels_changed.contains(name) {
58 diff.channels_changed.push(name.clone());
59 }
60 }
61 ConfigChange::McpServerAdded(name) | ConfigChange::McpServerRemoved(name) => {
62 if !diff.mcp_servers_changed.contains(name) {
63 diff.mcp_servers_changed.push(name.clone());
64 }
65 }
66 ConfigChange::MemoryConfigChanged => {
67 diff.memory_changed = true;
68 }
69 ConfigChange::ListenAddressChanged { .. } => {
71 diff.requires_restart.push("api_listen".to_string());
72 }
73 ConfigChange::ApiKeyChanged => {
74 diff.requires_restart.push("api_key".to_string());
75 }
76 }
77 }
78
79 diff
80 }
81
82 pub fn has_reloadable_changes(&self) -> bool {
84 self.rate_limit_changed
85 || self.model_changed
86 || !self.channels_changed.is_empty()
87 || !self.mcp_servers_changed.is_empty()
88 || self.memory_changed
89 }
90}
91
/// Shared, lock-protected list of change callbacks.
///
/// Each callback receives the newly applied configuration and the
/// [`KernelConfigDiff`] describing what changed.
type ConfigCallbacks = Arc<RwLock<Vec<Box<dyn Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync>>>>;
98
/// Polls a configuration file and applies valid changes to a shared
/// [`PunchConfig`], notifying registered callbacks after each applied reload.
pub struct KernelConfigWatcher {
    /// The live configuration; replaced by the watcher task on a successful reload.
    config: Arc<RwLock<PunchConfig>>,
    /// Path of the configuration file being polled.
    config_path: PathBuf,
    /// Modification timestamp of the config file captured at construction
    /// (`0` if it could not be determined); the starting point for `watch`.
    last_modified: AtomicU64,
    /// Callbacks invoked after a new configuration has been applied.
    callbacks: ConfigCallbacks,
}
110
111impl KernelConfigWatcher {
112 pub fn new(config_path: PathBuf, initial_config: PunchConfig) -> Self {
114 let mtime = Self::file_mtime(&config_path).unwrap_or(0);
115
116 Self {
117 config: Arc::new(RwLock::new(initial_config)),
118 config_path,
119 last_modified: AtomicU64::new(mtime),
120 callbacks: Arc::new(RwLock::new(Vec::new())),
121 }
122 }
123
124 pub async fn on_change<F>(&self, callback: F)
129 where
130 F: Fn(&PunchConfig, &KernelConfigDiff) + Send + Sync + 'static,
131 {
132 let mut cbs = self.callbacks.write().await;
133 cbs.push(Box::new(callback));
134 }
135
136 pub async fn current_config(&self) -> PunchConfig {
138 self.config.read().await.clone()
139 }
140
141 pub fn config_arc(&self) -> Arc<RwLock<PunchConfig>> {
143 Arc::clone(&self.config)
144 }
145
146 pub fn watch(&self) -> JoinHandle<()> {
156 let config = Arc::clone(&self.config);
157 let config_path = self.config_path.clone();
158 let last_modified = self.last_modified.load(Ordering::Relaxed);
159 let last_modified_atomic = Arc::new(AtomicU64::new(last_modified));
160 let callbacks = Arc::clone(&self.callbacks);
161
162 tokio::spawn(async move {
163 let mut interval = tokio::time::interval(Duration::from_secs(5));
164 interval.tick().await;
166
167 info!(path = %config_path.display(), "config poll watcher started (5s interval)");
168
169 loop {
170 interval.tick().await;
171
172 let current_mtime = match Self::file_mtime(&config_path) {
173 Some(m) => m,
174 None => {
175 debug!("config file not found or inaccessible, skipping check");
176 continue;
177 }
178 };
179
180 let prev_mtime = last_modified_atomic.load(Ordering::Relaxed);
181 if current_mtime == prev_mtime {
182 continue;
183 }
184
185 debug!(
186 old_mtime = prev_mtime,
187 new_mtime = current_mtime,
188 "config file mtime changed, reloading"
189 );
190
191 last_modified_atomic.store(current_mtime, Ordering::Relaxed);
192
193 let content = match tokio::fs::read_to_string(&config_path).await {
195 Ok(c) => c,
196 Err(e) => {
197 warn!(error = %e, "failed to read config file during hot reload");
198 continue;
199 }
200 };
201
202 let new_config: PunchConfig = match toml::from_str(&content) {
204 Ok(c) => c,
205 Err(e) => {
206 warn!(error = %e, "config parse error during hot reload — keeping old config");
207 continue;
208 }
209 };
210
211 let errors: Vec<_> = validate_config(&new_config)
213 .into_iter()
214 .filter(|v| matches!(v.severity, ValidationSeverity::Error))
215 .collect();
216
217 if !errors.is_empty() {
218 for err in &errors {
219 warn!(field = %err.field, message = %err.message, "config validation error — keeping old config");
220 }
221 continue;
222 }
223
224 let old_config = config.read().await.clone();
226 let changes = diff_configs(&old_config, &new_config);
227
228 if changes.is_empty() {
229 debug!("config file changed (mtime) but no effective differences");
230 continue;
231 }
232
233 let diff = KernelConfigDiff::from_changes(&changes);
234
235 for change in &changes {
237 info!(change = ?change, "config hot reload: change detected");
238 }
239
240 for field in &diff.requires_restart {
242 warn!(
243 field = %field,
244 "config field changed but requires restart to take effect"
245 );
246 }
247
248 {
250 let mut guard = config.write().await;
251 *guard = new_config.clone();
252 }
253
254 let cbs = callbacks.read().await;
256 for cb in cbs.iter() {
257 cb(&new_config, &diff);
258 }
259
260 info!(num_changes = changes.len(), "config hot reload complete");
261 }
262 })
263 }
264
265 fn file_mtime(path: &PathBuf) -> Option<u64> {
268 std::fs::metadata(path)
269 .ok()
270 .and_then(|m| m.modified().ok())
271 .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
272 .map(|d| d.as_secs())
273 }
274}
275
// Unit and integration-style tests for the config diff and the polling watcher.
// The watcher tests use real files and real sleeps, so they take several seconds.
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::config::{MemoryConfig, ModelConfig, Provider};
    use std::collections::HashMap;
    use std::sync::atomic::AtomicBool;

    /// Builds the baseline configuration shared by the tests below.
    fn make_test_config() -> PunchConfig {
        PunchConfig {
            api_listen: "127.0.0.1:6660".to_string(),
            api_key: "test-key".to_string(),
            rate_limit_rpm: 60,
            default_model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".to_string(),
                api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            memory: MemoryConfig {
                db_path: "/tmp/punch-test.db".to_string(),
                knowledge_graph_enabled: true,
                max_entries: Some(10000),
            },
            channels: HashMap::new(),
            mcp_servers: HashMap::new(),
        }
    }

    /// Every category of change ends up in the matching field of the diff.
    #[test]
    fn kernel_config_diff_from_changes() {
        let changes = vec![
            ConfigChange::RateLimitChanged { old: 60, new: 120 },
            ConfigChange::ModelChanged {
                old_model: "a".to_string(),
                new_model: "b".to_string(),
            },
            ConfigChange::ChannelAdded("slack".to_string()),
            ConfigChange::McpServerRemoved("fs".to_string()),
            ConfigChange::ListenAddressChanged {
                old: "a".to_string(),
                new: "b".to_string(),
            },
            ConfigChange::ApiKeyChanged,
        ];

        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert_eq!(diff.channels_changed, vec!["slack".to_string()]);
        assert_eq!(diff.mcp_servers_changed, vec!["fs".to_string()]);
        assert_eq!(diff.requires_restart.len(), 2);
        assert!(diff.requires_restart.contains(&"api_listen".to_string()));
        assert!(diff.requires_restart.contains(&"api_key".to_string()));
    }

    /// `has_reloadable_changes` ignores restart-only fields.
    #[test]
    fn kernel_config_diff_has_reloadable_changes() {
        let empty = KernelConfigDiff::default();
        assert!(!empty.has_reloadable_changes());

        let with_rate = KernelConfigDiff {
            rate_limit_changed: true,
            ..Default::default()
        };
        assert!(with_rate.has_reloadable_changes());

        let restart_only = KernelConfigDiff {
            requires_restart: vec!["api_listen".to_string()],
            ..Default::default()
        };
        assert!(!restart_only.has_reloadable_changes());
    }

    /// End-to-end: rewriting the file on disk updates the shared config and
    /// fires the registered callback.
    #[tokio::test]
    async fn watch_detects_file_change() {
        // A unique directory per run keeps parallel test runs from colliding.
        let dir = std::env::temp_dir().join(format!("punch-cfg-test-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&dir).expect("create temp dir");
        let config_path = dir.join("punch.toml");

        let initial = make_test_config();
        let toml_str = toml::to_string_pretty(&initial).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());

        let callback_fired = Arc::new(AtomicBool::new(false));
        let cb_flag = Arc::clone(&callback_fired);
        watcher
            .on_change(move |_cfg, _diff| {
                cb_flag.store(true, Ordering::Relaxed);
            })
            .await;

        let handle = watcher.watch();

        // Let the spawned watcher task start before touching the file.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let mut modified = initial.clone();
        modified.rate_limit_rpm = 120;
        let new_toml = toml::to_string_pretty(&modified).expect("serialize modified config");

        // Give the rewrite a modification time distinct from the initial write.
        tokio::time::sleep(Duration::from_secs(1)).await;
        std::fs::write(&config_path, &new_toml).expect("write modified config");

        // Wait longer than the watcher's 5-second poll period so at least one
        // check runs after the write.
        tokio::time::sleep(Duration::from_secs(7)).await;

        assert!(
            callback_fired.load(Ordering::Relaxed),
            "callback should have been fired after config change"
        );

        let current = watcher.current_config().await;
        assert_eq!(current.rate_limit_rpm, 120);

        handle.abort();
        // Best-effort cleanup; a failure here must not fail the test.
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Writing unparsable content must leave the active configuration untouched.
    #[tokio::test]
    async fn parse_error_keeps_old_config() {
        let dir = std::env::temp_dir().join(format!("punch-cfg-parse-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&dir).expect("create temp dir");
        let config_path = dir.join("punch.toml");

        let initial = make_test_config();
        let toml_str = toml::to_string_pretty(&initial).expect("serialize initial config");
        std::fs::write(&config_path, &toml_str).expect("write initial config");

        let watcher = KernelConfigWatcher::new(config_path.clone(), initial.clone());
        let handle = watcher.watch();

        // Give the rewrite a modification time distinct from the initial write.
        tokio::time::sleep(Duration::from_secs(1)).await;

        std::fs::write(&config_path, "this is not valid toml {{{}}}").expect("write bad config");

        // Wait longer than the watcher's 5-second poll period.
        tokio::time::sleep(Duration::from_secs(7)).await;

        let current = watcher.current_config().await;
        // Still the value from `make_test_config`.
        assert_eq!(current.rate_limit_rpm, 60);

        handle.abort();
        let _ = std::fs::remove_dir_all(&dir);
    }

    /// `diff_configs` output for two edited fields maps onto exactly those
    /// two flags in the kernel diff.
    #[test]
    fn diff_correctly_identifies_changed_fields() {
        let old = make_test_config();
        let mut new = old.clone();
        new.rate_limit_rpm = 200;
        new.default_model.model = "gpt-4o".to_string();

        let changes = diff_configs(&old, &new);
        let diff = KernelConfigDiff::from_changes(&changes);

        assert!(diff.rate_limit_changed);
        assert!(diff.model_changed);
        assert!(diff.channels_changed.is_empty());
        assert!(diff.mcp_servers_changed.is_empty());
        assert!(diff.requires_restart.is_empty());
    }

    /// Registering a callback stores it in the watcher's callback list.
    ///
    /// Only registration is checked here: no reload is triggered, so the
    /// callback itself never runs in this test.
    #[tokio::test]
    async fn callback_registration_and_invocation() {
        let config_path = PathBuf::from("/tmp/nonexistent-punch-test.toml");
        let config = make_test_config();
        let watcher = KernelConfigWatcher::new(config_path, config);

        let counter = Arc::new(AtomicU64::new(0));
        let c1 = Arc::clone(&counter);
        watcher
            .on_change(move |_cfg, _diff| {
                c1.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        let cbs = watcher.callbacks.read().await;
        assert_eq!(cbs.len(), 1);
    }

    /// More than one callback can be registered on the same watcher.
    #[tokio::test]
    async fn multiple_callbacks_supported() {
        let config_path = PathBuf::from("/tmp/nonexistent-punch-multi.toml");
        let config = make_test_config();
        let watcher = KernelConfigWatcher::new(config_path, config);

        let c1 = Arc::new(AtomicU64::new(0));
        let c2 = Arc::new(AtomicU64::new(0));

        let c1_clone = Arc::clone(&c1);
        let c2_clone = Arc::clone(&c2);

        watcher
            .on_change(move |_cfg, _diff| {
                c1_clone.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        watcher
            .on_change(move |_cfg, _diff| {
                c2_clone.fetch_add(1, Ordering::Relaxed);
            })
            .await;

        let cbs = watcher.callbacks.read().await;
        assert_eq!(cbs.len(), 2);
    }

    /// Listen-address and API-key changes are recorded as restart-only and do
    /// not count as reloadable.
    #[test]
    fn non_reloadable_fields_logged_as_requiring_restart() {
        let changes = vec![
            ConfigChange::ListenAddressChanged {
                old: "127.0.0.1:6660".to_string(),
                new: "0.0.0.0:8080".to_string(),
            },
            ConfigChange::ApiKeyChanged,
        ];

        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(!diff.has_reloadable_changes());
        assert_eq!(diff.requires_restart.len(), 2);
    }

    /// Readers and a writer can share the config lock returned by `config_arc`.
    ///
    /// The watcher task is not started; the write is performed directly on the
    /// shared lock.
    #[tokio::test]
    async fn concurrent_reads_during_reload() {
        let config = make_test_config();
        let watcher = KernelConfigWatcher::new(PathBuf::from("/tmp/test.toml"), config);
        let config_arc = watcher.config_arc();

        let mut handles = Vec::new();
        // Ten concurrent readers.
        for _ in 0..10 {
            let arc = Arc::clone(&config_arc);
            handles.push(tokio::spawn(async move {
                let cfg = arc.read().await;
                assert!(!cfg.api_listen.is_empty());
            }));
        }

        let arc_w = Arc::clone(&config_arc);
        // One writer competing with the readers.
        handles.push(tokio::spawn(async move {
            let mut cfg = arc_w.write().await;
            cfg.rate_limit_rpm = 999;
        }));

        for h in handles {
            h.await.expect("task should complete");
        }

        let final_cfg = config_arc.read().await;
        // All tasks have been joined, so the write must be visible.
        assert_eq!(final_cfg.rate_limit_rpm, 999);
    }

    /// A memory-config change sets its flag and counts as reloadable.
    #[test]
    fn memory_change_detected() {
        let changes = vec![ConfigChange::MemoryConfigChanged];
        let diff = KernelConfigDiff::from_changes(&changes);
        assert!(diff.memory_changed);
        assert!(diff.has_reloadable_changes());
    }
}