1use crate::error::Result;
18
19use super::corpus::{CorpusBuffer, CorpusBufferConfig, EvictionPolicy, Sample, SampleSource};
20use super::curriculum::{CurriculumScheduler, LinearCurriculum, ScoredSample};
21use super::drift::{DriftDetector, DriftStatus, ADWIN};
22use super::OnlineLearner;
23
/// Outcome reported by `RetrainOrchestrator::observe` for a single sample.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObserveResult {
    /// No drift detected; an incremental update may have been applied.
    Stable,
    /// Detector reported a warning, or drift fired before enough samples
    /// were buffered to allow a full retrain.
    Warning,
    /// Drift was confirmed and a full retrain was performed.
    Retrained,
    /// The corpus buffer rejected the sample (e.g. deduplication), so no
    /// drift handling or model update took place for it.
    Skipped,
}
36
/// Tuning knobs for the retraining orchestrator.
#[derive(Debug, Clone)]
pub struct RetrainConfig {
    /// Minimum buffered samples required before confirmed drift triggers a
    /// full retrain (below this, the orchestrator keeps updating incrementally).
    pub min_samples: usize,
    /// Capacity of the sample buffer (forwarded to `CorpusBufferConfig::max_size`).
    pub max_buffer_size: usize,
    /// Apply a `partial_fit` update on every observation while not retraining.
    pub incremental_updates: bool,
    /// Use easy-to-hard curriculum ordering during full retrains.
    pub curriculum_learning: bool,
    /// Number of curriculum stages per retrain.
    pub curriculum_stages: usize,
    /// NOTE(review): not read by any code visible in this file — presumably
    /// consumed by the builder/checkpointing code elsewhere; confirm.
    pub save_checkpoint: bool,
    /// Learning rate passed to every `partial_fit` call.
    pub learning_rate: f64,
    /// Epoch budget for a full retrain.
    pub retrain_epochs: usize,
}
57
58impl Default for RetrainConfig {
59 fn default() -> Self {
60 Self {
61 min_samples: 100,
62 max_buffer_size: 10_000,
63 incremental_updates: true,
64 curriculum_learning: true,
65 curriculum_stages: 5,
66 save_checkpoint: false,
67 learning_rate: 0.01,
68 retrain_epochs: 10,
69 }
70 }
71}
72
/// Running counters describing orchestrator activity, exposed via `stats()`.
#[derive(Debug, Clone, Default)]
pub struct OrchestratorStats {
    /// Total observations fed to `observe` (including skipped ones).
    pub samples_observed: u64,
    /// Number of completed full retrains.
    pub retrain_count: u64,
    /// Buffer occupancy as of the last accepted observation.
    pub buffer_size: usize,
    /// Drift detector status as of the last accepted observation.
    pub drift_status: DriftStatus,
    /// Sample count used by the most recent retrain.
    pub last_retrain_samples: usize,
    /// Observations seen since the last retrain completed (reset to 0 there).
    pub samples_since_retrain: u64,
}
89
/// Drift-aware training orchestrator.
///
/// Routes every labelled observation through a drift detector `D`, buffers
/// production samples in a `CorpusBuffer`, keeps the model `M` fresh with
/// incremental `partial_fit` updates, and performs a full retrain from the
/// buffer once drift is confirmed and enough samples are available.
#[derive(Debug)]
pub struct RetrainOrchestrator<
    M: OnlineLearner + std::fmt::Debug,
    D: DriftDetector + std::fmt::Debug,
