use rand::{Rng, RngExt};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Strategy used to derive the probability that a field value is dropped.
///
/// The variant names follow the usual statistical taxonomy of missingness
/// (MCAR / MAR / MNAR) plus a group-based "systematic" mode.
#[derive(Debug, Clone)]
pub enum MissingValueStrategy {
/// Missing completely at random: the chance does not depend on the value or
/// on any other field of the record.
MCAR {
/// Fallback drop probability; rates configured on `MissingValueConfig`
/// take precedence when they apply.
probability: f64,
},
/// Missing at random: the chance depends on *other* fields of the record,
/// which are supplied through the context map.
MAR {
/// Starting probability before any condition multiplier is applied.
base_probability: f64,
/// Conditions evaluated against the context; each one that holds scales
/// the probability by its multiplier. The result is capped at 1.0.
conditions: Vec<MissingCondition>,
},
/// Missing not at random: the chance depends on the field's own value.
MNAR {
/// Patterns tried in order; the first one matching both field name and
/// value supplies the probability. Without a match the configured
/// field/global rate is used.
value_patterns: Vec<MissingPattern>,
},
/// Fields that belong to a configured group share one probability.
Systematic {
/// Groups of field names; membership in any group selects `probability`.
/// Fields outside every group use the configured field/global rate.
field_groups: Vec<Vec<String>>,
/// Drop probability for every field that appears in a group.
probability: f64,
},
}
impl Default for MissingValueStrategy {
fn default() -> Self {
MissingValueStrategy::MCAR { probability: 0.01 }
}
}
/// A single rule of the MAR strategy: when the condition holds for another
/// field of the record, the running probability is scaled by `multiplier`.
#[derive(Debug, Clone)]
pub struct MissingCondition {
/// Name of the field looked up in the context map. If the context has no
/// entry for it, the condition is skipped.
pub field: String,
/// Test applied to the context value of `field`.
pub condition_type: ConditionType,
/// Factor the probability is multiplied by when the test succeeds.
pub multiplier: f64,
}
/// Predicate evaluated against the string value of a context field.
#[derive(Debug, Clone)]
pub enum ConditionType {
/// The value is exactly equal to the given string.
Equals(String),
/// The value contains the given substring.
Contains(String),
/// The value is the empty string.
IsEmpty,
/// NOTE(review): the visible evaluator treats this as a plain substring
/// test, identical to `Contains`; no pattern/regex engine is involved.
Matches(String),
/// The value parses as `f64` and is strictly greater than the threshold.
/// Values that do not parse never match.
GreaterThan(f64),
/// The value parses as `f64` and is strictly less than the threshold.
/// Values that do not parse never match.
LessThan(f64),
}
/// A single rule of the MNAR strategy: a value pattern tied to one field and
/// the drop probability to use when the field's own value matches it.
#[derive(Debug, Clone)]
pub struct MissingPattern {
/// Free-text description; not read by any code visible in this file.
pub description: String,
/// Name of the field this pattern applies to (compared for equality).
pub field: String,
/// Test applied to the field's value.
pub pattern_type: PatternType,
/// Probability returned when both the field name and the value match.
pub probability: f64,
}
/// Value test used by MNAR patterns. Numeric variants parse the value as
/// `f64`; a value that fails to parse never matches.
#[derive(Debug, Clone)]
pub enum PatternType {
/// Matches numeric values strictly greater than `threshold`.
HighValues { threshold: f64 },
/// Matches numeric values strictly less than `threshold`.
LowValues { threshold: f64 },
/// Matches numeric values strictly below `low` or strictly above `high`.
ExtremeValues { low: f64, high: f64 },
/// Matches values that contain at least one of the given substrings.
SensitivePatterns { patterns: Vec<String> },
}
/// Configuration for missing-value injection.
#[derive(Debug, Clone)]
pub struct MissingValueConfig {
/// Rate returned by `get_rate` for fields without a field-specific entry.
pub global_rate: f64,
/// Field-specific rates, keyed by field name; these override `global_rate`.
pub field_rates: HashMap<String, f64>,
/// Fields that must never be reported as missing (`get_rate` yields 0.0).
pub required_fields: HashSet<String>,
/// Strategy used to compute the drop probability for a field.
pub strategy: MissingValueStrategy,
/// When `true`, the injector updates its `MissingValueStats` counters.
pub track_statistics: bool,
}
impl Default for MissingValueConfig {
fn default() -> Self {
let mut required_fields = HashSet::new();
required_fields.insert("document_number".to_string());
required_fields.insert("company_code".to_string());
required_fields.insert("posting_date".to_string());
required_fields.insert("account_code".to_string());
Self {
global_rate: 0.01,
field_rates: HashMap::new(),
required_fields,
strategy: MissingValueStrategy::default(),
track_statistics: true,
}
}
}
impl MissingValueConfig {
    /// Builder step: replaces the entire per-field rate table with `rates`.
    pub fn with_field_rates(self, rates: HashMap<String, f64>) -> Self {
        Self {
            field_rates: rates,
            ..self
        }
    }

