entrenar/eval/retrain/
retrainer.rs1use super::action::Action;
4use super::config::RetrainConfig;
5use super::policy::RetrainPolicy;
6use crate::error::Result;
7use crate::eval::drift::{DriftDetector, DriftResult, DriftSummary, Severity};
8
9pub type RetrainCallback = Box<dyn Fn(&[DriftResult]) -> Result<String> + Send + Sync>;
11
12pub struct AutoRetrainer {
14 detector: DriftDetector,
15 config: RetrainConfig,
16 retrain_callback: Option<RetrainCallback>,
17 batches_since_retrain: usize,
18 total_retrains: usize,
19}
20
21impl AutoRetrainer {
22 pub fn new(detector: DriftDetector, config: RetrainConfig) -> Self {
24 Self {
25 detector,
26 config,
27 retrain_callback: None,
28 batches_since_retrain: 0,
29 total_retrains: 0,
30 }
31 }
32
33 pub fn on_retrain<F>(&mut self, callback: F)
38 where
39 F: Fn(&[DriftResult]) -> Result<String> + Send + Sync + 'static,
40 {
41 self.retrain_callback = Some(Box::new(callback));
42 }
43
44 pub fn process_batch(&mut self, batch: &[Vec<f64>]) -> Result<Action> {
48 self.batches_since_retrain += 1;
49
50 let results = self.detector.check(batch);
52
53 if results.is_empty() {
54 return Ok(Action::None);
55 }
56
57 let summary = DriftDetector::summary(&results);
58
59 if self.batches_since_retrain < self.config.cooldown_batches {
61 if summary.has_drift() && self.config.log_warnings {
62 return Ok(Action::WarningLogged);
63 }
64 return Ok(Action::None);
65 }
66
67 if self.config.max_retrains > 0 && self.total_retrains >= self.config.max_retrains {
69 if summary.has_drift() && self.config.log_warnings {
70 return Ok(Action::WarningLogged);
71 }
72 return Ok(Action::None);
73 }
74
75 let should_retrain = self.evaluate_policy(&results, &summary);
77
78 if should_retrain {
79 if let Some(ref callback) = self.retrain_callback {
80 let job_id = callback(&results)?;
81 self.batches_since_retrain = 0;
82 self.total_retrains += 1;
83 return Ok(Action::RetrainTriggered(job_id));
84 }
85 return Ok(Action::WarningLogged);
87 }
88
89 if summary.warnings > 0 && self.config.log_warnings {
90 return Ok(Action::WarningLogged);
91 }
92
93 Ok(Action::None)
94 }
95
96 fn evaluate_policy(&self, results: &[DriftResult], summary: &DriftSummary) -> bool {
98 match &self.config.policy {
99 RetrainPolicy::FeatureCount { count } => summary.drifted_features >= *count,
100
101 RetrainPolicy::CriticalFeature { names } => {
102 results.iter().any(|r| r.drifted && names.contains(&r.feature))
103 }
104
105 RetrainPolicy::DriftPercentage { threshold } => {
106 summary.drift_percentage() >= *threshold
107 }
108
109 RetrainPolicy::AnyCritical => results.iter().any(|r| r.severity == Severity::Critical),
110 }
111 }
112
113 pub fn detector(&self) -> &DriftDetector {
115 &self.detector
116 }
117
118 pub fn detector_mut(&mut self) -> &mut DriftDetector {
120 &mut self.detector
121 }
122
123 pub fn stats(&self) -> RetrainerStats {
125 RetrainerStats {
126 total_retrains: self.total_retrains,
127 batches_since_retrain: self.batches_since_retrain,
128 }
129 }
130
131 pub fn reset_cooldown(&mut self) {
133 self.batches_since_retrain = self.config.cooldown_batches;
134 }
135}
136
/// Snapshot of an `AutoRetrainer`'s counters, as returned by
/// `AutoRetrainer::stats`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly (for example
/// with `assert_eq!`), and `Default` for an all-zero snapshot.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct RetrainerStats {
    /// Number of retrains triggered so far.
    pub total_retrains: usize,
    /// Batches processed since the last triggered retrain.
    pub batches_since_retrain: usize,
}