use crate::primitives::{Matrix, Vector};
/// Outcome of a drift check.
///
/// The variants are ordered by severity; the two non-trivial variants carry the
/// score that caused the classification.
#[derive(Clone, Debug, PartialEq)]
pub enum DriftStatus {
    /// No drift was reported for the compared data.
    NoDrift,
    /// The score reached the warning threshold but not the drift threshold.
    Warning { score: f32 },
    /// The score reached the drift threshold.
    Drift { score: f32 },
}

impl DriftStatus {
    /// Returns `true` only for [`DriftStatus::Drift`]; warnings do not request retraining.
    #[must_use]
    pub fn needs_retraining(&self) -> bool {
        // Listed exhaustively so that adding a variant forces a decision here.
        match self {
            Self::Drift { .. } => true,
            Self::NoDrift | Self::Warning { .. } => false,
        }
    }

    /// Returns the score carried by `Warning` / `Drift`, or `None` for `NoDrift`.
    #[must_use]
    pub fn score(&self) -> Option<f32> {
        if let Self::Warning { score } | Self::Drift { score } = self {
            Some(*score)
        } else {
            None
        }
    }
}
/// Thresholds and window sizes used by the drift detectors and monitors in this module.
#[derive(Clone, Debug)]
pub struct DriftConfig {
    /// Scores at or above this value (and below `drift_threshold`) are classified as a warning.
    pub warning_threshold: f32,
    /// Scores at or above this value are classified as drift.
    pub drift_threshold: f32,
    /// Minimum number of samples each side must have before a univariate check is attempted.
    pub min_samples: usize,
    /// Maximum number of samples a rolling monitor keeps in each of its windows.
    pub window_size: usize,
}

impl Default for DriftConfig {
    /// Defaults: warning at `0.1`, drift at `0.2`, 30 samples minimum, windows of 100.
    fn default() -> Self {
        Self {
            window_size: 100,
            min_samples: 30,
            drift_threshold: 0.2,
            warning_threshold: 0.1,
        }
    }
}

impl DriftConfig {
    /// Builds a configuration with the given thresholds and default sample/window sizes.
    #[must_use]
    pub fn new(warning: f32, drift: f32) -> Self {
        // Only the sizes are taken from the defaults; both thresholds are caller-supplied.
        let Self {
            min_samples,
            window_size,
            ..
        } = Self::default();
        Self {
            warning_threshold: warning,
            drift_threshold: drift,
            min_samples,
            window_size,
        }
    }

    /// Returns the configuration with `min_samples` replaced by `min`.
    #[must_use]
    pub fn with_min_samples(self, min: usize) -> Self {
        Self {
            min_samples: min,
            ..self
        }
    }

    /// Returns the configuration with `window_size` replaced by `size`.
    #[must_use]
    pub fn with_window_size(self, size: usize) -> Self {
        Self {
            window_size: size,
            ..self
        }
    }
}
/// Stateless drift detector that compares a reference sample against a current sample.
///
/// Thresholds and the minimum sample count come from the wrapped [`DriftConfig`].
#[derive(Debug)]
pub struct DriftDetector {
    config: DriftConfig,
}

impl DriftDetector {
    /// Creates a detector that classifies scores with the thresholds in `config`.
    #[must_use]
    pub fn new(config: DriftConfig) -> Self {
        Self { config }
    }

    /// Checks a single feature for drift.
    ///
    /// The score is the absolute difference between the two means divided by the
    /// spread of `reference` (as computed by `std_dev`), so it is symmetric: a
    /// shift in either direction counts.
    ///
    /// Returns [`DriftStatus::NoDrift`] without scoring when either side has fewer
    /// than `min_samples` values, when either side is empty, or when the reference
    /// spread is effectively zero (below `1e-10`).
    #[must_use]
    pub fn detect_univariate(&self, reference: &Vector<f32>, current: &Vector<f32>) -> DriftStatus {
        // Even with `min_samples == 0` at least one value per side is required;
        // statistics of an empty sample are meaningless. This matches the empty
        // check in `detect_performance_drift`.
        let required = self.config.min_samples.max(1);
        if reference.len() < required || current.len() < required {
            return DriftStatus::NoDrift;
        }

        let ref_mean = reference.mean();
        let cur_mean = current.mean();
        let ref_std = std_dev(reference.as_slice(), ref_mean);
        // A (near-)constant reference gives no usable scale to normalise by.
        if ref_std < 1e-10 {
            return DriftStatus::NoDrift;
        }

        let score = (ref_mean - cur_mean).abs() / ref_std;
        self.classify_drift(score)
    }

