Skip to main content

trueno/tuner/data_collector/
drift.rs

1//! Online learning and concept drift detection (T-TUNER-005, GitHub #82).
2//!
3//! Tracks prediction errors via a sliding window to detect model staleness
4//! and trigger auto-retraining when concept drift is observed.
5
6use crate::tuner::brick_tuner::BrickTuner;
7use crate::tuner::features::TunerFeatures;
8
9use super::types::{ConceptDriftStatus, TrainingStats, UserFeedback};
10use super::TunerDataCollector;
11
12impl TunerDataCollector {
13    // ========================================================================
14    // T-TUNER-005: Online Learning (GitHub #82)
15    // ========================================================================
16
17    /// Record user feedback on a recommendation
18    pub fn record_feedback(&mut self, sample_index: usize, feedback: UserFeedback) {
19        // Extend feedback vector if needed
20        while self.feedback.len() <= sample_index {
21            self.feedback.push(UserFeedback::None);
22        }
23        self.feedback[sample_index] = feedback;
24    }
25
26    /// Get feedback for a sample
27    pub fn get_feedback(&self, sample_index: usize) -> UserFeedback {
28        self.feedback.get(sample_index).copied().unwrap_or(UserFeedback::None)
29    }
30
31    /// Record prediction error for concept drift detection
32    pub fn record_prediction_error(&mut self, predicted: f32, actual: f32) {
33        if !self.online_learning_enabled {
34            return;
35        }
36
37        // Compute relative error (0.0 = perfect, 1.0 = 100% error)
38        let error = if actual > 0.0 { ((predicted - actual) / actual).abs().min(1.0) } else { 1.0 };
39
40        // Add to sliding window
41        self.error_window.push(error);
42
43        // Trim window to max size
44        if self.error_window.len() > self.error_window_size {
45            self.error_window.remove(0);
46        }
47    }
48
49    /// Detect concept drift based on prediction error trends
50    pub fn detect_concept_drift(&self) -> ConceptDriftStatus {
51        let samples_since_training = self.samples.len().saturating_sub(self.samples_at_last_train);
52
53        // Not enough data for drift detection
54        if self.error_window.len() < 10 {
55            return ConceptDriftStatus {
56                drift_detected: false,
57                staleness_score: 0.0,
58                samples_since_training,
59                recommend_retrain: false,
60                explanation: "Insufficient data for drift detection".to_string(),
61            };
62        }
63
64        // Compute mean error
65        let mean_error: f32 =
66            self.error_window.iter().sum::<f32>() / self.error_window.len().max(1) as f32;
67
68        // Compute staleness score (0.0 = fresh, 1.0 = stale)
69        let staleness_score =
70            (samples_since_training as f32 / Self::STALENESS_THRESHOLD as f32).min(1.0);
71
72        // Detect drift
73        let drift_detected = mean_error > Self::DRIFT_ERROR_THRESHOLD;
74
75        // Recommend retrain if drift detected OR stale
76        let recommend_retrain = drift_detected || staleness_score > 0.8;
77
78        let explanation = if drift_detected {
79            format!(
80                "Concept drift detected: mean error {:.1}% exceeds threshold {:.1}%",
81                mean_error * 100.0,
82                Self::DRIFT_ERROR_THRESHOLD * 100.0
83            )
84        } else if staleness_score > 0.8 {
85            format!(
86                "Model stale: {} samples since last training (threshold: {})",
87                samples_since_training,
88                Self::STALENESS_THRESHOLD
89            )
90        } else {
91            format!(
92                "Model fresh: mean error {:.1}%, {} samples since training",
93                mean_error * 100.0,
94                samples_since_training
95            )
96        };
97
98        ConceptDriftStatus {
99            drift_detected,
100            staleness_score,
101            samples_since_training,
102            recommend_retrain,
103            explanation,
104        }
105    }
106
107    /// Check if auto-retrain should trigger
108    pub fn should_retrain(&self) -> bool {
109        if !self.online_learning_enabled {
110            return false;
111        }
112
113        let samples_since = self.samples.len().saturating_sub(self.samples_at_last_train);
114
115        // Retrain if we have enough new samples
116        if samples_since >= self.retrain_threshold {
117            return true;
118        }
119
120        // Or if concept drift is detected
121        let drift = self.detect_concept_drift();
122        drift.recommend_retrain && samples_since >= 10
123    }
124
125    /// Mark that training occurred (resets drift counters)
126    pub fn mark_trained(&mut self) {
127        self.samples_at_last_train = self.samples.len();
128        self.error_window.clear();
129    }
130
131    /// Get training statistics
132    pub fn training_stats(&self) -> TrainingStats {
133        let drift = self.detect_concept_drift();
134
135        // Count feedback types
136        let accepted_count = self.feedback.iter().filter(|f| **f == UserFeedback::Accepted).count();
137        let rejected_count = self.feedback.iter().filter(|f| **f == UserFeedback::Rejected).count();
138        let alternative_count =
139            self.feedback.iter().filter(|f| **f == UserFeedback::Alternative).count();
140
141        TrainingStats {
142            total_samples: self.samples.len(),
143            samples_since_training: drift.samples_since_training,
144            accepted_count,
145            rejected_count,
146            alternative_count,
147            staleness_score: drift.staleness_score,
148            drift_detected: drift.drift_detected,
149            online_learning_enabled: self.online_learning_enabled,
150        }
151    }
152
153    /// Auto-retrain and update BrickTuner if conditions are met
154    pub fn auto_retrain(&mut self, tuner: &mut BrickTuner) -> bool {
155        if !self.should_retrain() {
156            return false;
157        }
158
159        // Weight samples by feedback
160        let training_data = self.prepare_weighted_training_data();
161
162        if training_data.len() < 10 {
163            return false;
164        }
165
166        // Train and update
167        match tuner.train(&training_data) {
168            Ok(()) => {
169                self.mark_trained();
170                true
171            }
172            Err(_) => false,
173        }
174    }
175
176    /// Prepare training data with feedback weighting
177    pub(super) fn prepare_weighted_training_data(&self) -> Vec<(TunerFeatures, f32)> {
178        self.samples
179            .iter()
180            .enumerate()
181            .filter_map(|(i, s)| {
182                let feedback = self.get_feedback(i);
183
184                // Skip rejected samples (they had bad throughput measurements)
185                if feedback == UserFeedback::Rejected {
186                    return None;
187                }
188
189                // Weight accepted samples higher (duplicate them)
190                let weight = match feedback {
191                    UserFeedback::Accepted => 2,
192                    UserFeedback::Alternative | UserFeedback::Rejected | UserFeedback::None => 1,
193                };
194
195                Some((0..weight).map(|_| (s.features.clone(), s.throughput_tps)))
196            })
197            .flatten()
198            .collect()
199    }
200}