    /// Builder step: marks `field` as required in addition to the fields
    /// that are already required.
    pub fn with_required_field(mut self, field: &str) -> Self {
        self.required_fields.insert(String::from(field));
        self
    }

    /// Builder step: swaps in a different missingness strategy.
    pub fn with_strategy(self, strategy: MissingValueStrategy) -> Self {
        Self { strategy, ..self }
    }

    /// Returns the configured missing rate for `field`.
    ///
    /// Required fields always yield `0.0`. Otherwise the field-specific rate
    /// is returned when one exists, and the global rate when it does not.
    pub fn get_rate(&self, field: &str) -> f64 {
        if self.required_fields.contains(field) {
            0.0
        } else {
            self.field_rates
                .get(field)
                .copied()
                .unwrap_or(self.global_rate)
        }
    }
}
/// Counters collected by the injector while statistics tracking is enabled.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MissingValueStats {
/// Number of non-required field evaluations performed.
pub total_fields: usize,
/// Number of evaluations that resulted in a "missing" decision.
pub total_missing: usize,
/// "Missing" decisions broken down by field name.
pub by_field: HashMap<String, usize>,
/// Number of processed records reported as having at least one missing value.
pub records_with_missing: usize,
/// Number of records reported via `record_processed`.
pub total_records: usize,
}
impl MissingValueStats {
    /// Fraction of evaluated fields that were marked missing.
    ///
    /// Yields `0.0` when nothing has been evaluated yet.
    pub fn overall_rate(&self) -> f64 {
        match self.total_fields {
            0 => 0.0,
            evaluated => self.total_missing as f64 / evaluated as f64,
        }
    }

    /// Fraction of `total_records` in which `field` was marked missing.
    ///
    /// The denominator is the caller-supplied `total_records`, not the
    /// counter stored in the struct. Yields `0.0` when `total_records` is
    /// zero or when `field` was never marked missing.
    pub fn field_rate(&self, field: &str, total_records: usize) -> f64 {
        if total_records == 0 {
            return 0.0;
        }
        let missing_for_field = self.by_field.get(field).copied().unwrap_or(0);
        missing_for_field as f64 / total_records as f64
    }
}
/// Decides, one field at a time, whether a value should be treated as
/// missing, and keeps running statistics about those decisions.
pub struct MissingValueInjector {
// Rates, required fields, strategy and the statistics switch.
config: MissingValueConfig,
// Counters updated only while `config.track_statistics` is true.
stats: MissingValueStats,
}
impl MissingValueInjector {
    /// Creates an injector that applies `config` and starts with zeroed
    /// statistics.
    pub fn new(config: MissingValueConfig) -> Self {
        Self {
            config,
            stats: MissingValueStats::default(),
        }
    }

    /// Decides whether the value of `field` should be treated as missing.
    ///
    /// Required fields short-circuit to `false` without drawing from `rng`
    /// and without touching the statistics. For every other field exactly one
    /// `f64` is drawn from `rng` and compared against the probability
    /// produced by the configured strategy.
    ///
    /// # Arguments
    /// * `field` - name of the field being evaluated.
    /// * `value` - the field's current value, if any (used by MNAR patterns).
    /// * `context` - other field values of the record (used by MAR conditions).
    /// * `rng` - random source for the draw.
    ///
    /// # Returns
    /// `true` when the value should be replaced by a missing-value marker.
    pub fn should_be_missing<R: Rng>(
        &mut self,
        field: &str,
        value: Option<&str>,
        context: &HashMap<String, String>,
        rng: &mut R,
    ) -> bool {
        if self.config.required_fields.contains(field) {
            return false;
        }

        let probability = self.calculate_probability(field, value, context);
        let tracking = self.config.track_statistics;
        if tracking {
            self.stats.total_fields += 1;
        }

        // The draw happens unconditionally so the RNG stream advances by one
        // value per non-required field regardless of the probability.
        let is_missing = rng.random::<f64>() < probability;
        if is_missing && tracking {
            self.stats.total_missing += 1;
            *self.stats.by_field.entry(field.to_string()).or_insert(0) += 1;
        }
        is_missing
    }

