// aprender/online/orchestrator.rs
//! Retrain Orchestrator for Drift-Triggered Model Updates
//!
//! Automatically monitors model performance and triggers retraining
//! when concept drift is detected.
//!
//! # References
//!
//! - [Gama et al. 2004] DDM for drift detection
//! - [Bifet & Gavalda 2007] ADWIN for adaptive windowing
//!
//! # Toyota Way Principles
//!
//! - **Jidoka**: Stop and fix when problems detected
//! - **Just-in-Time**: Retrain only when needed (pull system)
//! - **Genchi Genbutsu**: Use actual prediction outcomes

use crate::error::Result;

use super::corpus::{CorpusBuffer, CorpusBufferConfig, EvictionPolicy, Sample, SampleSource};
use super::curriculum::{CurriculumScheduler, LinearCurriculum, ScoredSample};
use super::drift::{DriftDetector, DriftStatus, ADWIN};
use super::OnlineLearner;

/// Result of observing a new sample.
///
/// Returned by [`RetrainOrchestrator::observe`] to tell the caller what
/// action the orchestrator took for the sample.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObserveResult {
    /// Model is performing well; at most an incremental update was applied.
    Stable,
    /// Warning level - drift suspected, collecting data before retraining.
    Warning,
    /// Drift confirmed and enough data was buffered; model was retrained.
    Retrained,
    /// Sample was not buffered (e.g. rejected as a duplicate).
    Skipped,
}

/// Configuration for retrain orchestrator.
///
/// See [`RetrainConfig::default`] for the recommended starting values.
#[derive(Debug, Clone)]
pub struct RetrainConfig {
    /// Minimum buffered samples required before a drift-triggered retrain.
    pub min_samples: usize,
    /// Maximum number of samples held in the retraining buffer.
    pub max_buffer_size: usize,
    /// Apply an incremental `partial_fit` on each observed sample.
    pub incremental_updates: bool,
    /// Use curriculum learning (easy-to-hard ordering) during retrain.
    pub curriculum_learning: bool,
    /// Number of curriculum stages when `curriculum_learning` is enabled.
    pub curriculum_stages: usize,
    /// Save a checkpoint after each retrain.
    pub save_checkpoint: bool,
    /// Learning rate used for both incremental updates and retraining.
    pub learning_rate: f64,
    /// Number of epochs for a full retraining pass.
    pub retrain_epochs: usize,
}

