use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
/// Tunable thresholds and window sizing for [`DriftDetector`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftConfig {
/// Number of predictions summarized per analysis window.
pub window_size: usize,
/// Maximum number of windows analyzed; total buffered predictions = `window_size * num_windows`.
pub num_windows: usize,
/// Absolute change in mean confidence (oldest vs. newest window) treated as significant.
pub confidence_drift_threshold: f64,
/// KL-divergence threshold for a significant entity-type distribution shift.
pub distribution_drift_threshold: f64,
/// Fraction of current tokens absent from the baseline vocabulary treated as significant.
pub vocab_drift_threshold: f64,
/// Minimum logged predictions before analysis runs and the baseline is captured.
pub min_samples: usize,
}
impl Default for DriftConfig {
fn default() -> Self {
Self {
window_size: 1000,
num_windows: 5,
confidence_drift_threshold: 0.1,
distribution_drift_threshold: 0.5,
vocab_drift_threshold: 0.2,
min_samples: 500,
}
}
}
/// One logged model prediction, stored in the detector's sliding buffer.
#[derive(Debug, Clone)]
struct PredictionLog {
/// Caller-supplied timestamp; stored but not read by any analysis yet.
#[allow(dead_code)]
timestamp: u64,
/// Model confidence for this prediction.
confidence: f64,
/// Predicted entity label, e.g. "PER".
entity_type: String,
/// Entity surface text; tokenized on whitespace (lowercased) for vocabulary tracking.
entity_text: String,
}
/// Summary statistics for one contiguous window of buffered predictions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftWindow {
/// Zero-based window index; 0 is the oldest window analyzed.
pub window_id: usize,
/// Mean prediction confidence within the window.
pub mean_confidence: f64,
/// Population standard deviation of confidence within the window.
pub std_confidence: f64,
/// Relative frequency of each entity type in the window (values sum to 1.0).
pub type_distribution: HashMap<String, f64>,
/// Number of predictions in the window.
pub count: usize,
/// Count of distinct lowercased whitespace tokens across the window's entity texts.
pub unique_tokens: usize,
}
/// Aggregate result of one drift-analysis run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftReport {
/// True when any of the three drift signals is significant.
pub drift_detected: bool,
/// One-line human-readable description of the findings.
pub summary: String,
/// Confidence-trend comparison between the oldest and newest windows.
pub confidence_drift: ConfidenceDrift,
/// Entity-type distribution comparison between the oldest and newest windows.
pub distribution_drift: DistributionDrift,
/// Vocabulary turnover relative to the baseline snapshot.
pub vocabulary_drift: VocabularyDrift,
/// Per-window statistics the above comparisons were derived from.
pub windows: Vec<DriftWindow>,
/// Suggested follow-up actions (never empty; defaults to "Continue monitoring").
pub recommendations: Vec<String>,
}
/// Change in mean confidence between the oldest and newest analysis windows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfidenceDrift {
/// Mean confidence of the oldest analyzed window.
pub baseline_mean: f64,
/// Mean confidence of the newest analyzed window.
pub current_mean: f64,
/// `current_mean - baseline_mean`; negative means confidence dropped.
pub drift_amount: f64,
/// True when `|drift_amount|` exceeds `confidence_drift_threshold`.
pub is_significant: bool,
}
/// Shift in entity-type frequencies between the oldest and newest windows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionDrift {
/// KL divergence of the newest window's type distribution from the oldest's.
pub kl_divergence: f64,
/// Types whose relative frequency rose by more than 0.05, largest change first.
pub increased_types: Vec<(String, f64)>,
/// Types whose relative frequency fell by more than 0.05, most negative first.
pub decreased_types: Vec<(String, f64)>,
/// Types present in the newest window but absent from the oldest.
pub new_types: Vec<String>,
/// True when KL exceeds the threshold or any new types appeared.
pub is_significant: bool,
}
/// Vocabulary turnover relative to the baseline snapshot taken at `min_samples`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VocabularyDrift {
/// Distinct tokens in the baseline vocabulary snapshot.
pub baseline_vocab_size: usize,
/// Distinct tokens currently tracked by the detector.
pub current_vocab_size: usize,
/// Fraction of current tokens that do not appear in the baseline (0.0 when empty).
pub new_token_rate: f64,
/// True when `new_token_rate` exceeds `vocab_drift_threshold`.
pub is_significant: bool,
}
/// Sliding-window drift detector for a stream of model predictions.
///
/// Feed predictions via `log_prediction`; call `analyze` for a [`DriftReport`].
#[derive(Debug, Clone)]
pub struct DriftDetector {
/// Thresholds and window sizing.
config: DriftConfig,
/// Bounded FIFO of the most recent predictions (oldest evicted first).
predictions: VecDeque<PredictionLog>,
/// Token counts snapshotted once `min_samples` predictions have been seen.
baseline_vocab: HashMap<String, usize>,
/// Running token counts accumulated from logged entity texts.
current_vocab: HashMap<String, usize>,
/// Whether the baseline snapshot has been taken.
baseline_established: bool,
}
impl DriftDetector {
    /// Creates a detector with the given configuration.
    ///
    /// The prediction buffer is bounded at `window_size * num_windows`
    /// entries; once full, the oldest prediction is evicted per insert.
    pub fn new(config: DriftConfig) -> Self {
        let max_size = config.window_size * config.num_windows;
        Self {
            config,
            predictions: VecDeque::with_capacity(max_size),
            baseline_vocab: HashMap::new(),
            current_vocab: HashMap::new(),
            baseline_established: false,
        }
    }

