use crate::error::AnalyticsError;
/// Outcome of a chi-squared goodness-of-fit test, as produced by `chi_squared_uniformity`.
#[derive(Clone, Debug, PartialEq)]
pub struct UniformityResult {
    /// Observed count for each cell, copied from the caller's input.
    pub observed: Vec<u64>,
    /// Expected count for each cell under the null hypothesis.
    pub expected: Vec<f64>,
    /// Pearson's statistic: the sum over cells of `(observed - expected)^2 / expected`.
    pub chi_squared: f64,
    /// Degrees of freedom used for the test (number of cells minus one).
    pub degrees_of_freedom: u32,
    /// Rejection threshold for `chi_squared` at significance `alpha`.
    pub critical_value: f64,
    /// Significance level the test was run at.
    pub alpha: f64,
    /// `true` when `chi_squared <= critical_value`, i.e. the null hypothesis is not rejected.
    pub is_uniform: bool,
    /// Approximate upper-tail probability of `chi_squared`; informational only.
    pub p_value: f64,
}
/// Runs Pearson's chi-squared goodness-of-fit test on a vector of cell counts.
///
/// Expected counts are derived from the total of `observed`:
/// - `expected_weights == None`: every cell expects `total / k`.
/// - `Some(weights)`: cell `i` expects `total * weights[i] / sum(weights)`, so the weights only
///   need to be proportional, not normalised.
///
/// The statistic `Σ (o - e)² / e` is compared against [`chi_squared_critical_value`] for
/// `k - 1` degrees of freedom; `is_uniform` is `true` when the statistic does not exceed it.
/// `p_value` comes from [`chi_squared_p_value`] and is informational only — it does not
/// drive `is_uniform`.
///
/// # Errors
///
/// - [`AnalyticsError::InsufficientData`] when there are fewer than 2 cells, all counts are
///   zero, or any expected count is below 1.
/// - [`AnalyticsError::ConfigError`] when `alpha` is not in `(0, 1)` (NaN included), the weight
///   slice length differs from `observed.len()`, or any weight is not finite and positive.
pub fn chi_squared_uniformity(
    observed: &[u64],
    expected_weights: Option<&[f64]>,
    alpha: f64,
) -> Result<UniformityResult, AnalyticsError> {
    let k = observed.len();
    if k < 2 {
        return Err(AnalyticsError::InsufficientData(
            "chi-squared test requires at least 2 cells".to_string(),
        ));
    }
    // NaN fails every comparison, so a plain `alpha <= 0.0 || alpha >= 1.0` would accept it and
    // the critical-value lookup would silently treat it as the 1 % level.
    if alpha.is_nan() || alpha <= 0.0 || alpha >= 1.0 {
        return Err(AnalyticsError::ConfigError(format!(
            "alpha={alpha} must be in (0, 1)"
        )));
    }
    let total_n: u64 = observed.iter().sum();
    if total_n == 0 {
        return Err(AnalyticsError::InsufficientData(
            "observed counts sum to zero".to_string(),
        ));
    }
    let total = total_n as f64;

    let expected: Vec<f64> = match expected_weights {
        Some(weights) => {
            if weights.len() != k {
                return Err(AnalyticsError::ConfigError(format!(
                    "expected_weights length ({}) must match observed length ({})",
                    weights.len(),
                    k
                )));
            }
            // `w <= 0.0` is false for NaN and +inf; either would turn every expected count (and
            // therefore the statistic) into NaN instead of producing an error.
            if weights.iter().any(|&w| !w.is_finite() || w <= 0.0) {
                return Err(AnalyticsError::ConfigError(
                    "all expected weights must be finite and positive".to_string(),
                ));
            }
            let weight_sum: f64 = weights.iter().sum();
            // Normalise before scaling so `total * w` cannot overflow for very large weights.
            weights.iter().map(|&w| total * (w / weight_sum)).collect()
        }
        None => vec![total / k as f64; k],
    };

    // The chi-squared approximation is unreliable for tiny expected counts.
    if let Some((i, e)) = expected.iter().copied().enumerate().find(|&(_, e)| e < 1.0) {
        return Err(AnalyticsError::InsufficientData(format!(
            "expected count for cell {i} is {e:.2} < 1; increase sample size"
        )));
    }

    let chi_squared: f64 = observed
        .iter()
        .zip(&expected)
        .map(|(&o, &e)| {
            let diff = o as f64 - e;
            diff * diff / e
        })
        .sum();
    let df = (k - 1) as u32;
    let critical_value = chi_squared_critical_value(df, alpha);
    let p_value = chi_squared_p_value(chi_squared, df);
    let is_uniform = chi_squared <= critical_value;

    Ok(UniformityResult {
        observed: observed.to_vec(),
        expected,
        chi_squared,
        degrees_of_freedom: df,
        critical_value,
        alpha,
        is_uniform,
        p_value,
    })
}
/// Planned traffic allocation used for sample-ratio-mismatch detection.
#[derive(Clone, Debug)]
pub struct SrmConfig {
    /// Intended share of traffic per variant, in variant order. These are used as expected
    /// weights and normalised by their sum, so they need not add up to exactly one.
    pub planned_fractions: Vec<f64>,
    /// Significance level passed to the chi-squared test.
    pub alpha: f64,
}

