use super::metrics::{DistributionMetric, PSI};
use crate::dataset::{BinnedDataset, FeatureInfo, FeatureType};
/// Severity of a detected distribution shift.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertLevel {
    /// No drift detected.
    None,
    /// Drift crossed the warning threshold but not the critical one.
    Warning,
    /// Drift crossed the critical threshold.
    Critical,
}
/// Outcome of a drift check: per-feature scores plus aggregated alerting.
#[derive(Debug, Clone)]
pub struct ShiftResult {
    /// `(feature name, drift score)` for every compared feature.
    pub feature_scores: Vec<(String, f32)>,
    /// Mean drift score across all compared features.
    pub overall_score: f32,
    /// Aggregated alert level for the whole check.
    pub alert: AlertLevel,
    /// Features at warning level or above (critical features appear here too).
    pub drifted_features: Vec<String>,
    /// Features at critical level.
    pub critical_features: Vec<String>,
}
impl ShiftResult {
    /// True when the check raised any alert (warning or critical).
    pub fn has_drift(&self) -> bool {
        !matches!(self.alert, AlertLevel::None)
    }

    /// Number of features at warning level or above.
    pub fn n_drifted(&self) -> usize {
        self.drifted_features.len()
    }

    /// Number of features at critical level.
    pub fn n_critical(&self) -> usize {
        self.critical_features.len()
    }
}
/// Histogram width used when a feature's bin count is unknown (`num_bins == 0`).
pub const MAX_HISTOGRAM_BINS: usize = 256;
/// Detects distribution shift between a reference dataset and incoming
/// inference data by comparing per-feature binned histograms.
pub struct ShiftDetector {
    /// Per-feature bin-count histograms built from the reference data.
    reference_histograms: Vec<Vec<u32>>,
    /// Histogram width (number of bins) for each feature.
    bins_per_feature: Vec<usize>,
    /// Reference row count per feature; feeds the sample-size threshold scaling.
    reference_counts: Vec<usize>,
    /// Metadata (name, type, bins) for each tracked feature.
    feature_info: Vec<FeatureInfo>,
    /// Distance metric used to score histogram pairs (PSI by default).
    metric: Box<dyn DistributionMetric>,
    /// Score at or above which a feature is flagged as drifting.
    warning_threshold: f32,
    /// Score at or above which a feature is flagged as critical.
    critical_threshold: f32,
}
impl ShiftDetector {
    /// Builds the default metric (PSI) together with its recommended
    /// warning and critical thresholds. Shared by both constructors so the
    /// defaults cannot drift apart.
    fn default_metric() -> (Box<dyn DistributionMetric>, f32, f32) {
        let metric: Box<dyn DistributionMetric> = Box::new(PSI::default());
        let warning = metric.warning_threshold();
        let critical = metric.critical_threshold();
        (metric, warning, critical)
    }

    /// Creates a detector whose reference distribution is built from an
    /// already-binned dataset.
    ///
    /// Features reporting `num_bins == 0` fall back to [`MAX_HISTOGRAM_BINS`].
    /// Bin indices outside a feature's histogram are silently ignored.
    pub fn from_dataset(dataset: &BinnedDataset) -> Self {
        let num_features = dataset.num_features();
        let num_rows = dataset.num_rows();
        let feature_info = dataset.all_feature_info().to_vec();
        let bins_per_feature: Vec<usize> = feature_info
            .iter()
            .map(|info| {
                let bins = info.num_bins as usize;
                // 0 means "unknown bin count"; reserve the maximum width.
                if bins == 0 {
                    MAX_HISTOGRAM_BINS
                } else {
                    bins
                }
            })
            .collect();
        let mut reference_histograms: Vec<Vec<u32>> = bins_per_feature
            .iter()
            .map(|&bins| vec![0u32; bins])
            .collect();
        // Accumulate reference bin counts row by row.
        #[allow(clippy::needless_range_loop)]
        for row in 0..num_rows {
            for feature in 0..num_features {
                let bin = dataset.get_bin(row, feature) as usize;
                if bin < reference_histograms[feature].len() {
                    reference_histograms[feature][bin] += 1;
                }
            }
        }
        let reference_counts = vec![num_rows; num_features];
        let (metric, warning_threshold, critical_threshold) = Self::default_metric();
        Self {
            reference_histograms,
            bins_per_feature,
            reference_counts,
            feature_info,
            metric,
            warning_threshold,
            critical_threshold,
        }
    }

