use faer::{Col, Mat};
use thiserror::Error;
/// Strategy for dealing with observations (rows) that contain NaN values.
///
/// The `Fail` error text mentions `na.fail`, so the variants presumably mirror
/// R's `na.action` options — confirm against the project documentation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NaAction {
/// Drop every row that contains a NaN. This is the default action.
#[default]
Omit,
/// Drop rows exactly like `Omit`, but record the action so that
/// `NaInfo::expand` can later pad results back to the original length
/// with NaN at the dropped positions.
Exclude,
/// Refuse to continue: processing returns `NaError::NaValuesPresent` when
/// any row contains a NaN.
Fail,
/// Hand the data through unchanged, NaN values included.
Pass,
}
/// Errors reported while applying an [`NaAction`] to a data set.
#[derive(Debug, Error)]
pub enum NaError {
/// Raised for `NaAction::Fail` when at least one row contains a NaN;
/// `n_na` is the number of affected rows (not the number of NaN cells).
#[error("NA values found in data (na.fail): {n_na} rows contain missing values")]
NaValuesPresent { n_na: usize },
/// Raised for `NaAction::Omit` / `NaAction::Exclude` when the number of
/// rows containing NaN equals the total number of rows.
#[error("all observations contain NA values")]
AllNa,
/// Too few rows are left after removal. NOTE(review): nothing in the code
/// visible in this file constructs this variant; presumably callers that
/// know the required sample size raise it — confirm.
#[error("insufficient observations after NA removal: {remaining} remaining, {needed} needed")]
InsufficientAfterNa { remaining: usize, needed: usize },
}
/// Bookkeeping that describes which rows were dropped during NA handling.
///
/// All fields are public, so the relationships described below hold for
/// values built by `NaHandler` and `NaInfo::no_na`, but are not enforced.
#[derive(Debug, Clone)]
pub struct NaInfo {
/// Number of rows in the data before any removal.
pub n_original: usize,
/// Number of rows that remain after removal.
pub n_clean: usize,
/// One flag per original row; `true` marks a row that was dropped because
/// it contained a NaN. All `false` when nothing was removed.
pub na_mask: Vec<bool>,
/// Original row indices of the rows that were kept, in ascending order.
pub kept_indices: Vec<usize>,
/// Number of rows that were dropped.
pub n_removed: usize,
/// The action that produced this record.
pub action: NaAction,
}
impl NaInfo {
    /// Reports whether at least one row was dropped.
    pub fn has_removed(&self) -> bool {
        self.n_removed != 0
    }

    /// Reports whether per-observation results must be padded back to the
    /// original length: only for `NaAction::Exclude`, and only when rows were
    /// actually dropped.
    pub fn needs_expansion(&self) -> bool {
        matches!(self.action, NaAction::Exclude) && self.n_removed != 0
    }

    /// Maps values computed on the cleaned rows back onto the original rows.
    ///
    /// When [`needs_expansion`](Self::needs_expansion) is `false` the input is
    /// returned as a copy. Otherwise the result has `n_original` entries:
    /// rows flagged in `na_mask` receive NaN and the remaining rows receive
    /// the entries of `clean_values` in order.
    ///
    /// # Panics
    ///
    /// Indexing is not guarded, so this is expected to panic when
    /// `clean_values` has fewer entries than there are unflagged rows, or
    /// when `na_mask` is longer than `n_original`.
    pub fn expand(&self, clean_values: &Col<f64>) -> Col<f64> {
        if self.needs_expansion() {
            let mut padded: Col<f64> = Col::zeros(self.n_original);
            // Cursor into `clean_values`; advances only on rows that were kept.
            let mut next_clean = 0usize;
            for (row, flagged) in self.na_mask.iter().copied().enumerate() {
                padded[row] = if flagged {
                    f64::NAN
                } else {
                    let value = clean_values[next_clean];
                    next_clean += 1;
                    value
                };
            }
            padded
        } else {
            clean_values.clone()
        }
    }