    /// Computes the drop probability for `field` under the configured
    /// strategy. Callers are expected to have filtered out required fields.
    fn calculate_probability(
        &self,
        field: &str,
        value: Option<&str>,
        context: &HashMap<String, String>,
    ) -> f64 {
        match &self.config.strategy {
            MissingValueStrategy::MCAR { probability } => {
                // An explicit per-field rate always wins, including an
                // explicit 0.0 ("never drop this field"). Treating a zero
                // override as "unset" would let the strategy probability
                // drop values from a field configured to stay intact.
                match self.config.field_rates.get(field) {
                    Some(rate) => *rate,
                    None if self.config.global_rate > 0.0 => self.config.global_rate,
                    None => *probability,
                }
            }
            MissingValueStrategy::MAR {
                base_probability,
                conditions,
            } => {
                // Multiply in the factor of every condition whose context
                // field is present and satisfies the test, in list order.
                let scaled = conditions
                    .iter()
                    .filter(|condition| {
                        context.get(&condition.field).map_or(false, |actual| {
                            self.check_condition(&condition.condition_type, actual)
                        })
                    })
                    .fold(*base_probability, |prob, condition| {
                        prob * condition.multiplier
                    });
                scaled.min(1.0)
            }
            MissingValueStrategy::MNAR { value_patterns } => {
                // First pattern matching both the field name and the value
                // decides; without a value nothing can match.
                let matched = value.and_then(|val| {
                    value_patterns.iter().find(|pattern| {
                        pattern.field == field
                            && self.check_value_pattern(&pattern.pattern_type, val)
                    })
                });
                match matched {
                    Some(pattern) => pattern.probability,
                    None => self.config.get_rate(field),
                }
            }
            MissingValueStrategy::Systematic {
                field_groups,
                probability,
            } => {
                // Compare by `&str` so no temporary `String` is allocated
                // for every lookup.
                let in_any_group = field_groups
                    .iter()
                    .any(|group| group.iter().any(|member| member.as_str() == field));
                if in_any_group {
                    *probability
                } else {
                    self.config.get_rate(field)
                }
            }
        }
    }

    /// Evaluates a MAR condition against a context value.
    fn check_condition(&self, condition: &ConditionType, value: &str) -> bool {
        match condition {
            ConditionType::Equals(expected) => value == expected,
            ConditionType::Contains(substring) => value.contains(substring.as_str()),
            ConditionType::IsEmpty => value.is_empty(),
            // `Matches` is evaluated as a plain substring test; no
            // pattern/regex engine is used here.
            ConditionType::Matches(pattern) => value.contains(pattern.as_str()),
            // Non-numeric values never satisfy a numeric comparison.
            ConditionType::GreaterThan(threshold) => value
                .parse::<f64>()
                .map(|v| v > *threshold)
                .unwrap_or(false),
            ConditionType::LessThan(threshold) => value
                .parse::<f64>()
                .map(|v| v < *threshold)
                .unwrap_or(false),
        }
    }

    /// Evaluates an MNAR value pattern against the field's own value.
    fn check_value_pattern(&self, pattern: &PatternType, value: &str) -> bool {
        match pattern {
            PatternType::HighValues { threshold } => value
                .parse::<f64>()
                .map(|v| v > *threshold)
                .unwrap_or(false),
            PatternType::LowValues { threshold } => value
                .parse::<f64>()
                .map(|v| v < *threshold)
                .unwrap_or(false),
            PatternType::ExtremeValues { low, high } => value
                .parse::<f64>()
                .map(|v| v < *low || v > *high)
                .unwrap_or(false),
            PatternType::SensitivePatterns { patterns } => {
                patterns.iter().any(|p| value.contains(p.as_str()))
            }
        }
    }

    /// Records that one record has been processed; `had_missing` tells
    /// whether at least one of its values was dropped. Does nothing when
    /// statistics tracking is disabled.
    pub fn record_processed(&mut self, had_missing: bool) {
        if self.config.track_statistics {
            self.stats.total_records += 1;
            if had_missing {
                self.stats.records_with_missing += 1;
            }
        }
    }

    /// Returns the statistics collected so far.
    pub fn stats(&self) -> &MissingValueStats {
        &self.stats
    }