    /// Creates a detector sized for raw (unbinned) feature data laid out
    /// row-major in `features` (`features.len() / num_features` rows).
    /// Missing names default to `feature_{i}`.
    ///
    /// NOTE(review): the reference histograms are left all-zero here — the
    /// `features` values are only used to derive the row count, so `check`
    /// will score every feature as 0.0 until the histograms are populated
    /// elsewhere. Confirm this is the intended initialization path.
    pub fn from_raw(
        features: &[f32],
        num_features: usize,
        feature_names: Option<&[String]>,
    ) -> Self {
        // Guard against division by zero for degenerate inputs.
        let num_rows = if num_features > 0 {
            features.len() / num_features
        } else {
            0
        };
        let bins_per_feature = vec![MAX_HISTOGRAM_BINS; num_features];
        let reference_histograms = vec![vec![0u32; MAX_HISTOGRAM_BINS]; num_features];
        let reference_counts = vec![num_rows; num_features];
        let feature_info: Vec<FeatureInfo> = (0..num_features)
            .map(|i| {
                let name = feature_names
                    .and_then(|names| names.get(i).cloned())
                    .unwrap_or_else(|| format!("feature_{}", i));
                FeatureInfo {
                    name,
                    feature_type: FeatureType::Numeric,
                    num_bins: 0,
                    bin_boundaries: vec![],
                }
            })
            .collect();
        let (metric, warning_threshold, critical_threshold) = Self::default_metric();
        Self {
            reference_histograms,
            bins_per_feature,
            reference_counts,
            feature_info,
            metric,
            warning_threshold,
            critical_threshold,
        }
    }

    /// Overrides both alert thresholds. Call this *after* `with_metric`,
    /// which resets thresholds to the metric's defaults.
    pub fn with_thresholds(mut self, warning: f32, critical: f32) -> Self {
        self.warning_threshold = warning;
        self.critical_threshold = critical;
        self
    }

