use crate::tuner::brick_tuner::BrickTuner;
use crate::tuner::features::TunerFeatures;
use super::types::{ConceptDriftStatus, TrainingStats, UserFeedback};
use super::TunerDataCollector;
impl TunerDataCollector {
/// Record user feedback for the sample at `sample_index`.
///
/// The feedback vector is grown on demand, padding any gap with
/// `UserFeedback::None`, so out-of-range indices are always valid targets.
pub fn record_feedback(&mut self, sample_index: usize, feedback: UserFeedback) {
    // Grow in a single allocation instead of pushing one element at a time.
    // Guarded so we never shrink an already-longer vector.
    if self.feedback.len() <= sample_index {
        self.feedback.resize(sample_index + 1, UserFeedback::None);
    }
    self.feedback[sample_index] = feedback;
}
/// Feedback recorded for `sample_index`, or `UserFeedback::None` when no
/// feedback exists at that index.
pub fn get_feedback(&self, sample_index: usize) -> UserFeedback {
    match self.feedback.get(sample_index) {
        Some(fb) => *fb,
        None => UserFeedback::None,
    }
}
/// Record the relative error of one prediction into the sliding error window.
///
/// No-op unless online learning is enabled. The error is the absolute
/// relative deviation clamped to 1.0; a non-positive (or NaN) `actual` is
/// treated as maximal error.
pub fn record_prediction_error(&mut self, predicted: f32, actual: f32) {
    if !self.online_learning_enabled {
        return;
    }
    let error = match actual {
        // Guard keeps the same semantics as `actual > 0.0` (NaN falls through).
        a if a > 0.0 => ((predicted - a) / a).abs().min(1.0),
        _ => 1.0,
    };
    self.error_window.push(error);
    // Trim the oldest entry to keep the window at its configured size.
    // NOTE(review): front-removal is O(window); a VecDeque field would make
    // this O(1) — worth considering if the window grows large.
    if self.error_window.len() > self.error_window_size {
        self.error_window.remove(0);
    }
}
/// Assess whether the model's predictions have drifted from observed
/// behavior, and whether the model has gone stale.
///
/// Returns a `ConceptDriftStatus` combining:
/// - drift: mean windowed prediction error above `DRIFT_ERROR_THRESHOLD`;
/// - staleness: samples accumulated since the last training, normalized
///   against `STALENESS_THRESHOLD` and capped at 1.0.
///
/// A retrain is recommended on detected drift or staleness above 0.8.
pub fn detect_concept_drift(&self) -> ConceptDriftStatus {
    let samples_since_training = self.samples.len().saturating_sub(self.samples_at_last_train);
    // Too few recorded errors to judge drift reliably.
    if self.error_window.len() < 10 {
        return ConceptDriftStatus {
            drift_detected: false,
            staleness_score: 0.0,
            samples_since_training,
            recommend_retrain: false,
            explanation: "Insufficient data for drift detection".to_string(),
        };
    }
    // The guard above guarantees len >= 10, so the division is safe;
    // the previous defensive `.max(1)` was dead code.
    let mean_error: f32 =
        self.error_window.iter().sum::<f32>() / self.error_window.len() as f32;
    let staleness_score =
        (samples_since_training as f32 / Self::STALENESS_THRESHOLD as f32).min(1.0);
    let drift_detected = mean_error > Self::DRIFT_ERROR_THRESHOLD;
    let recommend_retrain = drift_detected || staleness_score > 0.8;
    let explanation = if drift_detected {
        format!(
            "Concept drift detected: mean error {:.1}% exceeds threshold {:.1}%",
            mean_error * 100.0,
            Self::DRIFT_ERROR_THRESHOLD * 100.0
        )
    } else if staleness_score > 0.8 {
        format!(
            "Model stale: {} samples since last training (threshold: {})",
            samples_since_training,
            Self::STALENESS_THRESHOLD
        )
    } else {
        format!(
            "Model fresh: mean error {:.1}%, {} samples since training",
            mean_error * 100.0,
            samples_since_training
        )
    };
    ConceptDriftStatus {
        drift_detected,
        staleness_score,
        samples_since_training,
        recommend_retrain,
        explanation,
    }
}
/// Decide whether the model should be retrained right now.
///
/// Triggers when enough new samples have accumulated to hit
/// `retrain_threshold`, or when drift/staleness is detected and at least
/// 10 new samples exist. Always false with online learning disabled.
pub fn should_retrain(&self) -> bool {
    if !self.online_learning_enabled {
        return false;
    }
    let new_samples = self.samples.len().saturating_sub(self.samples_at_last_train);
    // Raw sample count alone justifies a retrain.
    if new_samples >= self.retrain_threshold {
        return true;
    }
    // Otherwise require both a minimum of new data and a drift/staleness
    // recommendation. (Drift detection is pure, so short-circuiting on the
    // sample floor first does not change behavior.)
    new_samples >= 10 && self.detect_concept_drift().recommend_retrain
}
/// Reset drift-tracking state after a (re)training pass.
pub fn mark_trained(&mut self) {
    // Errors recorded against the old model are no longer meaningful.
    self.error_window.clear();
    // Remember the sample count the model was last trained at.
    self.samples_at_last_train = self.samples.len();
}
/// Summarize the collector's training state: sample counts, per-category
/// feedback tallies, and the current drift/staleness assessment.
pub fn training_stats(&self) -> TrainingStats {
    let drift = self.detect_concept_drift();
    // Tally all feedback categories in a single pass instead of three
    // separate scans over the feedback vector.
    let (mut accepted_count, mut rejected_count, mut alternative_count) = (0, 0, 0);
    for fb in &self.feedback {
        match fb {
            UserFeedback::Accepted => accepted_count += 1,
            UserFeedback::Rejected => rejected_count += 1,
            UserFeedback::Alternative => alternative_count += 1,
            UserFeedback::None => {}
        }
    }
    TrainingStats {
        total_samples: self.samples.len(),
        samples_since_training: drift.samples_since_training,
        accepted_count,
        rejected_count,
        alternative_count,
        staleness_score: drift.staleness_score,
        drift_detected: drift.drift_detected,
        online_learning_enabled: self.online_learning_enabled,
    }
}
/// Retrain `tuner` if the collector says a retrain is due.
///
/// Returns `true` only when training actually ran and succeeded; requires
/// at least 10 usable (non-rejected, weighted) samples. Training errors are
/// swallowed and reported as `false`.
pub fn auto_retrain(&mut self, tuner: &mut BrickTuner) -> bool {
    if !self.should_retrain() {
        return false;
    }
    let training_data = self.prepare_weighted_training_data();
    // Not enough usable data to train on.
    if training_data.len() < 10 {
        return false;
    }
    let succeeded = tuner.train(&training_data).is_ok();
    if succeeded {
        self.mark_trained();
    }
    succeeded
}
/// Build the `(features, throughput)` training set, weighted by user feedback.
///
/// Rejected samples are excluded entirely; accepted samples are duplicated
/// (weight 2) to pull the model toward them; everything else contributes once.
pub(super) fn prepare_weighted_training_data(&self) -> Vec<(TunerFeatures, f32)> {
    self.samples
        .iter()
        .enumerate()
        .flat_map(|(i, s)| {
            // A single exhaustive match decides the sample's multiplicity.
            // (The old code filtered Rejected first and then still listed an
            // unreachable `Rejected => 1` arm in the weight match.)
            let weight = match self.get_feedback(i) {
                UserFeedback::Rejected => 0,
                UserFeedback::Accepted => 2,
                UserFeedback::Alternative | UserFeedback::None => 1,
            };
            std::iter::repeat_with(move || (s.features.clone(), s.throughput_tps)).take(weight)
        })
        .collect()
}
}