1use std::collections::HashSet;
9use std::path::PathBuf;
10use std::sync::Arc;
11
12use chrono::{DateTime, Utc};
13use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
14use serde::{Deserialize, Serialize};
15use tokio::sync::{RwLock, watch};
16use tracing::{debug, error, info, warn};
17
18use crate::config::PunchConfig;
19use crate::error::{PunchError, PunchResult};
20
/// Hot-reload coordinator for the Punch configuration file.
///
/// The active [`PunchConfig`] lives in a `tokio::sync::watch` channel:
/// [`ConfigWatcher::current`] and [`ConfigWatcher::subscribe`] read from that channel, while
/// new values arrive either from the file watcher started by
/// [`ConfigWatcher::start_watching`] or from [`ConfigWatcher::apply_change`].
#[derive(Debug)]
pub struct ConfigWatcher {
    /// Path of the TOML config file that is re-read when it changes on disk.
    config_path: PathBuf,
    /// Shared copy of the active configuration, written by the reload and apply paths.
    /// Note: `current()` reads from the watch channel, not from this field.
    current: Arc<RwLock<PunchConfig>>,
    /// Publishing half of the config channel.
    tx: watch::Sender<PunchConfig>,
    /// Receiver retained by the watcher itself; cloned by `subscribe()` and read by `current()`.
    rx: watch::Receiver<PunchConfig>,
}
34
35impl ConfigWatcher {
36 pub fn new(config_path: PathBuf, initial_config: PunchConfig) -> Self {
38 let (tx, rx) = watch::channel(initial_config.clone());
39 Self {
40 config_path,
41 current: Arc::new(RwLock::new(initial_config)),
42 tx,
43 rx,
44 }
45 }
46
47 pub fn subscribe(&self) -> watch::Receiver<PunchConfig> {
49 self.rx.clone()
50 }
51
52 pub fn current(&self) -> PunchConfig {
54 self.rx.borrow().clone()
55 }
56
57 pub async fn start_watching(&self) -> PunchResult<()> {
62 let config_path = self.config_path.clone();
63 let current = Arc::clone(&self.current);
64 let tx = self.tx.clone();
65
66 let watch_path = config_path
68 .parent()
69 .map(|p| p.to_path_buf())
70 .unwrap_or_else(|| PathBuf::from("."));
71
72 let target_file = config_path
73 .file_name()
74 .map(|f| f.to_os_string())
75 .ok_or_else(|| PunchError::Config("config path has no file name".to_string()))?;
76
77 let (notify_tx, mut notify_rx) = tokio::sync::mpsc::channel::<Event>(16);
78
79 let _watcher: RecommendedWatcher = {
81 let notify_tx = notify_tx.clone();
82 let mut watcher =
83 notify::recommended_watcher(move |res: Result<Event, notify::Error>| match res {
84 Ok(event) => {
85 if let Err(e) = notify_tx.blocking_send(event) {
86 error!(error = %e, "failed to forward file event");
87 }
88 }
89 Err(e) => {
90 error!(error = %e, "filesystem watcher error");
91 }
92 })
93 .map_err(|e| PunchError::Config(format!("failed to create file watcher: {}", e)))?;
94
95 watcher
96 .watch(&watch_path, RecursiveMode::NonRecursive)
97 .map_err(|e| {
98 PunchError::Config(format!("failed to watch config directory: {}", e))
99 })?;
100
101 watcher
102 };
103
104 let config_path_for_task = config_path.clone();
105 let target_file_for_task = target_file.clone();
106
107 tokio::spawn(async move {
109 let _watcher = _watcher;
111
112 info!(path = %config_path_for_task.display(), "corner team watching config file");
113
114 while let Some(event) = notify_rx.recv().await {
115 let dominated = matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_));
117 if !dominated {
118 continue;
119 }
120
121 let affects_target = event.paths.iter().any(|p| {
122 p.file_name()
123 .map(|f| f == target_file_for_task)
124 .unwrap_or(false)
125 });
126 if !affects_target {
127 continue;
128 }
129
130 debug!("config file change detected, reloading");
131
132 let content = match tokio::fs::read_to_string(&config_path_for_task).await {
134 Ok(c) => c,
135 Err(e) => {
136 error!(error = %e, "failed to read config file during reload");
137 continue;
138 }
139 };
140
141 let new_config: PunchConfig = match toml::from_str(&content) {
142 Ok(c) => c,
143 Err(e) => {
144 error!(error = %e, "failed to parse config file during reload");
145 continue;
146 }
147 };
148
149 let errors: Vec<_> = validate_config(&new_config)
151 .into_iter()
152 .filter(|v| matches!(v.severity, ValidationSeverity::Error))
153 .collect();
154
155 if !errors.is_empty() {
156 for err in &errors {
157 error!(field = %err.field, message = %err.message, "config validation failed");
158 }
159 continue;
160 }
161
162 let old_config = {
164 let mut guard = current.write().await;
165 let old = guard.clone();
166 *guard = new_config.clone();
167 old
168 };
169
170 let changes = diff_configs(&old_config, &new_config);
171 if changes.is_empty() {
172 debug!("config file changed but no effective differences detected");
173 continue;
174 }
175
176 for change in &changes {
177 info!(change = ?change, "corner team adjustment applied");
178 }
179
180 if tx.send(new_config).is_err() {
182 warn!("no config subscribers remaining — corner team shouting into the void");
183 break;
184 }
185 }
186
187 info!("config watcher task ended");
188 });
189
190 Ok(())
191 }
192
193 pub fn apply_change(
198 &self,
199 new_config: PunchConfig,
200 ) -> Result<ConfigChangeSet, ConfigValidationError> {
201 let validation_errors: Vec<_> = validate_config(&new_config)
202 .into_iter()
203 .filter(|v| matches!(v.severity, ValidationSeverity::Error))
204 .collect();
205
206 if let Some(err) = validation_errors.into_iter().next() {
207 return Err(err);
208 }
209
210 let old_config = self.rx.borrow().clone();
211 let changes = diff_configs(&old_config, &new_config);
212
213 {
215 let current = Arc::clone(&self.current);
216 let new_config_clone = new_config.clone();
217 let rt = tokio::runtime::Handle::try_current();
220 match rt {
221 Ok(handle) => {
222 let current = current.clone();
223 let cfg = new_config_clone.clone();
224 handle.spawn(async move {
225 let mut guard = current.write().await;
226 *guard = cfg;
227 });
228 }
229 Err(_) => {
230 }
233 }
234 }
235
236 let _ = self.tx.send(new_config);
238
239 Ok(ConfigChangeSet {
240 changes,
241 applied_at: Utc::now(),
242 })
243 }
244}
245
/// Outcome of a successful `ConfigWatcher::apply_change`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigChangeSet {
    /// Tracked differences between the previous and the newly applied configuration
    /// (as computed by `diff_configs`); may be empty.
    pub changes: Vec<ConfigChange>,
    /// UTC timestamp taken when the change set was built, right after the new config was
    /// published.
    pub applied_at: DateTime<Utc>,
}
254
/// A single tracked difference between two configurations, as reported by `diff_configs`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ConfigChange {
    /// `default_model.model` differs; carries the old and new model names.
    ModelChanged {
        old_model: String,
        new_model: String,
    },
    /// `api_key` differs. Carries no values — presumably so the key never appears in the
    /// per-change log line emitted on reload.
    ApiKeyChanged,
    /// `rate_limit_rpm` differs; carries the old and new values.
    RateLimitChanged { old: u32, new: u32 },
    /// `api_listen` differs; carries the old and new address strings.
    ListenAddressChanged { old: String, new: String },
    /// A key present in the new config's `channels` map but not in the old one.
    ChannelAdded(String),
    /// A key present in the old config's `channels` map but not in the new one.
    ChannelRemoved(String),
    /// A key present in the new config's `mcp_servers` map but not in the old one.
    McpServerAdded(String),
    /// A key present in the old config's `mcp_servers` map but not in the new one.
    McpServerRemoved(String),
    /// The JSON serialisation of the `memory` section differs.
    MemoryConfigChanged,
}
280
/// A single finding produced by `validate_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigValidationError {
    /// Dotted path of the offending setting, e.g. `"default_model.model"`.
    pub field: String,
    /// Human-readable explanation of the problem.
    pub message: String,
    /// Whether the finding blocks the config (`Error`) or is advisory (`Warning`).
    pub severity: ValidationSeverity,
}
291
292impl std::fmt::Display for ConfigValidationError {
293 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294 write!(f, "[{:?}] {}: {}", self.severity, self.field, self.message)
295 }
296}
297
// Marker impl so validation findings can be used as `std::error::Error` values. The message
// comes from the `Display` impl, and no `source()` is provided.
impl std::error::Error for ConfigValidationError {}
299
/// How serious a `ConfigValidationError` is.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ValidationSeverity {
    /// Reported, but does not stop a config from being applied (e.g. an empty `api_key`).
    Warning,
    /// Blocks the config: `ConfigWatcher::apply_change` and file reloads reject it.
    Error,
}
308
309pub fn diff_configs(old: &PunchConfig, new: &PunchConfig) -> Vec<ConfigChange> {
311 let mut changes = Vec::new();
312
313 if old.default_model.model != new.default_model.model {
315 changes.push(ConfigChange::ModelChanged {
316 old_model: old.default_model.model.clone(),
317 new_model: new.default_model.model.clone(),
318 });
319 }
320
321 if old.api_key != new.api_key {
323 changes.push(ConfigChange::ApiKeyChanged);
324 }
325
326 if old.rate_limit_rpm != new.rate_limit_rpm {
328 changes.push(ConfigChange::RateLimitChanged {
329 old: old.rate_limit_rpm,
330 new: new.rate_limit_rpm,
331 });
332 }
333
334 if old.api_listen != new.api_listen {
336 changes.push(ConfigChange::ListenAddressChanged {
337 old: old.api_listen.clone(),
338 new: new.api_listen.clone(),
339 });
340 }
341
342 let old_channels: HashSet<&String> = old.channels.keys().collect();
344 let new_channels: HashSet<&String> = new.channels.keys().collect();
345
346 for added in new_channels.difference(&old_channels) {
347 changes.push(ConfigChange::ChannelAdded((*added).clone()));
348 }
349 for removed in old_channels.difference(&new_channels) {
350 changes.push(ConfigChange::ChannelRemoved((*removed).clone()));
351 }
352
353 let old_servers: HashSet<&String> = old.mcp_servers.keys().collect();
355 let new_servers: HashSet<&String> = new.mcp_servers.keys().collect();
356
357 for added in new_servers.difference(&old_servers) {
358 changes.push(ConfigChange::McpServerAdded((*added).clone()));
359 }
360 for removed in old_servers.difference(&new_servers) {
361 changes.push(ConfigChange::McpServerRemoved((*removed).clone()));
362 }
363
364 let old_mem = serde_json::to_string(&old.memory).unwrap_or_default();
366 let new_mem = serde_json::to_string(&new.memory).unwrap_or_default();
367 if old_mem != new_mem {
368 changes.push(ConfigChange::MemoryConfigChanged);
369 }
370
371 changes
372}
373
374pub fn validate_config(config: &PunchConfig) -> Vec<ConfigValidationError> {
379 let mut errors = Vec::new();
380
381 if config.api_listen.parse::<std::net::SocketAddr>().is_err() {
383 errors.push(ConfigValidationError {
384 field: "api_listen".to_string(),
385 message: format!(
386 "'{}' is not a valid socket address (expected host:port)",
387 config.api_listen
388 ),
389 severity: ValidationSeverity::Error,
390 });
391 }
392
393 if config.default_model.model.trim().is_empty() {
395 errors.push(ConfigValidationError {
396 field: "default_model.model".to_string(),
397 message: "model name cannot be empty — the fighter needs a stance".to_string(),
398 severity: ValidationSeverity::Error,
399 });
400 }
401
402 if config.memory.db_path.trim().is_empty() {
404 errors.push(ConfigValidationError {
405 field: "memory.db_path".to_string(),
406 message: "database path cannot be empty — the fighter needs memory".to_string(),
407 severity: ValidationSeverity::Error,
408 });
409 }
410
411 if config.rate_limit_rpm == 0 {
413 errors.push(ConfigValidationError {
414 field: "rate_limit_rpm".to_string(),
415 message: "rate limit must be greater than zero — even a slugger needs some pace"
416 .to_string(),
417 severity: ValidationSeverity::Error,
418 });
419 }
420
421 if config.api_key.is_empty() {
423 errors.push(ConfigValidationError {
424 field: "api_key".to_string(),
425 message: "API key is empty — running in dev mode with no authentication".to_string(),
426 severity: ValidationSeverity::Warning,
427 });
428 }
429
430 errors
431}
432
// Unit tests for config diffing, validation, and the `ConfigWatcher` API surface.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{MemoryConfig, ModelConfig, Provider};
    use std::collections::HashMap;

    // Baseline config that passes `validate_config` with no error-severity findings;
    // individual tests clone it and tweak one field at a time.
    fn make_test_config() -> PunchConfig {
        PunchConfig {
            api_listen: "127.0.0.1:6660".to_string(),
            api_key: "test-key-123".to_string(),
            rate_limit_rpm: 60,
            default_model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".to_string(),
                api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            memory: MemoryConfig {
                db_path: "/tmp/punch-test.db".to_string(),
                knowledge_graph_enabled: true,
                max_entries: Some(10000),
            },
            tunnel: None,
            channels: HashMap::new(),
            mcp_servers: HashMap::new(),
            model_routing: Default::default(),
            budget: Default::default(),
        }
    }

    // --- diff_configs ---

    #[test]
    fn diff_detects_model_change() {
        let old = make_test_config();
        let mut new = old.clone();
        new.default_model.model = "claude-opus-4-20250514".to_string();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ModelChanged {
            old_model: "claude-sonnet-4-20250514".to_string(),
            new_model: "claude-opus-4-20250514".to_string(),
        }));
    }

    #[test]
    fn diff_detects_rate_limit_change() {
        let old = make_test_config();
        let mut new = old.clone();
        new.rate_limit_rpm = 120;

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::RateLimitChanged { old: 60, new: 120 }));
    }

    #[test]
    fn diff_detects_channel_added() {
        let old = make_test_config();
        let mut new = old.clone();
        new.channels.insert(
            "slack".to_string(),
            crate::config::ChannelConfig {
                channel_type: "slack".to_string(),
                token_env: Some("SLACK_TOKEN".to_string()),
                webhook_secret_env: None,
                allowed_user_ids: vec![],
                rate_limit_per_user: 20,
                settings: HashMap::new(),
            },
        );

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ChannelAdded("slack".to_string())));
    }

    #[test]
    fn diff_detects_channel_removed() {
        let mut old = make_test_config();
        old.channels.insert(
            "discord".to_string(),
            crate::config::ChannelConfig {
                channel_type: "discord".to_string(),
                token_env: Some("DISCORD_TOKEN".to_string()),
                webhook_secret_env: None,
                allowed_user_ids: vec![],
                rate_limit_per_user: 20,
                settings: HashMap::new(),
            },
        );
        let new = make_test_config();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ChannelRemoved("discord".to_string())));
    }

    #[test]
    fn diff_returns_empty_for_identical_configs() {
        let config = make_test_config();
        let changes = diff_configs(&config, &config);
        assert!(
            changes.is_empty(),
            "identical configs should produce no changes"
        );
    }

    // --- validate_config ---

    #[test]
    fn validate_passes_valid_config() {
        let config = make_test_config();
        // Warnings are allowed here; only error-severity findings would make the config invalid.
        let errors: Vec<_> = validate_config(&config)
            .into_iter()
            .filter(|e| matches!(e.severity, ValidationSeverity::Error))
            .collect();
        assert!(errors.is_empty(), "valid config should produce no errors");
    }

    #[test]
    fn validate_catches_empty_model_name() {
        let mut config = make_test_config();
        config.default_model.model = "".to_string();

        let errors = validate_config(&config);
        assert!(
            errors.iter().any(|e| e.field == "default_model.model"
                && matches!(e.severity, ValidationSeverity::Error))
        );
    }

    #[test]
    fn validate_catches_empty_db_path() {
        let mut config = make_test_config();
        config.memory.db_path = "".to_string();

        let errors = validate_config(&config);
        assert!(errors.iter().any(
            |e| e.field == "memory.db_path" && matches!(e.severity, ValidationSeverity::Error)
        ));
    }

    #[test]
    fn validate_warns_on_empty_api_key() {
        let mut config = make_test_config();
        config.api_key = "".to_string();

        let errors = validate_config(&config);
        assert!(
            errors
                .iter()
                .any(|e| e.field == "api_key" && matches!(e.severity, ValidationSeverity::Warning))
        );
    }

    // --- ConfigWatcher ---

    #[test]
    fn config_watcher_can_be_created() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config.clone());
        assert_eq!(watcher.current().api_listen, config.api_listen);
    }

    // `apply_change` may spawn onto the current Tokio runtime, hence `#[tokio::test]`.
    #[tokio::test]
    async fn apply_change_returns_change_set() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config);

        let mut new_config = make_test_config();
        new_config.rate_limit_rpm = 120;

        let result = watcher.apply_change(new_config);
        assert!(result.is_ok());
        let change_set = result.expect("should succeed");
        assert!(
            change_set
                .changes
                .contains(&ConfigChange::RateLimitChanged { old: 60, new: 120 })
        );
    }

    #[tokio::test]
    async fn apply_change_rejects_invalid_config() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config);

        // An empty model name is an error-severity finding, so the change must be refused.
        let mut bad_config = make_test_config();
        bad_config.default_model.model = "".to_string();

        let result = watcher.apply_change(bad_config);
        assert!(result.is_err());
    }

    #[test]
    fn current_config_accessible_after_creation() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config.clone());

        let current = watcher.current();
        assert_eq!(current.api_listen, "127.0.0.1:6660");
        assert_eq!(current.rate_limit_rpm, 60);
        assert_eq!(current.default_model.model, "claude-sonnet-4-20250514");
        assert_eq!(current.memory.db_path, "/tmp/punch-test.db");
    }

    // --- diff_configs: remaining tracked fields ---

    #[test]
    fn diff_detects_mcp_server_added() {
        let old = make_test_config();
        let mut new = old.clone();
        new.mcp_servers.insert(
            "filesystem".to_string(),
            crate::config::McpServerConfig {
                command: "npx".to_string(),
                args: vec!["-y".to_string(), "@mcp/filesystem".to_string()],
                env: HashMap::new(),
            },
        );

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::McpServerAdded("filesystem".to_string())));
    }

    #[test]
    fn diff_detects_mcp_server_removed() {
        let mut old = make_test_config();
        old.mcp_servers.insert(
            "memory".to_string(),
            crate::config::McpServerConfig {
                command: "mcp-memory".to_string(),
                args: vec![],
                env: HashMap::new(),
            },
        );
        let new = make_test_config();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::McpServerRemoved("memory".to_string())));
    }

    #[test]
    fn diff_detects_memory_config_changed() {
        let old = make_test_config();
        let mut new = old.clone();
        new.memory.knowledge_graph_enabled = false;

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::MemoryConfigChanged));
    }

    #[test]
    fn diff_detects_api_key_changed() {
        let old = make_test_config();
        let mut new = old.clone();
        new.api_key = "new-secret-key".to_string();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ApiKeyChanged));
    }

    #[test]
    fn diff_detects_listen_address_changed() {
        let old = make_test_config();
        let mut new = old.clone();
        new.api_listen = "0.0.0.0:8080".to_string();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ListenAddressChanged {
            old: "127.0.0.1:6660".to_string(),
            new: "0.0.0.0:8080".to_string(),
        }));
    }

    // --- validate_config: remaining error checks ---

    #[test]
    fn validate_catches_invalid_socket_addr() {
        let mut config = make_test_config();
        config.api_listen = "not-a-valid-address".to_string();

        let errors = validate_config(&config);
        assert!(errors.iter().any(|e| e.field == "api_listen"
            && matches!(e.severity, ValidationSeverity::Error)));
    }

    #[test]
    fn validate_catches_zero_rate_limit() {
        let mut config = make_test_config();
        config.rate_limit_rpm = 0;

        let errors = validate_config(&config);
        assert!(errors.iter().any(
            |e| e.field == "rate_limit_rpm" && matches!(e.severity, ValidationSeverity::Error)
        ));
    }

    // A freshly subscribed receiver sees the initial config without waiting for a change.
    #[tokio::test]
    async fn subscriber_receives_initial_config() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config.clone());

        let rx = watcher.subscribe();
        let received = rx.borrow().clone();
        assert_eq!(received.api_listen, config.api_listen);
    }
}