use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Number of consecutive missing runs a feature may have while its drift verdict stays
/// [`DriftValidity::Valid`]. The next consecutive missing run — and every one after it until a
/// non-missing run resets the streak — is reported as [`DriftValidity::Unknown`].
pub const MAX_CONSECUTIVE_MISSING_RUNS: usize = 3;
/// Verdict returned by [`FeatureMissingnessTracker::update`] for a single run.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DriftValidity {
/// The current streak of missing runs does not exceed [`MAX_CONSECUTIVE_MISSING_RUNS`].
Valid,
/// More than [`MAX_CONSECUTIVE_MISSING_RUNS`] consecutive runs have been missing.
Unknown,
}
/// Per-feature bookkeeping of missing observations, advanced one run at a time via
/// [`FeatureMissingnessTracker::update`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FeatureMissingnessTracker {
/// Identifier of the feature being tracked.
pub feature_id: String,
/// Length of the current streak of missing runs; reset to 0 by a non-missing run.
pub consecutive_missing: usize,
/// Count of missing runs passed to `update`; a non-missing run does not reset it.
pub total_missing: usize,
/// Run indices for which `update` returned [`DriftValidity::Unknown`]. There is one entry per
/// such run, not just the first run of an over-threshold streak.
pub invalidation_events: Vec<usize>,
}
impl FeatureMissingnessTracker {
    /// Creates a tracker for `feature_id` with zeroed counters and no recorded events.
    pub fn new(feature_id: impl Into<String>) -> Self {
        Self {
            feature_id: feature_id.into(),
            consecutive_missing: 0,
            total_missing: 0,
            invalidation_events: Vec::new(),
        }
    }

    /// Records one run and returns its drift verdict.
    ///
    /// A missing run extends the current streak and bumps the running total; a present run
    /// ends the streak. Once the streak exceeds [`MAX_CONSECUTIVE_MISSING_RUNS`], the run is
    /// reported as [`DriftValidity::Unknown`] and `run_index` is appended to
    /// `invalidation_events`.
    pub fn update(&mut self, is_missing: bool, run_index: usize) -> DriftValidity {
        self.consecutive_missing = if is_missing {
            self.total_missing += 1;
            self.consecutive_missing + 1
        } else {
            0
        };

        if !self.is_invalidated() {
            return DriftValidity::Valid;
        }
        self.invalidation_events.push(run_index);
        DriftValidity::Unknown
    }

    /// Returns `true` while the current missing streak exceeds [`MAX_CONSECUTIVE_MISSING_RUNS`].
    #[must_use]
    pub fn is_invalidated(&self) -> bool {
        MAX_CONSECUTIVE_MISSING_RUNS < self.consecutive_missing
    }
}
/// Per-run output of [`MissingnessAwareGrammar::process`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MissingnessAwareRecord {
/// Position of the run within the processed series.
pub run_index: usize,
/// Feature this observation belongs to.
pub feature_id: String,
/// Observed value, passed through unchanged.
pub value: f64,
/// Whether the observation was flagged as missing (imputed).
pub is_missing: bool,
/// Verdict from the feature's tracker after this run was recorded.
pub drift_validity: DriftValidity,
/// `true` exactly when `drift_validity` is [`DriftValidity::Unknown`].
pub suppressed_by_missingness: bool,
}
/// Evaluates series of observations per feature, keeping one [`FeatureMissingnessTracker`]
/// per feature id.
#[derive(Debug, Default)]
pub struct MissingnessAwareGrammar {
/// Trackers keyed by feature id.
trackers: HashMap<String, FeatureMissingnessTracker>,
}
impl MissingnessAwareGrammar {
    /// Creates a grammar with no tracked features.
    pub fn new() -> Self {
        Self::default()
    }

    /// Evaluates one series of observations for `feature_id`, returning one record per run.
    ///
    /// `values[i]` and `is_imputed[i]` describe run `i`; an imputed observation counts as
    /// missing. Each call re-evaluates the series from scratch. The feature's tracker is fully
    /// reset before the first run, so the tracker (and therefore [`Self::summary`]) reflects
    /// only the most recent series processed for this feature.
    ///
    /// A record is `suppressed_by_missingness` exactly when its `drift_validity` is
    /// [`DriftValidity::Unknown`]. That happens once more than [`MAX_CONSECUTIVE_MISSING_RUNS`]
    /// consecutive runs have been missing.
    ///
    /// # Panics
    ///
    /// Panics if `values` and `is_imputed` have different lengths.
    pub fn process(
        &mut self,
        feature_id: &str,
        values: &[f64],
        is_imputed: &[bool],
    ) -> Vec<MissingnessAwareRecord> {
        assert_eq!(
            values.len(),
            is_imputed.len(),
            "values and is_imputed must have equal length"
        );
        let tracker = self
            .trackers
            .entry(feature_id.to_string())
            .or_insert_with(|| FeatureMissingnessTracker::new(feature_id));
        // Reset every per-series counter together. Resetting only the streak and the events
        // (but not `total_missing`) would make `summary()` combine this series' invalidations
        // with missing counts accumulated over earlier calls, double-counting re-processed
        // features.
        tracker.consecutive_missing = 0;
        tracker.total_missing = 0;
        tracker.invalidation_events.clear();
        values
            .iter()
            .zip(is_imputed)
            .enumerate()
            .map(|(run_index, (&value, &is_missing))| {
                let drift_validity = tracker.update(is_missing, run_index);
                MissingnessAwareRecord {
                    run_index,
                    feature_id: feature_id.to_string(),
                    value,
                    is_missing,
                    drift_validity,
                    suppressed_by_missingness: drift_validity == DriftValidity::Unknown,
                }
            })
            .collect()
    }

