use faer::{Col, Mat};
/// Outcome of a separation screen over the columns of a design matrix.
///
/// The derived `Default` is the "nothing detected" value: `has_separation` is
/// `false`, both vectors are empty and there is no warning message.
#[derive(Debug, Clone, Default)]
pub struct SeparationCheck {
/// `true` once at least one predictor has been flagged.
pub has_separation: bool,
/// Column indices of the flagged predictors, in ascending column order.
pub separated_predictors: Vec<usize>,
/// Classification results. `check_binary_separation` stores one entry per
/// column (including `SeparationType::None`), whereas `check_count_sparsity`
/// stores one entry per flagged predictor only.
pub separation_types: Vec<SeparationType>,
/// Human-readable summary covering every flagged predictor; `None` when
/// nothing was flagged.
pub warning_message: Option<String>,
}
/// Kind of separation found between one predictor and the response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeparationType {
/// No separation pattern was detected for the predictor.
None,
/// Ordering the observations by the predictor splits the two response
/// classes without any overlap.
Complete,
/// Only a small number of observations break an otherwise clean split.
Quasi,
/// Every distinct predictor value is tied to a single response class. Also
/// used by `check_count_sparsity` for indicator-like predictors whose
/// non-zero rows all have a zero response.
MonotonicResponse,
}
/// Screens every column of `x` for separation against the binary response `y`.
///
/// Each column is classified independently by `check_predictor_separation`.
/// The returned [`SeparationCheck`] contains:
///
/// * `separation_types` – one entry per column of `x`, in column order
///   (including `SeparationType::None` for unflagged columns);
/// * `separated_predictors` – the indices of the columns whose classification
///   is not `None`;
/// * `has_separation` – whether any column was flagged;
/// * `warning_message` – a single message listing every flagged column and the
///   kind of separation found, or `None` when nothing was flagged.
///
/// An `x` with zero rows or zero columns yields `SeparationCheck::default()`.
///
/// `y` is indexed once for every row of `x`, so it is assumed to hold at least
/// `x.nrows()` entries.
pub fn check_binary_separation(x: &Mat<f64>, y: &Col<f64>) -> SeparationCheck {
    let n_samples = x.nrows();
    let n_features = x.ncols();
    if n_samples == 0 || n_features == 0 {
        return SeparationCheck::default();
    }

    let mut result = SeparationCheck::default();
    let mut warnings = Vec::new();
    let y_vec: Vec<f64> = (0..n_samples).map(|i| y[i]).collect();

    for j in 0..n_features {
        let x_col: Vec<f64> = (0..n_samples).map(|i| x[(i, j)]).collect();
        let sep_type = check_predictor_separation(&x_col, &y_vec);
        result.separation_types.push(sep_type);

        // Exhaustive on purpose: adding a variant must force a decision here.
        let description = match sep_type {
            SeparationType::None => continue,
            SeparationType::Complete => "shows complete separation",
            SeparationType::Quasi => "shows quasi-separation",
            SeparationType::MonotonicResponse => {
                "has all responses in one class for some values"
            }
        };
        result.has_separation = true;
        result.separated_predictors.push(j);
        warnings.push(format!("Feature {} {}", j, description));
    }

    if !warnings.is_empty() {
        let combined = warnings.join("; ");
        // The header stays neutral because the list may mix complete,
        // quasi- and monotonic findings; each entry names its own kind.
        result.warning_message = Some(format!(
            "Separation detected: {}. Coefficients may be unstable. \
             Consider using regularization (Ridge/Lasso) or removing problematic features.",
            combined
        ));
    }
    result
}
/// Predictor values closer together than this are treated as tied.
const PREDICTOR_TIE_TOLERANCE: f64 = 1e-10;