impl SrmConfig {
    /// Builds a config that allots `1 / n_variants` of the traffic to each of `n_variants`
    /// variants. With `n_variants == 0` the fraction list is empty.
    pub fn equal_split(n_variants: usize, alpha: f64) -> Self {
        let share = (n_variants as f64).recip();
        let planned_fractions = vec![share; n_variants];
        Self {
            planned_fractions,
            alpha,
        }
    }
}
/// Checks observed per-variant counts for a sample-ratio mismatch against the planned split.
///
/// `config.planned_fractions` is used as the expected-weight vector for
/// [`chi_squared_uniformity`] at significance `config.alpha`; a result with
/// `is_uniform == false` indicates a mismatch.
///
/// # Errors
///
/// Returns [`AnalyticsError::ConfigError`] when the number of observed counts differs from the
/// number of planned fractions; otherwise propagates any error from [`chi_squared_uniformity`].
pub fn detect_srm(
    observed_counts: &[u64],
    config: &SrmConfig,
) -> Result<UniformityResult, AnalyticsError> {
    let planned: &[f64] = &config.planned_fractions;
    if observed_counts.len() == planned.len() {
        return chi_squared_uniformity(observed_counts, Some(planned), config.alpha);
    }
    Err(AnalyticsError::ConfigError(format!(
        "observed_counts length ({}) != planned_fractions length ({})",
        observed_counts.len(),
        planned.len()
    )))
}
/// Tests whether `values` are spread evenly across their own observed range.
///
/// The closed interval `[min(values), max(values)]` is split into `n_buckets` equal-width
/// buckets (the maximum value is placed in the last bucket), and the resulting counts are
/// passed to [`chi_squared_uniformity`] with equal expected weights.
///
/// # Errors
///
/// - [`AnalyticsError::InsufficientData`] if `values` is empty, contains a NaN or infinite
///   entry, spans a range too wide to represent as `f64`, or has all entries equal. Errors from
///   [`chi_squared_uniformity`] are propagated as well (e.g. an expected bucket count below 1).
/// - [`AnalyticsError::ConfigError`] if `n_buckets < 2`, or (via the chi-squared test) if
///   `alpha` is invalid.
pub fn bucket_uniformity_test(
    values: &[f64],
    n_buckets: usize,
    alpha: f64,
) -> Result<UniformityResult, AnalyticsError> {
    if values.is_empty() {
        return Err(AnalyticsError::InsufficientData(
            "values slice is empty".to_string(),
        ));
    }
    if n_buckets < 2 {
        return Err(AnalyticsError::ConfigError(
            "n_buckets must be >= 2".to_string(),
        ));
    }
    // `f64::min`/`f64::max` skip NaN, and NaN/inf bucket indices saturate to 0 when cast, so
    // non-finite input would otherwise yield silently wrong counts.
    if let Some((i, v)) = values.iter().enumerate().find(|(_, v)| !v.is_finite()) {
        return Err(AnalyticsError::InsufficientData(format!(
            "values[{i}] = {v} is not finite; cannot test uniformity"
        )));
    }
    let (min, max) = values
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });
    // Exact comparison: an absolute epsilon would reject data that is genuinely spread out on
    // a very small scale (e.g. every value below 1e-16).
    if max <= min {
        return Err(AnalyticsError::InsufficientData(
            "all values are identical; cannot test uniformity".to_string(),
        ));
    }
    let span = max - min;
    if !span.is_finite() {
        return Err(AnalyticsError::InsufficientData(
            "value range overflows f64; cannot test uniformity".to_string(),
        ));
    }
    let width = span / n_buckets as f64;
    let mut counts = vec![0u64; n_buckets];
    for &v in values {
        // `max` maps to index `n_buckets`; clamp it into the last bucket.
        let bucket = (((v - min) / width) as usize).min(n_buckets - 1);
        counts[bucket] += 1;
    }
    chi_squared_uniformity(&counts, None, alpha)
}
/// Returns the upper-tail chi-squared critical value for `df` degrees of freedom.
///
/// Only three significance levels are distinguished:
/// - `alpha >= 0.10` uses the 10 % value,
/// - `0.05 <= alpha < 0.10` uses the 5 % value,
/// - anything smaller (including NaN) uses the 1 % value.
///
/// An `alpha` between these levels therefore gets the critical value of a neighbouring level
/// rather than its own.
///
/// Tabulated values are used for df 1–10, 15, 20, 25 and 30. Every other `df` uses the
/// Wilson–Hilferty approximation `df · (1 − 2/(9df) + z·√(2/(9df)))³` with the matching
/// standard-normal quantile `z`. `df == 0` returns `0.0`.
pub fn chi_squared_critical_value(df: u32, alpha: f64) -> f64 {
    // (df, alpha = 0.10, alpha = 0.05, alpha = 0.01)
    const TABLE: &[(u32, f64, f64, f64)] = &[
        (1, 2.706, 3.841, 6.635),
        (2, 4.605, 5.991, 9.210),
        (3, 6.251, 7.815, 11.345),
        (4, 7.779, 9.488, 13.277),
        (5, 9.236, 11.070, 15.086),
        (6, 10.645, 12.592, 16.812),
        (7, 12.017, 14.067, 18.475),
        (8, 13.362, 15.507, 20.090),
        (9, 14.684, 16.919, 21.666),
        (10, 15.987, 18.307, 23.209),
        (15, 22.307, 24.996, 30.578),
        (20, 28.412, 31.410, 37.566),
        (25, 34.382, 37.652, 44.314),
        (30, 40.256, 43.773, 50.892),
    ];
    // Upper-tail standard-normal quantiles for the three supported levels.
    const Z10: f64 = 1.281_551_565_5;
    const Z05: f64 = 1.644_853_626_9;
    const Z01: f64 = 2.326_347_874_0;

    if df == 0 {
        // Degenerate case: with no degrees of freedom the statistic is identically zero.
        return 0.0;
    }

    let (cv10, cv05, cv01) = TABLE
        .iter()
        .find(|&&(d, ..)| d == df)
        .map(|&(_, a10, a05, a01)| (a10, a05, a01))
        .unwrap_or_else(|| {
            let k = f64::from(df);
            let var = 2.0 / (9.0 * k);
            let wilson_hilferty =
                |z: f64| (k * (1.0 - var + z * var.sqrt()).powi(3)).max(0.0);
            // All three levels are approximated. The 1 % value was previously left as 0.0,
            // which made every alpha < 0.05 test reject for untabulated df.
            (wilson_hilferty(Z10), wilson_hilferty(Z05), wilson_hilferty(Z01))
        });

    if alpha >= 0.10 {
        cv10
    } else if alpha >= 0.05 {
        cv05
    } else {
        cv01
    }
}
/// Upper-tail probability `P(X >= chi_sq)` for a chi-squared variable with `df` degrees of
/// freedom.
///
/// - `df == 0` returns `1.0`.
/// - `chi_sq <= 0` returns exactly `1.0`, since a chi-squared variable is non-negative.
/// - `df == 1` and `df == 2` use exact identities: `2·P(Z > √x)` and `exp(−x/2)`.
/// - Larger `df` use the Wilson–Hilferty cube-root normal approximation.
///
/// NaN input propagates to a NaN result. Accuracy is bounded by the normal-tail approximation
/// in `normal_sf`.
pub fn chi_squared_p_value(chi_sq: f64, df: u32) -> f64 {
    if df == 0 || chi_sq <= 0.0 {
        return 1.0;
    }
    match df {
        1 => (2.0 * normal_sf(chi_sq.sqrt())).min(1.0),
        2 => (-0.5 * chi_sq).exp(),
        _ => {
            let k = f64::from(df);
            let var = 2.0 / (9.0 * k);
            let z = ((chi_sq / k).cbrt() - (1.0 - var)) / var.sqrt();
            normal_sf(z)
        }
    }
}