    /// Resets every counter to its zero/empty default.
    pub fn reset_stats(&mut self) {
        self.stats = MissingValueStats::default();
    }
}
/// Ways a missing value can be represented in generated output.
#[derive(Debug, Clone, PartialEq)]
pub enum MissingValue {
/// Null marker; rendered as an empty string by `to_string_value`.
Null,
/// Empty value; also rendered as an empty string.
Empty,
/// Arbitrary marker text, rendered verbatim.
Marker(String),
/// Rendered as `N/A`.
NA,
/// Rendered as `-`.
Dash,
/// Rendered as `?`.
Unknown,
}
impl MissingValue {
    /// Renders this representation as the text that would appear in output.
    ///
    /// `Null` and `Empty` both produce an empty string; `Marker` produces
    /// its own text unchanged.
    pub fn to_string_value(&self) -> String {
        let rendered: &str = match self {
            MissingValue::Null | MissingValue::Empty => "",
            MissingValue::Marker(text) => text.as_str(),
            MissingValue::NA => "N/A",
            MissingValue::Dash => "-",
            MissingValue::Unknown => "?",
        };
        rendered.to_owned()
    }

    /// Lists the representations used for random selection, in a fixed order.
    pub fn common_representations() -> Vec<Self> {
        let mut all = Vec::with_capacity(8);
        all.push(MissingValue::Null);
        all.push(MissingValue::Empty);
        all.push(MissingValue::NA);
        // Literal marker strings, kept in their original order.
        all.extend(
            ["NULL", "NONE", "#N/A"]
                .iter()
                .map(|marker| MissingValue::Marker((*marker).to_string())),
        );
        all.push(MissingValue::Dash);
        all.push(MissingValue::Unknown);
        all
    }
}
/// Picks one of the common missing-value representations uniformly at random.
///
/// Exactly one `random_range` draw over the candidate indices is taken from
/// `rng`.
pub fn random_missing_representation<R: Rng>(rng: &mut R) -> MissingValue {
    let mut candidates = MissingValue::common_representations();
    let chosen = rng.random_range(0..candidates.len());
    // The list is a throwaway, so move the chosen element out of it.
    candidates.swap_remove(chosen)
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    // With a 50% rate, a seeded run of 1000 draws should flag roughly half.
    #[test]
    fn test_mcar_strategy() {
        let mut injector = MissingValueInjector::new(MissingValueConfig {
            global_rate: 0.5,
            strategy: MissingValueStrategy::MCAR { probability: 0.5 },
            ..Default::default()
        });
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let context = HashMap::new();

        let missing_count = (0..1000)
            .filter(|_| injector.should_be_missing("description", Some("test"), &context, &mut rng))
            .count();

        assert!((401..600).contains(&missing_count));
    }

    // Required fields are never missing, even at a 100% global rate.
    #[test]
    fn test_required_fields() {
        let mut injector = MissingValueInjector::new(MissingValueConfig {
            global_rate: 1.0,
            ..Default::default()
        });
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let context = HashMap::new();

        let required_missing =
            injector.should_be_missing("document_number", Some("JE001"), &context, &mut rng);
        assert!(!required_missing);

        let optional_missing =
            injector.should_be_missing("description", Some("test"), &context, &mut rng);
        assert!(optional_missing);
    }

    // Per-field rates override the global rate.
    #[test]
    fn test_field_specific_rates() {
        let field_rates: HashMap<String, f64> = [("description", 0.0), ("cost_center", 1.0)]
            .iter()
            .map(|(name, rate)| (name.to_string(), *rate))
            .collect();
        let mut injector =
            MissingValueInjector::new(MissingValueConfig::default().with_field_rates(field_rates));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let context = HashMap::new();

        let description_missing =
            injector.should_be_missing("description", Some("test"), &context, &mut rng);
        assert!(!description_missing);

        let cost_center_missing =
            injector.should_be_missing("cost_center", Some("CC001"), &context, &mut rng);
        assert!(cost_center_missing);
    }

    // Every non-required evaluation is counted while tracking is enabled.
    #[test]
    fn test_statistics() {
        let mut injector = MissingValueInjector::new(MissingValueConfig {
            global_rate: 0.5,
            track_statistics: true,
            ..Default::default()
        });
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let context = HashMap::new();

        (0..100).for_each(|_| {
            injector.should_be_missing("description", Some("test"), &context, &mut rng);
        });

        let stats = injector.stats();
        assert_eq!(stats.total_fields, 100);
        assert_ne!(stats.total_missing, 0);
    }
}