    /// Records one model prediction.
    ///
    /// Lowercased whitespace tokens from `entity_text` are counted into the
    /// running vocabulary. When the bounded buffer is full, the oldest
    /// prediction is evicted and its token counts are removed, so the
    /// vocabulary reflects only the retained predictions. (Previously evicted
    /// tokens were never removed, so `current_vocab` grew without bound and
    /// progressively dampened vocabulary-drift detection.)
    ///
    /// The baseline vocabulary is snapshotted the first time `min_samples`
    /// predictions have accumulated.
    pub fn log_prediction(
        &mut self,
        timestamp: u64,
        confidence: f64,
        entity_type: &str,
        entity_text: &str,
    ) {
        let log = PredictionLog {
            timestamp,
            confidence,
            entity_type: entity_type.to_string(),
            entity_text: entity_text.to_string(),
        };
        for token in entity_text.split_whitespace() {
            *self.current_vocab.entry(token.to_lowercase()).or_insert(0) += 1;
        }
        let max_size = self.config.window_size * self.config.num_windows;
        if self.predictions.len() >= max_size {
            // Keep the vocabulary in sync with the sliding window: the
            // evicted prediction's tokens no longer count as "current".
            if let Some(evicted) = self.predictions.pop_front() {
                self.remove_tokens(&evicted.entity_text);
            }
        }
        self.predictions.push_back(log);
        if !self.baseline_established && self.predictions.len() >= self.config.min_samples {
            self.establish_baseline();
        }
    }

    /// Decrements `current_vocab` counts for every lowercased whitespace
    /// token in `text`, dropping entries that reach zero.
    fn remove_tokens(&mut self, text: &str) {
        for token in text.split_whitespace() {
            let lower = token.to_lowercase();
            if let Some(count) = self.current_vocab.get_mut(&lower) {
                // saturating_sub guards against any count/eviction mismatch.
                *count = count.saturating_sub(1);
                if *count == 0 {
                    self.current_vocab.remove(&lower);
                }
            }
        }
    }

    /// Snapshots the current vocabulary as the baseline for vocabulary-drift
    /// comparisons. Invoked once, automatically, from `log_prediction`.
    fn establish_baseline(&mut self) {
        self.baseline_vocab = self.current_vocab.clone();
        self.baseline_established = true;
    }

    /// Clears all logged state, returning the detector to its freshly
    /// constructed (pre-baseline) condition. Configuration is kept.
    pub fn reset(&mut self) {
        self.predictions.clear();
        self.baseline_vocab.clear();
        self.current_vocab.clear();
        self.baseline_established = false;
    }