58impl Default for RetrainConfig {
59    fn default() -> Self {
60        Self {
61            min_samples: 100,
62            max_buffer_size: 10_000,
63            incremental_updates: true,
64            curriculum_learning: true,
65            curriculum_stages: 5,
66            save_checkpoint: false,
67            learning_rate: 0.01,
68            retrain_epochs: 10,
69        }
70    }
71}
72
73/// Statistics from the orchestrator
74#[derive(Debug, Clone, Default)]
75pub struct OrchestratorStats {
76    /// Total samples observed
77    pub samples_observed: u64,
78    /// Number of retraining events
79    pub retrain_count: u64,
80    /// Current buffer size
81    pub buffer_size: usize,
82    /// Current drift status
83    pub drift_status: DriftStatus,
84    /// Last retrain sample count
85    pub last_retrain_samples: usize,
86    /// Samples since last retrain
87    pub samples_since_retrain: u64,
88}
89
90/// Automatic retraining orchestrator
91///
92/// Monitors predictions for drift and triggers retraining when needed.
93#[derive(Debug)]
94pub struct RetrainOrchestrator<
95    M: OnlineLearner + std::fmt::Debug,
96    D: DriftDetector + std::fmt::Debug,
97> {
98    /// Current model
99    model: M,
100    /// Drift detector
101    detector: D,
102    /// Data buffer for retraining
103    buffer: CorpusBuffer,
104    /// Retraining configuration
105    config: RetrainConfig,
106    /// Statistics
107    stats: OrchestratorStats,
108    /// Number of features (for validation)
109    #[allow(dead_code)]
110    n_features: usize,
111}
112
113impl<M: OnlineLearner + std::fmt::Debug> RetrainOrchestrator<M, ADWIN> {
114    /// Create an orchestrator with ADWIN detector (recommended default)
115    pub fn new(model: M, n_features: usize) -> Self {
116        Self::with_detector(model, ADWIN::new(), n_features)
117    }
118}
119
120impl<M: OnlineLearner + std::fmt::Debug, D: DriftDetector + std::fmt::Debug>
121    RetrainOrchestrator<M, D>
122{
123    /// Create with custom drift detector
124    pub fn with_detector(model: M, detector: D, n_features: usize) -> Self {
125        let config = RetrainConfig::default();
126        let buffer_config = CorpusBufferConfig {
127            max_size: config.max_buffer_size,
128            policy: EvictionPolicy::Reservoir,
129            deduplicate: true,
130            ..Default::default()
131        };
132
133        Self {
134            model,
135            detector,
136            buffer: CorpusBuffer::with_config(buffer_config),
137            config,
138            stats: OrchestratorStats::default(),
139            n_features,
140        }
141    }
142
143    /// Create with custom configuration
144    pub fn with_config(model: M, detector: D, n_features: usize, config: RetrainConfig) -> Self {
145        let buffer_config = CorpusBufferConfig {
146            max_size: config.max_buffer_size,
147            policy: EvictionPolicy::Reservoir,
148            deduplicate: true,
149            ..Default::default()
150        };
151
152        Self {
153            model,
154            detector,
155            buffer: CorpusBuffer::with_config(buffer_config),
156            config,
157            stats: OrchestratorStats::default(),
158            n_features,
159        }
160    }
161
162    /// Process new sample and handle drift
163    ///
164    /// # Arguments
165    /// * `features` - Input features
166    /// * `target` - True target value
167    /// * `prediction` - Model's prediction
168    ///
169    /// # Returns
170    /// Result indicating what action was taken
171    pub fn observe(
172        &mut self,
173        features: &[f64],
174        target: &[f64],
175        prediction: &[f64],
176    ) -> Result<ObserveResult> {
177        self.stats.samples_observed += 1;
178        self.stats.samples_since_retrain += 1;
179
180        // Check prediction correctness
181        let error = self.compute_error(target, prediction);
182        self.detector.add_element(error);
183
184        // Buffer data for potential retraining
185        let sample =
186            Sample::with_source(features.to_vec(), target.to_vec(), SampleSource::Production);
187
188        if !self.buffer.add(sample) {
189            return Ok(ObserveResult::Skipped);
190        }
191
192        self.stats.buffer_size = self.buffer.len();
193        self.stats.drift_status = self.detector.detected_change();
194
195        match self.detector.detected_change() {
196            DriftStatus::Stable => {
197                // Incremental update only
198                if self.config.incremental_updates {
199                    self.model
200                        .partial_fit(features, target, Some(self.config.learning_rate))?;
201                }
202                Ok(ObserveResult::Stable)
203            }
204            DriftStatus::Warning => {
205                // Continue collecting data, maybe do incremental update
206                if self.config.incremental_updates {
207                    self.model
208                        .partial_fit(features, target, Some(self.config.learning_rate))?;
209                }
210                Ok(ObserveResult::Warning)
211            }
212            DriftStatus::Drift => {
213                // Check if we have enough samples
214                if self.buffer.len() >= self.config.min_samples {
215                    self.retrain()?;
216                    Ok(ObserveResult::Retrained)
217                } else {
218                    // Not enough data yet, do incremental update
219                    if self.config.incremental_updates {
220                        self.model.partial_fit(
221                            features,
222                            target,
223                            Some(self.config.learning_rate),
224                        )?;
225                    }
226                    Ok(ObserveResult::Warning)
227                }
228            }
229        }
230    }
231
232    /// Compute if prediction was an error
233    fn compute_error(&self, target: &[f64], prediction: &[f64]) -> bool {
234        let _ = self; // suppress unused self warning - method for consistency
235        if target.is_empty() || prediction.is_empty() {
236            return true;
237        }
238
239        // For regression: error if prediction is too far from target
240        // For classification: error if predicted class differs
241        if target.len() == 1 && prediction.len() == 1 {
242            // Regression or binary classification
243            let diff = (target[0] - prediction[0]).abs();
244            if target[0].abs() < 1.0 {
245                // Classification threshold
246                diff > 0.5
247            } else {
248                // Regression: relative error > 10%
249                diff / target[0].abs().max(1.0) > 0.1
250            }
251        } else {
252            // Multi-class: compare argmax
253            let target_class = target
254                .iter()
255                .enumerate()
256                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
257                .map_or(0, |(i, _)| i);
258
259            let pred_class = prediction
260                .iter()
261                .enumerate()
262                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
263                .map_or(0, |(i, _)| i);
264
265            target_class != pred_class
266        }
267    }
268
269    /// Perform full retraining on buffered data
270    fn retrain(&mut self) -> Result<()> {
271        let (features, targets, n_samples, n_features) = self.buffer.to_dataset();
272
273        if n_samples == 0 || n_features == 0 {
274            return Ok(());
275        }
276
277        // Reset model
278        self.model.reset();
279
280        if self.config.curriculum_learning {
281            self.retrain_with_curriculum(&features, &targets, n_samples, n_features)?;
282        } else {
283            self.retrain_standard(&features, &targets, n_samples, n_features)?;
284        }
285
286        // Update stats
287        self.stats.retrain_count += 1;
288        self.stats.last_retrain_samples = n_samples;
289        self.stats.samples_since_retrain = 0;
290
291        // Reset drift detector
292        self.detector.reset();
293
294        // Clear buffer but keep some recent samples
295        let keep = (self.config.min_samples / 2).min(self.buffer.len());
296        let recent: Vec<Sample> = self
297            .buffer
298            .samples()
299            .iter()
300            .rev()
301            .take(keep)
302            .cloned()
303            .collect();
304
305        self.buffer.clear();
306        for sample in recent {
307            self.buffer.add(sample);
308        }
309
310        Ok(())
311    }
312
313    /// Standard retraining without curriculum
314    fn retrain_standard(
315        &mut self,
316        features: &[f64],
317        targets: &[f64],
318        n_samples: usize,
319        n_features: usize,
320    ) -> Result<()> {
321        for _ in 0..self.config.retrain_epochs {
322            for i in 0..n_samples {
323                let x = &features[i * n_features..(i + 1) * n_features];
324                let y = &targets[i..=i];
325                self.model
326                    .partial_fit(x, y, Some(self.config.learning_rate))?;
327            }
328        }
329        Ok(())
330    }
331
332    /// Curriculum-based retraining
333    fn retrain_with_curriculum(
334        &mut self,
335        features: &[f64],
336        targets: &[f64],
337        n_samples: usize,
338        n_features: usize,
339    ) -> Result<()> {
340        // Score samples by loss (using current model)
341        let mut scored_samples: Vec<ScoredSample> = Vec::with_capacity(n_samples);
342
343        for i in 0..n_samples {
344            let x = &features[i * n_features..(i + 1) * n_features];
345            let y = targets[i];
346
347            // Use feature norm as difficulty proxy (simple but effective)
348            let difficulty: f64 = x.iter().map(|v| v * v).sum::<f64>().sqrt();
349
350            scored_samples.push(ScoredSample::new(x.to_vec(), y, difficulty));
351        }
352
353        // Sort by difficulty (easiest first)
354        scored_samples.sort_by(|a, b| {
355            a.difficulty
356                .partial_cmp(&b.difficulty)
357                .unwrap_or(std::cmp::Ordering::Equal)
358        });
359
360        // Create curriculum scheduler
361        let mut curriculum = LinearCurriculum::new(self.config.curriculum_stages);
362
363        // Train in stages
364        let samples_per_stage = n_samples / self.config.curriculum_stages.max(1);
365
366        for stage in 0..self.config.curriculum_stages {
367            let end_idx = ((stage + 1) * samples_per_stage).min(n_samples);
368
369            // Train on samples up to current stage
370            for _epoch in 0..self.config.retrain_epochs / self.config.curriculum_stages.max(1) {
371                for sample in scored_samples.iter().take(end_idx) {
372                    let y = &[sample.target];
373                    self.model
374                        .partial_fit(&sample.features, y, Some(self.config.learning_rate))?;
375                }
376            }
377
378            curriculum.advance();
379        }
380
381        Ok(())
382    }
383
384    /// Get reference to the model
385    pub fn model(&self) -> &M {
386        &self.model
387    }
388
389    /// Get mutable reference to the model
390    pub fn model_mut(&mut self) -> &mut M {
391        &mut self.model
392    }
393
394    /// Get reference to the drift detector
395    pub fn detector(&self) -> &D {
396        &self.detector
397    }
398
399    /// Get orchestrator statistics
400    pub fn stats(&self) -> &OrchestratorStats {
401        &self.stats
402    }
403
404    /// Get current drift status
405    pub fn drift_status(&self) -> DriftStatus {
406        self.detector.detected_change()
407    }
408
409    /// Force a retrain (useful for manual triggers)
410    pub fn force_retrain(&mut self) -> Result<()> {
411        self.retrain()
412    }
413
414    /// Get buffer size
415    pub fn buffer_size(&self) -> usize {
416        self.buffer.len()
417    }
418
419    /// Check if retraining is recommended
420    pub fn should_retrain(&self) -> bool {
421        self.detector.detected_change() == DriftStatus::Drift
422            && self.buffer.len() >= self.config.min_samples
423    }
424
425    /// Get configuration
426    pub fn config(&self) -> &RetrainConfig {
427        &self.config
428    }
429}
430
431/// Builder for `RetrainOrchestrator`
432#[derive(Debug)]
433pub struct OrchestratorBuilder<M: OnlineLearner + std::fmt::Debug> {
434    model: M,
435    n_features: usize,
436    config: RetrainConfig,
437    delta: f64, // For ADWIN
438}
439
440include!("orchestrator_builder.rs");
441include!("orchestrator_tests.rs");