#![allow(dead_code)]
use std::collections::HashMap;
use std::hash::Hash;
/// Scale factor applied to the raw median absolute deviation (MAD) in `detect_outliers`
/// before it is multiplied by `OutlierConfig::k`.
///
/// 1.4826 is the conventional constant (approximately `1 / 0.6745`) that makes the scaled
/// MAD comparable to a standard deviation when the data are normally distributed.
pub(crate) const MAD_GAUSSIAN_CONSISTENCY: f64 = 1.4826;
/// Tuning knobs for `detect_outliers`.
#[derive(Clone, Copy, Debug)]
pub(crate) struct OutlierConfig {
    /// Multiplier on the scaled MAD; the raw threshold is
    /// `median + k * MAD * MAD_GAUSSIAN_CONSISTENCY`.
    pub k: f64,
    /// Minimum number of finite log ratios required to run a pass; with fewer, the pass is
    /// skipped with `SkipReason::InsufficientSamples`.
    pub min_samples: usize,
    /// Fraction of the sample dropped from the high end before the median and MAD are
    /// computed. The number dropped is `floor(sample_size * trim_fraction)`.
    pub trim_fraction: f64,
}

impl Default for OutlierConfig {
    /// Returns `k = 5.0`, `min_samples = 30`, `trim_fraction = 0.05`.
    fn default() -> Self {
        const DEFAULT_K: f64 = 5.0;
        const DEFAULT_MIN_SAMPLES: usize = 30;
        const DEFAULT_TRIM_FRACTION: f64 = 0.05;

        OutlierConfig {
            k: DEFAULT_K,
            min_samples: DEFAULT_MIN_SAMPLES,
            trim_fraction: DEFAULT_TRIM_FRACTION,
        }
    }
}
/// Why a `detect_outliers` pass ended without producing a threshold.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub(crate) enum SkipReason {
/// Fewer finite ratios than `OutlierConfig::min_samples` were available, or trimming
/// removed every sample.
InsufficientSamples,
/// The MAD of the trimmed sample was below `f64::EPSILON`, so no meaningful spread
/// exists to scale a threshold from.
MadCollapsed,
/// No sample yielded a finite log ratio at all.
NoExtractableRatios,
}
/// Outcome of one `detect_outliers` pass over a keyed sample set.
#[derive(Clone, Debug)]
pub(crate) struct OutlierResult<K> {
    /// Keys whose log ratio was strictly greater than `threshold`. Empty when skipped.
    pub flagged: Vec<K>,
    /// Median of the trimmed log ratios; `None` when the pass was skipped before it was
    /// computed.
    pub median_log_ratio: Option<f64>,
    /// Raw (unscaled) median absolute deviation of the trimmed log ratios; `None` when the
    /// pass was skipped before it was computed.
    pub mad: Option<f64>,
    /// Effective threshold after applying the capacity ceiling; `None` whenever the pass
    /// was skipped.
    pub threshold: Option<f64>,
    /// `true` when the statistically derived threshold exceeded the capacity ceiling and
    /// was therefore capped to it.
    pub capacity_ceiling_binding: bool,
    /// Number of samples that produced a finite log ratio.
    pub sample_size: usize,
    /// Number of samples left after trimming; `0` for results built by `skip`.
    pub trimmed_sample_size: usize,
    /// `Some(reason)` when no threshold was produced, `None` on a completed pass.
    pub skip_reason: Option<SkipReason>,
}