    /// Runs the full drift analysis over the buffered predictions.
    ///
    /// Until `min_samples` predictions have been logged, returns a
    /// no-drift placeholder report explaining the data shortfall.
    pub fn analyze(&self) -> DriftReport {
        if self.predictions.len() < self.config.min_samples {
            return DriftReport {
                drift_detected: false,
                summary: format!(
                    "Insufficient data: {} predictions (need {})",
                    self.predictions.len(),
                    self.config.min_samples
                ),
                confidence_drift: ConfidenceDrift {
                    baseline_mean: 0.0,
                    current_mean: 0.0,
                    drift_amount: 0.0,
                    is_significant: false,
                },
                distribution_drift: DistributionDrift {
                    kl_divergence: 0.0,
                    increased_types: Vec::new(),
                    decreased_types: Vec::new(),
                    new_types: Vec::new(),
                    is_significant: false,
                },
                vocabulary_drift: VocabularyDrift {
                    baseline_vocab_size: 0,
                    current_vocab_size: 0,
                    new_token_rate: 0.0,
                    is_significant: false,
                },
                windows: Vec::new(),
                recommendations: vec!["Collect more data for drift analysis".into()],
            };
        }
        let windows = self.compute_windows();
        let confidence_drift = self.analyze_confidence_drift(&windows);
        let distribution_drift = self.analyze_distribution_drift(&windows);
        let vocabulary_drift = self.analyze_vocabulary_drift();
        let drift_detected = confidence_drift.is_significant
            || distribution_drift.is_significant
            || vocabulary_drift.is_significant;
        let (summary, recommendations) = self.generate_summary_and_recommendations(
            drift_detected,
            &confidence_drift,
            &distribution_drift,
            &vocabulary_drift,
        );
        DriftReport {
            drift_detected,
            summary,
            confidence_drift,
            distribution_drift,
            vocabulary_drift,
            windows,
            recommendations,
        }
    }

    /// Splits the buffered predictions into up to `num_windows` consecutive,
    /// equally sized windows (oldest first) and summarizes each.
    ///
    /// Windows are anchored to the newest data: any oldest predictions that
    /// do not fill a whole window are skipped.
    fn compute_windows(&self) -> Vec<DriftWindow> {
        let predictions: Vec<_> = self.predictions.iter().collect();
        let window_size = self.config.window_size.min(predictions.len());
        if window_size == 0 {
            return Vec::new();
        }
        let num_windows = (predictions.len() / window_size).min(self.config.num_windows);
        let mut windows = Vec::with_capacity(num_windows);
        for i in 0..num_windows {
            let start = predictions.len() - (num_windows - i) * window_size;
            let end = start + window_size;
            let window_preds = &predictions[start..end];
            let confidences: Vec<f64> = window_preds.iter().map(|p| p.confidence).collect();
            let mean_conf = confidences.iter().sum::<f64>() / confidences.len() as f64;
            // Population (not sample) standard deviation.
            let std_conf = (confidences
                .iter()
                .map(|c| (c - mean_conf).powi(2))
                .sum::<f64>()
                / confidences.len() as f64)
                .sqrt();
            let mut type_counts: HashMap<String, usize> = HashMap::new();
            let mut unique_tokens = std::collections::HashSet::new();
            for pred in window_preds {
                *type_counts.entry(pred.entity_type.clone()).or_insert(0) += 1;
                for token in pred.entity_text.split_whitespace() {
                    unique_tokens.insert(token.to_lowercase());
                }
            }
            let total = window_preds.len() as f64;
            let type_distribution: HashMap<String, f64> = type_counts
                .iter()
                .map(|(t, c)| (t.clone(), *c as f64 / total))
                .collect();
            windows.push(DriftWindow {
                window_id: i,
                mean_confidence: mean_conf,
                std_confidence: std_conf,
                type_distribution,
                count: window_preds.len(),
                unique_tokens: unique_tokens.len(),
            });
        }
        windows
    }

