use crate::{Result, TreeBoostError};
/// Statistical rule used to derive per-feature outlier bounds during fitting.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum OutlierMethod {
/// Interquartile-range rule: bounds are `Q1 - k * IQR` and `Q3 + k * IQR`.
Iqr {
/// Multiplier applied to the interquartile range.
k: f32,
},
/// Z-score rule: bounds are `mean - threshold * std` and `mean + threshold * std`.
ZScore {
/// Number of standard deviations away from the mean that is still accepted.
threshold: f32,
},
}
impl OutlierMethod {
pub fn iqr() -> Self {
Self::Iqr { k: 1.5 }
}
pub fn iqr_with_k(k: f32) -> Self {
Self::Iqr { k }
}
pub fn zscore() -> Self {
Self::ZScore { threshold: 3.0 }
}
pub fn zscore_with_threshold(threshold: f32) -> Self {
Self::ZScore { threshold }
}
}
impl Default for OutlierMethod {
fn default() -> Self {
Self::iqr()
}
}
/// What `OutlierDetector::transform` does with values that fall outside the fitted bounds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum OutlierAction {
/// Overwrite each outlier with the bound it crossed (this is the default action).
#[default]
Cap,
/// Leave the data untouched and return a 0/1 indicator per value.
Flag,
/// Leave the input untouched and return a copy without the rows that contain an outlier.
Remove,
}
/// Accepted value range for a single feature; values strictly outside it are outliers.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FeatureBounds {
/// Smallest value that is still accepted.
pub lower: f32,
/// Largest value that is still accepted.
pub upper: f32,
}
/// Per-feature outlier detector operating on flat `f32` buffers indexed as
/// `data[row * num_features + feat]`.
///
/// Bounds are learned by `fit` and then consulted by `is_outlier`, `detect`,
/// `transform` and `outlier_counts`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OutlierDetector {
/// Rule used by `fit` to derive the bounds.
method: OutlierMethod,
/// What `transform` does with values outside the bounds.
action: OutlierAction,
/// One entry per feature, in feature order; starts empty in `new`.
bounds: Vec<FeatureBounds>,
/// Set to `true` at the end of a successful `fit`.
fitted: bool,
}
impl OutlierDetector {
/// Creates an unfitted detector that uses `method` and the default action.
pub fn new(method: OutlierMethod) -> Self {
    let action = OutlierAction::default();
    Self {
        fitted: false,
        bounds: Vec::new(),
        action,
        method,
    }
}
pub fn with_action(mut self, action: OutlierAction) -> Self {
self.action = action;
self
}
pub fn method(&self) -> OutlierMethod {
self.method
}
pub fn action(&self) -> OutlierAction {
self.action
}
pub fn is_fitted(&self) -> bool {
self.fitted
}
/// Borrows the per-feature bounds, one entry per feature in feature order.
pub fn bounds(&self) -> &[FeatureBounds] {
    self.bounds.as_slice()
}
/// Learns per-feature bounds from `data`, which is read as
/// `data[row * num_features + feat]`.
///
/// Any previously stored bounds are discarded once validation has passed.
///
/// # Errors
///
/// Returns `TreeBoostError::Data` when `num_features` is zero, `data` is empty,
/// `data.len()` is not a multiple of `num_features`, or fewer than 4 rows are given.
pub fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
    if num_features == 0 {
        return Err(TreeBoostError::Data("num_features must be > 0".into()));
    }
    if data.is_empty() {
        return Err(TreeBoostError::Data("Empty data".into()));
    }

    let total = data.len();
    if !total.is_multiple_of(num_features) {
        return Err(TreeBoostError::Data(format!(
            "Data length {} not divisible by num_features {}",
            total, num_features
        )));
    }

    let num_rows = total / num_features;
    if num_rows < 4 {
        return Err(TreeBoostError::Data(
            "Need at least 4 rows to compute quartiles".into(),
        ));
    }

    // Start from a clean slate so a re-fit never mixes old and new bounds.
    self.bounds = Vec::with_capacity(num_features);
    let outcome = match self.method {
        OutlierMethod::Iqr { k } => self.fit_iqr(data, num_features, num_rows, k),
        OutlierMethod::ZScore { threshold } => {
            self.fit_zscore(data, num_features, num_rows, threshold)
        }
    };
    outcome?;