    /// Checks every column of `reference` against the same column of `current`.
    ///
    /// Returns the overall status followed by one status per reference column.
    /// The overall status is the classification of the largest per-feature score
    /// (features reporting `NoDrift` contribute no score).
    ///
    /// Columns that exist in `reference` but not in `current` cannot be compared
    /// and are reported as [`DriftStatus::NoDrift`]; extra columns in `current`
    /// are ignored.
    #[must_use]
    pub fn detect_multivariate(
        &self,
        reference: &Matrix<f32>,
        current: &Matrix<f32>,
    ) -> (DriftStatus, Vec<DriftStatus>) {
        let n_features = reference.n_cols();
        // Never ask `current` for a column index it does not have.
        let comparable = n_features.min(current.n_cols());

        let mut feature_statuses = Vec::with_capacity(n_features);
        let mut max_score: f32 = 0.0;
        for col in 0..n_features {
            let status = if col < comparable {
                let ref_col = reference.column(col);
                let cur_col = current.column(col);
                self.detect_univariate(&ref_col, &cur_col)
            } else {
                DriftStatus::NoDrift
            };
            if let Some(score) = status.score() {
                max_score = max_score.max(score);
            }
            feature_statuses.push(status);
        }

        let overall = self.classify_drift(max_score);
        (overall, feature_statuses)
    }

    /// Checks whether `current_scores` have dropped relative to `baseline_scores`.
    ///
    /// Unlike [`Self::detect_univariate`] this check is one-sided: only a current
    /// mean *below* the baseline mean produces a positive score; an increase is
    /// clamped to a score of `0.0`. `min_samples` is not consulted here.
    ///
    /// When the baseline spread is effectively zero (below `1e-10`) the score is
    /// the drop relative to the magnitude of the baseline mean instead of a
    /// spread-normalised difference. Returns [`DriftStatus::NoDrift`] if either
    /// slice is empty.
    #[must_use]
    pub fn detect_performance_drift(
        &self,
        baseline_scores: &[f32],
        current_scores: &[f32],
    ) -> DriftStatus {
        if baseline_scores.is_empty() || current_scores.is_empty() {
            return DriftStatus::NoDrift;
        }

        let baseline_mean = mean(baseline_scores);
        let current_mean = mean(current_scores);
        let baseline_std = std_dev(baseline_scores, baseline_mean);
        let drop = baseline_mean - current_mean;

        let score = if baseline_std < 1e-10 {
            // Constant baseline: fall back to a relative drop, with the divisor
            // floored so a zero baseline mean cannot divide by zero.
            drop / baseline_mean.abs().max(1e-10)
        } else {
            drop / baseline_std
        };
        self.classify_drift(score.max(0.0))
    }

    /// Maps a score onto a status using the configured thresholds.
    ///
    /// The drift threshold is tested first, so a score that satisfies both
    /// thresholds is reported as drift. A NaN score fails both comparisons and is
    /// reported as `NoDrift`.
    fn classify_drift(&self, score: f32) -> DriftStatus {
        if score >= self.config.drift_threshold {
            DriftStatus::Drift { score }
        } else if score >= self.config.warning_threshold {
            DriftStatus::Warning { score }
        } else {
            DriftStatus::NoDrift
        }
    }
}
/// Sliding-window drift monitor for a single stream of `f32` observations.
///
/// Keeps a reference window and a window of recent observations, each capped
/// at the configured `window_size`, and scores one against the other with a
/// [`DriftDetector`].
#[derive(Debug)]
pub struct RollingDriftMonitor {
    /// Baseline samples that recent observations are compared against.
    reference_window: Vec<f32>,
    /// Most recent observations, oldest first.
    current_window: Vec<f32>,
    /// Detector used to score `current_window` against `reference_window`.
    detector: DriftDetector,
    /// Maximum number of samples kept in either window.
    max_window: usize,
}

impl RollingDriftMonitor {
    /// Creates a monitor with empty windows sized from `config.window_size`.
    #[must_use]
    pub fn new(config: DriftConfig) -> Self {
        let capacity = config.window_size;
        let reference_window = Vec::with_capacity(capacity);
        let current_window = Vec::with_capacity(capacity);
        Self {
            reference_window,
            current_window,
            detector: DriftDetector::new(config),
            max_window: capacity,
        }
    }

    /// Replaces the reference window with the trailing `window_size` values of `data`
    /// (or all of `data` if it is shorter).
    pub fn set_reference(&mut self, data: &[f32]) {
        // Number of leading samples that do not fit into the window.
        let skipped = data.len().saturating_sub(self.max_window);
        self.reference_window.clear();
        self.reference_window.extend_from_slice(&data[skipped..]);
    }

    /// Records `value`, evicting the oldest observation if the window overflows,
    /// and returns the result of [`Self::check_drift`] on the updated window.
    pub fn observe(&mut self, value: f32) -> DriftStatus {
        self.current_window.push(value);
        let overflowed = self.current_window.len() > self.max_window;
        if overflowed {
            // Drop the oldest sample so the window holds at most `max_window` values.
            self.current_window.remove(0);
        }
        self.check_drift()
    }

    /// Scores the current window against the reference window.
    ///
    /// Returns [`DriftStatus::NoDrift`] when no reference data has been set;
    /// otherwise defers to [`DriftDetector::detect_univariate`].
    #[must_use]
    pub fn check_drift(&self) -> DriftStatus {
        match self.reference_window.as_slice() {
            [] => DriftStatus::NoDrift,
            baseline => {
                let ref_vec = Vector::from_slice(baseline);
                let cur_vec = Vector::from_slice(&self.current_window);
                self.detector.detect_univariate(&ref_vec, &cur_vec)
            }
        }
    }