impl<K> OutlierResult<K> {
    /// Builds the result for a pass abandoned before any statistics were computed:
    /// nothing flagged, no median/MAD/threshold, and `trimmed_sample_size` of zero.
    fn skip(reason: SkipReason, sample_size: usize) -> Self {
        let flagged = Vec::new();
        let skip_reason = Some(reason);
        OutlierResult {
            flagged,
            median_log_ratio: None,
            mad: None,
            threshold: None,
            capacity_ceiling_binding: false,
            sample_size,
            trimmed_sample_size: 0,
            skip_reason,
        }
    }
}
/// Flags samples whose log ratio is an upper-tail outlier under a `median + k * MAD` rule.
///
/// The pass proceeds as follows:
///
/// 1. `extract_log_ratio` is applied to every sample; `None` and non-finite values are
///    dropped.
/// 2. If nothing remains the pass is skipped with [`SkipReason::NoExtractableRatios`]; if
///    fewer than `config.min_samples` remain it is skipped with
///    [`SkipReason::InsufficientSamples`].
/// 3. The ratios are sorted ascending and the `floor(n * config.trim_fraction)` largest are
///    excluded from the statistics (an empty remainder is also reported as
///    [`SkipReason::InsufficientSamples`]).
/// 4. The median and the raw MAD of the remaining ratios are computed. A MAD below
///    `f64::EPSILON` skips the pass with [`SkipReason::MadCollapsed`], still reporting the
///    median and MAD.
/// 5. The threshold is `median + config.k * MAD * MAD_GAUSSIAN_CONSISTENCY`, capped at
///    `capacity_ceiling_log`; `capacity_ceiling_binding` records whether the cap applied.
/// 6. Every sample with a finite ratio (including the trimmed ones) whose ratio is
///    strictly greater than the threshold is flagged.
///
/// `flagged` is ordered by ascending ratio; the relative order of keys with equal ratios
/// follows the map's iteration order and is unspecified.
pub(crate) fn detect_outliers<K, S, F>(
    samples: &HashMap<K, S>,
    extract_log_ratio: F,
    config: &OutlierConfig,
    capacity_ceiling_log: f64,
) -> OutlierResult<K>
where
    K: Clone + Eq + Hash,
    F: Fn(&S) -> Option<f64>,
{
    // Keys are only borrowed while the statistics are computed; cloning is deferred until
    // the flagged subset is known so unflagged keys are never copied.
    let mut pairs: Vec<(&K, f64)> = samples
        .iter()
        .filter_map(|(key, sample)| {
            extract_log_ratio(sample)
                .filter(|ratio| ratio.is_finite())
                .map(|ratio| (key, ratio))
        })
        .collect();

    let sample_size = pairs.len();
    if sample_size == 0 {
        return OutlierResult::skip(SkipReason::NoExtractableRatios, sample_size);
    }
    if sample_size < config.min_samples {
        return OutlierResult::skip(SkipReason::InsufficientSamples, sample_size);
    }

    // All ratios are finite, so `total_cmp` gives the usual numeric order while remaining
    // a genuine total order for the sort.
    pairs.sort_by(|a, b| a.1.total_cmp(&b.1));

    // Float-to-integer `as` casts saturate, so a negative or NaN fraction trims nothing and
    // an oversized fraction is bounded by `saturating_sub` below.
    let trim_count = (sample_size as f64 * config.trim_fraction).floor() as usize;
    let trimmed_len = sample_size.saturating_sub(trim_count);

    // Only the high end is trimmed: the statistics use the `trimmed_len` smallest ratios.
    let trimmed_ratios: Vec<f64> = pairs[..trimmed_len].iter().map(|&(_, ratio)| ratio).collect();
    if trimmed_ratios.is_empty() {
        return OutlierResult::skip(SkipReason::InsufficientSamples, sample_size);
    }

    let median = median_of_sorted(&trimmed_ratios);
    let mut deviations: Vec<f64> = trimmed_ratios
        .iter()
        .map(|ratio| (ratio - median).abs())
        .collect();
    deviations.sort_by(f64::total_cmp);
    let mad = median_of_sorted(&deviations);

    if mad < f64::EPSILON {
        // No measurable spread: report what was computed, but do not derive a threshold.
        return OutlierResult {
            flagged: Vec::new(),
            median_log_ratio: Some(median),
            mad: Some(mad),
            threshold: None,
            capacity_ceiling_binding: false,
            sample_size,
            trimmed_sample_size: trimmed_len,
            skip_reason: Some(SkipReason::MadCollapsed),
        };
    }

    let scaled_mad = mad * MAD_GAUSSIAN_CONSISTENCY;
    let raw_threshold = median + config.k * scaled_mad;
    let capacity_ceiling_binding = raw_threshold > capacity_ceiling_log;
    let threshold = raw_threshold.min(capacity_ceiling_log);

    // Strict comparison: a sample sitting exactly on the threshold is not flagged.
    let flagged: Vec<K> = pairs
        .into_iter()
        .filter(|&(_, ratio)| ratio > threshold)
        .map(|(key, _)| key.clone())
        .collect();

    OutlierResult {
        flagged,
        median_log_ratio: Some(median),
        mad: Some(mad),
        threshold: Some(threshold),
        capacity_ceiling_binding,
        sample_size,
        trimmed_sample_size: trimmed_len,
        skip_reason: None,
    }
}
/// Returns the median of an ascending-sorted, non-empty slice.
///
/// For an even number of elements the result is the arithmetic mean of the two middle
/// values. The slice is assumed to be sorted already; it is not checked. An empty slice
/// trips the debug assertion in debug builds.
fn median_of_sorted(sorted: &[f64]) -> f64 {
    let len = sorted.len();
    debug_assert!(len != 0, "median_of_sorted called with empty slice");
    let upper_mid = len / 2;
    match len % 2 {
        0 => (sorted[upper_mid - 1] + sorted[upper_mid]) / 2.0,
        _ => sorted[upper_mid],
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `String`-keyed sample map from `(key, log_ratio)` literals.
    fn map_from(samples: &[(&str, f64)]) -> HashMap<String, f64> {
        samples.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    /// Config with the given `k` and `min_samples` and a 5% trim.
    fn cfg(k: f64, min_samples: usize) -> OutlierConfig {
        OutlierConfig {
            k,
            min_samples,
            trim_fraction: 0.05,
        }
    }

    #[test]
    fn empty_samples_skips_no_extractable_ratios() {
        let m: HashMap<String, f64> = HashMap::new();
        let r = detect_outliers(&m, |x| Some(*x), &OutlierConfig::default(), 10.0);
        assert_eq!(r.skip_reason, Some(SkipReason::NoExtractableRatios));
        assert!(r.flagged.is_empty());
    }

    #[test]
    fn below_min_samples_skips() {
        let m = map_from(&[("a", -1.0), ("b", -0.5), ("c", 0.0)]);
        let r = detect_outliers(&m, |x| Some(*x), &cfg(5.0, 30), 10.0);
        assert_eq!(r.skip_reason, Some(SkipReason::InsufficientSamples));
        assert!(r.flagged.is_empty());
    }

    #[test]
    fn all_identical_collapses_mad() {
        let m: HashMap<String, f64> = (0..50).map(|i| (format!("c{i}"), -1.0)).collect();
        let r = detect_outliers(&m, |x| Some(*x), &cfg(5.0, 30), 10.0);
        assert_eq!(r.skip_reason, Some(SkipReason::MadCollapsed));
        assert!(r.flagged.is_empty());
        assert_eq!(r.median_log_ratio, Some(-1.0));
    }

    #[test]
    fn worked_example_from_design_doc() {
        // Seven samples with a 5% trim: floor(0.35) = 0, so nothing is trimmed.
        let m = map_from(&[
            ("a", -1.5),
            ("b", -1.2),
            ("c", -1.0),
            ("d", -0.9),
            ("e", -0.8),
            ("f", -0.6),
            ("abuser", 2.5),
        ]);
        let cfg = OutlierConfig {
            k: 5.0,
            min_samples: 5,
            trim_fraction: 0.05,
        };
        let r = detect_outliers(&m, |x| Some(*x), &cfg, 10.0);
        assert!((r.median_log_ratio.unwrap() - (-0.9)).abs() < 1e-9);
        assert!((r.mad.unwrap() - 0.3).abs() < 1e-9);
        let expected_threshold = -0.9 + 5.0 * MAD_GAUSSIAN_CONSISTENCY * 0.3;
        assert!((r.threshold.unwrap() - expected_threshold).abs() < 1e-9);
        assert_eq!(r.flagged, vec!["abuser".to_string()]);
        assert_eq!(r.skip_reason, None);
    }

    #[test]
    fn standard_deviation_would_miss_what_mad_catches() {
        // One extreme value among 35 tightly clustered ones must not drag the threshold up.
        let mut m: HashMap<String, f64> = (0..35)
            .map(|i| (format!("c{i}"), (i as f64 - 17.0) * 0.02))
            .collect();
        m.insert("abuser".to_string(), 100.0);
        let r = detect_outliers(&m, |x| Some(*x), &OutlierConfig::default(), 1000.0);
        assert!(r.flagged.contains(&"abuser".to_string()));
        assert!(r.threshold.unwrap() < 10.0);
    }

    #[test]
    fn capacity_ceiling_clamps_threshold() {
        // Wide spread puts the statistical threshold far above the ceiling of 1.0.
        let m: HashMap<String, f64> = (0..50)
            .map(|i| (format!("c{i}"), (i as f64 - 25.0) * 0.5))
            .collect();
        let r = detect_outliers(&m, |x| Some(*x), &OutlierConfig::default(), 1.0);
        assert!(r.capacity_ceiling_binding);
        assert_eq!(r.threshold, Some(1.0));
    }

    #[test]
    fn returns_keys_only_for_those_exceeding_threshold() {
        let mut m: HashMap<String, f64> = (0..30)
            .map(|i| (format!("c{i}"), -1.0 + (i as f64) * 0.02))
            .collect();
        m.insert("far_above".into(), 100.0);
        m.insert("normal".into(), -0.95);
        let r = detect_outliers(&m, |x| Some(*x), &OutlierConfig::default(), 1000.0);
        assert!(r.flagged.contains(&"far_above".to_string()));
        assert!(!r.flagged.contains(&"normal".to_string()));
    }

    #[test]
    fn extract_returning_none_excludes_from_sample() {
        // Every other sample has no ratio; the remaining 30 are identical.
        let m: HashMap<String, Option<f64>> = (0..60)
            .map(|i| (format!("c{i}"), if i % 2 == 0 { Some(-1.0) } else { None }))
            .collect();
        let r = detect_outliers(&m, |x: &Option<f64>| *x, &OutlierConfig::default(), 10.0);
        assert_eq!(r.sample_size, 30);
        assert_eq!(r.skip_reason, Some(SkipReason::MadCollapsed));
    }

    #[test]
    fn skip_reason_none_when_pass_succeeds() {
        let m: HashMap<String, f64> = (0..30)
            .map(|i| (format!("c{i}"), -1.0 + (i as f64 - 15.0) * 0.05))
            .collect();
        let r = detect_outliers(&m, |x| Some(*x), &OutlierConfig::default(), 10.0);
        assert_eq!(r.skip_reason, None);
        assert!(r.flagged.is_empty());
    }

    #[test]
    fn non_finite_ratios_are_dropped_from_sample() {
        let mut m: HashMap<String, f64> = (0..30)
            .map(|i| (format!("c{i}"), (i as f64 - 15.0) * 0.05))
            .collect();
        m.insert("nan".into(), f64::NAN);
        m.insert("pos_inf".into(), f64::INFINITY);
        m.insert("neg_inf".into(), f64::NEG_INFINITY);
        let r = detect_outliers(&m, |x| Some(*x), &OutlierConfig::default(), 10.0);
        assert_eq!(r.sample_size, 30);
        assert!(r.median_log_ratio.unwrap().is_finite());
        assert!(!r.flagged.iter().any(|k| k == "nan"));
        assert!(!r.flagged.iter().any(|k| k == "pos_inf"));
        assert!(!r.flagged.iter().any(|k| k == "neg_inf"));
    }

    #[test]
    fn n_equals_one_with_min_samples_one() {
        // A single sample has zero deviation from its own median.
        let m = map_from(&[("only", 0.5)]);
        let cfg = OutlierConfig {
            k: 5.0,
            min_samples: 1,
            trim_fraction: 0.05,
        };
        let r = detect_outliers(&m, |x| Some(*x), &cfg, 10.0);
        assert_eq!(r.sample_size, 1);
        assert_eq!(r.skip_reason, Some(SkipReason::MadCollapsed));
        assert!(r.flagged.is_empty());
    }

    #[test]
    fn n_equals_two_with_min_samples_two() {
        let m = map_from(&[("low", -0.1), ("high", 0.1)]);
        let cfg = OutlierConfig {
            k: 5.0,
            min_samples: 2,
            trim_fraction: 0.05,
        };
        let r = detect_outliers(&m, |x| Some(*x), &cfg, 10.0);
        assert_eq!(r.sample_size, 2);
        assert!((r.median_log_ratio.unwrap() - 0.0).abs() < 1e-9);
        assert!((r.mad.unwrap() - 0.1).abs() < 1e-9);
        assert!(r.flagged.is_empty());
    }

    #[test]
    fn mad_collapses_after_trim_even_when_raw_set_varies() {
        // The single differing value is the largest, so the 5% trim removes it.
        let mut m: HashMap<String, f64> = (0..29).map(|i| (format!("c{i}"), -1.0)).collect();
        m.insert("the_one_with_variation".into(), 5.0);
        let r = detect_outliers(&m, |x| Some(*x), &OutlierConfig::default(), 10.0);
        assert_eq!(r.sample_size, 30);
        assert_eq!(r.skip_reason, Some(SkipReason::MadCollapsed));
        assert!(r.flagged.is_empty());
    }

    #[test]
    fn capacity_ceiling_not_binding_when_threshold_below() {
        let m: HashMap<String, f64> = (0..30)
            .map(|i| (format!("c{i}"), -1.0 + (i as f64 - 15.0) * 0.02))
            .collect();
        let r = detect_outliers(&m, |x| Some(*x), &OutlierConfig::default(), 100.0);
        assert!(!r.capacity_ceiling_binding);
        assert!(r.threshold.unwrap() < 100.0);
    }

    #[test]
    fn three_outliers_exceeding_trim_fraction() {
        // 5% of 30 trims a single sample, so two of the three abusers stay in the
        // statistics; the median/MAD rule must still flag all three.
        let mut m: HashMap<String, f64> = (0..27)
            .map(|i| (format!("honest{i}"), -1.0 + (i as f64 - 13.0) * 0.02))
            .collect();
        m.insert("abuser1".into(), 5.0);
        m.insert("abuser2".into(), 6.0);
        m.insert("abuser3".into(), 7.0);
        let r = detect_outliers(&m, |x| Some(*x), &OutlierConfig::default(), 100.0);
        assert!(r.flagged.contains(&"abuser1".to_string()));
        assert!(r.flagged.contains(&"abuser2".to_string()));
        assert!(r.flagged.contains(&"abuser3".to_string()));
        assert!(!r.flagged.iter().any(|k| k.starts_with("honest")));
    }

    #[test]
    fn sample_exactly_at_threshold_is_not_flagged() {
        // The probe sample must not change the population the statistics are computed
        // from, otherwise the threshold moves and the sample is no longer "at" it.
        // With a 25% trim the trim count goes from floor(31 * 0.25) = 7 to
        // floor(32 * 0.25) = 8 when the probe sample is added, and because the probe is
        // the largest value it is the extra one trimmed. Both runs therefore use the same
        // 24 ratios and produce bit-identical thresholds.
        let cfg = OutlierConfig {
            trim_fraction: 0.25,
            ..OutlierConfig::default()
        };
        let mut m: HashMap<String, f64> = (0..31)
            .map(|i| (format!("c{i}"), -1.0 + (i as f64 - 15.0) * 0.01))
            .collect();
        let r_probe = detect_outliers(&m, |x| Some(*x), &cfg, 100.0);
        let exact_threshold = r_probe.threshold.unwrap();

        m.insert("at_threshold".into(), exact_threshold);
        let r = detect_outliers(&m, |x| Some(*x), &cfg, 100.0);
        assert_eq!(r.trimmed_sample_size, r_probe.trimmed_sample_size);
        assert_eq!(r.threshold, Some(exact_threshold));
        assert!(
            !r.flagged.contains(&"at_threshold".to_string()),
            "sample at threshold should NOT be flagged (strict > semantics), flagged: {:?}",
            r.flagged
        );

        // Nudging the same sample just above the (unchanged) threshold must flag it.
        m.insert("at_threshold".into(), exact_threshold + 1e-9);
        let r_above = detect_outliers(&m, |x| Some(*x), &cfg, 100.0);
        assert_eq!(r_above.threshold, Some(exact_threshold));
        assert!(r_above.flagged.contains(&"at_threshold".to_string()));
    }

    #[test]
    fn invariant_under_translation() {
        // Shifting every ratio by a constant shifts the median and leaves the MAD alone.
        let base: Vec<(String, f64)> = (0..30)
            .map(|i| (format!("c{i}"), (i as f64 - 15.0) * 0.1))
            .collect();
        let shifted: Vec<(String, f64)> = base.iter().map(|(k, v)| (k.clone(), *v + 7.5)).collect();
        let m1: HashMap<_, _> = base.into_iter().collect();
        let m2: HashMap<_, _> = shifted.into_iter().collect();
        let r1 = detect_outliers(&m1, |x| Some(*x), &OutlierConfig::default(), 100.0);
        let r2 = detect_outliers(&m2, |x| Some(*x), &OutlierConfig::default(), 100.0);
        assert!((r2.median_log_ratio.unwrap() - r1.median_log_ratio.unwrap() - 7.5).abs() < 1e-9);
        assert!((r1.mad.unwrap() - r2.mad.unwrap()).abs() < 1e-9);
        assert_eq!(r1.flagged.len(), r2.flagged.len());
    }
}