    /// Builds the record for data from which nothing was removed: every row
    /// is kept, the mask is all `false`, and `n_removed` is zero.
    pub fn no_na(n_observations: usize, action: NaAction) -> Self {
        let na_mask = vec![false; n_observations];
        let kept_indices = (0..n_observations).collect::<Vec<usize>>();
        Self {
            n_original: n_observations,
            n_clean: n_observations,
            na_mask,
            kept_indices,
            n_removed: 0,
            action,
        }
    }
}
/// Output of NA processing: the (possibly row-reduced) data plus a record of
/// what was removed.
#[derive(Debug, Clone)]
pub struct NaResult {
/// Predictor matrix after the action was applied.
pub x_clean: Mat<f64>,
/// Response vector after the action was applied.
pub y_clean: Col<f64>,
/// Description of which original rows were kept or dropped.
pub na_info: NaInfo,
}
/// Stateless namespace type; all NA-handling routines are associated functions
/// on it.
pub struct NaHandler;
impl NaHandler {
pub fn process(x: &Mat<f64>, y: &Col<f64>, action: NaAction) -> Result<NaResult, NaError> {
let n_samples = x.nrows();
let n_features = x.ncols();
let na_mask = Self::find_na_rows(x, y);
let n_na = na_mask.iter().filter(|&&v| v).count();
match action {
NaAction::Fail => {
if n_na > 0 {
return Err(NaError::NaValuesPresent { n_na });
}
Ok(NaResult {
x_clean: x.clone(),
y_clean: y.clone(),
na_info: NaInfo::no_na(n_samples, action),
})
}
NaAction::Pass => {
Ok(NaResult {
x_clean: x.clone(),
y_clean: y.clone(),
na_info: NaInfo::no_na(n_samples, action),
})
}
NaAction::Omit | NaAction::Exclude => {
if n_na == n_samples {
return Err(NaError::AllNa);
}
if n_na == 0 {
return Ok(NaResult {
x_clean: x.clone(),
y_clean: y.clone(),
na_info: NaInfo::no_na(n_samples, action),
});
}
let kept_indices: Vec<usize> = na_mask
.iter()
.enumerate()
.filter_map(|(i, &had_na)| if !had_na { Some(i) } else { None })
.collect();
let n_clean = kept_indices.len();
let x_clean = Mat::from_fn(n_clean, n_features, |i, j| x[(kept_indices[i], j)]);
let y_clean = Col::from_fn(n_clean, |i| y[kept_indices[i]]);
let na_info = NaInfo {
n_original: n_samples,
n_clean,
na_mask,
kept_indices,
n_removed: n_na,
action,
};
Ok(NaResult {
x_clean,
y_clean,
na_info,
})
}
}
}
pub fn process_with_weights(
x: &Mat<f64>,
y: &Col<f64>,
weights: &Col<f64>,
action: NaAction,
) -> Result<(NaResult, Col<f64>), NaError> {
let n_samples = x.nrows();
let n_features = x.ncols();
let na_mask = Self::find_na_rows_with_weights(x, y, weights);
let n_na = na_mask.iter().filter(|&&v| v).count();
match action {
NaAction::Fail => {
if n_na > 0 {
return Err(NaError::NaValuesPresent { n_na });
}
Ok((
NaResult {
x_clean: x.clone(),
y_clean: y.clone(),
na_info: NaInfo::no_na(n_samples, action),
},
weights.clone(),
))
}
NaAction::Pass => Ok((
NaResult {
x_clean: x.clone(),
y_clean: y.clone(),
na_info: NaInfo::no_na(n_samples, action),
},
weights.clone(),
)),
NaAction::Omit | NaAction::Exclude => {
if n_na == n_samples {
return Err(NaError::AllNa);
}
if n_na == 0 {
return Ok((
NaResult {
x_clean: x.clone(),
y_clean: y.clone(),
na_info: NaInfo::no_na(n_samples, action),
},
weights.clone(),
));
}
let kept_indices: Vec<usize> = na_mask
.iter()
.enumerate()
.filter_map(|(i, &had_na)| if !had_na { Some(i) } else { None })
.collect();
let n_clean = kept_indices.len();
let x_clean = Mat::from_fn(n_clean, n_features, |i, j| x[(kept_indices[i], j)]);
let y_clean = Col::from_fn(n_clean, |i| y[kept_indices[i]]);
let weights_clean = Col::from_fn(n_clean, |i| weights[kept_indices[i]]);
let na_info = NaInfo {
n_original: n_samples,
n_clean,
na_mask,
kept_indices,
n_removed: n_na,
action,
};
Ok((
NaResult {
x_clean,
y_clean,
na_info,
},
weights_clean,
))
}
}
}
fn find_na_rows(x: &Mat<f64>, y: &Col<f64>) -> Vec<bool> {
let n_samples = x.nrows();
let n_features = x.ncols();
(0..n_samples)
.map(|i| {
if y[i].is_nan() {
return true;
}
for j in 0..n_features {
if x[(i, j)].is_nan() {
return true;
}
}
false
})
.collect()
}
fn find_na_rows_with_weights(x: &Mat<f64>, y: &Col<f64>, weights: &Col<f64>) -> Vec<bool> {
let n_samples = x.nrows();
let n_features = x.ncols();
(0..n_samples)
.map(|i| {
if y[i].is_nan() {
return true;
}
if weights[i].is_nan() {
return true;
}
for j in 0..n_features {
if x[(i, j)].is_nan() {
return true;
}
}
false
})
.collect()
}
pub fn has_na_matrix(x: &Mat<f64>) -> bool {
let n_rows = x.nrows();
let n_cols = x.ncols();
for i in 0..n_rows {
for j in 0..n_cols {
if x[(i, j)].is_nan() {
return true;
}
}
}
false
}
pub fn has_na_vector(v: &Col<f64>) -> bool {
v.iter().any(|&x| x.is_nan())
}
pub fn count_na_matrix(x: &Mat<f64>) -> usize {
let n_rows = x.nrows();
let n_cols = x.ncols();
let mut count = 0;
for i in 0..n_rows {
for j in 0..n_cols {
if x[(i, j)].is_nan() {
count += 1;
}
}
}
count
}
pub fn count_na_vector(v: &Col<f64>) -> usize {
v.iter().filter(|&&x| x.is_nan()).count()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 5x2 predictors and 5 responses with a NaN at `x[(2, 0)]` and at `y[3]`.
    fn dirty_fixture() -> (Mat<f64>, Col<f64>) {
        let x = Mat::from_fn(5, 2, |row, col| match (row, col) {
            (2, 0) => f64::NAN,
            _ => (row * 2 + col) as f64,
        });
        let y = Col::from_fn(5, |row| match row {
            3 => f64::NAN,
            _ => (row * 10) as f64,
        });
        (x, y)
    }

    /// Same shape and values as `dirty_fixture`, without any NaN.
    fn clean_fixture() -> (Mat<f64>, Col<f64>) {
        (
            Mat::from_fn(5, 2, |row, col| (row * 2 + col) as f64),
            Col::from_fn(5, |row| (row * 10) as f64),
        )
    }

    #[test]
    fn test_na_omit() {
        let (x, y) = dirty_fixture();
        let outcome = NaHandler::process(&x, &y, NaAction::Omit).unwrap();
        let info = &outcome.na_info;

        assert_eq!(outcome.x_clean.nrows(), 3);
        assert_eq!(outcome.y_clean.nrows(), 3);
        assert_eq!(info.n_removed, 2);
        assert_eq!(info.kept_indices, vec![0, 1, 4]);
        assert!(!info.needs_expansion());
    }

    #[test]
    fn test_na_exclude() {
        let (x, y) = dirty_fixture();
        let outcome = NaHandler::process(&x, &y, NaAction::Exclude).unwrap();

        assert_eq!(outcome.x_clean.nrows(), 3);
        assert_eq!(outcome.y_clean.nrows(), 3);
        assert_eq!(outcome.na_info.n_removed, 2);
        assert!(outcome.na_info.needs_expansion());

        let residuals = Col::from_fn(3, |k| (k + 1) as f64);
        let expanded = outcome.na_info.expand(&residuals);
        assert_eq!(expanded.nrows(), 5);

        // `None` marks the rows that were dropped and must come back as NaN.
        let expected = [Some(1.0), Some(2.0), None, None, Some(3.0)];
        for (row, want) in expected.iter().copied().enumerate() {
            match want {
                Some(value) => assert!((expanded[row] - value).abs() < 1e-10),
                None => assert!(expanded[row].is_nan()),
            }
        }
    }

    #[test]
    fn test_na_fail() {
        let (x, y) = dirty_fixture();
        let outcome = NaHandler::process(&x, &y, NaAction::Fail);
        assert!(matches!(
            outcome,
            Err(NaError::NaValuesPresent { n_na: 2 })
        ));
    }

    #[test]
    fn test_na_fail_no_na() {
        let (x, y) = clean_fixture();
        let outcome = NaHandler::process(&x, &y, NaAction::Fail).unwrap();
        assert_eq!(outcome.x_clean.nrows(), 5);
        assert_eq!(outcome.na_info.n_removed, 0);
    }

    #[test]
    fn test_na_pass() {
        let (x, y) = dirty_fixture();
        let outcome = NaHandler::process(&x, &y, NaAction::Pass).unwrap();
        assert_eq!(outcome.x_clean.nrows(), 5);
        assert!(outcome.x_clean[(2, 0)].is_nan());
        assert!(outcome.y_clean[3].is_nan());
    }

    #[test]
    fn test_clean_data() {
        let (x, y) = clean_fixture();
        let outcome = NaHandler::process(&x, &y, NaAction::Omit).unwrap();
        assert_eq!(outcome.x_clean.nrows(), 5);
        assert_eq!(outcome.na_info.n_removed, 0);
        assert!(!outcome.na_info.needs_expansion());
    }

    #[test]
    fn test_all_na() {
        let x = Mat::from_fn(3, 2, |_, _| f64::NAN);
        let y = Col::from_fn(3, |_| f64::NAN);
        let outcome = NaHandler::process(&x, &y, NaAction::Omit);
        assert!(matches!(outcome, Err(NaError::AllNa)));
    }

    #[test]
    fn test_has_na_helpers() {
        let (dirty_x, dirty_y) = dirty_fixture();
        assert!(NaHandler::has_na_matrix(&dirty_x));
        assert!(NaHandler::has_na_vector(&dirty_y));

        let (clean_x, clean_y) = clean_fixture();
        assert!(!NaHandler::has_na_matrix(&clean_x));
        assert!(!NaHandler::has_na_vector(&clean_y));
    }

    #[test]
    fn test_count_na() {
        let (x, y) = dirty_fixture();
        assert_eq!(NaHandler::count_na_matrix(&x), 1);
        assert_eq!(NaHandler::count_na_vector(&y), 1);
    }

    #[test]
    fn test_process_with_weights() {
        let (x, y) = clean_fixture();
        // Only the weight of row 2 is missing.
        let weights = Col::from_fn(5, |row| if row == 2 { f64::NAN } else { (row + 1) as f64 });

        let (outcome, kept_weights) =
            NaHandler::process_with_weights(&x, &y, &weights, NaAction::Omit).unwrap();

        assert_eq!(outcome.x_clean.nrows(), 4);
        assert_eq!(kept_weights.nrows(), 4);
        assert_eq!(outcome.na_info.n_removed, 1);
    }

    #[test]
    fn test_na_info_no_na() {
        let info = NaInfo::no_na(10, NaAction::Omit);
        assert_eq!(info.n_original, 10);
        assert_eq!(info.n_clean, 10);
        assert_eq!(info.n_removed, 0);
        assert!(!info.has_removed());
        assert!(!info.needs_expansion());
    }

    #[test]
    fn test_expand_no_expansion_needed() {
        let info = NaInfo::no_na(5, NaAction::Omit);
        let values = Col::from_fn(5, |k| k as f64);
        let expanded = info.expand(&values);

        assert_eq!(expanded.nrows(), 5);
        for k in 0..5usize {
            let gap = expanded[k] - k as f64;
            assert!(gap.abs() < 1e-10);
        }
    }

    #[test]
    fn test_na_action_default() {
        assert_eq!(NaAction::default(), NaAction::Omit);
    }
}