use crate::{Result, TreeBoostError};
use std::collections::HashMap;
/// How [`SimpleImputer`] derives the fill value for each feature column.
///
/// Statistics are computed over the column's non-NaN entries only.
#[derive(Debug, Clone, Copy, Default, serde::Serialize, serde::Deserialize, PartialEq)]
pub enum ImputeStrategy {
    /// Arithmetic mean of the observed (non-NaN) values. This is the default strategy.
    #[default]
    Mean,
    /// Median of the observed (non-NaN) values.
    Median,
    /// Most frequent observed value, as computed by `compute_mode`.
    Mode,
    /// A fixed, caller-supplied fill value.
    Constant(f32),
}
/// Replaces missing (NaN) entries of a row-major `f32` feature matrix with per-column
/// fill values learned by [`SimpleImputer::fit`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SimpleImputer {
    /// Strategy used to turn each column's observed values into a fill value.
    strategy: ImputeStrategy,
    /// One fill value per feature column; empty until `fit` runs.
    fill_values: Vec<f32>,
    /// Set by a successful `fit`; `transform` refuses to run while this is `false`.
    fitted: bool,
}
impl SimpleImputer {
    /// Creates an unfitted imputer that will compute fill values with `strategy`.
    pub fn new(strategy: ImputeStrategy) -> Self {
        Self {
            strategy,
            fill_values: Vec::new(),
            fitted: false,
        }
    }

    /// Shorthand for `SimpleImputer::new(ImputeStrategy::Mean)`.
    pub fn mean() -> Self {
        Self::new(ImputeStrategy::Mean)
    }

    /// Shorthand for `SimpleImputer::new(ImputeStrategy::Median)`.
    pub fn median() -> Self {
        Self::new(ImputeStrategy::Median)
    }

    /// Shorthand for `SimpleImputer::new(ImputeStrategy::Mode)`.
    pub fn mode() -> Self {
        Self::new(ImputeStrategy::Mode)
    }

    /// Shorthand for `SimpleImputer::new(ImputeStrategy::Constant(value))`.
    pub fn constant(value: f32) -> Self {
        Self::new(ImputeStrategy::Constant(value))
    }

    /// Learns one fill value per column from row-major `data`.
    ///
    /// `data` holds `data.len() / num_features` rows of `num_features` values each.
    /// NaN entries count as missing and are ignored when computing statistics.
    /// Under [`ImputeStrategy::Constant`] every column gets the constant, including
    /// columns that are entirely NaN. Under the other strategies a column with no
    /// observed values falls back to `0.0`.
    ///
    /// On error the imputer keeps its previous state.
    ///
    /// # Errors
    ///
    /// - [`TreeBoostError::Data`] if `data` is empty or its length is not a multiple
    ///   of `num_features`.
    /// - [`TreeBoostError::Config`] if `num_features` is zero.
    pub fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
        if data.is_empty() {
            return Err(TreeBoostError::Data("Cannot fit on empty data".into()));
        }
        Self::check_shape(data.len(), num_features)?;