/// Classifies how a single predictor relates to a binary response.
///
/// Observations are the pairs `(x[i], y[i])`; if the slices differ in length
/// the surplus of the longer one is ignored. A response of at least `0.5` is
/// class 1, anything else is class 0.
///
/// The checks run in this order and the first match wins:
///
/// 1. [`SeparationType::Complete`] – every class-0 value lies strictly on one
///    side of every class-1 value. Both directions are recognised (class 1 at
///    the high end or at the low end of the predictor).
/// 2. [`SeparationType::Quasi`] – in one of the two directions, at most two
///    class-1 observations fail to lie strictly beyond all class-0
///    observations, and they make up less than 10% of class 1.
/// 3. [`SeparationType::MonotonicResponse`] – every distinct predictor value
///    carries a single class, both all-0 and all-1 values occur, and at least
///    one value is observed more than once.
///
/// Class-1 observations that tie with the outermost class-0 observation count
/// as overlapping, so the verdict never depends on the order of the input
/// rows; in particular a constant predictor is never reported as separated.
///
/// Returns [`SeparationType::None`] when nothing matches, when there are no
/// observations, or when only one class is present.
fn check_predictor_separation(x: &[f64], y: &[f64]) -> SeparationType {
    let mut pairs: Vec<(f64, bool)> = x
        .iter()
        .zip(y)
        .map(|(&xi, &yi)| (xi, yi >= 0.5))
        .collect();

    let total_1s = pairs.iter().filter(|&&(_, is_one)| is_one).count();
    if total_1s == 0 || total_1s == pairs.len() {
        // Covers the empty case too: separation needs both classes.
        return SeparationType::None;
    }

    // `total_cmp` is a genuine total order, so NaN input cannot hand the sort
    // an inconsistent comparator.
    pairs.sort_by(|a, b| a.0.total_cmp(&b.0));
    let descending: Vec<(f64, bool)> = pairs.iter().rev().copied().collect();

    // Overlap when class 1 is expected at high values / at low values.
    let overlap_rising = misplaced_ones(&pairs);
    let overlap_falling = misplaced_ones(&descending);

    if overlap_rising == 0 || overlap_falling == 0 {
        return SeparationType::Complete;
    }

    let is_minor =
        |overlap: usize| overlap <= 2 && (overlap as f64) / (total_1s as f64) < 0.1;
    if is_minor(overlap_rising) || is_minor(overlap_falling) {
        return SeparationType::Quasi;
    }

    if is_monotonic_response(&pairs) {
        return SeparationType::MonotonicResponse;
    }
    SeparationType::None
}

/// Reports whether two predictor values are indistinguishable.
fn values_tied(a: f64, b: f64) -> bool {
    let gap = (a - b).abs();
    // A NaN gap (e.g. `inf - inf`) is treated as a tie rather than a break.
    gap <= PREDICTOR_TIE_TOLERANCE || gap.is_nan()
}

/// Counts the class-1 observations that are not strictly after every class-0
/// observation in `ordered` (a sequence sorted by predictor value in either
/// direction): those positioned before the last class-0 entry plus those
/// tied with it.
fn misplaced_ones(ordered: &[(f64, bool)]) -> usize {
    let last_zero = match ordered.iter().rposition(|&(_, is_one)| !is_one) {
        Some(index) => index,
        None => return 0,
    };
    let boundary = ordered[last_zero].0;

    let ahead = ordered[..last_zero]
        .iter()
        .filter(|&&(_, is_one)| is_one)
        .count();
    // Everything after the last zero is class 1; ties form a prefix of it.
    let tied = ordered[last_zero + 1..]
        .iter()
        .take_while(|&&(xi, _)| values_tied(xi, boundary))
        .count();
    ahead + tied
}

