Skip to main content

entrenar/eval/retrain/
retrainer.rs

1//! Auto-retrainer implementation.
2
3use super::action::Action;
4use super::config::RetrainConfig;
5use super::policy::RetrainPolicy;
6use crate::error::Result;
7use crate::eval::drift::{DriftDetector, DriftResult, DriftSummary, Severity};
8
9/// Callback type for retrain triggers
10pub type RetrainCallback = Box<dyn Fn(&[DriftResult]) -> Result<String> + Send + Sync>;
11
12/// Auto-retrainer that monitors drift and triggers retraining
13pub struct AutoRetrainer {
14    detector: DriftDetector,
15    config: RetrainConfig,
16    retrain_callback: Option<RetrainCallback>,
17    batches_since_retrain: usize,
18    total_retrains: usize,
19}
20
21impl AutoRetrainer {
22    /// Create a new auto-retrainer with given detector and config
23    pub fn new(detector: DriftDetector, config: RetrainConfig) -> Self {
24        Self {
25            detector,
26            config,
27            retrain_callback: None,
28            batches_since_retrain: 0,
29            total_retrains: 0,
30        }
31    }
32
33    /// Set the callback to invoke when retraining is triggered
34    ///
35    /// The callback receives the drift results and should return a job ID
36    /// or an error if retraining failed to start.
37    pub fn on_retrain<F>(&mut self, callback: F)
38    where
39        F: Fn(&[DriftResult]) -> Result<String> + Send + Sync + 'static,
40    {
41        self.retrain_callback = Some(Box::new(callback));
42    }
43
44    /// Process a batch of data and check for drift
45    ///
46    /// Returns the action taken based on drift detection and policy.
47    pub fn process_batch(&mut self, batch: &[Vec<f64>]) -> Result<Action> {
48        self.batches_since_retrain += 1;
49
50        // Check for drift
51        let results = self.detector.check(batch);
52
53        if results.is_empty() {
54            return Ok(Action::None);
55        }
56
57        let summary = DriftDetector::summary(&results);
58
59        // Check if we're in cooldown
60        if self.batches_since_retrain < self.config.cooldown_batches {
61            if summary.has_drift() && self.config.log_warnings {
62                return Ok(Action::WarningLogged);
63            }
64            return Ok(Action::None);
65        }
66
67        // Check max retrains limit
68        if self.config.max_retrains > 0 && self.total_retrains >= self.config.max_retrains {
69            if summary.has_drift() && self.config.log_warnings {
70                return Ok(Action::WarningLogged);
71            }
72            return Ok(Action::None);
73        }
74
75        // Evaluate policy
76        let should_retrain = self.evaluate_policy(&results, &summary);
77
78        if should_retrain {
79            if let Some(ref callback) = self.retrain_callback {
80                let job_id = callback(&results)?;
81                self.batches_since_retrain = 0;
82                self.total_retrains += 1;
83                return Ok(Action::RetrainTriggered(job_id));
84            }
85            // No callback set but policy says retrain
86            return Ok(Action::WarningLogged);
87        }
88
89        if summary.warnings > 0 && self.config.log_warnings {
90            return Ok(Action::WarningLogged);
91        }
92
93        Ok(Action::None)
94    }
95
96    /// Evaluate whether retraining should be triggered based on policy
97    fn evaluate_policy(&self, results: &[DriftResult], summary: &DriftSummary) -> bool {
98        match &self.config.policy {
99            RetrainPolicy::FeatureCount { count } => summary.drifted_features >= *count,
100
101            RetrainPolicy::CriticalFeature { names } => {
102                results.iter().any(|r| r.drifted && names.contains(&r.feature))
103            }
104
105            RetrainPolicy::DriftPercentage { threshold } => {
106                summary.drift_percentage() >= *threshold
107            }
108
109            RetrainPolicy::AnyCritical => results.iter().any(|r| r.severity == Severity::Critical),
110        }
111    }
112
113    /// Get the underlying drift detector
114    pub fn detector(&self) -> &DriftDetector {
115        &self.detector
116    }
117
118    /// Get mutable reference to drift detector (for setting baseline)
119    pub fn detector_mut(&mut self) -> &mut DriftDetector {
120        &mut self.detector
121    }
122
123    /// Get statistics about retraining
124    pub fn stats(&self) -> RetrainerStats {
125        RetrainerStats {
126            total_retrains: self.total_retrains,
127            batches_since_retrain: self.batches_since_retrain,
128        }
129    }
130
131    /// Reset the cooldown counter
132    pub fn reset_cooldown(&mut self) {
133        self.batches_since_retrain = self.config.cooldown_batches;
134    }
135}
136
/// Snapshot of an auto-retrainer's counters, as returned by `AutoRetrainer::stats`.
///
/// The default value (all zeros) matches a freshly created retrainer.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct RetrainerStats {
    /// Total number of retrains successfully triggered.
    pub total_retrains: usize,
    /// Batches processed since the last retrain (or since creation if none).
    pub batches_since_retrain: usize,
}