        let fill_values: Vec<f32> = (0..num_features)
            .map(|col| self.column_fill_value(data, num_features, col))
            .collect();
        self.fill_values = fill_values;
        self.fitted = true;
        Ok(())
    }

    /// Replaces every NaN in row-major `data` with the fitted fill value of its column.
    ///
    /// Empty `data` is accepted and left as is. `data` is not modified when an error
    /// is returned.
    ///
    /// # Errors
    ///
    /// - [`TreeBoostError::Config`] if the imputer has not been fitted, if `num_features`
    ///   differs from the column count seen by [`fit`](Self::fit), or if it is zero.
    /// - [`TreeBoostError::Data`] if `data.len()` is not a multiple of `num_features`.
    pub fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()> {
        if !self.fitted {
            return Err(TreeBoostError::Config(
                "SimpleImputer not fitted. Call fit() first.".into(),
            ));
        }
        if self.fill_values.len() != num_features {
            return Err(TreeBoostError::Config(format!(
                "Feature count mismatch: fitted with {} features, got {}",
                self.fill_values.len(),
                num_features
            )));
        }
        // Rejects zero columns (e.g. a deserialized imputer with no fill values) and
        // partial trailing rows before any element is touched.
        Self::check_shape(data.len(), num_features)?;

        for row in data.chunks_exact_mut(num_features) {
            for (value, &fill) in row.iter_mut().zip(&self.fill_values) {
                if value.is_nan() {
                    *value = fill;
                }
            }
        }
        Ok(())
    }

    /// Calls [`fit`](Self::fit) on `data`, then [`transform`](Self::transform) on it in place.
    ///
    /// # Errors
    ///
    /// Propagates any error from `fit` or `transform`.
    pub fn fit_transform(&mut self, data: &mut [f32], num_features: usize) -> Result<()> {
        self.fit(data, num_features)?;
        self.transform(data, num_features)
    }

    /// Returns `true` once [`fit`](Self::fit) has succeeded.
    pub fn is_fitted(&self) -> bool {
        self.fitted
    }

    /// Returns the learned per-column fill values. The slice is empty before fitting.
    pub fn fill_values(&self) -> &[f32] {
        &self.fill_values
    }

    /// Returns the strategy this imputer was built with.
    pub fn strategy(&self) -> ImputeStrategy {
        self.strategy
    }

    /// Checks that `len` values form whole rows of `num_features` (> 0) columns.
    fn check_shape(len: usize, num_features: usize) -> Result<()> {
        if num_features == 0 {
            return Err(TreeBoostError::Config(
                "num_features must be greater than zero".into(),
            ));
        }
        let whole_rows = len / num_features;
        if whole_rows * num_features != len {
            return Err(TreeBoostError::Data(format!(
                "Data length {} is not divisible by num_features {}",
                len, num_features
            )));
        }
        Ok(())
    }

    /// Computes the fill value for column `col` of row-major `data` under `self.strategy`.
    ///
    /// Expects `num_features > 0` (guaranteed by `check_shape`).
    fn column_fill_value(&self, data: &[f32], num_features: usize, col: usize) -> f32 {
        let observed: Vec<f32> = data
            .iter()
            .skip(col)
            .step_by(num_features)
            .copied()
            .filter(|v| !v.is_nan())
            .collect();

        match self.strategy {
            // A constant does not depend on the data, so it applies even to all-NaN columns.
            ImputeStrategy::Constant(c) => c,
            // Nothing observed to summarise.
            _ if observed.is_empty() => 0.0,
            ImputeStrategy::Mean => {
                // Accumulate in f64 to avoid f32 overflow and rounding drift on long columns.
                let sum: f64 = observed.iter().copied().map(f64::from).sum();
                (sum / observed.len() as f64) as f32
            }
            ImputeStrategy::Median => compute_median(&observed),
            ImputeStrategy::Mode => compute_mode(&observed),
        }
    }
}
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct IndicatorImputer {
suffix: String,
only_if_missing: bool,
}
impl IndicatorImputer {
    /// Creates an indicator builder with suffix `"_missing"`.
    ///
    /// It emits indicators only for columns that contain at least one NaN.
    pub fn new() -> Self {
        Self {
            suffix: "_missing".to_string(),
            only_if_missing: true,
        }
    }

    /// Sets the suffix appended to a feature name to form its indicator name.
    pub fn with_suffix(mut self, suffix: impl Into<String>) -> Self {
        self.suffix = suffix.into();
        self
    }

    /// Emits an indicator for every column, even columns without missing values.
    pub fn for_all_columns(mut self) -> Self {
        self.only_if_missing = false;
        self
    }