    /// Discards all recent observations; the reference window is left untouched.
    pub fn reset_current(&mut self) {
        self.current_window.clear();
    }

    /// Promotes the current observations to be the new reference and starts an
    /// empty current window.
    pub fn update_reference(&mut self) {
        // After the swap `current_window` holds the old reference, which is then discarded.
        std::mem::swap(&mut self.reference_window, &mut self.current_window);
        self.current_window.clear();
    }
}
/// Decides when a model should be retrained, based on sustained performance drift.
///
/// The trigger fires once the performance monitor has reported drift for
/// `consecutive_required` observations in a row.
#[derive(Debug)]
pub struct RetrainingTrigger {
    /// Monitor fed by [`RetrainingTrigger::observe_performance`].
    performance_monitor: RollingDriftMonitor,
    /// One monitor per feature. In this impl they only receive baselines and
    /// resets; no method here feeds them observations.
    feature_monitors: Vec<RollingDriftMonitor>,
    /// Number of consecutive drift results needed to fire (at least 1 when set
    /// through `with_consecutive_required`).
    consecutive_required: usize,
    /// Current run length of consecutive drift results.
    consecutive_count: usize,
}

impl RetrainingTrigger {
    /// Creates a trigger with `n_features` feature monitors and one performance
    /// monitor, all sharing `config`. Three consecutive drift results are
    /// required by default.
    #[must_use]
    pub fn new(n_features: usize, config: DriftConfig) -> Self {
        let feature_monitors: Vec<RollingDriftMonitor> =
            std::iter::repeat_with(|| RollingDriftMonitor::new(config.clone()))
                .take(n_features)
                .collect();
        let performance_monitor = RollingDriftMonitor::new(config);
        Self {
            performance_monitor,
            feature_monitors,
            consecutive_required: 3,
            consecutive_count: 0,
        }
    }

    /// Sets how many consecutive drift results are required; values below 1 are raised to 1.
    #[must_use]
    pub fn with_consecutive_required(self, count: usize) -> Self {
        Self {
            consecutive_required: count.max(1),
            ..self
        }
    }

    /// Uses `scores` as the reference window of the performance monitor.
    pub fn set_baseline_performance(&mut self, scores: &[f32]) {
        self.performance_monitor.set_reference(scores);
    }

    /// Uses column `i` of `features` as the reference window of feature monitor `i`.
    ///
    /// Monitors without a matching column are left unchanged, and columns without
    /// a matching monitor are ignored.
    pub fn set_baseline_features(&mut self, features: &Matrix<f32>) {
        let n_rows = features.n_rows();
        let n_cols = features.n_cols();
        let paired = self.feature_monitors.iter_mut().enumerate().take(n_cols);
        for (col_idx, monitor) in paired {
            let column: Vec<f32> = (0..n_rows)
                .map(|row| features.get(row, col_idx))
                .collect();
            monitor.set_reference(&column);
        }
    }

    /// Feeds one performance score to the monitor and returns whether the trigger
    /// is now firing.
    ///
    /// A drift result extends the current run; any other result resets it to zero.
    pub fn observe_performance(&mut self, score: f32) -> bool {
        let drifted = self.performance_monitor.observe(score).needs_retraining();
        self.consecutive_count = if drifted {
            self.consecutive_count + 1
        } else {
            0
        };
        self.is_triggered()
    }

    /// Returns `true` while the run of consecutive drift results meets the requirement.
    #[must_use]
    pub fn is_triggered(&self) -> bool {
        self.consecutive_required <= self.consecutive_count
    }

    /// Clears the drift run and the current windows of every monitor; reference
    /// windows are kept.
    pub fn reset(&mut self) {
        self.consecutive_count = 0;
        self.performance_monitor.reset_current();
        self.feature_monitors
            .iter_mut()
            .for_each(RollingDriftMonitor::reset_current);
    }
}
/// Arithmetic mean of `data`, or `0.0` for an empty slice.
fn mean(data: &[f32]) -> f32 {
    match data.len() {
        // Avoid the 0/0 division an empty slice would otherwise produce.
        0 => 0.0,
        count => data.iter().sum::<f32>() / count as f32,
    }
}
/// Spread of `data` around the caller-supplied mean `mean_val`.
///
/// Thin wrapper that forwards both arguments unchanged to
/// `batuta_common::math::std_dev_f32_with_mean`.
// NOTE(review): whether the helper computes the population or the sample
// standard deviation, and what it returns for empty input, is decided by
// `batuta_common` — confirm there before relying on either.
fn std_dev(data: &[f32], mean_val: f32) -> f32 {
batuta_common::math::std_dev_f32_with_mean(data, mean_val)
}
// Tests for this module are kept in `drift_tests.rs` and compiled only for test builds.
#[cfg(test)]
#[path = "drift_tests.rs"]
mod tests;
// Further tests from `tests_drift_contract.rs`, likewise compiled only for test builds.
#[cfg(test)]
#[path = "tests_drift_contract.rs"]
mod tests_drift_contract;