Skip to main content

punch_types/
hot_reload.rs

//! Hot Config Reload — corner team adjustments between rounds.
//!
//! This module enables mid-fight strategy changes by watching the config file
//! for modifications and broadcasting validated updates to all subscribers.
//! Like a corner team making tactical adjustments between rounds, the system
//! applies config changes without pulling the fighter from the ring. Updates can
//! also be pushed programmatically through `ConfigWatcher::apply_change`.
7
8use std::collections::HashSet;
9use std::path::PathBuf;
10use std::sync::Arc;
11
12use chrono::{DateTime, Utc};
13use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
14use serde::{Deserialize, Serialize};
15use tokio::sync::{RwLock, watch};
16use tracing::{debug, error, info, warn};
17
18use crate::config::PunchConfig;
19use crate::error::{PunchError, PunchResult};
20
/// Watches the config file and broadcasts validated changes — the corner team
/// that keeps the fighter's strategy sharp without stopping the bout.
///
/// The configuration is held in two places: a `tokio::sync::watch` channel
/// (`tx`/`rx`), which `current()` and `subscribe()` read from, and a separate
/// lock-protected copy (`current`).
#[derive(Debug)]
pub struct ConfigWatcher {
    /// Path to the configuration file being watched.
    config_path: PathBuf,
    /// Lock-protected copy of the configuration, written by the file-watcher task
    /// and by `apply_change`. Note that `current()` reads the watch channel, not
    /// this field.
    current: Arc<RwLock<PunchConfig>>,
    /// Sender half of the watch channel for broadcasting config updates.
    tx: watch::Sender<PunchConfig>,
    /// Receiver half of the watch channel — cloned for each subscriber. Holding it
    /// here also keeps the channel open for as long as this watcher is alive.
    rx: watch::Receiver<PunchConfig>,
}
34
35impl ConfigWatcher {
36    /// Create a new ConfigWatcher ready to observe the corner team's playbook.
37    pub fn new(config_path: PathBuf, initial_config: PunchConfig) -> Self {
38        let (tx, rx) = watch::channel(initial_config.clone());
39        Self {
40            config_path,
41            current: Arc::new(RwLock::new(initial_config)),
42            tx,
43            rx,
44        }
45    }
46
47    /// Subscribe to config changes — get a ringside seat for every strategy adjustment.
48    pub fn subscribe(&self) -> watch::Receiver<PunchConfig> {
49        self.rx.clone()
50    }
51
52    /// Get a snapshot of the current config — check what game plan the fighter is using right now.
53    pub fn current(&self) -> PunchConfig {
54        self.rx.borrow().clone()
55    }
56
57    /// Start watching the config file for changes — the corner team takes their position.
58    ///
59    /// Spawns a background task that monitors the config file using filesystem events
60    /// and applies validated changes automatically.
61    pub async fn start_watching(&self) -> PunchResult<()> {
62        let config_path = self.config_path.clone();
63        let current = Arc::clone(&self.current);
64        let tx = self.tx.clone();
65
66        // Resolve the parent directory and file name for the watcher.
67        let watch_path = config_path
68            .parent()
69            .map(|p| p.to_path_buf())
70            .unwrap_or_else(|| PathBuf::from("."));
71
72        let target_file = config_path
73            .file_name()
74            .map(|f| f.to_os_string())
75            .ok_or_else(|| PunchError::Config("config path has no file name".to_string()))?;
76
77        let (notify_tx, mut notify_rx) = tokio::sync::mpsc::channel::<Event>(16);
78
79        // Create the filesystem watcher on a blocking thread since notify uses sync callbacks.
80        let _watcher: RecommendedWatcher = {
81            let notify_tx = notify_tx.clone();
82            let mut watcher =
83                notify::recommended_watcher(move |res: Result<Event, notify::Error>| match res {
84                    Ok(event) => {
85                        if let Err(e) = notify_tx.blocking_send(event) {
86                            error!(error = %e, "failed to forward file event");
87                        }
88                    }
89                    Err(e) => {
90                        error!(error = %e, "filesystem watcher error");
91                    }
92                })
93                .map_err(|e| PunchError::Config(format!("failed to create file watcher: {}", e)))?;
94
95            watcher
96                .watch(&watch_path, RecursiveMode::NonRecursive)
97                .map_err(|e| {
98                    PunchError::Config(format!("failed to watch config directory: {}", e))
99                })?;
100
101            watcher
102        };
103
104        let config_path_for_task = config_path.clone();
105        let target_file_for_task = target_file.clone();
106
107        // Spawn the background task — the corner team is now watching the fight.
108        tokio::spawn(async move {
109            // Keep the watcher alive for the lifetime of this task.
110            let _watcher = _watcher;
111
112            info!(path = %config_path_for_task.display(), "corner team watching config file");
113
114            while let Some(event) = notify_rx.recv().await {
115                // Only react to modify/create events on our target file.
116                let dominated = matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_));
117                if !dominated {
118                    continue;
119                }
120
121                let affects_target = event.paths.iter().any(|p| {
122                    p.file_name()
123                        .map(|f| f == target_file_for_task)
124                        .unwrap_or(false)
125                });
126                if !affects_target {
127                    continue;
128                }
129
130                debug!("config file change detected, reloading");
131
132                // Read and parse the new config.
133                let content = match tokio::fs::read_to_string(&config_path_for_task).await {
134                    Ok(c) => c,
135                    Err(e) => {
136                        error!(error = %e, "failed to read config file during reload");
137                        continue;
138                    }
139                };
140
141                let new_config: PunchConfig = match toml::from_str(&content) {
142                    Ok(c) => c,
143                    Err(e) => {
144                        error!(error = %e, "failed to parse config file during reload");
145                        continue;
146                    }
147                };
148
149                // Validate the new config.
150                let errors: Vec<_> = validate_config(&new_config)
151                    .into_iter()
152                    .filter(|v| matches!(v.severity, ValidationSeverity::Error))
153                    .collect();
154
155                if !errors.is_empty() {
156                    for err in &errors {
157                        error!(field = %err.field, message = %err.message, "config validation failed");
158                    }
159                    continue;
160                }
161
162                // Apply the change.
163                let old_config = {
164                    let mut guard = current.write().await;
165                    let old = guard.clone();
166                    *guard = new_config.clone();
167                    old
168                };
169
170                let changes = diff_configs(&old_config, &new_config);
171                if changes.is_empty() {
172                    debug!("config file changed but no effective differences detected");
173                    continue;
174                }
175
176                for change in &changes {
177                    info!(change = ?change, "corner team adjustment applied");
178                }
179
180                // Broadcast the new config to all subscribers.
181                if tx.send(new_config).is_err() {
182                    warn!("no config subscribers remaining — corner team shouting into the void");
183                    break;
184                }
185            }
186
187            info!("config watcher task ended");
188        });
189
190        Ok(())
191    }
192
193    /// Validate and apply a new config programmatically — a direct corner team call.
194    ///
195    /// Returns the set of changes if the config is valid, or a validation error
196    /// if the new config fails checks.
197    pub fn apply_change(
198        &self,
199        new_config: PunchConfig,
200    ) -> Result<ConfigChangeSet, ConfigValidationError> {
201        let validation_errors: Vec<_> = validate_config(&new_config)
202            .into_iter()
203            .filter(|v| matches!(v.severity, ValidationSeverity::Error))
204            .collect();
205
206        if let Some(err) = validation_errors.into_iter().next() {
207            return Err(err);
208        }
209
210        let old_config = self.rx.borrow().clone();
211        let changes = diff_configs(&old_config, &new_config);
212
213        // Update current config behind the lock (blocking context is fine for apply_change).
214        {
215            let current = Arc::clone(&self.current);
216            let new_config_clone = new_config.clone();
217            // Use try_write to avoid blocking in sync context. If contended, fall back to
218            // a blocking write via std::thread::spawn, but for apply_change this is acceptable.
219            let rt = tokio::runtime::Handle::try_current();
220            match rt {
221                Ok(handle) => {
222                    let current = current.clone();
223                    let cfg = new_config_clone.clone();
224                    handle.spawn(async move {
225                        let mut guard = current.write().await;
226                        *guard = cfg;
227                    });
228                }
229                Err(_) => {
230                    // If no runtime, we're in a sync context — just best-effort.
231                    // The watch channel is the source of truth anyway.
232                }
233            }
234        }
235
236        // Broadcast through the watch channel — this is the authoritative update.
237        let _ = self.tx.send(new_config);
238
239        Ok(ConfigChangeSet {
240            changes,
241            applied_at: Utc::now(),
242        })
243    }
244}
245
/// A set of changes applied in a single config reload — the corner team's adjustment notes.
///
/// Returned by `ConfigWatcher::apply_change` when a new config is accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigChangeSet {
    /// Individual changes detected between old and new configs, as computed by
    /// `diff_configs`. May be empty if no tracked field changed.
    pub changes: Vec<ConfigChange>,
    /// Timestamp (UTC) when these adjustments were applied.
    pub applied_at: DateTime<Utc>,
}
254
/// A single configuration change detected during a reload.
///
/// Produced by `diff_configs`; each variant documents the config field that triggers it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ConfigChange {
    /// The default model was swapped — switching fighting stance (`default_model.model`).
    ModelChanged {
        old_model: String,
        new_model: String,
    },
    /// API key was rotated — new credentials for the fight (`api_key`).
    /// Carries no payload, so neither key value appears in the change set.
    ApiKeyChanged,
    /// Rate limit was adjusted — changing the pace of the bout (`rate_limit_rpm`).
    RateLimitChanged { old: u32, new: u32 },
    /// Listen address was changed — moving to a different ring (`api_listen`).
    ListenAddressChanged { old: String, new: String },
    /// A new channel entered the arena (a key added to `channels`).
    ChannelAdded(String),
    /// A channel was pulled from the fight card (a key removed from `channels`).
    ChannelRemoved(String),
    /// A new MCP server joined the corner team (a key added to `mcp_servers`).
    McpServerAdded(String),
    /// An MCP server was cut from the roster (a key removed from `mcp_servers`).
    McpServerRemoved(String),
    /// Memory configuration was adjusted — changing the fighter's recall strategy.
    /// Raised when any field of `memory` differs.
    MemoryConfigChanged,
}
280
/// A validation error found in a config — a foul called by the referee.
///
/// Returned in bulk (at either severity) by `validate_config`. The first
/// error-severity issue is also the `Err` value of `ConfigWatcher::apply_change`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigValidationError {
    /// The config field that failed validation, as a dotted path (e.g. `memory.db_path`).
    pub field: String,
    /// Human-readable description of the issue.
    pub message: String,
    /// How severe this validation failure is.
    pub severity: ValidationSeverity,
}
291
292impl std::fmt::Display for ConfigValidationError {
293    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294        write!(f, "[{:?}] {}: {}", self.severity, self.field, self.message)
295    }
296}
297
298impl std::error::Error for ConfigValidationError {}
299
/// Severity of a configuration validation issue.
///
/// Only `Error` blocks a config from being applied; `Warning` issues are still
/// returned by `validate_config` but are advisory.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ValidationSeverity {
    /// Something worth noting but not a showstopper — the fighter can continue.
    Warning,
    /// A hard foul — the config cannot be accepted.
    Error,
}
308
309/// Compare two configs and enumerate what changed — scouting the opponent's adjustments.
310pub fn diff_configs(old: &PunchConfig, new: &PunchConfig) -> Vec<ConfigChange> {
311    let mut changes = Vec::new();
312
313    // Model change
314    if old.default_model.model != new.default_model.model {
315        changes.push(ConfigChange::ModelChanged {
316            old_model: old.default_model.model.clone(),
317            new_model: new.default_model.model.clone(),
318        });
319    }
320
321    // API key change
322    if old.api_key != new.api_key {
323        changes.push(ConfigChange::ApiKeyChanged);
324    }
325
326    // Rate limit change
327    if old.rate_limit_rpm != new.rate_limit_rpm {
328        changes.push(ConfigChange::RateLimitChanged {
329            old: old.rate_limit_rpm,
330            new: new.rate_limit_rpm,
331        });
332    }
333
334    // Listen address change
335    if old.api_listen != new.api_listen {
336        changes.push(ConfigChange::ListenAddressChanged {
337            old: old.api_listen.clone(),
338            new: new.api_listen.clone(),
339        });
340    }
341
342    // Channel diffs
343    let old_channels: HashSet<&String> = old.channels.keys().collect();
344    let new_channels: HashSet<&String> = new.channels.keys().collect();
345
346    for added in new_channels.difference(&old_channels) {
347        changes.push(ConfigChange::ChannelAdded((*added).clone()));
348    }
349    for removed in old_channels.difference(&new_channels) {
350        changes.push(ConfigChange::ChannelRemoved((*removed).clone()));
351    }
352
353    // MCP server diffs
354    let old_servers: HashSet<&String> = old.mcp_servers.keys().collect();
355    let new_servers: HashSet<&String> = new.mcp_servers.keys().collect();
356
357    for added in new_servers.difference(&old_servers) {
358        changes.push(ConfigChange::McpServerAdded((*added).clone()));
359    }
360    for removed in old_servers.difference(&new_servers) {
361        changes.push(ConfigChange::McpServerRemoved((*removed).clone()));
362    }
363
364    // Memory config change — compare serialized forms to catch any field differences.
365    let old_mem = serde_json::to_string(&old.memory).unwrap_or_default();
366    let new_mem = serde_json::to_string(&new.memory).unwrap_or_default();
367    if old_mem != new_mem {
368        changes.push(ConfigChange::MemoryConfigChanged);
369    }
370
371    changes
372}
373
374/// Validate a config for correctness — the referee's pre-fight inspection.
375///
376/// Returns a list of validation issues. Errors must be fixed before the config
377/// can be accepted; warnings are advisory.
378pub fn validate_config(config: &PunchConfig) -> Vec<ConfigValidationError> {
379    let mut errors = Vec::new();
380
381    // Check api_listen is a valid socket address format.
382    if config.api_listen.parse::<std::net::SocketAddr>().is_err() {
383        errors.push(ConfigValidationError {
384            field: "api_listen".to_string(),
385            message: format!(
386                "'{}' is not a valid socket address (expected host:port)",
387                config.api_listen
388            ),
389            severity: ValidationSeverity::Error,
390        });
391    }
392
393    // Check default_model has a non-empty model name.
394    if config.default_model.model.trim().is_empty() {
395        errors.push(ConfigValidationError {
396            field: "default_model.model".to_string(),
397            message: "model name cannot be empty — the fighter needs a stance".to_string(),
398            severity: ValidationSeverity::Error,
399        });
400    }
401
402    // Check memory db_path is non-empty.
403    if config.memory.db_path.trim().is_empty() {
404        errors.push(ConfigValidationError {
405            field: "memory.db_path".to_string(),
406            message: "database path cannot be empty — the fighter needs memory".to_string(),
407            severity: ValidationSeverity::Error,
408        });
409    }
410
411    // Check rate_limit_rpm is > 0.
412    if config.rate_limit_rpm == 0 {
413        errors.push(ConfigValidationError {
414            field: "rate_limit_rpm".to_string(),
415            message: "rate limit must be greater than zero — even a slugger needs some pace"
416                .to_string(),
417            severity: ValidationSeverity::Error,
418        });
419    }
420
421    // Warn if api_key is empty (dev mode).
422    if config.api_key.is_empty() {
423        errors.push(ConfigValidationError {
424            field: "api_key".to_string(),
425            message: "API key is empty — running in dev mode with no authentication".to_string(),
426            severity: ValidationSeverity::Warning,
427        });
428    }
429
430    errors
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{MemoryConfig, ModelConfig, Provider};
    use std::collections::HashMap;

    // ----- fixtures -----

    /// A fully valid baseline config — a well-prepared fighter entering the ring.
    fn make_test_config() -> PunchConfig {
        let default_model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".to_string(),
            api_key_env: Some("ANTHROPIC_API_KEY".to_string()),
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        let memory = MemoryConfig {
            db_path: "/tmp/punch-test.db".to_string(),
            knowledge_graph_enabled: true,
            max_entries: Some(10000),
        };
        PunchConfig {
            api_listen: "127.0.0.1:6660".to_string(),
            api_key: "test-key-123".to_string(),
            rate_limit_rpm: 60,
            default_model,
            memory,
            tunnel: None,
            channels: HashMap::new(),
            mcp_servers: HashMap::new(),
            model_routing: Default::default(),
            budget: Default::default(),
        }
    }

    /// A channel entry of type `kind` whose token env var is `token_env`.
    fn channel(kind: &str, token_env: &str) -> crate::config::ChannelConfig {
        crate::config::ChannelConfig {
            channel_type: kind.to_string(),
            token_env: Some(token_env.to_string()),
            webhook_secret_env: None,
            allowed_user_ids: vec![],
            rate_limit_per_user: 20,
            settings: HashMap::new(),
        }
    }

    /// An MCP server entry with the given command and args and an empty env.
    fn mcp_server(command: &str, args: &[&str]) -> crate::config::McpServerConfig {
        crate::config::McpServerConfig {
            command: command.to_string(),
            args: args.iter().map(|a| a.to_string()).collect(),
            env: HashMap::new(),
        }
    }

    /// Apply `mutate` to a copy of the baseline and diff baseline -> mutated copy.
    fn diff_after(mutate: impl FnOnce(&mut PunchConfig)) -> Vec<ConfigChange> {
        let baseline = make_test_config();
        let mut edited = baseline.clone();
        mutate(&mut edited);
        diff_configs(&baseline, &edited)
    }

    /// Whether validating `config` reports an issue on `field` at `severity`.
    fn reports(config: &PunchConfig, field: &str, severity: ValidationSeverity) -> bool {
        validate_config(config)
            .iter()
            .any(|issue| issue.field == field && issue.severity == severity)
    }

    /// A watcher over a placeholder path seeded with `config`.
    fn watcher_for(config: PunchConfig) -> ConfigWatcher {
        ConfigWatcher::new(PathBuf::from("/tmp/punch.toml"), config)
    }

    // ----- diff_configs -----

    #[test]
    fn diff_detects_model_change() {
        let changes =
            diff_after(|c| c.default_model.model = "claude-opus-4-20250514".to_string());
        let expected = ConfigChange::ModelChanged {
            old_model: "claude-sonnet-4-20250514".to_string(),
            new_model: "claude-opus-4-20250514".to_string(),
        };
        assert!(changes.contains(&expected));
    }

    #[test]
    fn diff_detects_rate_limit_change() {
        let changes = diff_after(|c| c.rate_limit_rpm = 120);
        assert!(changes.contains(&ConfigChange::RateLimitChanged { old: 60, new: 120 }));
    }

    #[test]
    fn diff_detects_channel_added() {
        let changes = diff_after(|c| {
            c.channels
                .insert("slack".to_string(), channel("slack", "SLACK_TOKEN"));
        });
        assert!(changes.contains(&ConfigChange::ChannelAdded("slack".to_string())));
    }

    #[test]
    fn diff_detects_channel_removed() {
        let mut old = make_test_config();
        old.channels
            .insert("discord".to_string(), channel("discord", "DISCORD_TOKEN"));

        let changes = diff_configs(&old, &make_test_config());
        assert!(changes.contains(&ConfigChange::ChannelRemoved("discord".to_string())));
    }

    #[test]
    fn diff_returns_empty_for_identical_configs() {
        let config = make_test_config();
        assert!(
            diff_configs(&config, &config).is_empty(),
            "identical configs should produce no changes"
        );
    }

    #[test]
    fn diff_detects_mcp_server_added() {
        let changes = diff_after(|c| {
            c.mcp_servers.insert(
                "filesystem".to_string(),
                mcp_server("npx", &["-y", "@mcp/filesystem"]),
            );
        });
        assert!(changes.contains(&ConfigChange::McpServerAdded("filesystem".to_string())));
    }

    #[test]
    fn diff_detects_mcp_server_removed() {
        let mut old = make_test_config();
        old.mcp_servers
            .insert("memory".to_string(), mcp_server("mcp-memory", &[]));

        let changes = diff_configs(&old, &make_test_config());
        assert!(changes.contains(&ConfigChange::McpServerRemoved("memory".to_string())));
    }

    #[test]
    fn diff_detects_memory_config_changed() {
        let changes = diff_after(|c| c.memory.knowledge_graph_enabled = false);
        assert!(changes.contains(&ConfigChange::MemoryConfigChanged));
    }

    #[test]
    fn diff_detects_api_key_changed() {
        let changes = diff_after(|c| c.api_key = "new-secret-key".to_string());
        assert!(changes.contains(&ConfigChange::ApiKeyChanged));
    }

    #[test]
    fn diff_detects_listen_address_changed() {
        let changes = diff_after(|c| c.api_listen = "0.0.0.0:8080".to_string());
        let expected = ConfigChange::ListenAddressChanged {
            old: "127.0.0.1:6660".to_string(),
            new: "0.0.0.0:8080".to_string(),
        };
        assert!(changes.contains(&expected));
    }

    // ----- validate_config -----

    #[test]
    fn validate_passes_valid_config() {
        let config = make_test_config();
        let has_error = validate_config(&config)
            .iter()
            .any(|issue| issue.severity == ValidationSeverity::Error);
        assert!(!has_error, "valid config should produce no errors");
    }

    #[test]
    fn validate_catches_empty_model_name() {
        let mut config = make_test_config();
        config.default_model.model = String::new();
        assert!(reports(&config, "default_model.model", ValidationSeverity::Error));
    }

    #[test]
    fn validate_catches_empty_db_path() {
        let mut config = make_test_config();
        config.memory.db_path = String::new();
        assert!(reports(&config, "memory.db_path", ValidationSeverity::Error));
    }

    #[test]
    fn validate_warns_on_empty_api_key() {
        let mut config = make_test_config();
        config.api_key = String::new();
        assert!(reports(&config, "api_key", ValidationSeverity::Warning));
    }

    #[test]
    fn validate_catches_invalid_socket_addr() {
        let mut config = make_test_config();
        config.api_listen = "not-a-valid-address".to_string();
        assert!(reports(&config, "api_listen", ValidationSeverity::Error));
    }

    #[test]
    fn validate_catches_zero_rate_limit() {
        let mut config = make_test_config();
        config.rate_limit_rpm = 0;
        assert!(reports(&config, "rate_limit_rpm", ValidationSeverity::Error));
    }

    // ----- ConfigWatcher -----

    #[test]
    fn config_watcher_can_be_created() {
        let config = make_test_config();
        let watcher = watcher_for(config.clone());
        assert_eq!(watcher.current().api_listen, config.api_listen);
    }

    #[test]
    fn current_config_accessible_after_creation() {
        let current = watcher_for(make_test_config()).current();
        assert_eq!(current.api_listen, "127.0.0.1:6660");
        assert_eq!(current.rate_limit_rpm, 60);
        assert_eq!(current.default_model.model, "claude-sonnet-4-20250514");
        assert_eq!(current.memory.db_path, "/tmp/punch-test.db");
    }

    #[tokio::test]
    async fn apply_change_returns_change_set() {
        let watcher = watcher_for(make_test_config());
        let mut new_config = make_test_config();
        new_config.rate_limit_rpm = 120;

        let change_set = watcher.apply_change(new_config).expect("should succeed");
        assert!(
            change_set
                .changes
                .contains(&ConfigChange::RateLimitChanged { old: 60, new: 120 })
        );
    }

    #[tokio::test]
    async fn apply_change_rejects_invalid_config() {
        let watcher = watcher_for(make_test_config());
        let mut bad_config = make_test_config();
        bad_config.default_model.model = String::new();

        assert!(watcher.apply_change(bad_config).is_err());
    }

    #[tokio::test]
    async fn subscriber_receives_initial_config() {
        let config = make_test_config();
        let watcher = watcher_for(config.clone());

        let rx = watcher.subscribe();
        let received = rx.borrow().clone();
        assert_eq!(received.api_listen, config.api_listen);
    }
}