    /// Builds 0/1 missing-value indicator columns for row-major `data`.
    ///
    /// Returns `(indicators, names)`. `indicators` is row-major, with one column per entry
    /// of `names`: `1.0` where the source value is NaN and `0.0` otherwise. Each name is
    /// the feature name (or `f{col}` when `feature_names` has no entry for that column)
    /// followed by the configured suffix.
    ///
    /// Only complete rows are scanned; a trailing partial row is ignored. Returns two
    /// empty vectors when no column qualifies or when `num_features` is zero.
    pub fn create_indicators(
        &self,
        data: &[f32],
        num_features: usize,
        feature_names: &[String],
    ) -> (Vec<f32>, Vec<String>) {
        if num_features == 0 {
            return (Vec::new(), Vec::new());
        }

        let flagged: Vec<usize> = (0..num_features)
            .filter(|&col| {
                !self.only_if_missing
                    || data.chunks_exact(num_features).any(|row| row[col].is_nan())
            })
            .collect();
        if flagged.is_empty() {
            return (Vec::new(), Vec::new());
        }

        let names: Vec<String> = flagged
            .iter()
            .map(|&col| format!("{}{}", Self::column_name(feature_names, col), self.suffix))
            .collect();

        let num_rows = data.len() / num_features;
        let mut indicators: Vec<f32> = Vec::with_capacity(num_rows * flagged.len());
        for row in data.chunks_exact(num_features) {
            indicators.extend(
                flagged
                    .iter()
                    .map(|&col| if row[col].is_nan() { 1.0 } else { 0.0 }),
            );
        }
        (indicators, names)
    }

    /// Appends the indicator columns from
    /// [`create_indicators`](Self::create_indicators) to each row of `data`.
    ///
    /// When no indicator is produced, copies of `data` and `feature_names` are returned
    /// unchanged. Otherwise each output row is the original row followed by its
    /// indicators. The returned names are exactly `num_features` base names followed by
    /// the indicator names, so they always line up with the output columns. A base name
    /// missing from `feature_names` becomes `f{col}`, and any extra entries are dropped.
    pub fn transform_with_indicators(
        &self,
        data: &[f32],
        num_features: usize,
        feature_names: &[String],
    ) -> (Vec<f32>, Vec<String>) {
        let (indicators, indicator_names) =
            self.create_indicators(data, num_features, feature_names);
        if indicator_names.is_empty() {
            return (data.to_vec(), feature_names.to_vec());
        }

        // A non-empty name list implies num_features > 0 and num_indicators > 0,
        // so both chunk sizes below are valid.
        let num_indicators = indicator_names.len();
        let num_rows = data.len() / num_features;
        let mut combined = Vec::with_capacity(num_rows * (num_features + num_indicators));
        for (row, flags) in data
            .chunks_exact(num_features)
            .zip(indicators.chunks_exact(num_indicators))
        {
            combined.extend_from_slice(row);
            combined.extend_from_slice(flags);
        }

        let mut combined_names: Vec<String> = (0..num_features)
            .map(|col| Self::column_name(feature_names, col))
            .collect();
        combined_names.extend(indicator_names);
        (combined, combined_names)
    }

