trueno/tuner/data_collector/
drift.rs1use crate::tuner::brick_tuner::BrickTuner;
7use crate::tuner::features::TunerFeatures;
8
9use super::types::{ConceptDriftStatus, TrainingStats, UserFeedback};
10use super::TunerDataCollector;
11
12impl TunerDataCollector {
13 pub fn record_feedback(&mut self, sample_index: usize, feedback: UserFeedback) {
19 while self.feedback.len() <= sample_index {
21 self.feedback.push(UserFeedback::None);
22 }
23 self.feedback[sample_index] = feedback;
24 }
25
26 pub fn get_feedback(&self, sample_index: usize) -> UserFeedback {
28 self.feedback.get(sample_index).copied().unwrap_or(UserFeedback::None)
29 }
30
31 pub fn record_prediction_error(&mut self, predicted: f32, actual: f32) {
33 if !self.online_learning_enabled {
34 return;
35 }
36
37 let error = if actual > 0.0 { ((predicted - actual) / actual).abs().min(1.0) } else { 1.0 };
39
40 self.error_window.push(error);
42
43 if self.error_window.len() > self.error_window_size {
45 self.error_window.remove(0);
46 }
47 }
48
49 pub fn detect_concept_drift(&self) -> ConceptDriftStatus {
51 let samples_since_training = self.samples.len().saturating_sub(self.samples_at_last_train);
52
53 if self.error_window.len() < 10 {
55 return ConceptDriftStatus {
56 drift_detected: false,
57 staleness_score: 0.0,
58 samples_since_training,
59 recommend_retrain: false,
60 explanation: "Insufficient data for drift detection".to_string(),
61 };
62 }
63
64 let mean_error: f32 =
66 self.error_window.iter().sum::<f32>() / self.error_window.len().max(1) as f32;
67
68 let staleness_score =
70 (samples_since_training as f32 / Self::STALENESS_THRESHOLD as f32).min(1.0);
71
72 let drift_detected = mean_error > Self::DRIFT_ERROR_THRESHOLD;
74
75 let recommend_retrain = drift_detected || staleness_score > 0.8;
77
78 let explanation = if drift_detected {
79 format!(
80 "Concept drift detected: mean error {:.1}% exceeds threshold {:.1}%",
81 mean_error * 100.0,
82 Self::DRIFT_ERROR_THRESHOLD * 100.0
83 )
84 } else if staleness_score > 0.8 {
85 format!(
86 "Model stale: {} samples since last training (threshold: {})",
87 samples_since_training,
88 Self::STALENESS_THRESHOLD
89 )
90 } else {
91 format!(
92 "Model fresh: mean error {:.1}%, {} samples since training",
93 mean_error * 100.0,
94 samples_since_training
95 )
96 };
97
98 ConceptDriftStatus {
99 drift_detected,
100 staleness_score,
101 samples_since_training,
102 recommend_retrain,
103 explanation,
104 }
105 }
106
107 pub fn should_retrain(&self) -> bool {
109 if !self.online_learning_enabled {
110 return false;
111 }
112
113 let samples_since = self.samples.len().saturating_sub(self.samples_at_last_train);
114
115 if samples_since >= self.retrain_threshold {
117 return true;
118 }
119
120 let drift = self.detect_concept_drift();
122 drift.recommend_retrain && samples_since >= 10
123 }
124
125 pub fn mark_trained(&mut self) {
127 self.samples_at_last_train = self.samples.len();
128 self.error_window.clear();
129 }
130
131 pub fn training_stats(&self) -> TrainingStats {
133 let drift = self.detect_concept_drift();
134
135 let accepted_count = self.feedback.iter().filter(|f| **f == UserFeedback::Accepted).count();
137 let rejected_count = self.feedback.iter().filter(|f| **f == UserFeedback::Rejected).count();
138 let alternative_count =
139 self.feedback.iter().filter(|f| **f == UserFeedback::Alternative).count();
140
141 TrainingStats {
142 total_samples: self.samples.len(),
143 samples_since_training: drift.samples_since_training,
144 accepted_count,
145 rejected_count,
146 alternative_count,
147 staleness_score: drift.staleness_score,
148 drift_detected: drift.drift_detected,
149 online_learning_enabled: self.online_learning_enabled,
150 }
151 }
152
153 pub fn auto_retrain(&mut self, tuner: &mut BrickTuner) -> bool {
155 if !self.should_retrain() {
156 return false;
157 }
158
159 let training_data = self.prepare_weighted_training_data();
161
162 if training_data.len() < 10 {
163 return false;
164 }
165
166 match tuner.train(&training_data) {
168 Ok(()) => {
169 self.mark_trained();
170 true
171 }
172 Err(_) => false,
173 }
174 }
175
176 pub(super) fn prepare_weighted_training_data(&self) -> Vec<(TunerFeatures, f32)> {
178 self.samples
179 .iter()
180 .enumerate()
181 .filter_map(|(i, s)| {
182 let feedback = self.get_feedback(i);
183
184 if feedback == UserFeedback::Rejected {
186 return None;
187 }
188
189 let weight = match feedback {
191 UserFeedback::Accepted => 2,
192 UserFeedback::Alternative | UserFeedback::Rejected | UserFeedback::None => 1,
193 };
194
195 Some((0..weight).map(|_| (s.features.clone(), s.throughput_tps)))
196 })
197 .flatten()
198 .collect()
199 }
200}