1use std::collections::HashSet;
9use std::path::PathBuf;
10use std::sync::Arc;
11
12use chrono::{DateTime, Utc};
13use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
14use serde::{Deserialize, Serialize};
15use tokio::sync::{RwLock, watch};
16use tracing::{debug, error, info, warn};
17
18use crate::config::PunchConfig;
19use crate::error::{PunchError, PunchResult};
20
/// Watches a TOML configuration file and publishes validated reloads to subscribers.
///
/// The latest configuration lives in a `tokio::sync::watch` channel:
/// [`ConfigWatcher::current`] reads it and [`ConfigWatcher::subscribe`] hands out
/// receivers for it. Nothing is watched on disk until
/// [`ConfigWatcher::start_watching`] is called.
#[derive(Debug)]
pub struct ConfigWatcher {
    /// Path of the configuration file being watched.
    config_path: PathBuf,
    /// Copy of the configuration behind an async `RwLock`. Note that `current()` reads
    /// the watch channel, not this field.
    current: Arc<RwLock<PunchConfig>>,
    /// Sending half of the watch channel; reloads and `apply_change` publish through it.
    tx: watch::Sender<PunchConfig>,
    /// Receiver retained by the watcher itself; `subscribe()` hands out clones of it.
    rx: watch::Receiver<PunchConfig>,
}
34
35impl ConfigWatcher {
36 pub fn new(config_path: PathBuf, initial_config: PunchConfig) -> Self {
38 let (tx, rx) = watch::channel(initial_config.clone());
39 Self {
40 config_path,
41 current: Arc::new(RwLock::new(initial_config)),
42 tx,
43 rx,
44 }
45 }
46
47 pub fn subscribe(&self) -> watch::Receiver<PunchConfig> {
49 self.rx.clone()
50 }
51
52 pub fn current(&self) -> PunchConfig {
54 self.rx.borrow().clone()
55 }
56
57 pub async fn start_watching(&self) -> PunchResult<()> {
62 let config_path = self.config_path.clone();
63 let current = Arc::clone(&self.current);
64 let tx = self.tx.clone();
65
66 let watch_path = config_path
68 .parent()
69 .map(|p| p.to_path_buf())
70 .unwrap_or_else(|| PathBuf::from("."));
71
72 let target_file = config_path
73 .file_name()
74 .map(|f| f.to_os_string())
75 .ok_or_else(|| PunchError::Config("config path has no file name".to_string()))?;
76
77 let (notify_tx, mut notify_rx) = tokio::sync::mpsc::channel::<Event>(16);
78
79 let _watcher: RecommendedWatcher = {
81 let notify_tx = notify_tx.clone();
82 let mut watcher =
83 notify::recommended_watcher(move |res: Result<Event, notify::Error>| match res {
84 Ok(event) => {
85 if let Err(e) = notify_tx.blocking_send(event) {
86 error!(error = %e, "failed to forward file event");
87 }
88 }
89 Err(e) => {
90 error!(error = %e, "filesystem watcher error");
91 }
92 })
93 .map_err(|e| PunchError::Config(format!("failed to create file watcher: {}", e)))?;
94
95 watcher
96 .watch(&watch_path, RecursiveMode::NonRecursive)
97 .map_err(|e| {
98 PunchError::Config(format!("failed to watch config directory: {}", e))
99 })?;
100
101 watcher
102 };
103
104 let config_path_for_task = config_path.clone();
105 let target_file_for_task = target_file.clone();
106
107 tokio::spawn(async move {
109 let _watcher = _watcher;
111
112 info!(path = %config_path_for_task.display(), "corner team watching config file");
113
114 while let Some(event) = notify_rx.recv().await {
115 let dominated = matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_));
117 if !dominated {
118 continue;
119 }
120
121 let affects_target = event.paths.iter().any(|p| {
122 p.file_name()
123 .map(|f| f == target_file_for_task)
124 .unwrap_or(false)
125 });
126 if !affects_target {
127 continue;
128 }
129
130 debug!("config file change detected, reloading");
131
132 let content = match tokio::fs::read_to_string(&config_path_for_task).await {
134 Ok(c) => c,
135 Err(e) => {
136 error!(error = %e, "failed to read config file during reload");
137 continue;
138 }
139 };
140
141 let new_config: PunchConfig = match toml::from_str(&content) {
142 Ok(c) => c,
143 Err(e) => {
144 error!(error = %e, "failed to parse config file during reload");
145 continue;
146 }
147 };
148
149 let errors: Vec<_> = validate_config(&new_config)
151 .into_iter()
152 .filter(|v| matches!(v.severity, ValidationSeverity::Error))
153 .collect();
154
155 if !errors.is_empty() {
156 for err in &errors {
157 error!(field = %err.field, message = %err.message, "config validation failed");
158 }
159 continue;
160 }
161
162 let old_config = {
164 let mut guard = current.write().await;
165 let old = guard.clone();
166 *guard = new_config.clone();
167 old
168 };
169
170 let changes = diff_configs(&old_config, &new_config);
171 if changes.is_empty() {
172 debug!("config file changed but no effective differences detected");
173 continue;
174 }
175
176 for change in &changes {
177 info!(change = ?change, "corner team adjustment applied");
178 }
179
180 if tx.send(new_config).is_err() {
182 warn!("no config subscribers remaining — corner team shouting into the void");
183 break;
184 }
185 }
186
187 info!("config watcher task ended");
188 });
189
190 Ok(())
191 }
192
193 pub fn apply_change(
198 &self,
199 new_config: PunchConfig,
200 ) -> Result<ConfigChangeSet, ConfigValidationError> {
201 let validation_errors: Vec<_> = validate_config(&new_config)
202 .into_iter()
203 .filter(|v| matches!(v.severity, ValidationSeverity::Error))
204 .collect();
205
206 if let Some(err) = validation_errors.into_iter().next() {
207 return Err(err);
208 }
209
210 let old_config = self.rx.borrow().clone();
211 let changes = diff_configs(&old_config, &new_config);
212
213 {
215 let current = Arc::clone(&self.current);
216 let new_config_clone = new_config.clone();
217 let rt = tokio::runtime::Handle::try_current();
220 match rt {
221 Ok(handle) => {
222 let current = current.clone();
223 let cfg = new_config_clone.clone();
224 handle.spawn(async move {
225 let mut guard = current.write().await;
226 *guard = cfg;
227 });
228 }
229 Err(_) => {
230 }
233 }
234 }
235
236 let _ = self.tx.send(new_config);
238
239 Ok(ConfigChangeSet {
240 changes,
241 applied_at: Utc::now(),
242 })
243 }
244}
245
/// Outcome of applying a configuration update: the tracked differences and when the
/// update was applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigChangeSet {
    /// Differences reported by [`diff_configs`] between the previous and the new
    /// configuration. It may be empty when the update touched no tracked field.
    pub changes: Vec<ConfigChange>,
    /// UTC timestamp taken (`Utc::now()`) when the change set was built.
    pub applied_at: DateTime<Utc>,
}
254
/// A single tracked difference between two configurations, as reported by
/// [`diff_configs`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ConfigChange {
    /// `default_model.model` changed.
    ModelChanged {
        /// Model name before the change.
        old_model: String,
        /// Model name after the change.
        new_model: String,
    },
    /// `api_key` changed. The variant carries neither the old nor the new key.
    ApiKeyChanged,
    /// `rate_limit_rpm` changed from `old` to `new`.
    RateLimitChanged { old: u32, new: u32 },
    /// `api_listen` changed from `old` to `new`.
    ListenAddressChanged { old: String, new: String },
    /// A channel with this key exists in the new configuration but not in the old one.
    ChannelAdded(String),
    /// A channel with this key exists in the old configuration but not in the new one.
    ChannelRemoved(String),
    /// An MCP server with this key exists in the new configuration but not in the old one.
    McpServerAdded(String),
    /// An MCP server with this key exists in the old configuration but not in the new one.
    McpServerRemoved(String),
    /// The JSON serialization of the `memory` section differs.
    MemoryConfigChanged,
}
280
/// One finding produced by [`validate_config`].
///
/// Despite the name, a finding may be a mere warning; check `severity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigValidationError {
    /// Dotted path of the offending field, e.g. `default_model.model`.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
    /// Whether the finding blocks the configuration or is only advisory.
    pub severity: ValidationSeverity,
}
291
292impl std::fmt::Display for ConfigValidationError {
293 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294 write!(f, "[{:?}] {}: {}", self.severity, self.field, self.message)
295 }
296}
297
298impl std::error::Error for ConfigValidationError {}
299
/// How serious a [`ConfigValidationError`] is.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ValidationSeverity {
    /// Advisory only. In this module, only `Error` findings cause a configuration to be
    /// rejected.
    Warning,
    /// Blocking: `ConfigWatcher::apply_change` and the file-reload task refuse a
    /// configuration with any finding of this severity.
    Error,
}
308
309pub fn diff_configs(old: &PunchConfig, new: &PunchConfig) -> Vec<ConfigChange> {
311 let mut changes = Vec::new();
312
313 if old.default_model.model != new.default_model.model {
315 changes.push(ConfigChange::ModelChanged {
316 old_model: old.default_model.model.clone(),
317 new_model: new.default_model.model.clone(),
318 });
319 }
320
321 if old.api_key != new.api_key {
323 changes.push(ConfigChange::ApiKeyChanged);
324 }
325
326 if old.rate_limit_rpm != new.rate_limit_rpm {
328 changes.push(ConfigChange::RateLimitChanged {
329 old: old.rate_limit_rpm,
330 new: new.rate_limit_rpm,
331 });
332 }
333
334 if old.api_listen != new.api_listen {
336 changes.push(ConfigChange::ListenAddressChanged {
337 old: old.api_listen.clone(),
338 new: new.api_listen.clone(),
339 });
340 }
341
342 let old_channels: HashSet<&String> = old.channels.keys().collect();
344 let new_channels: HashSet<&String> = new.channels.keys().collect();
345
346 for added in new_channels.difference(&old_channels) {
347 changes.push(ConfigChange::ChannelAdded((*added).clone()));
348 }
349 for removed in old_channels.difference(&new_channels) {
350 changes.push(ConfigChange::ChannelRemoved((*removed).clone()));
351 }
352
353 let old_servers: HashSet<&String> = old.mcp_servers.keys().collect();
355 let new_servers: HashSet<&String> = new.mcp_servers.keys().collect();
356
357 for added in new_servers.difference(&old_servers) {
358 changes.push(ConfigChange::McpServerAdded((*added).clone()));
359 }
360 for removed in old_servers.difference(&new_servers) {
361 changes.push(ConfigChange::McpServerRemoved((*removed).clone()));
362 }
363
364 let old_mem = serde_json::to_string(&old.memory).unwrap_or_default();
366 let new_mem = serde_json::to_string(&new.memory).unwrap_or_default();
367 if old_mem != new_mem {
368 changes.push(ConfigChange::MemoryConfigChanged);
369 }
370
371 changes
372}
373
374pub fn validate_config(config: &PunchConfig) -> Vec<ConfigValidationError> {
379 let mut errors = Vec::new();
380
381 if config.api_listen.parse::<std::net::SocketAddr>().is_err() {
383 errors.push(ConfigValidationError {
384 field: "api_listen".to_string(),
385 message: format!(
386 "'{}' is not a valid socket address (expected host:port)",
387 config.api_listen
388 ),
389 severity: ValidationSeverity::Error,
390 });
391 }
392
393 if config.default_model.model.trim().is_empty() {
395 errors.push(ConfigValidationError {
396 field: "default_model.model".to_string(),
397 message: "model name cannot be empty — the fighter needs a stance".to_string(),
398 severity: ValidationSeverity::Error,
399 });
400 }
401
402 if config.memory.db_path.trim().is_empty() {
404 errors.push(ConfigValidationError {
405 field: "memory.db_path".to_string(),
406 message: "database path cannot be empty — the fighter needs memory".to_string(),
407 severity: ValidationSeverity::Error,
408 });
409 }
410
411 if config.rate_limit_rpm == 0 {
413 errors.push(ConfigValidationError {
414 field: "rate_limit_rpm".to_string(),
415 message: "rate limit must be greater than zero — even a slugger needs some pace"
416 .to_string(),
417 severity: ValidationSeverity::Error,
418 });
419 }
420
421 if config.api_key.is_empty() {
423 errors.push(ConfigValidationError {
424 field: "api_key".to_string(),
425 message: "API key is empty — running in dev mode with no authentication".to_string(),
426 severity: ValidationSeverity::Warning,
427 });
428 }
429
430 errors
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{MemoryConfig, ModelConfig, Provider};
    use std::collections::HashMap;

    /// Builds a valid baseline configuration shared by the tests below.
    fn make_test_config() -> PunchConfig {
        PunchConfig {
            api_listen: "127.0.0.1:6660".to_string(),
            api_key: "test-key-123".to_string(),
            rate_limit_rpm: 60,
            default_model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".to_string(),
                api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            memory: MemoryConfig {
                db_path: "/tmp/punch-test.db".to_string(),
                knowledge_graph_enabled: true,
                max_entries: Some(10000),
            },
            channels: HashMap::new(),
            mcp_servers: HashMap::new(),
        }
    }

    #[test]
    fn diff_detects_model_change() {
        let old = make_test_config();
        let mut new = old.clone();
        new.default_model.model = "claude-opus-4-20250514".to_string();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ModelChanged {
            old_model: "claude-sonnet-4-20250514".to_string(),
            new_model: "claude-opus-4-20250514".to_string(),
        }));
    }

    #[test]
    fn diff_detects_rate_limit_change() {
        let old = make_test_config();
        let mut new = old.clone();
        new.rate_limit_rpm = 120;

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::RateLimitChanged { old: 60, new: 120 }));
    }

    #[test]
    fn diff_detects_channel_added() {
        let old = make_test_config();
        let mut new = old.clone();
        new.channels.insert(
            "slack".to_string(),
            crate::config::ChannelConfig {
                channel_type: "slack".to_string(),
                token_env: Some("SLACK_TOKEN".to_string()),
                settings: HashMap::new(),
            },
        );

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ChannelAdded("slack".to_string())));
    }

    #[test]
    fn diff_detects_channel_removed() {
        let mut old = make_test_config();
        old.channels.insert(
            "discord".to_string(),
            crate::config::ChannelConfig {
                channel_type: "discord".to_string(),
                token_env: Some("DISCORD_TOKEN".to_string()),
                settings: HashMap::new(),
            },
        );
        let new = make_test_config();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ChannelRemoved("discord".to_string())));
    }

    #[test]
    fn diff_returns_empty_for_identical_configs() {
        let config = make_test_config();
        let changes = diff_configs(&config, &config);
        assert!(
            changes.is_empty(),
            "identical configs should produce no changes"
        );
    }

    #[test]
    fn validate_passes_valid_config() {
        let config = make_test_config();
        let errors: Vec<_> = validate_config(&config)
            .into_iter()
            .filter(|e| matches!(e.severity, ValidationSeverity::Error))
            .collect();
        assert!(errors.is_empty(), "valid config should produce no errors");
    }

    #[test]
    fn validate_catches_empty_model_name() {
        let mut config = make_test_config();
        config.default_model.model = "".to_string();

        let errors = validate_config(&config);
        assert!(
            errors.iter().any(|e| e.field == "default_model.model"
                && matches!(e.severity, ValidationSeverity::Error))
        );
    }

    #[test]
    fn validate_catches_empty_db_path() {
        let mut config = make_test_config();
        config.memory.db_path = "".to_string();

        let errors = validate_config(&config);
        assert!(errors.iter().any(
            |e| e.field == "memory.db_path" && matches!(e.severity, ValidationSeverity::Error)
        ));
    }

    #[test]
    fn validate_warns_on_empty_api_key() {
        let mut config = make_test_config();
        config.api_key = "".to_string();

        let errors = validate_config(&config);
        assert!(
            errors
                .iter()
                .any(|e| e.field == "api_key" && matches!(e.severity, ValidationSeverity::Warning))
        );
    }

    #[test]
    fn validate_reports_every_error() {
        let mut config = make_test_config();
        config.default_model.model = "   ".to_string();
        config.rate_limit_rpm = 0;

        let findings = validate_config(&config);
        let fields: Vec<&str> = findings.iter().map(|e| e.field.as_str()).collect();
        assert!(fields.contains(&"default_model.model"));
        assert!(fields.contains(&"rate_limit_rpm"));
    }

    #[test]
    fn config_watcher_can_be_created() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config.clone());
        assert_eq!(watcher.current().api_listen, config.api_listen);
    }

    #[tokio::test]
    async fn apply_change_returns_change_set() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config);

        let mut new_config = make_test_config();
        new_config.rate_limit_rpm = 120;

        let result = watcher.apply_change(new_config);
        assert!(result.is_ok());
        let change_set = result.expect("should succeed");
        assert!(
            change_set
                .changes
                .contains(&ConfigChange::RateLimitChanged { old: 60, new: 120 })
        );
    }

    #[tokio::test]
    async fn apply_change_rejects_invalid_config() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config);

        let mut bad_config = make_test_config();
        bad_config.default_model.model = "".to_string();

        let result = watcher.apply_change(bad_config);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn apply_change_rejection_leaves_config_untouched() {
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), make_test_config());

        let mut bad_config = make_test_config();
        bad_config.rate_limit_rpm = 0;

        assert!(watcher.apply_change(bad_config).is_err());
        assert_eq!(watcher.current().rate_limit_rpm, 60);
    }

    #[tokio::test]
    async fn apply_change_accepts_config_with_only_warnings() {
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), make_test_config());

        let mut new_config = make_test_config();
        new_config.api_key = "".to_string();

        let change_set = watcher
            .apply_change(new_config)
            .expect("warnings must not block apply_change");
        assert!(change_set.changes.contains(&ConfigChange::ApiKeyChanged));
    }

    #[tokio::test]
    async fn apply_change_diffs_against_previously_applied_config() {
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), make_test_config());

        let mut first = make_test_config();
        first.rate_limit_rpm = 120;
        watcher.apply_change(first).expect("first change should apply");

        let mut second = make_test_config();
        second.rate_limit_rpm = 240;
        let change_set = watcher.apply_change(second).expect("second change should apply");

        assert!(
            change_set
                .changes
                .contains(&ConfigChange::RateLimitChanged { old: 120, new: 240 })
        );
    }

    #[tokio::test]
    async fn apply_change_notifies_subscribers() {
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), make_test_config());
        let mut rx = watcher.subscribe();

        let mut new_config = make_test_config();
        new_config.rate_limit_rpm = 120;
        watcher.apply_change(new_config).expect("valid config should apply");

        assert!(rx.has_changed().expect("channel should still be open"));
        assert_eq!(rx.borrow_and_update().rate_limit_rpm, 120);
    }

    #[test]
    fn apply_change_works_without_runtime() {
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), make_test_config());

        let mut new_config = make_test_config();
        new_config.rate_limit_rpm = 90;

        let change_set = watcher
            .apply_change(new_config)
            .expect("valid config should apply");
        assert!(
            change_set
                .changes
                .contains(&ConfigChange::RateLimitChanged { old: 60, new: 90 })
        );
        assert_eq!(watcher.current().rate_limit_rpm, 90);
    }

    #[test]
    fn current_config_accessible_after_creation() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config.clone());

        let current = watcher.current();
        assert_eq!(current.api_listen, "127.0.0.1:6660");
        assert_eq!(current.rate_limit_rpm, 60);
        assert_eq!(current.default_model.model, "claude-sonnet-4-20250514");
        assert_eq!(current.memory.db_path, "/tmp/punch-test.db");
    }

    #[test]
    fn diff_detects_mcp_server_added() {
        let old = make_test_config();
        let mut new = old.clone();
        new.mcp_servers.insert(
            "filesystem".to_string(),
            crate::config::McpServerConfig {
                command: "npx".to_string(),
                args: vec!["-y".to_string(), "@mcp/filesystem".to_string()],
                env: HashMap::new(),
            },
        );

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::McpServerAdded("filesystem".to_string())));
    }

    #[test]
    fn diff_detects_mcp_server_removed() {
        let mut old = make_test_config();
        old.mcp_servers.insert(
            "memory".to_string(),
            crate::config::McpServerConfig {
                command: "mcp-memory".to_string(),
                args: vec![],
                env: HashMap::new(),
            },
        );
        let new = make_test_config();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::McpServerRemoved("memory".to_string())));
    }

    #[test]
    fn diff_detects_memory_config_changed() {
        let old = make_test_config();
        let mut new = old.clone();
        new.memory.knowledge_graph_enabled = false;

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::MemoryConfigChanged));
    }

    #[test]
    fn diff_detects_api_key_changed() {
        let old = make_test_config();
        let mut new = old.clone();
        new.api_key = "new-secret-key".to_string();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ApiKeyChanged));
    }

    #[test]
    fn diff_detects_listen_address_changed() {
        let old = make_test_config();
        let mut new = old.clone();
        new.api_listen = "0.0.0.0:8080".to_string();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ListenAddressChanged {
            old: "127.0.0.1:6660".to_string(),
            new: "0.0.0.0:8080".to_string(),
        }));
    }

    #[test]
    fn validate_catches_invalid_socket_addr() {
        let mut config = make_test_config();
        config.api_listen = "not-a-valid-address".to_string();

        let errors = validate_config(&config);
        assert!(errors.iter().any(|e| e.field == "api_listen"
            && matches!(e.severity, ValidationSeverity::Error)));
    }

    #[test]
    fn validate_catches_zero_rate_limit() {
        let mut config = make_test_config();
        config.rate_limit_rpm = 0;

        let errors = validate_config(&config);
        assert!(errors.iter().any(
            |e| e.field == "rate_limit_rpm" && matches!(e.severity, ValidationSeverity::Error)
        ));
    }

    #[tokio::test]
    async fn subscriber_receives_initial_config() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config.clone());

        let rx = watcher.subscribe();
        let received = rx.borrow().clone();
        assert_eq!(received.api_listen, config.api_listen);
    }
}