    /// Swaps in a different distribution metric and adopts its default
    /// thresholds (discarding any previously set ones).
    pub fn with_metric<M: DistributionMetric + 'static>(mut self, metric: M) -> Self {
        self.warning_threshold = metric.warning_threshold();
        self.critical_threshold = metric.critical_threshold();
        self.metric = Box::new(metric);
        self
    }

    /// Sample-size adjustment for alert thresholds:
    /// `sqrt((ref + inf) / inf)`, so small inference batches get more
    /// lenient (larger) thresholds. Returns 1.0 when either count is zero.
    fn compute_threshold_multiplier(&self, feature_idx: usize, inference_count: usize) -> f32 {
        debug_assert!(
            feature_idx < self.reference_counts.len(),
            "Feature index {} out of bounds (max: {})",
            feature_idx,
            self.reference_counts.len()
        );
        let ref_count = self.reference_counts[feature_idx];
        if ref_count == 0 || inference_count == 0 {
            return 1.0;
        }
        ((ref_count + inference_count) as f32 / inference_count as f32).sqrt()
    }

    /// Scores incoming binned data against the reference distribution.
    ///
    /// Histograms are compared feature by feature using the configured
    /// metric; each feature's thresholds are scaled by
    /// `compute_threshold_multiplier`. Features with no reference or no
    /// inference mass score 0.0 but still count toward the overall average.
    /// The overall alert escalates if any feature trips, or if the mean
    /// score crosses the (unadjusted) thresholds.
    pub fn check(&self, inference_data: &BinnedDataset) -> ShiftResult {
        // Only compare features both sides know about.
        let num_features = self.feature_info.len().min(inference_data.num_features());
        let num_rows = inference_data.num_rows();
        if num_rows == 0 || num_features == 0 {
            return ShiftResult {
                feature_scores: Vec::new(),
                overall_score: 0.0,
                alert: AlertLevel::None,
                drifted_features: Vec::new(),
                critical_features: Vec::new(),
            };
        }
        // Bin the inference data with the same per-feature widths as the reference.
        let mut inference_histograms: Vec<Vec<u32>> = self
            .bins_per_feature
            .iter()
            .take(num_features)
            .map(|&bins| vec![0u32; bins])
            .collect();
        #[allow(clippy::needless_range_loop)]
        for row in 0..num_rows {
            for feature in 0..num_features {
                let bin = inference_data.get_bin(row, feature) as usize;
                if bin < inference_histograms[feature].len() {
                    inference_histograms[feature][bin] += 1;
                }
            }
        }
        let mut feature_scores = Vec::with_capacity(num_features);
        let mut drifted_features = Vec::new();
        let mut critical_features = Vec::new();
        let mut total_score = 0.0f32;
        for (i, feature_info) in self.feature_info.iter().take(num_features).enumerate() {
            let ref_total: f32 = self.reference_histograms[i].iter().map(|&c| c as f32).sum();
            let inf_total: f32 = inference_histograms[i].iter().map(|&c| c as f32).sum();
            // No mass on either side: nothing to compare, score as 0.
            if ref_total == 0.0 || inf_total == 0.0 {
                feature_scores.push((feature_info.name.clone(), 0.0));
                continue;
            }
            // Normalize counts into probability distributions for the metric.
            let ref_probs: Vec<f32> = self.reference_histograms[i]
                .iter()
                .map(|&c| c as f32 / ref_total)
                .collect();
            let inf_probs: Vec<f32> = inference_histograms[i]
                .iter()
                .map(|&c| c as f32 / inf_total)
                .collect();
            let score = self.metric.compute(&ref_probs, &inf_probs);
            total_score += score;
            feature_scores.push((feature_info.name.clone(), score));
            let threshold_multiplier = self.compute_threshold_multiplier(i, num_rows);
            let adjusted_warning = self.warning_threshold * threshold_multiplier;
            let adjusted_critical = self.critical_threshold * threshold_multiplier;
            if score >= adjusted_critical {
                // Critical features are also counted as drifted.
                drifted_features.push(feature_info.name.clone());
                critical_features.push(feature_info.name.clone());
            } else if score >= adjusted_warning {
                drifted_features.push(feature_info.name.clone());
            }
        }
        let overall_score = total_score / num_features as f32;
        let alert = if !critical_features.is_empty() || overall_score >= self.critical_threshold {
            AlertLevel::Critical
        } else if !drifted_features.is_empty() || overall_score >= self.warning_threshold {
            AlertLevel::Warning
        } else {
            AlertLevel::None
        };
        ShiftResult {
            feature_scores,
            overall_score,
            alert,
            drifted_features,
            critical_features,
        }
    }

    /// Metadata for every tracked feature.
    pub fn feature_info(&self) -> &[FeatureInfo] {
        &self.feature_info
    }

    /// Current (unadjusted) warning threshold.
    pub fn warning_threshold(&self) -> f32 {
        self.warning_threshold
    }

    /// Current (unadjusted) critical threshold.
    pub fn critical_threshold(&self) -> f32 {
        self.critical_threshold
    }

    /// Name of the configured distribution metric.
    pub fn metric_name(&self) -> &'static str {
        self.metric.name()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Derived PartialEq: variants compare equal only to themselves.
    #[test]
    fn test_alert_level() {
        assert_eq!(AlertLevel::None, AlertLevel::None);
        assert_ne!(AlertLevel::None, AlertLevel::Warning);
        assert_ne!(AlertLevel::Warning, AlertLevel::Critical);
    }

    // A warning-level result with two drifted features, one critical.
    #[test]
    fn test_shift_result() {
        let res = ShiftResult {
            feature_scores: vec![("a".to_string(), 0.1), ("b".to_string(), 0.3)],
            overall_score: 0.2,
            alert: AlertLevel::Warning,
            drifted_features: vec!["a".to_string(), "b".to_string()],
            critical_features: vec!["b".to_string()],
        };
        assert!(res.has_drift());
        assert_eq!(res.n_drifted(), 2);
        assert_eq!(res.n_critical(), 1);
    }

    // A clean result reports no drift at all.
    #[test]
    fn test_shift_result_no_drift() {
        let res = ShiftResult {
            feature_scores: vec![("a".to_string(), 0.01)],
            overall_score: 0.01,
            alert: AlertLevel::None,
            drifted_features: vec![],
            critical_features: vec![],
        };
        assert!(!res.has_drift());
        assert_eq!(res.n_drifted(), 0);
    }

    // Multiplier is sqrt((ref + inf) / inf): shrinking inference samples
    // must yield larger (more lenient) multipliers.
    #[test]
    fn test_threshold_multiplier() {
        use crate::dataset::FeatureType;

        let info = vec![FeatureInfo {
            name: "test".to_string(),
            feature_type: FeatureType::Numeric,
            num_bins: 2,
            bin_boundaries: vec![0.0, 0.5, 1.0],
        }];
        let det = ShiftDetector {
            reference_histograms: vec![vec![0; 2]],
            bins_per_feature: vec![2],
            reference_counts: vec![1000],
            feature_info: info,
            metric: Box::new(PSI::default()),
            warning_threshold: 0.1,
            critical_threshold: 0.25,
        };

        // Equal sizes: sqrt(2000/1000) = sqrt(2).
        let m_eq = det.compute_threshold_multiplier(0, 1000);
        assert!(
            (m_eq - 1.414).abs() < 0.01,
            "Expected ~1.414, got {}",
            m_eq
        );

        // Small batch: sqrt(1100/100) = sqrt(11).
        let m_small = det.compute_threshold_multiplier(0, 100);
        assert!(
            m_small > m_eq,
            "Small samples should have larger multipliers (more lenient)"
        );
        assert!(
            (m_small - 3.32).abs() < 0.01,
            "Expected ~3.32, got {}",
            m_small
        );

        // Large batch: sqrt(11000/10000) = sqrt(1.1).
        let m_large = det.compute_threshold_multiplier(0, 10000);
        assert!(
            m_large < m_eq,
            "Large samples should have smaller multipliers (stricter)"
        );
        assert!(
            (m_large - 1.049).abs() < 0.01,
            "Expected ~1.049, got {}",
            m_large
        );

        assert!(
            m_small > m_large,
            "Smaller inference samples should have LARGER multipliers (more lenient thresholds). Got small={}, large={}",
            m_small, m_large
        );
        assert!(
            m_small > m_eq && m_eq > m_large,
            "Expected: small ({}) > equal ({}) > large ({})",
            m_small,
            m_eq,
            m_large
        );
    }
}