    /// Returns the per-feature trackers, keyed by feature id.
    pub fn trackers(&self) -> &HashMap<String, FeatureMissingnessTracker> {
        &self.trackers
    }

    /// Aggregates missingness statistics over every tracked feature.
    pub fn summary(&self) -> MissingSummary {
        let mut summary = MissingSummary {
            total_features: self.trackers.len(),
            features_with_invalidations: 0,
            total_invalidation_events: 0,
            total_missing_observations: 0,
            max_consecutive_missing_threshold: MAX_CONSECUTIVE_MISSING_RUNS,
        };
        // One pass over the trackers instead of one pass per statistic.
        for tracker in self.trackers.values() {
            if !tracker.invalidation_events.is_empty() {
                summary.features_with_invalidations += 1;
            }
            summary.total_invalidation_events += tracker.invalidation_events.len();
            summary.total_missing_observations += tracker.total_missing;
        }
        summary
    }
}
/// Aggregate statistics produced by [`MissingnessAwareGrammar::summary`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MissingSummary {
/// Number of features with a tracker.
pub total_features: usize,
/// Number of features whose tracker recorded at least one invalidation event.
pub features_with_invalidations: usize,
/// Sum of invalidation events across all trackers (one per run reported as Unknown).
pub total_invalidation_events: usize,
/// Sum of each tracker's `total_missing` count.
pub total_missing_observations: usize,
/// Value of [`MAX_CONSECUTIVE_MISSING_RUNS`] used for the verdicts.
pub max_consecutive_missing_threshold: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `count` consecutive missing runs (indices `0..count`) into `tracker`.
    fn feed_missing(tracker: &mut FeatureMissingnessTracker, count: usize) {
        (0..count).for_each(|run| {
            tracker.update(true, run);
        });
    }

    #[test]
    fn tracker_valid_below_threshold() {
        let mut tracker = FeatureMissingnessTracker::new("S001");
        for run in 0..3 {
            assert_eq!(tracker.update(true, run), DriftValidity::Valid);
        }
        assert_eq!(tracker.consecutive_missing, 3);
    }

    #[test]
    fn tracker_invalidates_after_threshold() {
        let mut tracker = FeatureMissingnessTracker::new("S002");
        feed_missing(&mut tracker, 4);
        assert_eq!(tracker.update(true, 4), DriftValidity::Unknown);
        assert!(tracker.is_invalidated());
    }

    #[test]
    fn tracker_resets_on_valid_observation() {
        let mut tracker = FeatureMissingnessTracker::new("S003");
        feed_missing(&mut tracker, 10);
        assert_eq!(tracker.update(false, 10), DriftValidity::Valid);
        assert_eq!(tracker.consecutive_missing, 0);
        assert!(!tracker.is_invalidated());
    }

    #[test]
    fn grammar_filter_suppresses_after_threshold() {
        let mut grammar = MissingnessAwareGrammar::new();
        let values = [0.0_f64; 8];
        // First four runs imputed, last four observed.
        let is_imputed: Vec<bool> = (0..8).map(|run| run < 4).collect();
        let records = grammar.process("S001", &values, &is_imputed);

        let validity_at = |i: usize| records[i].drift_validity;
        assert_eq!(validity_at(0), DriftValidity::Valid);
        assert_eq!(validity_at(2), DriftValidity::Valid);
        assert_eq!(validity_at(3), DriftValidity::Unknown);
        assert!(records[3].suppressed_by_missingness);
        assert_eq!(validity_at(4), DriftValidity::Valid);
        assert!(!records[4].suppressed_by_missingness);
    }

    #[test]
    fn summary_counts_invalidated_features() {
        let mut grammar = MissingnessAwareGrammar::new();
        let values = vec![0.0; 5];
        grammar.process("S_BAD", &values, &[true; 5]);
        grammar.process("S_GOOD", &values, &[false; 5]);

        let summary = grammar.summary();
        assert_eq!(summary.total_features, 2);
        assert_eq!(summary.features_with_invalidations, 1);
    }
}