    self.fitted = true;
    Ok(())
}
/// Appends IQR-based bounds (`Q1 - k * IQR`, `Q3 + k * IQR`) for every feature.
///
/// Non-finite values are left out of the quartile computation. A feature with no
/// finite values gets `(-inf, +inf)` so that nothing in it is reported as an outlier.
fn fit_iqr(
    &mut self,
    data: &[f32],
    num_features: usize,
    num_rows: usize,
    k: f32,
) -> Result<()> {
    for feat in 0..num_features {
        // Gather the finite values of this feature's column.
        let mut finite_values: Vec<f32> = Vec::new();
        for row in 0..num_rows {
            let v = data[row * num_features + feat];
            if v.is_finite() {
                finite_values.push(v);
            }
        }

        let limits = if finite_values.is_empty() {
            FeatureBounds {
                lower: f32::NEG_INFINITY,
                upper: f32::INFINITY,
            }
        } else {
            // `percentile` expects ascending input.
            finite_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            let q1 = percentile(&finite_values, 0.25);
            let q3 = percentile(&finite_values, 0.75);
            let spread = q3 - q1;
            FeatureBounds {
                lower: q1 - k * spread,
                upper: q3 + k * spread,
            }
        };
        self.bounds.push(limits);
    }
    Ok(())
}
/// Appends z-score bounds (`mean ± threshold * std`) for every feature.
///
/// Non-finite values are left out of the statistics. A feature with no finite
/// values gets `(-inf, +inf)` so that nothing in it is reported as an outlier.
/// The standard deviation is the population form (divides by the number of finite
/// values) and is floored at `1e-10`.
///
/// Mean and variance are accumulated in `f64`. Summing in `f32` loses precision on
/// long columns, and squared deviations overflow to infinity for large magnitudes,
/// which would silently widen the bounds to `±inf` (or produce NaN bounds).
fn fit_zscore(
    &mut self,
    data: &[f32],
    num_features: usize,
    num_rows: usize,
    threshold: f32,
) -> Result<()> {
    let threshold = f64::from(threshold);
    for feat in 0..num_features {
        let column: Vec<f64> = (0..num_rows)
            .map(|row| data[row * num_features + feat])
            .filter(|v| v.is_finite())
            .map(f64::from)
            .collect();
        if column.is_empty() {
            self.bounds.push(FeatureBounds {
                lower: f32::NEG_INFINITY,
                upper: f32::INFINITY,
            });
            continue;
        }

        let n = column.len() as f64;
        let mean = column.iter().sum::<f64>() / n;
        let variance = column
            .iter()
            .map(|x| {
                let deviation = x - mean;
                deviation * deviation
            })
            .sum::<f64>()
            / n;
        // Floor the deviation so a constant column still yields well-defined bounds.
        let std = variance.sqrt().max(1e-10);

        // Narrow back to the stored precision only at the very end.
        let lower = (mean - threshold * std) as f32;
        let upper = (mean + threshold * std) as f32;
        self.bounds.push(FeatureBounds { lower, upper });
    }
    Ok(())
}
/// Returns `true` when `value` lies strictly outside the bounds of `feature_idx`.
///
/// Returns `false` when the detector is not fitted or `feature_idx` has no bounds.
/// A NaN `value` also yields `false`, because both comparisons are false for NaN.
pub fn is_outlier(&self, value: f32, feature_idx: usize) -> bool {
    if !self.fitted {
        return false;
    }
    match self.bounds.get(feature_idx) {
        Some(limits) => value < limits.lower || value > limits.upper,
        None => false,
    }
}
/// Lists every `(row, feature)` position in `data` whose value is an outlier.
///
/// `data` is read as `data[row * num_features + feat]`. Non-finite values are
/// never reported. Positions are returned in row order, then feature order.
///
/// # Errors
///
/// Returns `TreeBoostError::Data` when the detector is not fitted, when
/// `num_features` differs from the number of fitted features, when `num_features`
/// is zero, or when `data.len()` is not a multiple of `num_features` (the same
/// check `transform` performs, so a trailing partial row is rejected rather than
/// silently ignored).
pub fn detect(&self, data: &[f32], num_features: usize) -> Result<Vec<(usize, usize)>> {
    if !self.fitted {
        return Err(TreeBoostError::Data(
            "OutlierDetector not fitted. Call fit() first.".into(),
        ));
    }
    if num_features != self.bounds.len() {
        return Err(TreeBoostError::Data(format!(
            "num_features mismatch: fit with {}, detect with {}",
            self.bounds.len(),
            num_features
        )));
    }
    // Only reachable with a detector whose state did not come from `fit`
    // (e.g. deserialized with empty bounds); guards the chunking below.
    if num_features == 0 {
        return Err(TreeBoostError::Data("num_features must be > 0".into()));
    }
    if !data.len().is_multiple_of(num_features) {
        return Err(TreeBoostError::Data(format!(
            "Data length {} not divisible by num_features {}",
            data.len(),
            num_features
        )));
    }

    let mut outliers = Vec::new();
    for (row, values) in data.chunks_exact(num_features).enumerate() {
        for (feat, (&value, limits)) in values.iter().zip(&self.bounds).enumerate() {
            if value.is_finite() && (value < limits.lower || value > limits.upper) {
                outliers.push((row, feat));
            }
        }
    }
    Ok(outliers)
}
/// Applies the configured action to `data`, read as `data[row * num_features + feat]`.
///
/// Only the `Cap` action writes to `data`; `Flag` and `Remove` read it and return
/// their output inside the result. `feature_names` is used to name the indicator
/// columns produced by `Flag`.
///
/// # Errors
///
/// Returns `TreeBoostError::Data` when the detector is not fitted, when
/// `num_features` differs from the number of fitted features, or when
/// `data.len()` is not a multiple of `num_features`.
pub fn transform(
    &self,
    data: &mut [f32],
    num_features: usize,
    feature_names: &[String],
) -> Result<TransformResult> {
    if !self.fitted {
        return Err(TreeBoostError::Data(
            "OutlierDetector not fitted. Call fit() first.".into(),
        ));
    }

    let fitted_features = self.bounds.len();
    if fitted_features != num_features {
        return Err(TreeBoostError::Data(format!(
            "num_features mismatch: fit with {}, transform with {}",
            fitted_features, num_features
        )));
    }

    let total = data.len();
    if !total.is_multiple_of(num_features) {
        return Err(TreeBoostError::Data(format!(
            "Data length {} not divisible by num_features {}",
            total, num_features
        )));
    }

    let num_rows = total / num_features;
    match self.action {
        OutlierAction::Remove => {
            self.transform_remove(data, num_features, num_rows, feature_names)
        }
        OutlierAction::Flag => self.transform_flag(data, num_features, num_rows, feature_names),
        OutlierAction::Cap => self.transform_cap(data, num_features, num_rows),
    }
}
/// Clamps every finite outlier in `data` to the bound it crossed, in place.
///
/// Non-finite values are left as they are and are not counted. The result carries
/// the number of values that were overwritten. Expects `data` to hold
/// `num_rows * num_features` values, as validated by `transform`.
fn transform_cap(
    &self,
    data: &mut [f32],
    num_features: usize,
    num_rows: usize,
) -> Result<TransformResult> {
    let mut outlier_count = 0usize;
    let cells = num_rows * num_features;
    for (idx, slot) in data.iter_mut().enumerate().take(cells) {
        let current = *slot;
        if !current.is_finite() {
            continue;
        }
        // The flat index encodes the feature as its remainder.
        let limits = &self.bounds[idx % num_features];
        let capped = if current < limits.lower {
            limits.lower
        } else if current > limits.upper {
            limits.upper
        } else {
            continue;
        };
        *slot = capped;
        outlier_count += 1;
    }
    Ok(TransformResult::Capped { outlier_count })
}
/// Builds a 0/1 indicator for every value in `data` without modifying it.
///
/// The indicator buffer has the same layout as `data`; a cell is `1.0` when the
/// corresponding value is finite and outside its feature's bounds. Indicator
/// columns are named `<feature>_outlier`, falling back to `f<index>` as the
/// feature name when `feature_names` has no entry for that index. Expects `data`
/// to hold `num_rows * num_features` values, as validated by `transform`.
fn transform_flag(
    &self,
    data: &[f32],
    num_features: usize,
    num_rows: usize,
    feature_names: &[String],
) -> Result<TransformResult> {
    let indicator_names: Vec<String> = (0..num_features)
        .map(|feat| {
            let base = feature_names
                .get(feat)
                .cloned()
                .unwrap_or_else(|| format!("f{}", feat));
            format!("{}_outlier", base)
        })
        .collect();

    let mut indicators = vec![0.0f32; num_rows * num_features];
    for (cell, flag) in indicators.iter_mut().enumerate() {
        let value = data[cell];
        // The flat index encodes the feature as its remainder.
        let limits = &self.bounds[cell % num_features];
        let outside = value < limits.lower || value > limits.upper;
        if value.is_finite() && outside {
            *flag = 1.0;
        }
    }

    Ok(TransformResult::Flagged {
        indicators,
        indicator_names,
    })
}
/// Returns a copy of `data` without the rows that contain at least one finite outlier.
///
/// The input is not modified. The result holds the surviving rows (same layout as
/// the input), their original row indices, and the number of rows dropped.
/// `_feature_names` is accepted for signature parity with the other actions and
/// is not used.
fn transform_remove(
    &self,
    data: &[f32],
    num_features: usize,
    num_rows: usize,
    _feature_names: &[String],
) -> Result<TransformResult> {
    // A row is dropped as soon as one of its values is a finite outlier.
    let has_outlier = |row: usize| {
        (0..num_features).any(|feat| {
            let value = data[row * num_features + feat];
            value.is_finite() && self.is_outlier(value, feat)
        })
    };

    let kept_indices: Vec<usize> = (0..num_rows).filter(|&row| !has_outlier(row)).collect();

    let mut cleaned_data = Vec::with_capacity(kept_indices.len() * num_features);
    for &row in &kept_indices {
        let start = row * num_features;
        cleaned_data.extend_from_slice(&data[start..start + num_features]);
    }

    let removed_count = num_rows - kept_indices.len();
    Ok(TransformResult::Removed {
        cleaned_data,
        kept_indices,
        removed_count,
    })
}
/// Counts the finite outliers in each feature column of `data`.
///
/// `data` is read as `data[row * num_features + feat]`. The returned vector has
/// one count per feature, in feature order.
///
/// # Errors
///
/// Returns `TreeBoostError::Data` when the detector is not fitted, when
/// `num_features` is zero (previously a division-by-zero panic), when
/// `num_features` differs from the number of fitted features (previously extra
/// features were silently counted as zero), or when `data.len()` is not a
/// multiple of `num_features`.
pub fn outlier_counts(&self, data: &[f32], num_features: usize) -> Result<Vec<usize>> {
    if !self.fitted {
        return Err(TreeBoostError::Data(
            "OutlierDetector not fitted. Call fit() first.".into(),
        ));
    }
    if num_features == 0 {
        return Err(TreeBoostError::Data("num_features must be > 0".into()));
    }
    if num_features != self.bounds.len() {
        return Err(TreeBoostError::Data(format!(
            "num_features mismatch: fit with {}, outlier_counts with {}",
            self.bounds.len(),
            num_features
        )));
    }
    if !data.len().is_multiple_of(num_features) {
        return Err(TreeBoostError::Data(format!(
            "Data length {} not divisible by num_features {}",
            data.len(),
            num_features
        )));
    }

    let mut counts = vec![0usize; num_features];
    for values in data.chunks_exact(num_features) {
        for ((&value, limits), count) in values.iter().zip(&self.bounds).zip(&mut counts) {
            if value.is_finite() && (value < limits.lower || value > limits.upper) {
                *count += 1;
            }
        }
    }
    Ok(counts)
}
}
impl Default for OutlierDetector {
fn default() -> Self {
Self::new(OutlierMethod::default())
}
}
/// Outcome of `OutlierDetector::transform`; the variant matches the configured action.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum TransformResult {
/// Outliers were clamped in place in the caller's buffer.
Capped {
/// Number of values that were overwritten with a bound.
outlier_count: usize,
},
/// Outliers were marked in a separate indicator buffer.
Flagged {
/// One 0.0/1.0 entry per input value, in the same layout as the input.
indicators: Vec<f32>,
/// One name per feature, of the form `<feature>_outlier`.
indicator_names: Vec<String>,
},
/// Rows containing an outlier were dropped from a copy of the input.
Removed {
/// The surviving rows, in the same layout as the input.
cleaned_data: Vec<f32>,
/// Original row indices of the surviving rows, in ascending order.
kept_indices: Vec<usize>,
/// Number of rows that were dropped.
removed_count: usize,
},
}
impl TransformResult {
    /// Returns the headline count for this result.
    ///
    /// That is the number of capped values for `Capped`, the number of indicator
    /// entries greater than zero for `Flagged`, and the number of dropped rows
    /// for `Removed`.
    pub fn outlier_count(&self) -> usize {
        match self {
            Self::Capped { outlier_count: n } | Self::Removed { removed_count: n, .. } => *n,
            Self::Flagged { indicators, .. } => indicators
                .iter()
                .copied()
                .filter(|flag| *flag > 0.0)
                .count(),
        }
    }
}
/// Linearly interpolated percentile of an ascending-sorted slice.
///
/// `p` is a fraction (`0.25` for the first quartile). The target position is
/// `p * (len - 1)`; when it falls between two elements the result is their
/// weighted average. An empty slice yields `0.0`, a single element yields that
/// element, and a position at or past the end yields the last element.
fn percentile(sorted_data: &[f32], p: f32) -> f32 {
    match sorted_data {
        [] => 0.0,
        [only] => *only,
        [.., last] => {
            let len = sorted_data.len();
            let position = p * (len - 1) as f32;
            // Float-to-usize casts saturate, so negative or NaN positions land on 0.
            let below = position.floor() as usize;
            let above = position.ceil() as usize;
            if above >= len {
                *last
            } else if below == above {
                sorted_data[below]
            } else {
                let weight = position - below as f32;
                sorted_data[below] * (1.0 - weight) + sorted_data[above] * weight
            }
        }
    }
}
// Unit tests for the outlier detection types, the `percentile` helper, and
// serde round-trips of the public types.
#[cfg(test)]
mod tests {
use super::*;
// The default constructors use k = 1.5 and threshold = 3.0.
#[test]
fn test_outlier_method_defaults() {
let iqr = OutlierMethod::iqr();
assert!(matches!(iqr, OutlierMethod::Iqr { k } if (k - 1.5).abs() < 1e-6));
let zscore = OutlierMethod::zscore();
assert!(
matches!(zscore, OutlierMethod::ZScore { threshold } if (threshold - 3.0).abs() < 1e-6)
);
}
// The parameterised constructors store the value they are given.
#[test]
fn test_outlier_method_custom() {
let iqr = OutlierMethod::iqr_with_k(2.0);
assert!(matches!(iqr, OutlierMethod::Iqr { k } if (k - 2.0).abs() < 1e-6));
let zscore = OutlierMethod::zscore_with_threshold(2.5);
assert!(
matches!(zscore, OutlierMethod::ZScore { threshold } if (threshold - 2.5).abs() < 1e-6)
);
}
// A single extreme value in an otherwise tight column is detected by the IQR rule.
#[test]
fn test_iqr_detection_basic() {
let data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 100.0, ];
let mut detector = OutlierDetector::new(OutlierMethod::iqr());
detector.fit(&data, 1).unwrap();
assert!(detector.is_outlier(100.0, 0));
assert!(!detector.is_outlier(5.0, 0));
}
// For 1..=8 the interpolated quartiles are 2.75 and 6.25, so with k = 1.5 the
// bounds are [-2.5, 11.5]; the probes sit just inside and just outside them.
#[test]
fn test_iqr_bounds_computation() {
let data: Vec<f32> = (1..=8).map(|x| x as f32).collect();
let mut detector = OutlierDetector::new(OutlierMethod::iqr());
detector.fit(&data, 1).unwrap();
assert!(!detector.is_outlier(-2.0, 0)); assert!(detector.is_outlier(-3.0, 0)); assert!(!detector.is_outlier(11.0, 0)); assert!(detector.is_outlier(12.0, 0)); }
// Fitting 4 rows x 2 features yields one bounds entry per feature.
#[test]
fn test_iqr_multifeature() {
let data = vec![
1.0, 100.0, 2.0, 200.0, 3.0, 300.0, 4.0, 400.0, ];
let mut detector = OutlierDetector::new(OutlierMethod::iqr());
detector.fit(&data, 2).unwrap();
assert_eq!(detector.bounds.len(), 2);
}
// With the default threshold, values far from the mean of 1..=10 are flagged
// on both sides while a central value is not.
#[test]
fn test_zscore_detection_basic() {
let data: Vec<f32> = (1..=10).map(|x| x as f32).collect();
let mut detector = OutlierDetector::new(OutlierMethod::zscore());
detector.fit(&data, 1).unwrap();
assert!(!detector.is_outlier(5.0, 0)); assert!(detector.is_outlier(20.0, 0)); assert!(detector.is_outlier(-10.0, 0)); }
// A threshold of 2.0 is tight enough that 12.0 is flagged for 1..=10.
#[test]
fn test_zscore_custom_threshold() {
let data: Vec<f32> = (1..=10).map(|x| x as f32).collect();
let mut detector = OutlierDetector::new(OutlierMethod::zscore_with_threshold(2.0));
detector.fit(&data, 1).unwrap();
assert!(detector.is_outlier(12.0, 0));
}
// Cap rewrites both extreme values in place and reports two capped values.
#[test]
fn test_cap_action() {
let mut data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 100.0, -50.0, ];
let mut detector =
OutlierDetector::new(OutlierMethod::iqr()).with_action(OutlierAction::Cap);
detector.fit(&data, 1).unwrap();
let result = detector.transform(&mut data, 1, &["f0".into()]).unwrap();
assert!(data[8] < 100.0); assert!(data[9] > -50.0);
if let TransformResult::Capped { outlier_count } = result {
assert_eq!(outlier_count, 2);
} else {
panic!("Expected Capped result");
}
}
// Flag on 5 rows x 2 features: only row 4 / feature 0 (value 100.0) is marked,
// and indicator columns are named after the supplied feature names.
#[test]
fn test_flag_action() {
let mut data = vec![
1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 100.0, 50.0, ];
let mut detector =
OutlierDetector::new(OutlierMethod::iqr()).with_action(OutlierAction::Flag);
detector.fit(&data, 2).unwrap();
let names = vec!["f0".into(), "f1".into()];
let result = detector.transform(&mut data, 2, &names).unwrap();
if let TransformResult::Flagged {
indicators,
indicator_names,
} = result
{
assert_eq!(indicator_names.len(), 2);
assert_eq!(indicator_names[0], "f0_outlier");
assert_eq!(indicator_names[1], "f1_outlier");
assert!((indicators[4 * 2 + 0] - 1.0).abs() < 1e-6);
assert!((indicators[4 * 2 + 1] - 0.0).abs() < 1e-6);
} else {
panic!("Expected Flagged result");
}
}
// Remove on the same fixture drops row 4 and keeps the other four rows.
#[test]
fn test_remove_action() {
let mut data = vec![
1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 100.0, 50.0, ];
let mut detector =
OutlierDetector::new(OutlierMethod::iqr()).with_action(OutlierAction::Remove);
detector.fit(&data, 2).unwrap();
let names = vec!["f0".into(), "f1".into()];
let result = detector.transform(&mut data, 2, &names).unwrap();
if let TransformResult::Removed {
cleaned_data,
kept_indices,
removed_count,
} = result
{
assert_eq!(removed_count, 1);
assert_eq!(kept_indices.len(), 4);
assert_eq!(cleaned_data.len(), 8); assert!(!kept_indices.contains(&4)); } else {
panic!("Expected Removed result");
}
}
// Capping data without outliers reports a count of zero.
// NOTE(review): there is no `else` branch, so an unexpected variant would pass silently.
#[test]
fn test_no_outliers() {
let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let mut detector =
OutlierDetector::new(OutlierMethod::iqr()).with_action(OutlierAction::Cap);
detector.fit(&data, 1).unwrap();
let result = detector.transform(&mut data, 1, &["f0".into()]).unwrap();
if let TransformResult::Capped { outlier_count } = result {
assert_eq!(outlier_count, 0);
}
}
// Fitting tolerates NaN in the training data, and NaN is never reported as an outlier.
#[test]
fn test_nan_handling() {
let data = vec![1.0, f32::NAN, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mut detector = OutlierDetector::new(OutlierMethod::iqr());
detector.fit(&data, 1).unwrap();
assert!(!detector.is_outlier(f32::NAN, 0));
}
// `detect` reports the (row, feature) position of the single extreme value.
#[test]
fn test_detect_method() {
let data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 100.0, ];
let mut detector = OutlierDetector::new(OutlierMethod::iqr());
detector.fit(&data, 1).unwrap();
let outliers = detector.detect(&data, 1).unwrap();
assert_eq!(outliers.len(), 1);
assert_eq!(outliers[0], (8, 0)); }
// Each of the two features has exactly one outlier, both in the last row.
#[test]
fn test_outlier_counts() {
let data = vec![
1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 100.0, 1000.0, ];
let mut detector = OutlierDetector::new(OutlierMethod::iqr());
detector.fit(&data, 2).unwrap();
let counts = detector.outlier_counts(&data, 2).unwrap();
assert_eq!(counts[0], 1); assert_eq!(counts[1], 1); }
// Calling `detect` before `fit` is an error.
#[test]
fn test_not_fitted_error() {
let detector = OutlierDetector::new(OutlierMethod::iqr());
let result = detector.detect(&[1.0, 2.0], 1);
assert!(result.is_err());
}
// Calling `detect` with a different feature count than `fit` is an error.
#[test]
fn test_feature_mismatch_error() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let mut detector = OutlierDetector::new(OutlierMethod::iqr());
detector.fit(&data, 1).unwrap();
let result = detector.detect(&[1.0, 2.0, 3.0, 4.0], 2);
assert!(result.is_err());
}
// Percentiles that land exactly on an element return that element.
#[test]
fn test_percentile_basic() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
assert!((percentile(&data, 0.0) - 1.0).abs() < 1e-6);
assert!((percentile(&data, 0.5) - 3.0).abs() < 1e-6);
assert!((percentile(&data, 1.0) - 5.0).abs() < 1e-6);
}
// Percentiles that land between two elements are linearly interpolated.
#[test]
fn test_percentile_interpolation() {
let data = vec![1.0, 2.0, 3.0, 4.0];
let q1 = percentile(&data, 0.25);
assert!((q1 - 1.75).abs() < 1e-6);
let q3 = percentile(&data, 0.75);
assert!((q3 - 3.25).abs() < 1e-6);
}
// A fitted detector survives a JSON round-trip with its state intact.
#[test]
fn test_outlier_detector_serialization() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mut detector =
OutlierDetector::new(OutlierMethod::iqr()).with_action(OutlierAction::Cap);
detector.fit(&data, 1).unwrap();
let json = serde_json::to_string(&detector).unwrap();
assert!(!json.is_empty());
let loaded: OutlierDetector = serde_json::from_str(&json).unwrap();
assert!(loaded.is_fitted());
assert_eq!(loaded.bounds.len(), 1);
assert_eq!(loaded.method(), OutlierMethod::iqr());
assert_eq!(loaded.action(), OutlierAction::Cap);
}
// Every method variant round-trips through JSON unchanged.
#[test]
fn test_outlier_method_serialization() {
let methods = vec![
OutlierMethod::iqr(),
OutlierMethod::iqr_with_k(2.0),
OutlierMethod::zscore(),
OutlierMethod::zscore_with_threshold(2.5),
];
for method in methods {
let json = serde_json::to_string(&method).unwrap();
let loaded: OutlierMethod = serde_json::from_str(&json).unwrap();
assert_eq!(loaded, method);
}
}
// Every action variant round-trips through JSON unchanged.
#[test]
fn test_outlier_action_serialization() {
let actions = vec![
OutlierAction::Cap,
OutlierAction::Flag,
OutlierAction::Remove,
];
for action in actions {
let json = serde_json::to_string(&action).unwrap();
let loaded: OutlierAction = serde_json::from_str(&json).unwrap();
assert_eq!(loaded, action);
}
}
// Each result variant round-trips through JSON and reports the same outlier count.
#[test]
fn test_transform_result_serialization() {
let result1 = TransformResult::Capped { outlier_count: 5 };
let json1 = serde_json::to_string(&result1).unwrap();
let loaded1: TransformResult = serde_json::from_str(&json1).unwrap();
assert_eq!(loaded1.outlier_count(), 5);
let result2 = TransformResult::Flagged {
indicators: vec![0.0, 1.0, 0.0],
indicator_names: vec!["f0_outlier".into(), "f1_outlier".into()],
};
let json2 = serde_json::to_string(&result2).unwrap();
let loaded2: TransformResult = serde_json::from_str(&json2).unwrap();
assert_eq!(loaded2.outlier_count(), 1);
let result3 = TransformResult::Removed {
cleaned_data: vec![1.0, 2.0, 3.0],
kept_indices: vec![0, 1, 2],
removed_count: 2,
};
let json3 = serde_json::to_string(&result3).unwrap();
let loaded3: TransformResult = serde_json::from_str(&json3).unwrap();
assert_eq!(loaded3.outlier_count(), 2);
}
}