    /// Compares mean confidence of the oldest vs. newest window. With fewer
    /// than two windows, reports zeroes and no significance.
    fn analyze_confidence_drift(&self, windows: &[DriftWindow]) -> ConfidenceDrift {
        if windows.len() < 2 {
            return ConfidenceDrift {
                baseline_mean: 0.0,
                current_mean: 0.0,
                drift_amount: 0.0,
                is_significant: false,
            };
        }
        let baseline_mean = windows[0].mean_confidence;
        let current_mean = windows.last().map(|w| w.mean_confidence).unwrap_or(0.0);
        let drift_amount = current_mean - baseline_mean;
        let is_significant = drift_amount.abs() > self.config.confidence_drift_threshold;
        ConfidenceDrift {
            baseline_mean,
            current_mean,
            drift_amount,
            is_significant,
        }
    }

    /// Measures the entity-type distribution shift from the oldest to the
    /// newest window via KL divergence, listing types whose relative
    /// frequency changed by more than 5 percentage points.
    fn analyze_distribution_drift(&self, windows: &[DriftWindow]) -> DistributionDrift {
        if windows.len() < 2 {
            return DistributionDrift {
                kl_divergence: 0.0,
                increased_types: Vec::new(),
                decreased_types: Vec::new(),
                new_types: Vec::new(),
                is_significant: false,
            };
        }
        let baseline = &windows[0].type_distribution;
        let current = &windows[windows.len() - 1].type_distribution;
        // Smoothing floor: types absent from the baseline get probability
        // epsilon so the log term stays finite.
        let epsilon = 1e-10;
        let mut kl_div = 0.0;
        for (typ, p) in current {
            let q = baseline.get(typ).copied().unwrap_or(epsilon);
            kl_div += p * (p / q).ln();
        }
        let mut increased = Vec::new();
        let mut decreased = Vec::new();
        let mut new_types = Vec::new();
        for (typ, curr_freq) in current {
            if let Some(base_freq) = baseline.get(typ) {
                let change = curr_freq - base_freq;
                if change > 0.05 {
                    increased.push((typ.clone(), change));
                } else if change < -0.05 {
                    decreased.push((typ.clone(), change));
                }
            } else {
                new_types.push(typ.clone());
            }
        }
        // Largest increase first; most negative decrease first.
        increased.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        decreased.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        let is_significant =
            kl_div > self.config.distribution_drift_threshold || !new_types.is_empty();
        DistributionDrift {
            kl_divergence: kl_div,
            increased_types: increased,
            decreased_types: decreased,
            new_types,
            is_significant,
        }
    }

    /// Compares the current vocabulary against the baseline snapshot.
    /// Before the baseline exists, reports no drift.
    fn analyze_vocabulary_drift(&self) -> VocabularyDrift {
        if !self.baseline_established {
            return VocabularyDrift {
                baseline_vocab_size: 0,
                current_vocab_size: self.current_vocab.len(),
                new_token_rate: 0.0,
                is_significant: false,
            };
        }
        let baseline_size = self.baseline_vocab.len();
        let current_size = self.current_vocab.len();
        let new_tokens: usize = self
            .current_vocab
            .keys()
            .filter(|t| !self.baseline_vocab.contains_key(*t))
            .count();
        let new_token_rate = if current_size == 0 {
            0.0
        } else {
            new_tokens as f64 / current_size as f64
        };
        let is_significant = new_token_rate > self.config.vocab_drift_threshold;
        VocabularyDrift {
            baseline_vocab_size: baseline_size,
            current_vocab_size: current_size,
            new_token_rate,
            is_significant,
        }
    }