> {
    /// The online model being kept up to date.
    model: M,
    /// Detector fed one boolean error signal per observation.
    detector: D,
    /// Reservoir of production samples used for full retraining.
    buffer: CorpusBuffer,
    /// Retraining policy knobs.
    config: RetrainConfig,
    /// Running counters exposed via `stats()`.
    stats: OrchestratorStats,
    /// Not read by the code visible in this file (hence `dead_code`);
    /// presumably retained for validation or the included builder — confirm.
    #[allow(dead_code)]
    n_features: usize,
}
112
impl<M: OnlineLearner + std::fmt::Debug> RetrainOrchestrator<M, ADWIN> {
    /// Convenience constructor pairing the model with a default-configured
    /// ADWIN drift detector and the default `RetrainConfig`.
    pub fn new(model: M, n_features: usize) -> Self {
        Self::with_detector(model, ADWIN::new(), n_features)
    }
}
119
120impl<M: OnlineLearner + std::fmt::Debug, D: DriftDetector + std::fmt::Debug>
121 RetrainOrchestrator<M, D>
122{
123 pub fn with_detector(model: M, detector: D, n_features: usize) -> Self {
125 let config = RetrainConfig::default();
126 let buffer_config = CorpusBufferConfig {
127 max_size: config.max_buffer_size,
128 policy: EvictionPolicy::Reservoir,
129 deduplicate: true,
130 ..Default::default()
131 };
132
133 Self {
134 model,
135 detector,
136 buffer: CorpusBuffer::with_config(buffer_config),
137 config,
138 stats: OrchestratorStats::default(),
139 n_features,
140 }
141 }
142
143 pub fn with_config(model: M, detector: D, n_features: usize, config: RetrainConfig) -> Self {
145 let buffer_config = CorpusBufferConfig {
146 max_size: config.max_buffer_size,
147 policy: EvictionPolicy::Reservoir,
148 deduplicate: true,
149 ..Default::default()
150 };
151
152 Self {
153 model,
154 detector,
155 buffer: CorpusBuffer::with_config(buffer_config),
156 config,
157 stats: OrchestratorStats::default(),
158 n_features,
159 }
160 }
161
162 pub fn observe(
172 &mut self,
173 features: &[f64],
174 target: &[f64],
175 prediction: &[f64],
176 ) -> Result<ObserveResult> {
177 self.stats.samples_observed += 1;
178 self.stats.samples_since_retrain += 1;
179
180 let error = self.compute_error(target, prediction);
182 self.detector.add_element(error);
183
184 let sample =
186 Sample::with_source(features.to_vec(), target.to_vec(), SampleSource::Production);
187
188 if !self.buffer.add(sample) {
189 return Ok(ObserveResult::Skipped);
190 }
191
192 self.stats.buffer_size = self.buffer.len();
193 self.stats.drift_status = self.detector.detected_change();
194
195 match self.detector.detected_change() {
196 DriftStatus::Stable => {
197 if self.config.incremental_updates {
199 self.model
200 .partial_fit(features, target, Some(self.config.learning_rate))?;
201 }
202 Ok(ObserveResult::Stable)
203 }
204 DriftStatus::Warning => {
205 if self.config.incremental_updates {
207 self.model
208 .partial_fit(features, target, Some(self.config.learning_rate))?;
209 }
210 Ok(ObserveResult::Warning)
211 }
212 DriftStatus::Drift => {
213 if self.buffer.len() >= self.config.min_samples {
215 self.retrain()?;
216 Ok(ObserveResult::Retrained)
217 } else {
218 if self.config.incremental_updates {
220 self.model.partial_fit(
221 features,
222 target,
223 Some(self.config.learning_rate),
224 )?;
225 }
226 Ok(ObserveResult::Warning)
227 }
228 }
229 }
230 }
231
232 fn compute_error(&self, target: &[f64], prediction: &[f64]) -> bool {
234 let _ = self; if target.is_empty() || prediction.is_empty() {
236 return true;
237 }
238
239 if target.len() == 1 && prediction.len() == 1 {
242 let diff = (target[0] - prediction[0]).abs();
244 if target[0].abs() < 1.0 {
245 diff > 0.5
247 } else {
248 diff / target[0].abs().max(1.0) > 0.1
250 }
251 } else {
252 let target_class = target
254 .iter()
255 .enumerate()
256 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
257 .map_or(0, |(i, _)| i);
258
259 let pred_class = prediction
260 .iter()
261 .enumerate()
262 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
263 .map_or(0, |(i, _)| i);
264
265 target_class != pred_class
266 }
267 }
268
269 fn retrain(&mut self) -> Result<()> {
271 let (features, targets, n_samples, n_features) = self.buffer.to_dataset();
272
273 if n_samples == 0 || n_features == 0 {
274 return Ok(());
275 }
276
277 self.model.reset();
279
280 if self.config.curriculum_learning {
281 self.retrain_with_curriculum(&features, &targets, n_samples, n_features)?;
282 } else {
283 self.retrain_standard(&features, &targets, n_samples, n_features)?;
284 }
285
286 self.stats.retrain_count += 1;
288 self.stats.last_retrain_samples = n_samples;
289 self.stats.samples_since_retrain = 0;
290
291 self.detector.reset();
293
294 let keep = (self.config.min_samples / 2).min(self.buffer.len());
296 let recent: Vec<Sample> = self
297 .buffer
298 .samples()
299 .iter()
300 .rev()
301 .take(keep)
302 .cloned()
303 .collect();
304
305 self.buffer.clear();
306 for sample in recent {
307 self.buffer.add(sample);
308 }
309
310 Ok(())
311 }
312
313 fn retrain_standard(
315 &mut self,
316 features: &[f64],
317 targets: &[f64],
318 n_samples: usize,
319 n_features: usize,
320 ) -> Result<()> {
321 for _ in 0..self.config.retrain_epochs {
322 for i in 0..n_samples {
323 let x = &features[i * n_features..(i + 1) * n_features];
324 let y = &targets[i..=i];
325 self.model
326 .partial_fit(x, y, Some(self.config.learning_rate))?;
327 }
328 }
329 Ok(())
330 }
331
332 fn retrain_with_curriculum(
334 &mut self,
335 features: &[f64],
336 targets: &[f64],
337 n_samples: usize,
338 n_features: usize,
339 ) -> Result<()> {
340 let mut scored_samples: Vec<ScoredSample> = Vec::with_capacity(n_samples);
342
343 for i in 0..n_samples {
344 let x = &features[i * n_features..(i + 1) * n_features];
345 let y = targets[i];
346
347 let difficulty: f64 = x.iter().map(|v| v * v).sum::<f64>().sqrt();
349
350 scored_samples.push(ScoredSample::new(x.to_vec(), y, difficulty));
351 }
352
353 scored_samples.sort_by(|a, b| {
355 a.difficulty
356 .partial_cmp(&b.difficulty)
357 .unwrap_or(std::cmp::Ordering::Equal)
358 });
359
360 let mut curriculum = LinearCurriculum::new(self.config.curriculum_stages);
362
363 let samples_per_stage = n_samples / self.config.curriculum_stages.max(1);
365
366 for stage in 0..self.config.curriculum_stages {
367 let end_idx = ((stage + 1) * samples_per_stage).min(n_samples);
368
369 for _epoch in 0..self.config.retrain_epochs / self.config.curriculum_stages.max(1) {
371 for sample in scored_samples.iter().take(end_idx) {
372 let y = &[sample.target];
373 self.model
374 .partial_fit(&sample.features, y, Some(self.config.learning_rate))?;
375 }
376 }
377
378 curriculum.advance();
379 }
380
381 Ok(())
382 }
383
384 pub fn model(&self) -> &M {
386 &self.model
387 }
388
389 pub fn model_mut(&mut self) -> &mut M {
391 &mut self.model
392 }
393
394 pub fn detector(&self) -> &D {
396 &self.detector
397 }
398
399 pub fn stats(&self) -> &OrchestratorStats {
401 &self.stats
402 }
403
404 pub fn drift_status(&self) -> DriftStatus {
406 self.detector.detected_change()
407 }
408
409 pub fn force_retrain(&mut self) -> Result<()> {
411 self.retrain()
412 }
413
414 pub fn buffer_size(&self) -> usize {
416 self.buffer.len()
417 }
418
419 pub fn should_retrain(&self) -> bool {
421 self.detector.detected_change() == DriftStatus::Drift
422 && self.buffer.len() >= self.config.min_samples
423 }
424
425 pub fn config(&self) -> &RetrainConfig {
427 &self.config
428 }
429}
430
/// Builder for `RetrainOrchestrator` (methods are defined in
/// `orchestrator_builder.rs`, spliced in via `include!` below).
#[derive(Debug)]
pub struct OrchestratorBuilder<M: OnlineLearner + std::fmt::Debug> {
    /// Model the built orchestrator will wrap.
    model: M,
    /// Feature dimensionality forwarded to the orchestrator.
    n_features: usize,
    /// Retraining configuration assembled by the builder.
    config: RetrainConfig,
    /// NOTE(review): not read in this file — presumably the ADWIN confidence
    /// parameter consumed by the included builder; confirm there.
    delta: f64,
}
439
// Builder methods and unit tests are maintained in sibling files and spliced
// into this module textually, so they share its private items.
include!("orchestrator_builder.rs");
include!("orchestrator_tests.rs");