1use std::collections::HashSet;
9use std::path::PathBuf;
10use std::sync::Arc;
11
12use chrono::{DateTime, Utc};
13use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
14use serde::{Deserialize, Serialize};
15use tokio::sync::{RwLock, watch};
16use tracing::{debug, error, info, warn};
17
18use crate::config::PunchConfig;
19use crate::error::{PunchError, PunchResult};
20
/// Watches the Punch config file and distributes validated configs to subscribers.
///
/// Published configs flow through a `tokio::sync::watch` channel (`tx`/`rx`);
/// `current` is an additional lock-protected copy of the config.
#[derive(Debug)]
pub struct ConfigWatcher {
    /// Path of the config file to watch.
    config_path: PathBuf,
    /// Lock-protected copy of the config kept alongside the watch channel.
    /// NOTE(review): its updates are not synchronized with `tx`, so it may lag
    /// behind the published value.
    current: Arc<RwLock<PunchConfig>>,
    /// Sending half of the watch channel used to publish new configs.
    tx: watch::Sender<PunchConfig>,
    /// Receiver owned by the watcher; cloned to hand out subscriptions.
    rx: watch::Receiver<PunchConfig>,
}
34
35impl ConfigWatcher {
36 pub fn new(config_path: PathBuf, initial_config: PunchConfig) -> Self {
38 let (tx, rx) = watch::channel(initial_config.clone());
39 Self {
40 config_path,
41 current: Arc::new(RwLock::new(initial_config)),
42 tx,
43 rx,
44 }
45 }
46
47 pub fn subscribe(&self) -> watch::Receiver<PunchConfig> {
49 self.rx.clone()
50 }
51
52 pub fn current(&self) -> PunchConfig {
54 self.rx.borrow().clone()
55 }
56
57 pub async fn start_watching(&self) -> PunchResult<()> {
62 let config_path = self.config_path.clone();
63 let current = Arc::clone(&self.current);
64 let tx = self.tx.clone();
65
66 let watch_path = config_path
68 .parent()
69 .map(|p| p.to_path_buf())
70 .unwrap_or_else(|| PathBuf::from("."));
71
72 let target_file = config_path
73 .file_name()
74 .map(|f| f.to_os_string())
75 .ok_or_else(|| PunchError::Config("config path has no file name".to_string()))?;
76
77 let (notify_tx, mut notify_rx) = tokio::sync::mpsc::channel::<Event>(16);
78
79 let _watcher: RecommendedWatcher = {
81 let notify_tx = notify_tx.clone();
82 let mut watcher =
83 notify::recommended_watcher(move |res: Result<Event, notify::Error>| match res {
84 Ok(event) => {
85 if let Err(e) = notify_tx.blocking_send(event) {
86 error!(error = %e, "failed to forward file event");
87 }
88 }
89 Err(e) => {
90 error!(error = %e, "filesystem watcher error");
91 }
92 })
93 .map_err(|e| PunchError::Config(format!("failed to create file watcher: {}", e)))?;
94
95 watcher
96 .watch(&watch_path, RecursiveMode::NonRecursive)
97 .map_err(|e| {
98 PunchError::Config(format!("failed to watch config directory: {}", e))
99 })?;
100
101 watcher
102 };
103
104 let config_path_for_task = config_path.clone();
105 let target_file_for_task = target_file.clone();
106
107 tokio::spawn(async move {
109 let _watcher = _watcher;
111
112 info!(path = %config_path_for_task.display(), "corner team watching config file");
113
114 while let Some(event) = notify_rx.recv().await {
115 let dominated = matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_));
117 if !dominated {
118 continue;
119 }
120
121 let affects_target = event.paths.iter().any(|p| {
122 p.file_name()
123 .map(|f| f == target_file_for_task)
124 .unwrap_or(false)
125 });
126 if !affects_target {
127 continue;
128 }
129
130 debug!("config file change detected, reloading");
131
132 let content = match tokio::fs::read_to_string(&config_path_for_task).await {
134 Ok(c) => c,
135 Err(e) => {
136 error!(error = %e, "failed to read config file during reload");
137 continue;
138 }
139 };
140
141 let new_config: PunchConfig = match toml::from_str(&content) {
142 Ok(c) => c,
143 Err(e) => {
144 error!(error = %e, "failed to parse config file during reload");
145 continue;
146 }
147 };
148
149 let errors: Vec<_> = validate_config(&new_config)
151 .into_iter()
152 .filter(|v| matches!(v.severity, ValidationSeverity::Error))
153 .collect();
154
155 if !errors.is_empty() {
156 for err in &errors {
157 error!(field = %err.field, message = %err.message, "config validation failed");
158 }
159 continue;
160 }
161
162 let old_config = {
164 let mut guard = current.write().await;
165 let old = guard.clone();
166 *guard = new_config.clone();
167 old
168 };
169
170 let changes = diff_configs(&old_config, &new_config);
171 if changes.is_empty() {
172 debug!("config file changed but no effective differences detected");
173 continue;
174 }
175
176 for change in &changes {
177 info!(change = ?change, "corner team adjustment applied");
178 }
179
180 if tx.send(new_config).is_err() {
182 warn!("no config subscribers remaining — corner team shouting into the void");
183 break;
184 }
185 }
186
187 info!("config watcher task ended");
188 });
189
190 Ok(())
191 }
192
193 pub fn apply_change(
198 &self,
199 new_config: PunchConfig,
200 ) -> Result<ConfigChangeSet, ConfigValidationError> {
201 let validation_errors: Vec<_> = validate_config(&new_config)
202 .into_iter()
203 .filter(|v| matches!(v.severity, ValidationSeverity::Error))
204 .collect();
205
206 if let Some(err) = validation_errors.into_iter().next() {
207 return Err(err);
208 }
209
210 let old_config = self.rx.borrow().clone();
211 let changes = diff_configs(&old_config, &new_config);
212
213 {
215 let current = Arc::clone(&self.current);
216 let new_config_clone = new_config.clone();
217 let rt = tokio::runtime::Handle::try_current();
220 match rt {
221 Ok(handle) => {
222 let current = current.clone();
223 let cfg = new_config_clone.clone();
224 handle.spawn(async move {
225 let mut guard = current.write().await;
226 *guard = cfg;
227 });
228 }
229 Err(_) => {
230 }
233 }
234 }
235
236 let _ = self.tx.send(new_config);
238
239 Ok(ConfigChangeSet {
240 changes,
241 applied_at: Utc::now(),
242 })
243 }
244}
245
/// Outcome of applying a config: the detected differences and when they were applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigChangeSet {
    /// Differences between the previously published and the new config, as
    /// computed by `diff_configs`.
    pub changes: Vec<ConfigChange>,
    /// UTC timestamp taken when this change set was built.
    pub applied_at: DateTime<Utc>,
}
254
/// A single tracked difference between two configs, as produced by `diff_configs`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ConfigChange {
    /// `default_model.model` changed.
    ModelChanged {
        /// Model name before the change.
        old_model: String,
        /// Model name after the change.
        new_model: String,
    },
    /// `api_key` changed. Unlike the other variants this carries no values,
    /// presumably so the key never appears in logs or serialized change sets.
    ApiKeyChanged,
    /// `rate_limit_rpm` changed from `old` to `new`.
    RateLimitChanged { old: u32, new: u32 },
    /// `api_listen` changed from `old` to `new`.
    ListenAddressChanged { old: String, new: String },
    /// A channel key present only in the new config.
    ChannelAdded(String),
    /// A channel key present only in the old config.
    ChannelRemoved(String),
    /// An MCP server key present only in the new config.
    McpServerAdded(String),
    /// An MCP server key present only in the old config.
    McpServerRemoved(String),
    /// The `memory` section differs (compared via its JSON serialization).
    MemoryConfigChanged,
}
280
/// One finding reported by `validate_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigValidationError {
    /// Dotted path of the offending field, e.g. `"memory.db_path"`.
    pub field: String,
    /// Human-readable explanation of the problem.
    pub message: String,
    /// Whether this finding blocks the config (`Error`) or is advisory (`Warning`).
    pub severity: ValidationSeverity,
}
291
292impl std::fmt::Display for ConfigValidationError {
293 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294 write!(f, "[{:?}] {}: {}", self.severity, self.field, self.message)
295 }
296}
297
298impl std::error::Error for ConfigValidationError {}
299
/// How serious a `ConfigValidationError` is.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ValidationSeverity {
    /// Reported, but does not stop the config from being applied.
    Warning,
    /// Prevents the config from being applied (both `apply_change` and the
    /// file-reload path filter on this severity).
    Error,
}
308
309pub fn diff_configs(old: &PunchConfig, new: &PunchConfig) -> Vec<ConfigChange> {
311 let mut changes = Vec::new();
312
313 if old.default_model.model != new.default_model.model {
315 changes.push(ConfigChange::ModelChanged {
316 old_model: old.default_model.model.clone(),
317 new_model: new.default_model.model.clone(),
318 });
319 }
320
321 if old.api_key != new.api_key {
323 changes.push(ConfigChange::ApiKeyChanged);
324 }
325
326 if old.rate_limit_rpm != new.rate_limit_rpm {
328 changes.push(ConfigChange::RateLimitChanged {
329 old: old.rate_limit_rpm,
330 new: new.rate_limit_rpm,
331 });
332 }
333
334 if old.api_listen != new.api_listen {
336 changes.push(ConfigChange::ListenAddressChanged {
337 old: old.api_listen.clone(),
338 new: new.api_listen.clone(),
339 });
340 }
341
342 let old_channels: HashSet<&String> = old.channels.keys().collect();
344 let new_channels: HashSet<&String> = new.channels.keys().collect();
345
346 for added in new_channels.difference(&old_channels) {
347 changes.push(ConfigChange::ChannelAdded((*added).clone()));
348 }
349 for removed in old_channels.difference(&new_channels) {
350 changes.push(ConfigChange::ChannelRemoved((*removed).clone()));
351 }
352
353 let old_servers: HashSet<&String> = old.mcp_servers.keys().collect();
355 let new_servers: HashSet<&String> = new.mcp_servers.keys().collect();
356
357 for added in new_servers.difference(&old_servers) {
358 changes.push(ConfigChange::McpServerAdded((*added).clone()));
359 }
360 for removed in old_servers.difference(&new_servers) {
361 changes.push(ConfigChange::McpServerRemoved((*removed).clone()));
362 }
363
364 let old_mem = serde_json::to_string(&old.memory).unwrap_or_default();
366 let new_mem = serde_json::to_string(&new.memory).unwrap_or_default();
367 if old_mem != new_mem {
368 changes.push(ConfigChange::MemoryConfigChanged);
369 }
370
371 changes
372}
373
374pub fn validate_config(config: &PunchConfig) -> Vec<ConfigValidationError> {
379 let mut errors = Vec::new();
380
381 if config.api_listen.parse::<std::net::SocketAddr>().is_err() {
383 errors.push(ConfigValidationError {
384 field: "api_listen".to_string(),
385 message: format!(
386 "'{}' is not a valid socket address (expected host:port)",
387 config.api_listen
388 ),
389 severity: ValidationSeverity::Error,
390 });
391 }
392
393 if config.default_model.model.trim().is_empty() {
395 errors.push(ConfigValidationError {
396 field: "default_model.model".to_string(),
397 message: "model name cannot be empty — the fighter needs a stance".to_string(),
398 severity: ValidationSeverity::Error,
399 });
400 }
401
402 if config.memory.db_path.trim().is_empty() {
404 errors.push(ConfigValidationError {
405 field: "memory.db_path".to_string(),
406 message: "database path cannot be empty — the fighter needs memory".to_string(),
407 severity: ValidationSeverity::Error,
408 });
409 }
410
411 if config.rate_limit_rpm == 0 {
413 errors.push(ConfigValidationError {
414 field: "rate_limit_rpm".to_string(),
415 message: "rate limit must be greater than zero — even a slugger needs some pace"
416 .to_string(),
417 severity: ValidationSeverity::Error,
418 });
419 }
420
421 if config.api_key.is_empty() {
423 errors.push(ConfigValidationError {
424 field: "api_key".to_string(),
425 message: "API key is empty — running in dev mode with no authentication".to_string(),
426 severity: ValidationSeverity::Warning,
427 });
428 }
429
430 errors
431}
432
// Unit tests for config diffing, validation, and the `ConfigWatcher` API.
// The file-watching task itself (`start_watching`) is not exercised here.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{MemoryConfig, ModelConfig, Provider};
    use std::collections::HashMap;

    /// Baseline config used by every test; tests mutate clones of it.
    /// `validate_passes_valid_config` asserts it has no `Error`-severity findings.
    fn make_test_config() -> PunchConfig {
        PunchConfig {
            api_listen: "127.0.0.1:6660".to_string(),
            api_key: "test-key-123".to_string(),
            rate_limit_rpm: 60,
            default_model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".to_string(),
                api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            memory: MemoryConfig {
                db_path: "/tmp/punch-test.db".to_string(),
                knowledge_graph_enabled: true,
                max_entries: Some(10000),
            },
            tunnel: None,
            channels: HashMap::new(),
            mcp_servers: HashMap::new(),
        }
    }

    // --- diff_configs ---

    #[test]
    fn diff_detects_model_change() {
        let old = make_test_config();
        let mut new = old.clone();
        new.default_model.model = "claude-opus-4-20250514".to_string();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ModelChanged {
            old_model: "claude-sonnet-4-20250514".to_string(),
            new_model: "claude-opus-4-20250514".to_string(),
        }));
    }

    #[test]
    fn diff_detects_rate_limit_change() {
        let old = make_test_config();
        let mut new = old.clone();
        new.rate_limit_rpm = 120;

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::RateLimitChanged { old: 60, new: 120 }));
    }

    #[test]
    fn diff_detects_channel_added() {
        let old = make_test_config();
        let mut new = old.clone();
        new.channels.insert(
            "slack".to_string(),
            crate::config::ChannelConfig {
                channel_type: "slack".to_string(),
                token_env: Some("SLACK_TOKEN".to_string()),
                webhook_secret_env: None,
                allowed_user_ids: vec![],
                rate_limit_per_user: 20,
                settings: HashMap::new(),
            },
        );

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ChannelAdded("slack".to_string())));
    }

    #[test]
    fn diff_detects_channel_removed() {
        let mut old = make_test_config();
        old.channels.insert(
            "discord".to_string(),
            crate::config::ChannelConfig {
                channel_type: "discord".to_string(),
                token_env: Some("DISCORD_TOKEN".to_string()),
                webhook_secret_env: None,
                allowed_user_ids: vec![],
                rate_limit_per_user: 20,
                settings: HashMap::new(),
            },
        );
        let new = make_test_config();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ChannelRemoved("discord".to_string())));
    }

    #[test]
    fn diff_returns_empty_for_identical_configs() {
        let config = make_test_config();
        let changes = diff_configs(&config, &config);
        assert!(
            changes.is_empty(),
            "identical configs should produce no changes"
        );
    }

    // --- validate_config ---

    #[test]
    fn validate_passes_valid_config() {
        let config = make_test_config();
        // Warnings are allowed; only `Error`-severity findings fail this test.
        let errors: Vec<_> = validate_config(&config)
            .into_iter()
            .filter(|e| matches!(e.severity, ValidationSeverity::Error))
            .collect();
        assert!(errors.is_empty(), "valid config should produce no errors");
    }

    #[test]
    fn validate_catches_empty_model_name() {
        let mut config = make_test_config();
        config.default_model.model = "".to_string();

        let errors = validate_config(&config);
        assert!(
            errors.iter().any(|e| e.field == "default_model.model"
                && matches!(e.severity, ValidationSeverity::Error))
        );
    }

    #[test]
    fn validate_catches_empty_db_path() {
        let mut config = make_test_config();
        config.memory.db_path = "".to_string();

        let errors = validate_config(&config);
        assert!(errors.iter().any(
            |e| e.field == "memory.db_path" && matches!(e.severity, ValidationSeverity::Error)
        ));
    }

    #[test]
    fn validate_warns_on_empty_api_key() {
        let mut config = make_test_config();
        config.api_key = "".to_string();

        let errors = validate_config(&config);
        assert!(
            errors
                .iter()
                .any(|e| e.field == "api_key" && matches!(e.severity, ValidationSeverity::Warning))
        );
    }

    // --- ConfigWatcher API ---

    #[test]
    fn config_watcher_can_be_created() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config.clone());
        assert_eq!(watcher.current().api_listen, config.api_listen);
    }

    // Runs inside a Tokio runtime because `apply_change` may spawn onto the
    // ambient runtime.
    #[tokio::test]
    async fn apply_change_returns_change_set() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config);

        let mut new_config = make_test_config();
        new_config.rate_limit_rpm = 120;

        let result = watcher.apply_change(new_config);
        assert!(result.is_ok());
        let change_set = result.expect("should succeed");
        assert!(
            change_set
                .changes
                .contains(&ConfigChange::RateLimitChanged { old: 60, new: 120 })
        );
    }

    #[tokio::test]
    async fn apply_change_rejects_invalid_config() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config);

        let mut bad_config = make_test_config();
        bad_config.default_model.model = "".to_string();

        let result = watcher.apply_change(bad_config);
        assert!(result.is_err());
    }

    #[test]
    fn current_config_accessible_after_creation() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config.clone());

        let current = watcher.current();
        assert_eq!(current.api_listen, "127.0.0.1:6660");
        assert_eq!(current.rate_limit_rpm, 60);
        assert_eq!(current.default_model.model, "claude-sonnet-4-20250514");
        assert_eq!(current.memory.db_path, "/tmp/punch-test.db");
    }

    // --- more diff_configs coverage ---

    #[test]
    fn diff_detects_mcp_server_added() {
        let old = make_test_config();
        let mut new = old.clone();
        new.mcp_servers.insert(
            "filesystem".to_string(),
            crate::config::McpServerConfig {
                command: "npx".to_string(),
                args: vec!["-y".to_string(), "@mcp/filesystem".to_string()],
                env: HashMap::new(),
            },
        );

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::McpServerAdded("filesystem".to_string())));
    }

    #[test]
    fn diff_detects_mcp_server_removed() {
        let mut old = make_test_config();
        old.mcp_servers.insert(
            "memory".to_string(),
            crate::config::McpServerConfig {
                command: "mcp-memory".to_string(),
                args: vec![],
                env: HashMap::new(),
            },
        );
        let new = make_test_config();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::McpServerRemoved("memory".to_string())));
    }

    #[test]
    fn diff_detects_memory_config_changed() {
        let old = make_test_config();
        let mut new = old.clone();
        new.memory.knowledge_graph_enabled = false;

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::MemoryConfigChanged));
    }

    #[test]
    fn diff_detects_api_key_changed() {
        let old = make_test_config();
        let mut new = old.clone();
        new.api_key = "new-secret-key".to_string();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ApiKeyChanged));
    }

    #[test]
    fn diff_detects_listen_address_changed() {
        let old = make_test_config();
        let mut new = old.clone();
        new.api_listen = "0.0.0.0:8080".to_string();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ListenAddressChanged {
            old: "127.0.0.1:6660".to_string(),
            new: "0.0.0.0:8080".to_string(),
        }));
    }

    // --- more validate_config coverage ---

    #[test]
    fn validate_catches_invalid_socket_addr() {
        let mut config = make_test_config();
        config.api_listen = "not-a-valid-address".to_string();

        let errors = validate_config(&config);
        assert!(errors.iter().any(|e| e.field == "api_listen"
            && matches!(e.severity, ValidationSeverity::Error)));
    }

    #[test]
    fn validate_catches_zero_rate_limit() {
        let mut config = make_test_config();
        config.rate_limit_rpm = 0;

        let errors = validate_config(&config);
        assert!(errors.iter().any(
            |e| e.field == "rate_limit_rpm" && matches!(e.severity, ValidationSeverity::Error)
        ));
    }

    // A fresh subscription immediately sees the config passed to `new`.
    #[tokio::test]
    async fn subscriber_receives_initial_config() {
        let config = make_test_config();
        let watcher = ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config.clone());

        let rx = watcher.subscribe();
        let received = rx.borrow().clone();
        assert_eq!(received.api_listen, config.api_listen);
    }
}
729}