    /// Builds the human-readable summary line and the recommendation list
    /// from the three drift signals. Recommendations are never empty.
    fn generate_summary_and_recommendations(
        &self,
        drift_detected: bool,
        confidence: &ConfidenceDrift,
        distribution: &DistributionDrift,
        vocabulary: &VocabularyDrift,
    ) -> (String, Vec<String>) {
        let mut issues = Vec::new();
        let mut recommendations = Vec::new();
        if confidence.is_significant {
            if confidence.drift_amount < 0.0 {
                issues.push(format!(
                    "Confidence dropped by {:.1}%",
                    confidence.drift_amount.abs() * 100.0
                ));
                recommendations
                    .push("Model may be encountering harder examples - consider retraining".into());
            } else {
                issues.push(format!(
                    "Confidence increased by {:.1}%",
                    confidence.drift_amount * 100.0
                ));
                recommendations
                    .push("Verify model isn't becoming overconfident on new patterns".into());
            }
        }
        if distribution.is_significant {
            issues.push(format!(
                "Entity type distribution shifted (KL={:.2})",
                distribution.kl_divergence
            ));
            if !distribution.new_types.is_empty() {
                recommendations.push(format!(
                    "New entity types detected: {:?} - update training data",
                    distribution.new_types
                ));
            }
        }
        if vocabulary.is_significant {
            issues.push(format!(
                "{:.1}% new vocabulary",
                vocabulary.new_token_rate * 100.0
            ));
            recommendations
                .push("Significant vocabulary shift - consider domain adaptation".into());
        }
        let summary = if drift_detected {
            format!("Drift detected: {}", issues.join("; "))
        } else {
            "No significant drift detected".into()
        };
        if recommendations.is_empty() {
            recommendations.push("Continue monitoring".into());
        }
        (summary, recommendations)
    }
}
impl Default for DriftDetector {
/// Constructs a detector using [`DriftConfig::default`] settings.
fn default() -> Self {
Self::new(DriftConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insufficient_data() {
        // A fresh detector has no data, so analysis must bail out gracefully.
        let detector = DriftDetector::default();
        let report = detector.analyze();
        assert!(!report.drift_detected);
        assert!(report.summary.contains("Insufficient"));
    }

    #[test]
    fn test_no_drift() {
        let config = DriftConfig {
            min_samples: 10,
            window_size: 5,
            num_windows: 2,
            ..Default::default()
        };
        let mut detector = DriftDetector::new(config);
        // A perfectly steady stream: identical confidence, type, and text.
        for ts in 0u64..20 {
            detector.log_prediction(ts, 0.90, "PER", "John Smith");
        }
        let report = detector.analyze();
        assert!(!report.confidence_drift.is_significant);
    }

    #[test]
    fn test_confidence_drift_detection() {
        let config = DriftConfig {
            min_samples: 10,
            window_size: 10,
            num_windows: 2,
            confidence_drift_threshold: 0.1,
            ..Default::default()
        };
        let mut detector = DriftDetector::new(config);
        // First window: high confidence; second window: a sharp drop.
        for ts in 0u64..20 {
            let confidence = if ts < 10 { 0.95 } else { 0.60 };
            detector.log_prediction(ts, confidence, "PER", "John");
        }
        let report = detector.analyze();
        assert!(report.confidence_drift.is_significant);
        assert!(report.confidence_drift.drift_amount < 0.0);
    }

    #[test]
    fn test_vocabulary_drift() {
        let config = DriftConfig {
            min_samples: 5,
            window_size: 5,
            num_windows: 2,
            vocab_drift_threshold: 0.3,
            ..Default::default()
        };
        let mut detector = DriftDetector::new(config);
        // Baseline tokens first, then a burst of previously unseen names.
        for ts in 0u64..5 {
            detector.log_prediction(ts, 0.90, "PER", "John Smith");
        }
        for ts in 5u64..10 {
            detector.log_prediction(ts, 0.90, "PER", "Xiangjun Chen Zhang Wei");
        }
        let report = detector.analyze();
        assert!(report.vocabulary_drift.new_token_rate > 0.0);
    }

    #[test]
    fn test_reset() {
        let mut detector = DriftDetector::default();
        detector.log_prediction(0, 0.9, "PER", "Test");
        detector.reset();
        // After reset the detector must behave as if freshly constructed.
        let report = detector.analyze();
        assert!(report.summary.contains("Insufficient"));
    }
}