/// Checks the monotonic-response pattern on pairs sorted by ascending value:
/// no predictor value mixes classes, both pure-0 and pure-1 values exist, and
/// some value occurs more than once.
fn is_monotonic_response(sorted: &[(f64, bool)]) -> bool {
    let mut pure_zero_value = false;
    let mut pure_one_value = false;
    let mut repeated_value = false;

    let mut start = 0;
    while start < sorted.len() {
        // A group is every following entry tied with the group's first value.
        let anchor = sorted[start].0;
        let len = 1 + sorted[start + 1..]
            .iter()
            .take_while(|&&(xi, _)| values_tied(xi, anchor))
            .count();
        let ones = sorted[start..start + len]
            .iter()
            .filter(|&&(_, is_one)| is_one)
            .count();

        if ones != 0 && ones != len {
            return false;
        }
        if ones == 0 {
            pure_zero_value = true;
        } else {
            pure_one_value = true;
        }
        repeated_value |= len > 1;
        start += len;
    }
    pure_zero_value && pure_one_value && repeated_value
}
/// Flags low-cardinality predictors whose active rows never carry a non-zero
/// response in a zero-dominated response vector.
///
/// The screen only runs when more than 90% of the entries of `y` are below
/// `1e-10`. A column of `x` is examined when it has at most three distinct
/// values after multiplying by 1000 and truncating to an integer. Such a
/// column is flagged as `SeparationType::MonotonicResponse` when it has at
/// least one entry with magnitude above `1e-10` and none of those rows has a
/// response above `1e-10`.
///
/// `separation_types` receives one entry per flagged column, parallel to
/// `separated_predictors`. An `x` with zero rows or zero columns yields
/// `SeparationCheck::default()`.
pub fn check_count_sparsity(x: &Mat<f64>, y: &Col<f64>) -> SeparationCheck {
    let (n_samples, n_features) = (x.nrows(), x.ncols());
    let mut result = SeparationCheck::default();
    if n_samples == 0 || n_features == 0 {
        return result;
    }

    let responses: Vec<f64> = (0..n_samples).map(|row| y[row]).collect();
    let zero_responses = responses.iter().filter(|&&value| value < 1e-10).count();
    // Not zero-dominated enough for the screen to be meaningful.
    if zero_responses as f64 / n_samples as f64 <= 0.9 {
        return result;
    }

    let mut flagged_notes = Vec::new();
    for j in 0..n_features {
        let column: Vec<f64> = (0..n_samples).map(|row| x[(row, j)]).collect();

        // Only indicator-like columns (a handful of distinct levels) qualify.
        let distinct_levels = column
            .iter()
            .map(|&value| (value * 1000.0) as i64)
            .collect::<std::collections::HashSet<i64>>()
            .len();
        if distinct_levels > 3 {
            continue;
        }

        // Responses of the rows on which the predictor is switched on.
        let mut active_responses = column
            .iter()
            .zip(&responses)
            .filter(|(value, _)| value.abs() > 1e-10)
            .map(|(_, &response)| response)
            .peekable();
        if active_responses.peek().is_none() {
            continue;
        }
        if active_responses.any(|response| response > 1e-10) {
            continue;
        }

        result.has_separation = true;
        result.separated_predictors.push(j);
        result
            .separation_types
            .push(SeparationType::MonotonicResponse);
        flagged_notes.push(format!(
            "Feature {} (binary indicator) has all-zero responses when active",
            j
        ));
    }

    if flagged_notes.is_empty() {
        return result;
    }
    result.warning_message = Some(format!(
        "Sparse data separation detected: {}. Coefficients may diverge to -infinity. \
         Consider using regularization or removing sparse indicator features.",
        flagged_notes.join("; ")
    ));
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-column matrix whose entry in row `i` is `i` itself.
    fn index_column(rows: usize) -> Mat<f64> {
        Mat::from_fn(rows, 1, |row, _| row as f64)
    }

    /// Single-column 0/1 indicator: 0 before `first_active`, 1 from there on.
    fn indicator_column(rows: usize, first_active: usize) -> Mat<f64> {
        Mat::from_fn(rows, 1, |row, _| if row < first_active { 0.0 } else { 1.0 })
    }

    /// Response equal to `head` before row `switch_at` and `tail` afterwards.
    fn step_response(rows: usize, switch_at: usize, head: f64, tail: f64) -> Col<f64> {
        Col::from_fn(rows, |row| if row < switch_at { head } else { tail })
    }

    /// Response vector copied from a slice.
    fn response_from(values: &[f64]) -> Col<f64> {
        Col::from_fn(values.len(), |row| values[row])
    }

    /// Warning text of a check; fails the test when no warning was produced.
    fn warning_text(check: &SeparationCheck) -> &str {
        check
            .warning_message
            .as_deref()
            .expect("a warning message should be present")
    }

    #[test]
    fn test_complete_separation() {
        let predictors = index_column(10);
        let response = step_response(10, 5, 0.0, 1.0);

        let outcome = check_binary_separation(&predictors, &response);

        assert!(outcome.has_separation);
        assert!(outcome.separated_predictors.contains(&0));
        assert_eq!(outcome.separation_types[0], SeparationType::Complete);
    }

    #[test]
    fn test_no_separation() {
        let predictors = index_column(10);
        let response = response_from(&[0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);

        let outcome = check_binary_separation(&predictors, &response);

        assert!(!outcome.has_separation);
    }

    #[test]
    fn test_count_sparsity() {
        let predictors = indicator_column(20, 10);
        let response = Col::from_fn(20, |row| {
            if row < 10 {
                (row as f64) * 0.01
            } else {
                0.0
            }
        });

        let outcome = check_count_sparsity(&predictors, &response);

        // A raised flag with a non-empty predictor list is the only failure.
        assert!(!(outcome.has_separation && !outcome.separated_predictors.is_empty()));
    }

    #[test]
    fn test_quasi_separation() {
        let rows = 30;
        let predictors = index_column(rows);
        // One stray positive at the very bottom, then a clean split at row 18.
        let response = Col::from_fn(rows, |row| match row {
            0 => 1.0,
            1..=17 => 0.0,
            _ => 1.0,
        });

        let outcome = check_binary_separation(&predictors, &response);

        assert!(outcome.has_separation);
        assert_eq!(outcome.separation_types[0], SeparationType::Quasi);
        assert!(warning_text(&outcome).contains("quasi-separation"));
    }

    #[test]
    fn test_monotonic_response() {
        // Three distinct predictor values, each observed twice.
        let predictors = Mat::from_fn(6, 1, |row, _| (row / 2) as f64);
        let response = response_from(&[0.0, 0.0, 1.0, 1.0, 0.0, 0.0]);

        let outcome = check_binary_separation(&predictors, &response);

        assert!(outcome.has_separation);
        assert_eq!(
            outcome.separation_types[0],
            SeparationType::MonotonicResponse
        );
        assert!(warning_text(&outcome).contains("all responses in one class"));
    }

    #[test]
    fn test_count_sparsity_with_separation() {
        let rows = 100;
        let predictors = indicator_column(rows, 50);
        let response = step_response(rows, 5, 1.0, 0.0);

        let outcome = check_count_sparsity(&predictors, &response);

        assert!(outcome.has_separation);
        assert!(outcome.separated_predictors.contains(&0));
        assert_eq!(
            outcome.separation_types[0],
            SeparationType::MonotonicResponse
        );
        assert!(warning_text(&outcome).contains("all-zero responses when active"));
    }

    #[test]
    fn test_empty_matrix_binary() {
        let predictors = Mat::<f64>::zeros(0, 0);
        let response = Col::<f64>::zeros(0);

        let outcome = check_binary_separation(&predictors, &response);

        assert!(!outcome.has_separation);
        assert!(outcome.separated_predictors.is_empty());
    }

    #[test]
    fn test_empty_matrix_count() {
        let predictors = Mat::<f64>::zeros(0, 0);
        let response = Col::<f64>::zeros(0);

        let outcome = check_count_sparsity(&predictors, &response);

        assert!(!outcome.has_separation);
        assert!(outcome.separated_predictors.is_empty());
    }

    #[test]
    fn test_all_same_class() {
        let predictors = index_column(10);
        let response = Col::from_fn(10, |_| 1.0);

        let outcome = check_binary_separation(&predictors, &response);

        assert!(!outcome.has_separation);
    }

    #[test]
    fn test_multiple_predictors() {
        // Column 0 cycles through 0, 1, 2; column 1 is the row index.
        let predictors = Mat::from_fn(10, 2, |row, col| match col {
            0 => (row % 3) as f64,
            _ => row as f64,
        });
        let response = step_response(10, 5, 0.0, 1.0);

        let outcome = check_binary_separation(&predictors, &response);

        assert!(outcome.has_separation);
        assert!(outcome.separated_predictors.contains(&1));
        assert_eq!(outcome.separation_types[1], SeparationType::Complete);
    }

    #[test]
    fn test_count_sparsity_not_sparse_enough() {
        let rows = 10;
        let predictors = indicator_column(rows, 5);
        let response = step_response(rows, 3, 1.0, 0.0);

        let outcome = check_count_sparsity(&predictors, &response);

        assert!(!outcome.has_separation);
    }

    #[test]
    fn test_count_sparsity_non_binary_predictor() {
        let rows = 100;
        let predictors = index_column(rows);
        let response = step_response(rows, 5, 1.0, 0.0);

        let outcome = check_count_sparsity(&predictors, &response);

        assert!(!outcome.has_separation);
    }

    #[test]
    fn test_separation_type_clone() {
        let original = SeparationType::Complete;
        let duplicate = Clone::clone(&original);
        assert_eq!(original, duplicate);
    }

    #[test]
    fn test_separation_check_default() {
        let SeparationCheck {
            has_separation,
            separated_predictors,
            separation_types,
            warning_message,
        } = SeparationCheck::default();

        assert!(!has_separation);
        assert!(separated_predictors.is_empty());
        assert!(separation_types.is_empty());
        assert!(warning_message.is_none());
    }

    #[test]
    fn test_separation_check_clone() {
        let original = SeparationCheck {
            has_separation: true,
            separated_predictors: vec![0],
            separation_types: vec![SeparationType::Complete],
            warning_message: Some("test".to_string()),
        };

        let duplicate = original.clone();

        assert!(duplicate.has_separation);
        assert_eq!(duplicate.separated_predictors, vec![0]);
        assert_eq!(duplicate.separation_types, vec![SeparationType::Complete]);
        assert_eq!(duplicate.warning_message, Some("test".to_string()));
    }
}