Skip to main content

punch_types/
hot_reload.rs

1//! Hot Config Reload — corner team adjustments between rounds.
2//!
3//! This module enables mid-fight strategy changes by watching the config file
4//! for modifications and broadcasting validated updates to all subscribers.
5//! Like a corner team making tactical adjustments between rounds, the system
6//! applies config changes without pulling the fighter from the ring.
7
8use std::collections::HashSet;
9use std::path::PathBuf;
10use std::sync::Arc;
11
12use chrono::{DateTime, Utc};
13use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
14use serde::{Deserialize, Serialize};
15use tokio::sync::{RwLock, watch};
16use tracing::{debug, error, info, warn};
17
18use crate::config::PunchConfig;
19use crate::error::{PunchError, PunchResult};
20
21/// Watches the config file and broadcasts validated changes — the corner team
22/// that keeps the fighter's strategy sharp without stopping the bout.
23#[derive(Debug)]
24pub struct ConfigWatcher {
25    /// Path to the configuration file being watched.
26    config_path: PathBuf,
27    /// Thread-safe handle to the current configuration.
28    current: Arc<RwLock<PunchConfig>>,
29    /// Sender half of the watch channel for broadcasting config updates.
30    tx: watch::Sender<PunchConfig>,
31    /// Receiver half of the watch channel — cloned for each subscriber.
32    rx: watch::Receiver<PunchConfig>,
33}
34
35impl ConfigWatcher {
36    /// Create a new ConfigWatcher ready to observe the corner team's playbook.
37    pub fn new(config_path: PathBuf, initial_config: PunchConfig) -> Self {
38        let (tx, rx) = watch::channel(initial_config.clone());
39        Self {
40            config_path,
41            current: Arc::new(RwLock::new(initial_config)),
42            tx,
43            rx,
44        }
45    }
46
47    /// Subscribe to config changes — get a ringside seat for every strategy adjustment.
48    pub fn subscribe(&self) -> watch::Receiver<PunchConfig> {
49        self.rx.clone()
50    }
51
52    /// Get a snapshot of the current config — check what game plan the fighter is using right now.
53    pub fn current(&self) -> PunchConfig {
54        self.rx.borrow().clone()
55    }
56
57    /// Start watching the config file for changes — the corner team takes their position.
58    ///
59    /// Spawns a background task that monitors the config file using filesystem events
60    /// and applies validated changes automatically.
61    pub async fn start_watching(&self) -> PunchResult<()> {
62        let config_path = self.config_path.clone();
63        let current = Arc::clone(&self.current);
64        let tx = self.tx.clone();
65
66        // Resolve the parent directory and file name for the watcher.
67        let watch_path = config_path
68            .parent()
69            .map(|p| p.to_path_buf())
70            .unwrap_or_else(|| PathBuf::from("."));
71
72        let target_file = config_path
73            .file_name()
74            .map(|f| f.to_os_string())
75            .ok_or_else(|| PunchError::Config("config path has no file name".to_string()))?;
76
77        let (notify_tx, mut notify_rx) = tokio::sync::mpsc::channel::<Event>(16);
78
79        // Create the filesystem watcher on a blocking thread since notify uses sync callbacks.
80        let _watcher: RecommendedWatcher = {
81            let notify_tx = notify_tx.clone();
82            let mut watcher =
83                notify::recommended_watcher(move |res: Result<Event, notify::Error>| match res {
84                    Ok(event) => {
85                        if let Err(e) = notify_tx.blocking_send(event) {
86                            error!(error = %e, "failed to forward file event");
87                        }
88                    }
89                    Err(e) => {
90                        error!(error = %e, "filesystem watcher error");
91                    }
92                })
93                .map_err(|e| PunchError::Config(format!("failed to create file watcher: {}", e)))?;
94
95            watcher
96                .watch(&watch_path, RecursiveMode::NonRecursive)
97                .map_err(|e| {
98                    PunchError::Config(format!("failed to watch config directory: {}", e))
99                })?;
100
101            watcher
102        };
103
104        let config_path_for_task = config_path.clone();
105        let target_file_for_task = target_file.clone();
106
107        // Spawn the background task — the corner team is now watching the fight.
108        tokio::spawn(async move {
109            // Keep the watcher alive for the lifetime of this task.
110            let _watcher = _watcher;
111
112            info!(path = %config_path_for_task.display(), "corner team watching config file");
113
114            while let Some(event) = notify_rx.recv().await {
115                // Only react to modify/create events on our target file.
116                let dominated = matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_));
117                if !dominated {
118                    continue;
119                }
120
121                let affects_target = event.paths.iter().any(|p| {
122                    p.file_name()
123                        .map(|f| f == target_file_for_task)
124                        .unwrap_or(false)
125                });
126                if !affects_target {
127                    continue;
128                }
129
130                debug!("config file change detected, reloading");
131
132                // Read and parse the new config.
133                let content = match tokio::fs::read_to_string(&config_path_for_task).await {
134                    Ok(c) => c,
135                    Err(e) => {
136                        error!(error = %e, "failed to read config file during reload");
137                        continue;
138                    }
139                };
140
141                let new_config: PunchConfig = match toml::from_str(&content) {
142                    Ok(c) => c,
143                    Err(e) => {
144                        error!(error = %e, "failed to parse config file during reload");
145                        continue;
146                    }
147                };
148
149                // Validate the new config.
150                let errors: Vec<_> = validate_config(&new_config)
151                    .into_iter()
152                    .filter(|v| matches!(v.severity, ValidationSeverity::Error))
153                    .collect();
154
155                if !errors.is_empty() {
156                    for err in &errors {
157                        error!(field = %err.field, message = %err.message, "config validation failed");
158                    }
159                    continue;
160                }
161
162                // Apply the change.
163                let old_config = {
164                    let mut guard = current.write().await;
165                    let old = guard.clone();
166                    *guard = new_config.clone();
167                    old
168                };
169
170                let changes = diff_configs(&old_config, &new_config);
171                if changes.is_empty() {
172                    debug!("config file changed but no effective differences detected");
173                    continue;
174                }
175
176                for change in &changes {
177                    info!(change = ?change, "corner team adjustment applied");
178                }
179
180                // Broadcast the new config to all subscribers.
181                if tx.send(new_config).is_err() {
182                    warn!("no config subscribers remaining — corner team shouting into the void");
183                    break;
184                }
185            }
186
187            info!("config watcher task ended");
188        });
189
190        Ok(())
191    }
192
193    /// Validate and apply a new config programmatically — a direct corner team call.
194    ///
195    /// Returns the set of changes if the config is valid, or a validation error
196    /// if the new config fails checks.
197    pub fn apply_change(
198        &self,
199        new_config: PunchConfig,
200    ) -> Result<ConfigChangeSet, ConfigValidationError> {
201        let validation_errors: Vec<_> = validate_config(&new_config)
202            .into_iter()
203            .filter(|v| matches!(v.severity, ValidationSeverity::Error))
204            .collect();
205
206        if let Some(err) = validation_errors.into_iter().next() {
207            return Err(err);
208        }
209
210        let old_config = self.rx.borrow().clone();
211        let changes = diff_configs(&old_config, &new_config);
212
213        // Update current config behind the lock (blocking context is fine for apply_change).
214        {
215            let current = Arc::clone(&self.current);
216            let new_config_clone = new_config.clone();
217            // Use try_write to avoid blocking in sync context. If contended, fall back to
218            // a blocking write via std::thread::spawn, but for apply_change this is acceptable.
219            let rt = tokio::runtime::Handle::try_current();
220            match rt {
221                Ok(handle) => {
222                    let current = current.clone();
223                    let cfg = new_config_clone.clone();
224                    handle.spawn(async move {
225                        let mut guard = current.write().await;
226                        *guard = cfg;
227                    });
228                }
229                Err(_) => {
230                    // If no runtime, we're in a sync context — just best-effort.
231                    // The watch channel is the source of truth anyway.
232                }
233            }
234        }
235
236        // Broadcast through the watch channel — this is the authoritative update.
237        let _ = self.tx.send(new_config);
238
239        Ok(ConfigChangeSet {
240            changes,
241            applied_at: Utc::now(),
242        })
243    }
244}
245
246/// A set of changes applied in a single config reload — the corner team's adjustment notes.
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct ConfigChangeSet {
249    /// Individual changes detected between old and new configs.
250    pub changes: Vec<ConfigChange>,
251    /// Timestamp when these adjustments were applied.
252    pub applied_at: DateTime<Utc>,
253}
254
255/// A single configuration change detected during a reload.
256#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
257pub enum ConfigChange {
258    /// The default model was swapped — switching fighting stance.
259    ModelChanged {
260        old_model: String,
261        new_model: String,
262    },
263    /// API key was rotated — new credentials for the fight.
264    ApiKeyChanged,
265    /// Rate limit was adjusted — changing the pace of the bout.
266    RateLimitChanged { old: u32, new: u32 },
267    /// Listen address was changed — moving to a different ring.
268    ListenAddressChanged { old: String, new: String },
269    /// A new channel entered the arena.
270    ChannelAdded(String),
271    /// A channel was pulled from the fight card.
272    ChannelRemoved(String),
273    /// A new MCP server joined the corner team.
274    McpServerAdded(String),
275    /// An MCP server was cut from the roster.
276    McpServerRemoved(String),
277    /// Memory configuration was adjusted — changing the fighter's recall strategy.
278    MemoryConfigChanged,
279}
280
281/// A validation error found in a config — a foul called by the referee.
282#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct ConfigValidationError {
284    /// The config field that failed validation.
285    pub field: String,
286    /// Human-readable description of the issue.
287    pub message: String,
288    /// How severe this validation failure is.
289    pub severity: ValidationSeverity,
290}
291
292impl std::fmt::Display for ConfigValidationError {
293    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294        write!(f, "[{:?}] {}: {}", self.severity, self.field, self.message)
295    }
296}
297
298impl std::error::Error for ConfigValidationError {}
299
300/// Severity of a configuration validation issue.
301#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
302pub enum ValidationSeverity {
303    /// Something worth noting but not a showstopper — the fighter can continue.
304    Warning,
305    /// A hard foul — the config cannot be accepted.
306    Error,
307}
308
309/// Compare two configs and enumerate what changed — scouting the opponent's adjustments.
310pub fn diff_configs(old: &PunchConfig, new: &PunchConfig) -> Vec<ConfigChange> {
311    let mut changes = Vec::new();
312
313    // Model change
314    if old.default_model.model != new.default_model.model {
315        changes.push(ConfigChange::ModelChanged {
316            old_model: old.default_model.model.clone(),
317            new_model: new.default_model.model.clone(),
318        });
319    }
320
321    // API key change
322    if old.api_key != new.api_key {
323        changes.push(ConfigChange::ApiKeyChanged);
324    }
325
326    // Rate limit change
327    if old.rate_limit_rpm != new.rate_limit_rpm {
328        changes.push(ConfigChange::RateLimitChanged {
329            old: old.rate_limit_rpm,
330            new: new.rate_limit_rpm,
331        });
332    }
333
334    // Listen address change
335    if old.api_listen != new.api_listen {
336        changes.push(ConfigChange::ListenAddressChanged {
337            old: old.api_listen.clone(),
338            new: new.api_listen.clone(),
339        });
340    }
341
342    // Channel diffs
343    let old_channels: HashSet<&String> = old.channels.keys().collect();
344    let new_channels: HashSet<&String> = new.channels.keys().collect();
345
346    for added in new_channels.difference(&old_channels) {
347        changes.push(ConfigChange::ChannelAdded((*added).clone()));
348    }
349    for removed in old_channels.difference(&new_channels) {
350        changes.push(ConfigChange::ChannelRemoved((*removed).clone()));
351    }
352
353    // MCP server diffs
354    let old_servers: HashSet<&String> = old.mcp_servers.keys().collect();
355    let new_servers: HashSet<&String> = new.mcp_servers.keys().collect();
356
357    for added in new_servers.difference(&old_servers) {
358        changes.push(ConfigChange::McpServerAdded((*added).clone()));
359    }
360    for removed in old_servers.difference(&new_servers) {
361        changes.push(ConfigChange::McpServerRemoved((*removed).clone()));
362    }
363
364    // Memory config change — compare serialized forms to catch any field differences.
365    let old_mem = serde_json::to_string(&old.memory).unwrap_or_default();
366    let new_mem = serde_json::to_string(&new.memory).unwrap_or_default();
367    if old_mem != new_mem {
368        changes.push(ConfigChange::MemoryConfigChanged);
369    }
370
371    changes
372}
373
374/// Validate a config for correctness — the referee's pre-fight inspection.
375///
376/// Returns a list of validation issues. Errors must be fixed before the config
377/// can be accepted; warnings are advisory.
378pub fn validate_config(config: &PunchConfig) -> Vec<ConfigValidationError> {
379    let mut errors = Vec::new();
380
381    // Check api_listen is a valid socket address format.
382    if config.api_listen.parse::<std::net::SocketAddr>().is_err() {
383        errors.push(ConfigValidationError {
384            field: "api_listen".to_string(),
385            message: format!(
386                "'{}' is not a valid socket address (expected host:port)",
387                config.api_listen
388            ),
389            severity: ValidationSeverity::Error,
390        });
391    }
392
393    // Check default_model has a non-empty model name.
394    if config.default_model.model.trim().is_empty() {
395        errors.push(ConfigValidationError {
396            field: "default_model.model".to_string(),
397            message: "model name cannot be empty — the fighter needs a stance".to_string(),
398            severity: ValidationSeverity::Error,
399        });
400    }
401
402    // Check memory db_path is non-empty.
403    if config.memory.db_path.trim().is_empty() {
404        errors.push(ConfigValidationError {
405            field: "memory.db_path".to_string(),
406            message: "database path cannot be empty — the fighter needs memory".to_string(),
407            severity: ValidationSeverity::Error,
408        });
409    }
410
411    // Check rate_limit_rpm is > 0.
412    if config.rate_limit_rpm == 0 {
413        errors.push(ConfigValidationError {
414            field: "rate_limit_rpm".to_string(),
415            message: "rate limit must be greater than zero — even a slugger needs some pace"
416                .to_string(),
417            severity: ValidationSeverity::Error,
418        });
419    }
420
421    // Warn if api_key is empty (dev mode).
422    if config.api_key.is_empty() {
423        errors.push(ConfigValidationError {
424            field: "api_key".to_string(),
425            message: "API key is empty — running in dev mode with no authentication".to_string(),
426            severity: ValidationSeverity::Warning,
427        });
428    }
429
430    errors
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ChannelConfig, McpServerConfig, MemoryConfig, ModelConfig, Provider};
    use std::collections::HashMap;

    /// Path handed to every test watcher. No test calls `start_watching`, so the
    /// file is never read.
    const TEST_CONFIG_PATH: &str = "/tmp/punch.toml";

    /// Build a valid test config — a well-prepared fighter entering the ring.
    fn make_test_config() -> PunchConfig {
        PunchConfig {
            api_listen: "127.0.0.1:6660".to_string(),
            api_key: "test-key-123".to_string(),
            rate_limit_rpm: 60,
            default_model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".to_string(),
                api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            memory: MemoryConfig {
                db_path: "/tmp/punch-test.db".to_string(),
                knowledge_graph_enabled: true,
                max_entries: Some(10000),
            },
            tunnel: None,
            channels: HashMap::new(),
            mcp_servers: HashMap::new(),
        }
    }

    /// Build a channel fixture of the given type with a token env var.
    fn make_channel(kind: &str, token_env: &str) -> ChannelConfig {
        ChannelConfig {
            channel_type: kind.to_string(),
            token_env: Some(token_env.to_string()),
            webhook_secret_env: None,
            allowed_user_ids: vec![],
            rate_limit_per_user: 20,
            settings: HashMap::new(),
        }
    }

    /// Build an MCP server fixture from a command and its arguments.
    fn make_mcp_server(command: &str, args: &[&str]) -> McpServerConfig {
        McpServerConfig {
            command: command.to_string(),
            args: args.iter().map(|arg| (*arg).to_string()).collect(),
            env: HashMap::new(),
        }
    }

    /// Build a watcher over the shared test path.
    fn make_watcher(config: PunchConfig) -> ConfigWatcher {
        ConfigWatcher::new(PathBuf::from(TEST_CONFIG_PATH), config)
    }

    /// True when validation reports `field` with exactly `severity`.
    fn has_issue(config: &PunchConfig, field: &str, severity: ValidationSeverity) -> bool {
        validate_config(config)
            .iter()
            .any(|issue| issue.field == field && issue.severity == severity)
    }

    #[test]
    fn diff_detects_model_change() {
        let old = make_test_config();
        let mut new = old.clone();
        new.default_model.model = "claude-opus-4-20250514".to_string();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ModelChanged {
            old_model: "claude-sonnet-4-20250514".to_string(),
            new_model: "claude-opus-4-20250514".to_string(),
        }));
    }

    #[test]
    fn diff_detects_rate_limit_change() {
        let old = make_test_config();
        let mut new = old.clone();
        new.rate_limit_rpm = 120;

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::RateLimitChanged { old: 60, new: 120 }));
    }

    #[test]
    fn diff_detects_channel_added() {
        let old = make_test_config();
        let mut new = old.clone();
        new.channels
            .insert("slack".to_string(), make_channel("slack", "SLACK_TOKEN"));

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ChannelAdded("slack".to_string())));
    }

    #[test]
    fn diff_detects_channel_removed() {
        let mut old = make_test_config();
        old.channels.insert(
            "discord".to_string(),
            make_channel("discord", "DISCORD_TOKEN"),
        );
        let new = make_test_config();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ChannelRemoved("discord".to_string())));
    }

    #[test]
    fn diff_returns_empty_for_identical_configs() {
        let config = make_test_config();
        let changes = diff_configs(&config, &config);
        assert!(
            changes.is_empty(),
            "identical configs should produce no changes"
        );
    }

    #[test]
    fn validate_passes_valid_config() {
        let config = make_test_config();
        let errors: Vec<_> = validate_config(&config)
            .into_iter()
            .filter(|e| matches!(e.severity, ValidationSeverity::Error))
            .collect();
        assert!(errors.is_empty(), "valid config should produce no errors");
    }

    #[test]
    fn validate_catches_empty_model_name() {
        let mut config = make_test_config();
        config.default_model.model = "".to_string();

        assert!(has_issue(
            &config,
            "default_model.model",
            ValidationSeverity::Error
        ));
    }

    #[test]
    fn validate_catches_empty_db_path() {
        let mut config = make_test_config();
        config.memory.db_path = "".to_string();

        assert!(has_issue(
            &config,
            "memory.db_path",
            ValidationSeverity::Error
        ));
    }

    #[test]
    fn validate_warns_on_empty_api_key() {
        let mut config = make_test_config();
        config.api_key = "".to_string();

        assert!(has_issue(&config, "api_key", ValidationSeverity::Warning));
    }

    #[test]
    fn config_watcher_can_be_created() {
        let config = make_test_config();
        let watcher = make_watcher(config.clone());
        assert_eq!(watcher.current().api_listen, config.api_listen);
    }

    #[tokio::test]
    async fn apply_change_returns_change_set() {
        let watcher = make_watcher(make_test_config());
        let subscriber = watcher.subscribe();

        let mut new_config = make_test_config();
        new_config.rate_limit_rpm = 120;

        let change_set = watcher
            .apply_change(new_config)
            .expect("valid config should be applied");
        assert!(
            change_set
                .changes
                .contains(&ConfigChange::RateLimitChanged { old: 60, new: 120 })
        );

        // The change must actually be visible, not just reported.
        assert_eq!(watcher.current().rate_limit_rpm, 120);
        assert_eq!(subscriber.borrow().rate_limit_rpm, 120);
    }

    #[tokio::test]
    async fn apply_change_rejects_invalid_config() {
        let watcher = make_watcher(make_test_config());

        let mut bad_config = make_test_config();
        bad_config.default_model.model = "".to_string();

        let err = watcher
            .apply_change(bad_config)
            .expect_err("empty model name should be rejected");
        assert_eq!(err.field, "default_model.model");
        assert_eq!(err.severity, ValidationSeverity::Error);

        // A rejected config must leave the active one untouched.
        assert_eq!(
            watcher.current().default_model.model,
            "claude-sonnet-4-20250514"
        );
    }

    #[test]
    fn apply_change_works_without_runtime() {
        // Plain `#[test]`: no tokio runtime is available here.
        let watcher = make_watcher(make_test_config());

        let mut new_config = make_test_config();
        new_config.api_listen = "0.0.0.0:8080".to_string();

        let change_set = watcher
            .apply_change(new_config)
            .expect("valid config should be applied");
        assert!(change_set.changes.contains(&ConfigChange::ListenAddressChanged {
            old: "127.0.0.1:6660".to_string(),
            new: "0.0.0.0:8080".to_string(),
        }));
        assert_eq!(watcher.current().api_listen, "0.0.0.0:8080");
    }

    #[test]
    fn current_config_accessible_after_creation() {
        let watcher = make_watcher(make_test_config());

        let current = watcher.current();
        assert_eq!(current.api_listen, "127.0.0.1:6660");
        assert_eq!(current.rate_limit_rpm, 60);
        assert_eq!(current.default_model.model, "claude-sonnet-4-20250514");
        assert_eq!(current.memory.db_path, "/tmp/punch-test.db");
    }

    #[test]
    fn diff_detects_mcp_server_added() {
        let old = make_test_config();
        let mut new = old.clone();
        new.mcp_servers.insert(
            "filesystem".to_string(),
            make_mcp_server("npx", &["-y", "@mcp/filesystem"]),
        );

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::McpServerAdded("filesystem".to_string())));
    }

    #[test]
    fn diff_detects_mcp_server_removed() {
        let mut old = make_test_config();
        old.mcp_servers
            .insert("memory".to_string(), make_mcp_server("mcp-memory", &[]));
        let new = make_test_config();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::McpServerRemoved("memory".to_string())));
    }

    #[test]
    fn diff_detects_memory_config_changed() {
        let old = make_test_config();
        let mut new = old.clone();
        new.memory.knowledge_graph_enabled = false;

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::MemoryConfigChanged));
    }

    #[test]
    fn diff_detects_api_key_changed() {
        let old = make_test_config();
        let mut new = old.clone();
        new.api_key = "new-secret-key".to_string();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ApiKeyChanged));
    }

    #[test]
    fn diff_detects_listen_address_changed() {
        let old = make_test_config();
        let mut new = old.clone();
        new.api_listen = "0.0.0.0:8080".to_string();

        let changes = diff_configs(&old, &new);
        assert!(changes.contains(&ConfigChange::ListenAddressChanged {
            old: "127.0.0.1:6660".to_string(),
            new: "0.0.0.0:8080".to_string(),
        }));
    }

    #[test]
    fn validate_catches_invalid_socket_addr() {
        let mut config = make_test_config();
        config.api_listen = "not-a-valid-address".to_string();

        assert!(has_issue(&config, "api_listen", ValidationSeverity::Error));
    }

    #[test]
    fn validate_catches_zero_rate_limit() {
        let mut config = make_test_config();
        config.rate_limit_rpm = 0;

        assert!(has_issue(
            &config,
            "rate_limit_rpm",
            ValidationSeverity::Error
        ));
    }

    #[tokio::test]
    async fn subscriber_receives_initial_config() {
        let config = make_test_config();
        let watcher = make_watcher(config.clone());

        let rx = watcher.subscribe();
        let received = rx.borrow().clone();
        assert_eq!(received.api_listen, config.api_listen);
    }
}