    /// Returns the name of column `col`, or `f{col}` when `feature_names` has no entry for it.
    fn column_name(feature_names: &[String], col: usize) -> String {
        feature_names
            .get(col)
            .cloned()
            .unwrap_or_else(|| format!("f{col}"))
    }
}
/// Returns the median of `values`, or `0.0` if `values` is empty.
///
/// For an even count the two middle elements are averaged. Sorting uses
/// [`f32::total_cmp`], so the order is deterministic even if a NaN slips through.
/// Callers normally filter NaNs out first.
fn compute_median(values: &[f32]) -> f32 {
    if values.is_empty() {
        return 0.0;
    }
    let mut sorted = values.to_vec();
    sorted.sort_unstable_by(f32::total_cmp);
    let mid = sorted.len() / 2;
    if sorted.len() % 2 == 1 {
        sorted[mid]
    } else {
        // Halve before adding so two large values (e.g. f32::MAX) don't overflow to infinity.
        sorted[mid - 1] / 2.0 + sorted[mid] / 2.0
    }
}
/// Returns the most frequent value in `values`, or `0.0` if `values` is empty.
///
/// Values are grouped into buckets via `(v * 100.0).round() as i64`, so floats within
/// roughly 0.01 of each other count as the same value. The value reported for a bucket
/// is the first one seen in `values`. When several buckets share the highest count,
/// the bucket holding the smallest values wins, which makes the result deterministic.
/// This matches scikit-learn's "smallest most-frequent value" rule.
///
/// `as i64` maps NaN to bucket 0 and saturates out-of-range magnitudes, so callers
/// should filter NaNs first.
fn compute_mode(values: &[f32]) -> f32 {
    let mut counts: HashMap<i64, (usize, f32)> = HashMap::new();
    for &v in values {
        let bucket = (v * 100.0).round() as i64;
        counts.entry(bucket).or_insert((0, v)).0 += 1;
    }
    counts
        .into_iter()
        // Highest count wins. On equal counts the smaller bucket compares as "greater"
        // so it wins. Buckets are unique keys, so this order is strict and the result no
        // longer depends on HashMap's randomized iteration order.
        .max_by(|(bucket_a, (count_a, _)), (bucket_b, (count_b, _))| {
            count_a.cmp(count_b).then_with(|| bucket_b.cmp(bucket_a))
        })
        .map(|(_, (_, value))| value)
        .unwrap_or(0.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless `actual` is within the suite-wide 0.01 tolerance of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 0.01,
            "expected ~{expected}, got {actual}"
        );
    }

    #[test]
    fn test_simple_imputer_mean() {
        // Row-major 3x2 matrix: column 0 = [1, NaN, 3], column 1 = [2, 4, NaN].
        let mut matrix = vec![1.0, 2.0, f32::NAN, 4.0, 3.0, f32::NAN];
        let mut imputer = SimpleImputer::mean();
        imputer.fit(&matrix, 2).unwrap();
        assert!(imputer.is_fitted());
        assert_close(imputer.fill_values()[0], 2.0);
        assert_close(imputer.fill_values()[1], 3.0);

        imputer.transform(&mut matrix, 2).unwrap();
        assert_close(matrix[2], 2.0);
        assert_close(matrix[5], 3.0);
    }

    #[test]
    fn test_simple_imputer_median() {
        let mut column = vec![1.0, 3.0, f32::NAN, 5.0];
        let mut imputer = SimpleImputer::median();
        imputer.fit(&column, 1).unwrap();
        assert_close(imputer.fill_values()[0], 3.0);

        imputer.transform(&mut column, 1).unwrap();
        assert_close(column[2], 3.0);
    }

    #[test]
    fn test_simple_imputer_mode() {
        let mut column = vec![1.0, 2.0, 2.0, f32::NAN, 3.0];
        let mut imputer = SimpleImputer::mode();
        imputer.fit(&column, 1).unwrap();
        assert_close(imputer.fill_values()[0], 2.0);

        imputer.transform(&mut column, 1).unwrap();
        assert_close(column[3], 2.0);
    }

    #[test]
    fn test_simple_imputer_constant() {
        let sentinel = -999.0;
        let mut column = vec![1.0, f32::NAN, 3.0];
        let mut imputer = SimpleImputer::constant(sentinel);
        imputer.fit(&column, 1).unwrap();
        assert_close(imputer.fill_values()[0], sentinel);

        imputer.transform(&mut column, 1).unwrap();
        assert_close(column[1], sentinel);
    }

    #[test]
    fn test_simple_imputer_fit_transform() {
        let mut column = vec![1.0, f32::NAN, 3.0];
        let mut imputer = SimpleImputer::mean();
        imputer.fit_transform(&mut column, 1).unwrap();
        assert!(imputer.is_fitted());
        assert_close(column[1], 2.0);
    }

    #[test]
    fn test_simple_imputer_all_nan_column() {
        // Column 0 is entirely NaN; column 1 = [1, 2].
        let matrix = vec![f32::NAN, 1.0, f32::NAN, 2.0];
        let mut imputer = SimpleImputer::mean();
        imputer.fit(&matrix, 2).unwrap();
        assert_close(imputer.fill_values()[0], 0.0);
        assert_close(imputer.fill_values()[1], 1.5);
    }

    #[test]
    fn test_simple_imputer_not_fitted_error() {
        let mut row = vec![1.0, 2.0];
        let outcome = SimpleImputer::mean().transform(&mut row, 2);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not fitted"));
    }

    #[test]
    fn test_indicator_imputer_basic() {
        let names = vec!["age".to_string(), "income".to_string()];
        let matrix = vec![1.0, f32::NAN, 3.0, 4.0];
        let (flags, flag_names) = IndicatorImputer::new().create_indicators(&matrix, 2, &names);
        assert_eq!(flag_names.len(), 1);
        assert_eq!(flag_names[0], "income_missing");
        assert_eq!(flags, vec![1.0, 0.0]);
    }

    #[test]
    fn test_indicator_imputer_all_columns() {
        let names = vec!["a".to_string(), "b".to_string()];
        let matrix = vec![1.0, 2.0, 3.0, 4.0];
        let (flags, flag_names) = IndicatorImputer::new()
            .for_all_columns()
            .create_indicators(&matrix, 2, &names);
        assert_eq!(flag_names.len(), 2);
        assert_eq!(flags, vec![0.0; 4]);
    }

    #[test]
    fn test_indicator_imputer_custom_suffix() {
        let names = vec!["x".to_string()];
        let column = vec![f32::NAN, 2.0];
        let (_, flag_names) = IndicatorImputer::new()
            .with_suffix("_is_null")
            .create_indicators(&column, 1, &names);
        assert_eq!(flag_names[0], "x_is_null");
    }

    #[test]
    fn test_indicator_imputer_transform_with_indicators() {
        let names = vec!["a".to_string(), "b".to_string()];
        let matrix = vec![1.0, f32::NAN, 3.0, 4.0];
        let (combined, combined_names) =
            IndicatorImputer::new().transform_with_indicators(&matrix, 2, &names);

        assert_eq!(combined_names.len(), 3);
        assert_eq!(combined_names, vec!["a", "b", "b_missing"]);
        assert_eq!(combined.len(), 6);

        // Row 0: [1, NaN | missing=1]
        assert_close(combined[0], 1.0);
        assert!(combined[1].is_nan());
        assert_close(combined[2], 1.0);
        // Row 1: [3, 4 | missing=0]
        assert_close(combined[3], 3.0);
        assert_close(combined[4], 4.0);
        assert_close(combined[5], 0.0);
    }

    #[test]
    fn test_indicator_imputer_no_missing() {
        let names = vec!["a".to_string(), "b".to_string()];
        let matrix = vec![1.0, 2.0, 3.0, 4.0];
        let (flags, flag_names) = IndicatorImputer::new().create_indicators(&matrix, 2, &names);
        assert!(flags.is_empty());
        assert!(flag_names.is_empty());
    }

    #[test]
    fn test_compute_median() {
        assert_close(compute_median(&[1.0, 2.0, 3.0]), 2.0);
        assert_close(compute_median(&[1.0, 2.0, 3.0, 4.0]), 2.5);
        assert_close(compute_median(&[5.0]), 5.0);
        assert_close(compute_median(&[]), 0.0);
    }

    #[test]
    fn test_compute_mode() {
        assert_close(compute_mode(&[1.0, 2.0, 2.0, 3.0]), 2.0);
        assert_close(compute_mode(&[5.0]), 5.0);
        assert_close(compute_mode(&[]), 0.0);
    }

    #[test]
    fn test_imputer_serialization() {
        let mut original = SimpleImputer::mean();
        original.fit(&[1.0, 2.0, 3.0, 4.0], 2).unwrap();

        let json = serde_json::to_string(&original).unwrap();
        let restored: SimpleImputer = serde_json::from_str(&json).unwrap();

        assert!(restored.is_fitted());
        assert_eq!(restored.fill_values(), original.fill_values());
    }
}