/// Standard-normal survival function `P(Z > z)`.
///
/// Uses the Abramowitz & Stegun 7.1.26 rational approximation (absolute error below 1.5e-7 on
/// `erf`). For the upper tail, `erfc` is evaluated directly rather than as `1 − erf`. This keeps
/// small tail probabilities from cancelling to exactly zero.
fn normal_sf(z: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    // Coefficients a1..a5 of the approximation.
    const A: [f64; 5] = [
        0.254_829_592,
        -0.284_496_736,
        1.421_413_741,
        -1.453_152_027,
        1.061_405_429,
    ];
    let x = z.abs() / std::f64::consts::SQRT_2;
    let t = 1.0 / (1.0 + P * x);
    // Horner evaluation of t·(a1 + t·(a2 + t·(a3 + t·(a4 + t·a5)))).
    let poly = t * A.iter().rev().fold(0.0, |acc, &a| acc * t + a);
    let erfc = poly * (-x * x).exp();
    if z >= 0.0 {
        0.5 * erfc
    } else {
        1.0 - 0.5 * erfc
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn uniform_distribution_passes_test() {
        let observed = vec![2_500u64, 2_500, 2_500, 2_500];
        let result = chi_squared_uniformity(&observed, None, 0.05).expect("should succeed");
        assert!(result.is_uniform, "perfectly uniform split should pass");
        assert!((result.chi_squared).abs() < 1e-9, "χ² should be 0");
    }

    #[test]
    fn severely_skewed_distribution_fails_test() {
        let observed = vec![9_000u64, 333, 334, 333];
        let result = chi_squared_uniformity(&observed, None, 0.05).expect("should succeed");
        assert!(
            !result.is_uniform,
            "severely skewed distribution should fail uniformity test"
        );
    }

    #[test]
    fn chi_squared_statistic_correct_for_known_case() {
        let observed = vec![60u64, 40];
        let result = chi_squared_uniformity(&observed, None, 0.05).expect("should succeed");
        assert!(
            (result.chi_squared - 4.0).abs() < 1e-9,
            "χ²={} expected 4.0",
            result.chi_squared
        );
    }

    #[test]
    fn weighted_expected_counts_follow_weights() {
        let result = chi_squared_uniformity(&[300u64, 100], Some(&[3.0, 1.0][..]), 0.05)
            .expect("should succeed");
        assert!((result.expected[0] - 300.0).abs() < 1e-9);
        assert!((result.expected[1] - 100.0).abs() < 1e-9);
        assert!(result.chi_squared.abs() < 1e-9);
        assert!(result.is_uniform);
    }

    #[test]
    fn mismatched_weight_length_returns_error() {
        let result = chi_squared_uniformity(&[10u64, 10, 10], Some(&[1.0, 1.0][..]), 0.05);
        assert!(result.is_err());
    }

    #[test]
    fn non_positive_weight_returns_error() {
        let result = chi_squared_uniformity(&[10u64, 10], Some(&[1.0, 0.0][..]), 0.05);
        assert!(result.is_err());
    }

    #[test]
    fn two_cells_error_on_one_cell() {
        let result = chi_squared_uniformity(&[100u64], None, 0.05);
        assert!(result.is_err());
    }

    #[test]
    fn zero_total_count_returns_error() {
        let result = chi_squared_uniformity(&[0u64, 0, 0], None, 0.05);
        assert!(result.is_err());
    }

    #[test]
    fn invalid_alpha_returns_error() {
        let result = chi_squared_uniformity(&[100u64, 100], None, 0.0);
        assert!(result.is_err());
        let result2 = chi_squared_uniformity(&[100u64, 100], None, 1.0);
        assert!(result2.is_err());
    }

    #[test]
    fn srm_not_detected_for_perfect_split() {
        let config = SrmConfig::equal_split(2, 0.01);
        let result = detect_srm(&[5_000u64, 5_000], &config).expect("should succeed");
        assert!(result.is_uniform, "no SRM expected for perfect split");
    }

    #[test]
    fn srm_not_detected_for_matching_unequal_split() {
        let config = SrmConfig {
            planned_fractions: vec![0.9, 0.1],
            alpha: 0.01,
        };
        let result = detect_srm(&[9_000u64, 1_000], &config).expect("should succeed");
        assert!(
            result.is_uniform,
            "counts matching a 90/10 plan should not flag SRM"
        );
    }

    #[test]
    fn srm_detected_for_extreme_imbalance() {
        let config = SrmConfig::equal_split(3, 0.01);
        let result = detect_srm(&[8_000u64, 1_000, 1_000], &config).expect("should succeed");
        assert!(
            !result.is_uniform,
            "SRM expected for highly imbalanced assignment"
        );
    }

    #[test]
    fn srm_planned_fractions_mismatch_length_error() {
        let config = SrmConfig {
            planned_fractions: vec![0.5, 0.5],
            alpha: 0.05,
        };
        let result = detect_srm(&[100u64, 100, 100], &config);
        assert!(result.is_err());
    }

    #[test]
    fn bucket_uniformity_passes_for_uniform_values() {
        let values: Vec<f64> = (0..1_000).map(|i| (i % 10) as f64).collect();
        let result = bucket_uniformity_test(&values, 10, 0.05).expect("should succeed");
        assert!(
            result.is_uniform,
            "uniform bucket distribution should pass, χ²={}",
            result.chi_squared
        );
    }

    #[test]
    fn bucket_uniformity_fails_for_all_in_one_bucket() {
        let mut values: Vec<f64> = vec![0.0; 999];
        values.push(10.0);
        let result = bucket_uniformity_test(&values, 5, 0.05).expect("should succeed");
        assert!(
            !result.is_uniform,
            "heavily skewed bucket distribution should fail"
        );
    }

    #[test]
    fn bucket_uniformity_identical_values_error() {
        let values = vec![5.0f64; 100];
        let result = bucket_uniformity_test(&values, 5, 0.05);
        assert!(result.is_err());
    }

    #[test]
    fn assign_variant_fnv_uniformity_10k() {
        fn fnv1a_32(data: &[u8]) -> u32 {
            let mut hash: u32 = 2_166_136_261;
            for &b in data {
                hash ^= u32::from(b);
                hash = hash.wrapping_mul(16_777_619);
            }
            hash
        }
        let n_variants = 4u32;
        let n_users = 10_000usize;
        let mut counts = vec![0u64; n_variants as usize];
        for i in 0..n_users {
            let user_id = format!("user_{i:06}");
            let hash = fnv1a_32(user_id.as_bytes());
            let variant = (hash % n_variants) as usize;
            counts[variant] += 1;
        }
        let result = chi_squared_uniformity(&counts, None, 0.05)
            .expect("chi-squared test should succeed");
        assert!(
            result.is_uniform,
            "FNV-1a assignment over 10K users is not uniform: counts={counts:?}, χ²={:.4}",
            result.chi_squared
        );
    }

    #[test]
    fn p_value_near_one_for_zero_chi_squared() {
        // P(X >= 0) is 1 for any chi-squared variable.
        let p = chi_squared_p_value(0.0, 3);
        assert!(p > 0.99, "p-value for χ²=0 should be (close to) 1, got {p}");
    }

    #[test]
    fn critical_value_for_df1_alpha05_known() {
        let cv = chi_squared_critical_value(1, 0.05);
        assert!(
            (cv - 3.841).abs() < 0.01,
            "critical value df=1,α=0.05 should be 3.841, got {cv}"
        );
    }
}