use serde::{Deserialize, Serialize};
use statrs::distribution::{
Beta as BetaDist, ContinuousCDF, Gamma as GammaDist, LogNormal as LogNormalDist,
Normal as NormalDist,
};
/// A parametric probability distribution, evaluated through its quantile
/// function (see [`Distribution::quantile`]).
///
/// Serialized with serde as an internally tagged object: the variant name is
/// written to a `"kind"` field next to the variant's parameters, shaped like
/// `{"kind": "Uniform", "lo": 0.0, "hi": 1.0}`.
///
/// Parameters are not validated when a value is constructed or deserialized.
/// The per-variant requirements noted below are checked, where the code
/// asserts them, only when `quantile` is evaluated.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(tag = "kind")]
pub enum Distribution {
    /// Continuous uniform distribution on `[lo, hi]`.
    Uniform { lo: f64, hi: f64 },
    /// Normal distribution with mean `mu` and standard deviation `sigma` (`sigma > 0`).
    Normal { mu: f64, sigma: f64 },
    /// Log-normal distribution. `mu_log` and `sigma_log` are handed to statrs
    /// `LogNormal::new` as its location and scale, so the median is `exp(mu_log)`.
    /// Requires `sigma_log > 0`.
    LogNormal { mu_log: f64, sigma_log: f64 },
    /// Triangular distribution on `[lo, hi]` peaking at `mode`.
    /// Requires `lo < hi` and `lo <= mode <= hi`.
    Triangular { lo: f64, mode: f64, hi: f64 },
    /// Beta(`alpha`, `beta`) mapped affinely from `[0, 1]` onto `[lo, hi]`.
    Beta {
        /// First shape parameter; must be `> 0`.
        alpha: f64,
        /// Second shape parameter; must be `> 0`.
        beta: f64,
        /// Lower bound of the rescaled support; must be `< hi`.
        lo: f64,
        /// Upper bound of the rescaled support.
        hi: f64,
    },
    /// Gamma distribution in shape/scale form (internally converted to the
    /// rate `1 / scale` that statrs expects). Both must be `> 0`.
    Gamma { shape: f64, scale: f64 },
    /// Weibull distribution with CDF `1 - exp(-(x / scale)^shape)`; both `> 0`.
    Weibull { shape: f64, scale: f64 },
    /// Exponential distribution with rate `lambda > 0`: CDF `1 - exp(-lambda * x)`.
    Exponential { lambda: f64 },
    /// Takes the value `1.0` with probability `p` and `0.0` otherwise; `p` in `[0, 1]`.
    Bernoulli { p: f64 },
    /// Each integer in `lo..=hi` (inclusive) equally likely; requires `lo <= hi`.
    /// Quantiles are reported as `f64`.
    DiscreteUniform { lo: i64, hi: i64 },
}
impl Distribution {
    /// Evaluates the quantile function (inverse CDF) at probability `u`.
    ///
    /// `u` is clamped to `[0, 1]` first, so probabilities outside that range
    /// saturate at the ends of the distribution. A NaN `u` passes through the
    /// clamp unchanged and is handed to the per-variant helper as-is.
    ///
    /// # Panics
    ///
    /// When the variant's parameters violate the preconditions asserted by its
    /// helper, e.g. `sigma <= 0` for `Normal` or `mode` outside `[lo, hi]` for
    /// `Triangular`.
    #[must_use]
    pub fn quantile(&self, u: f64) -> f64 {
        let level = u.clamp(0.0, 1.0);
        match *self {
            // Bounded continuous families.
            Self::Uniform { lo, hi } => uniform_quantile(lo, hi, level),
            Self::Triangular { lo, mode, hi } => triangular_quantile(lo, mode, hi, level),
            Self::Beta {
                alpha,
                beta,
                lo,
                hi,
            } => beta_quantile(alpha, beta, lo, hi, level),
            // Unbounded / half-bounded continuous families.
            Self::Normal { mu, sigma } => normal_quantile(mu, sigma, level),
            Self::LogNormal { mu_log, sigma_log } => lognormal_quantile(mu_log, sigma_log, level),
            Self::Gamma { shape, scale } => gamma_quantile(shape, scale, level),
            Self::Weibull { shape, scale } => weibull_quantile(shape, scale, level),
            Self::Exponential { lambda } => exponential_quantile(lambda, level),
            // Discrete families.
            Self::Bernoulli { p } => bernoulli_quantile(p, level),
            Self::DiscreteUniform { lo, hi } => discrete_uniform_quantile(lo, hi, level),
        }
    }

    /// Returns the `(lower, upper)` bounds of the distribution's support.
    ///
    /// Unbounded sides are reported as `f64::NEG_INFINITY` / `f64::INFINITY`.
    /// Bounded variants report their own `lo`/`hi` parameters. `Bernoulli`
    /// reports `(0.0, 1.0)`.
    #[must_use]
    pub fn support(&self) -> (f64, f64) {
        match *self {
            Self::Normal { .. } => (f64::NEG_INFINITY, f64::INFINITY),
            Self::LogNormal { .. }
            | Self::Gamma { .. }
            | Self::Weibull { .. }
            | Self::Exponential { .. } => (0.0, f64::INFINITY),
            Self::Uniform { lo, hi }
            | Self::Triangular { lo, hi, .. }
            | Self::Beta { lo, hi, .. } => (lo, hi),
            Self::Bernoulli { .. } => (0.0, 1.0),
            // i64 -> f64 may round for |values| > 2^53; acceptable for a support bound.
            #[allow(clippy::cast_precision_loss)]
            Self::DiscreteUniform { lo, hi } => (lo as f64, hi as f64),
        }
    }
}
/// Inverse CDF of the continuous uniform distribution on `[lo, hi]`.
///
/// The endpoints are returned exactly: `u <= 0` yields `lo` and `u >= 1`
/// yields `hi`. The plain `lo + u * (hi - lo)` form can miss `hi` when
/// `hi - lo` rounds; e.g. `lo = -1e17, hi = 1.0` gave `0.0` at `u = 1`.
/// Interior values use that same formula but are capped at `hi`, so rounding
/// can never leave the support. The map stays non-decreasing in `u`, and a
/// NaN `u` still yields NaN.
///
/// # Panics
///
/// If `lo <= hi` does not hold (including NaN bounds), matching the range
/// checks of the other bounded helpers.
fn uniform_quantile(lo: f64, hi: f64, u: f64) -> f64 {
    assert!(lo <= hi, "Uniform: lo must be <= hi");
    if u <= 0.0 {
        return lo;
    }
    if u >= 1.0 {
        return hi;
    }
    let x = lo + u * (hi - lo);
    // Explicit comparison rather than `f64::min`, which would turn a NaN `x` into `hi`.
    if x > hi {
        hi
    } else {
        x
    }
}
/// Inverse CDF of the triangular distribution on `[lo, hi]` peaking at `mode`.
///
/// With `w = hi - lo` and `F(mode) = (mode - lo) / w`:
/// - `u <= F(mode)`: `lo + w * sqrt(u * (mode - lo) / w)`
/// - otherwise:      `hi - w * sqrt((1 - u) * (hi - mode) / w)`
///
/// These are algebraically the textbook `lo + sqrt(u * w * (mode - lo))` forms.
/// Dividing by `w` before the square root keeps the radicand in `[0, 1]`, so
/// the product `w * (mode - lo)` can no longer overflow to infinity for large
/// bounds (as long as `w` itself is finite). `u <= 0` and `u >= 1` return
/// `lo` and `hi` exactly, even when `lo + w` would round away from `hi`.
///
/// # Panics
///
/// If `lo >= hi` or `mode` lies outside `[lo, hi]` (including NaN parameters).
fn triangular_quantile(lo: f64, mode: f64, hi: f64, u: f64) -> f64 {
    assert!(lo < hi, "Triangular: lo must be < hi");
    assert!(
        lo <= mode && mode <= hi,
        "Triangular: mode must be in [lo, hi]"
    );
    if u <= 0.0 {
        return lo;
    }
    if u >= 1.0 {
        return hi;
    }
    let width = hi - lo;
    let f_mode = (mode - lo) / width;
    if u <= f_mode {
        lo + width * (u * f_mode).sqrt()
    } else {
        hi - width * ((1.0 - u) * ((hi - mode) / width)).sqrt()
    }
}
/// Inverse CDF of the Weibull distribution: `Q(u) = scale * (-ln(1 - u))^(1 / shape)`.
///
/// `-ln(1 - u)` is evaluated as `-ln_1p(-u)`. This stays accurate for tiny
/// `u`, where `1.0 - u` would round to exactly `1.0` and collapse the result
/// to zero. It also yields `+0.0` rather than `-0.0` at `u = 0`. `u >= 1`
/// maps to `+∞`, the upper end of the support.
///
/// # Panics
///
/// If `shape` or `scale` is not strictly positive (including NaN).
fn weibull_quantile(shape: f64, scale: f64, u: f64) -> f64 {
    assert!(shape > 0.0, "Weibull: shape must be > 0");
    assert!(scale > 0.0, "Weibull: scale must be > 0");
    if u >= 1.0 {
        return f64::INFINITY;
    }
    let cumulative_hazard = -(-u).ln_1p();
    scale * cumulative_hazard.powf(1.0 / shape)
}
/// Inverse CDF of the exponential distribution with rate `lambda`:
/// `Q(u) = -ln(1 - u) / lambda`.
///
/// `ln(1 - u)` is evaluated as `ln_1p(-u)`. This stays accurate for tiny
/// `u`, where `1.0 - u` would round to exactly `1.0`, and it returns `+0.0`
/// (not `-0.0`) at `u = 0`. `u >= 1` maps to `+∞`, the upper end of the
/// support.
///
/// # Panics
///
/// If `lambda` is not strictly positive (including NaN).
fn exponential_quantile(lambda: f64, u: f64) -> f64 {
    assert!(lambda > 0.0, "Exponential: lambda must be > 0");
    if u >= 1.0 {
        return f64::INFINITY;
    }
    -(-u).ln_1p() / lambda
}
/// Generalized inverse CDF of a Bernoulli(`p`) variable.
///
/// The mass at zero is `1 - p`, so every `u` up to and including `1 - p` maps
/// to `0.0`. Anything else maps to `1.0`, including a NaN `u`, because the
/// comparison is false.
///
/// # Panics
///
/// If `p` lies outside `[0, 1]` (including NaN).
fn bernoulli_quantile(p: f64, u: f64) -> f64 {
    assert!((0.0..=1.0).contains(&p), "Bernoulli: p must be in [0, 1]");
    let mass_at_zero = 1.0 - p;
    let lands_on_zero = u <= mass_at_zero;
    if lands_on_zero {
        0.0
    } else {
        1.0
    }
}
/// Inverse CDF of the discrete uniform distribution on the integers `lo..=hi`.
///
/// `[0, 1]` is split into `n = hi - lo + 1` equal bins. A `u` in bin `k`
/// (0-based) maps to `lo + k`; `u == 1` is folded into the last bin so it
/// yields `hi`. The arithmetic is done in `i128` so that ranges wider than
/// `i64::MAX` (e.g. `lo = i64::MIN, hi >= 0`) no longer overflow. The
/// original `hi - lo + 1` panicked in debug builds and wrapped in release.
///
/// # Panics
///
/// If `lo > hi`.
// Casts are intentional: n <= 2^64 is exactly representable as f64, the
// f64 -> i128 cast saturates (NaN -> 0) before clamping, and the final value
// lies in lo..=hi, where i128 -> f64 rounds exactly like i64 -> f64.
#[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
fn discrete_uniform_quantile(lo: i64, hi: i64, u: f64) -> f64 {
    assert!(lo <= hi, "DiscreteUniform: lo must be <= hi");
    let lo = i128::from(lo);
    let n = i128::from(hi) - lo + 1;
    let scaled = u * (n as f64);
    let idx = (scaled.floor() as i128).clamp(0, n - 1);
    (lo + idx) as f64
}
/// Inverse CDF of `Normal(mu, sigma)`, delegated to statrs' `inverse_cdf`.
///
/// # Panics
///
/// If `sigma` is not strictly positive (the assertion), or if statrs still
/// refuses the parameters (the `expect`).
// `expect` is allowed: the assertion is meant to rule out constructor errors,
// so a failure here indicates a broken invariant rather than bad user input.
#[allow(clippy::expect_used)]
fn normal_quantile(mu: f64, sigma: f64, u: f64) -> f64 {
    assert!(sigma > 0.0, "Normal: sigma must be > 0");
    NormalDist::new(mu, sigma)
        .expect("Normal::new param check")
        .inverse_cdf(u)
}
/// Inverse CDF of a log-normal distribution, delegated to statrs.
///
/// `mu_log` and `sigma_log` are passed to statrs `LogNormal::new` as its
/// location and scale.
///
/// # Panics
///
/// If `sigma_log` is not strictly positive (the assertion), or if statrs still
/// refuses the parameters (the `expect`).
// `expect` is allowed: the assertion is meant to rule out constructor errors.
#[allow(clippy::expect_used)]
fn lognormal_quantile(mu_log: f64, sigma_log: f64, u: f64) -> f64 {
    assert!(sigma_log > 0.0, "LogNormal: sigma_log must be > 0");
    LogNormalDist::new(mu_log, sigma_log)
        .expect("LogNormal::new param check")
        .inverse_cdf(u)
}
/// Inverse CDF of Beta(`alpha`, `beta`) rescaled from `[0, 1]` onto `[lo, hi]`.
///
/// The standard Beta quantile comes from statrs and is then mapped affinely
/// as `lo + (hi - lo) * v`.
///
/// # Panics
///
/// If `alpha` or `beta` is not strictly positive, if `lo >= hi`, or if statrs
/// still refuses the shape parameters (the `expect`).
// `expect` is allowed: the assertions are meant to rule out constructor errors.
#[allow(clippy::expect_used)]
fn beta_quantile(alpha: f64, beta: f64, lo: f64, hi: f64, u: f64) -> f64 {
    assert!(alpha > 0.0, "Beta: alpha must be > 0");
    assert!(beta > 0.0, "Beta: beta must be > 0");
    assert!(lo < hi, "Beta: lo must be < hi");
    let standard = BetaDist::new(alpha, beta)
        .expect("Beta::new param check")
        .inverse_cdf(u);
    let width = hi - lo;
    lo + width * standard
}
/// Inverse CDF of the Gamma distribution in shape/scale form, delegated to statrs.
///
/// # Panics
///
/// If `shape` or `scale` is not strictly positive, or if statrs still refuses
/// the derived parameters (the `expect`).
// `expect` is allowed: the assertions are meant to rule out constructor errors.
#[allow(clippy::expect_used)]
fn gamma_quantile(shape: f64, scale: f64, u: f64) -> f64 {
    assert!(shape > 0.0, "Gamma: shape must be > 0");
    assert!(scale > 0.0, "Gamma: scale must be > 0");
    // statrs' Gamma constructor takes a rate, the reciprocal of the scale.
    let rate = 1.0 / scale;
    GammaDist::new(shape, rate)
        .expect("Gamma::new param check")
        .inverse_cdf(u)
}
#[cfg(test)]
#[allow(
    clippy::float_cmp,
    clippy::approx_constant,
    clippy::cast_precision_loss
)]
mod tests {
    use super::*;

    /// Asserts `|got - want| <= tol`, reporting `ctx` and the observed gap on failure.
    fn assert_close(got: f64, want: f64, tol: f64, ctx: &str) {
        assert!(
            (got - want).abs() <= tol,
            "{ctx}: got {got}, want {want}, |Δ|={}, tol={tol}",
            (got - want).abs()
        );
    }

    /// Checks that `d.quantile` is non-decreasing over a fixed grid spanning `[0, 1]`,
    /// including both endpoints.
    fn assert_monotone_non_decreasing(d: &Distribution) {
        let us = [
            0.0, 0.001, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 0.999,
            1.0,
        ];
        let mut prev = f64::NEG_INFINITY;
        for &u in &us {
            let q = d.quantile(u);
            assert!(
                q >= prev || (q.is_nan() && prev.is_nan()),
                "monotonicity violated for {d:?}: q({u}) = {q} < prev {prev}"
            );
            prev = q;
        }
    }

    #[test]
    fn uniform_zero_one_quantile_is_u() {
        let d = Distribution::Uniform { lo: 0.0, hi: 1.0 };
        for u in [0.0, 0.25, 0.5, 0.75, 1.0] {
            assert_eq!(d.quantile(u), u);
        }
    }

    #[test]
    fn uniform_general_quantile_linearly_maps() {
        let d = Distribution::Uniform { lo: 10.0, hi: 30.0 };
        assert_eq!(d.quantile(0.0), 10.0);
        assert_eq!(d.quantile(0.5), 20.0);
        assert_eq!(d.quantile(1.0), 30.0);
    }

    #[test]
    fn uniform_negative_range() {
        let d = Distribution::Uniform { lo: -5.0, hi: 5.0 };
        assert_eq!(d.quantile(0.5), 0.0);
        assert_eq!(d.quantile(0.0), -5.0);
        assert_eq!(d.quantile(1.0), 5.0);
    }

    #[test]
    fn uniform_support_matches_params() {
        let d = Distribution::Uniform { lo: 2.5, hi: 7.5 };
        assert_eq!(d.support(), (2.5, 7.5));
    }

    #[test]
    fn uniform_monotone() {
        assert_monotone_non_decreasing(&Distribution::Uniform { lo: 0.0, hi: 1.0 });
        assert_monotone_non_decreasing(&Distribution::Uniform {
            lo: -10.0,
            hi: 100.0,
        });
    }

    #[test]
    fn uniform_saturates_out_of_range_u() {
        let d = Distribution::Uniform { lo: 0.0, hi: 1.0 };
        assert_eq!(d.quantile(-0.5), 0.0);
        assert_eq!(d.quantile(1.5), 1.0);
    }

    #[test]
    fn normal_quantile_at_half_is_mean() {
        let d = Distribution::Normal {
            mu: 5.0,
            sigma: 2.0,
        };
        assert_close(d.quantile(0.5), 5.0, 1e-12, "Normal median");
    }

    #[test]
    fn normal_quantile_one_sigma_above_mean() {
        let d = Distribution::Normal {
            mu: 0.0,
            sigma: 1.0,
        };
        // Φ(1) for the standard normal.
        assert_close(d.quantile(0.841_344_746_068_543), 1.0, 1e-9, "+1σ");
    }

    #[test]
    fn normal_quantile_symmetric_about_mean() {
        let d = Distribution::Normal {
            mu: 7.0,
            sigma: 3.0,
        };
        for u in [0.1, 0.2, 0.3, 0.4] {
            let q_lo = d.quantile(u);
            let q_hi = d.quantile(1.0 - u);
            assert_close(q_lo + q_hi, 2.0 * 7.0, 1e-9, "Normal symmetry");
        }
    }

    #[test]
    fn normal_support_is_unbounded() {
        let d = Distribution::Normal {
            mu: 0.0,
            sigma: 1.0,
        };
        let (lo, hi) = d.support();
        assert_eq!(lo, f64::NEG_INFINITY);
        assert_eq!(hi, f64::INFINITY);
    }

    #[test]
    fn normal_monotone() {
        assert_monotone_non_decreasing(&Distribution::Normal {
            mu: 0.0,
            sigma: 1.0,
        });
    }

    #[test]
    #[should_panic(expected = "sigma must be > 0")]
    fn normal_zero_sigma_panics() {
        let d = Distribution::Normal {
            mu: 0.0,
            sigma: 0.0,
        };
        let _ = d.quantile(0.5);
    }

    #[test]
    fn lognormal_quantile_at_half_is_exp_mu_log() {
        let d = Distribution::LogNormal {
            mu_log: 1.0,
            sigma_log: 0.5,
        };
        assert_close(d.quantile(0.5), 1.0_f64.exp(), 1e-9, "LogNormal median");
    }

    #[test]
    fn lognormal_support_is_zero_to_infinity() {
        let d = Distribution::LogNormal {
            mu_log: 0.0,
            sigma_log: 1.0,
        };
        let (lo, hi) = d.support();
        assert_eq!(lo, 0.0);
        assert_eq!(hi, f64::INFINITY);
    }

    #[test]
    fn lognormal_monotone() {
        assert_monotone_non_decreasing(&Distribution::LogNormal {
            mu_log: 0.0,
            sigma_log: 1.0,
        });
    }

    #[test]
    fn triangular_quantile_at_zero_is_lo() {
        let d = Distribution::Triangular {
            lo: 0.0,
            mode: 0.5,
            hi: 1.0,
        };
        assert_eq!(d.quantile(0.0), 0.0);
    }

    #[test]
    fn triangular_quantile_at_one_is_hi() {
        let d = Distribution::Triangular {
            lo: 0.0,
            mode: 0.5,
            hi: 1.0,
        };
        assert_eq!(d.quantile(1.0), 1.0);
    }

    #[test]
    fn triangular_at_f_mode_is_mode() {
        let d = Distribution::Triangular {
            lo: 0.0,
            mode: 0.5,
            hi: 1.0,
        };
        assert_close(d.quantile(0.5), 0.5, 1e-12, "Triangular at F(mode)");
    }

    #[test]
    fn triangular_asymmetric_mode() {
        let d = Distribution::Triangular {
            lo: 0.0,
            mode: 0.25,
            hi: 1.0,
        };
        // On [0, 1], F(mode) == mode, so Q(mode) == mode.
        assert_close(d.quantile(0.25), 0.25, 1e-12, "asymmetric mode");
    }

    #[test]
    fn triangular_support_matches_params() {
        let d = Distribution::Triangular {
            lo: -2.0,
            mode: 0.0,
            hi: 5.0,
        };
        assert_eq!(d.support(), (-2.0, 5.0));
    }

    #[test]
    fn triangular_monotone_symmetric() {
        assert_monotone_non_decreasing(&Distribution::Triangular {
            lo: 0.0,
            mode: 0.5,
            hi: 1.0,
        });
    }

    #[test]
    fn triangular_monotone_asymmetric() {
        assert_monotone_non_decreasing(&Distribution::Triangular {
            lo: -10.0,
            mode: -3.0,
            hi: 7.0,
        });
    }

    #[test]
    fn beta_quantile_at_half_for_alpha_eq_beta_is_midpoint() {
        let d = Distribution::Beta {
            alpha: 2.0,
            beta: 2.0,
            lo: 0.0,
            hi: 1.0,
        };
        assert_close(d.quantile(0.5), 0.5, 1e-9, "symmetric Beta median");
    }

    #[test]
    fn beta_affine_to_general_range() {
        let d = Distribution::Beta {
            alpha: 2.0,
            beta: 2.0,
            lo: 10.0,
            hi: 30.0,
        };
        assert_close(d.quantile(0.5), 20.0, 1e-8, "Beta affine median");
    }

    #[test]
    fn beta_quantile_at_zero_is_lo() {
        let d = Distribution::Beta {
            alpha: 2.0,
            beta: 5.0,
            lo: 1.0,
            hi: 7.0,
        };
        assert_close(d.quantile(0.0), 1.0, 1e-12, "Beta lo edge");
    }

    #[test]
    fn beta_quantile_at_one_is_hi() {
        let d = Distribution::Beta {
            alpha: 2.0,
            beta: 5.0,
            lo: 1.0,
            hi: 7.0,
        };
        assert_close(d.quantile(1.0), 7.0, 1e-12, "Beta hi edge");
    }

    #[test]
    fn beta_uniform_special_case() {
        let d = Distribution::Beta {
            alpha: 1.0,
            beta: 1.0,
            lo: 0.0,
            hi: 1.0,
        };
        for u in [0.1, 0.3, 0.5, 0.7, 0.9] {
            assert_close(d.quantile(u), u, 1e-9, "Beta(1,1) ≡ Uniform");
        }
    }

    #[test]
    fn beta_monotone() {
        assert_monotone_non_decreasing(&Distribution::Beta {
            alpha: 2.0,
            beta: 5.0,
            lo: 0.0,
            hi: 1.0,
        });
    }

    #[test]
    fn gamma_shape_one_collapses_to_exponential() {
        let d_g = Distribution::Gamma {
            shape: 1.0,
            scale: 2.0,
        };
        // Rate is the reciprocal of the Gamma scale.
        let d_e = Distribution::Exponential { lambda: 0.5 };
        for u in [0.1, 0.3, 0.5, 0.7, 0.9] {
            assert_close(
                d_g.quantile(u),
                d_e.quantile(u),
                1e-7,
                "Gamma(1) ≡ Exponential",
            );
        }
    }

    #[test]
    fn gamma_quantile_at_zero_is_zero() {
        let d = Distribution::Gamma {
            shape: 2.0,
            scale: 3.0,
        };
        assert_close(d.quantile(0.0), 0.0, 1e-12, "Gamma lo edge");
    }

    #[test]
    fn gamma_support() {
        let d = Distribution::Gamma {
            shape: 2.0,
            scale: 3.0,
        };
        assert_eq!(d.support(), (0.0, f64::INFINITY));
    }

    #[test]
    fn gamma_monotone() {
        assert_monotone_non_decreasing(&Distribution::Gamma {
            shape: 2.0,
            scale: 3.0,
        });
    }

    #[test]
    fn weibull_shape_one_collapses_to_exponential() {
        let d_w = Distribution::Weibull {
            shape: 1.0,
            scale: 4.0,
        };
        let d_e = Distribution::Exponential { lambda: 0.25 };
        for u in [0.1, 0.3, 0.5, 0.7, 0.9] {
            assert_close(
                d_w.quantile(u),
                d_e.quantile(u),
                1e-12,
                "Weibull(1) ≡ Exponential",
            );
        }
    }

    #[test]
    fn weibull_quantile_at_zero_is_zero() {
        let d = Distribution::Weibull {
            shape: 2.0,
            scale: 1.0,
        };
        assert_close(d.quantile(0.0), 0.0, 1e-12, "Weibull lo edge");
    }

    #[test]
    fn weibull_quantile_at_one_is_infinity() {
        let d = Distribution::Weibull {
            shape: 2.0,
            scale: 1.0,
        };
        assert_eq!(d.quantile(1.0), f64::INFINITY);
    }

    #[test]
    fn weibull_monotone() {
        assert_monotone_non_decreasing(&Distribution::Weibull {
            shape: 2.0,
            scale: 1.0,
        });
    }

    #[test]
    fn exponential_quantile_at_zero_is_zero() {
        let d = Distribution::Exponential { lambda: 1.0 };
        assert_eq!(d.quantile(0.0), 0.0);
    }

    #[test]
    fn exponential_quantile_at_one_is_infinity() {
        let d = Distribution::Exponential { lambda: 1.0 };
        assert_eq!(d.quantile(1.0), f64::INFINITY);
    }

    #[test]
    fn exponential_quantile_known_point() {
        let lambda = 2.0_f64;
        let d = Distribution::Exponential { lambda };
        // F(1/λ) = 1 - e^{-1}.
        let u = 1.0 - (-1.0_f64).exp();
        assert_close(d.quantile(u), 1.0 / lambda, 1e-12, "Exponential @1/λ");
    }

    #[test]
    fn exponential_monotone() {
        assert_monotone_non_decreasing(&Distribution::Exponential { lambda: 1.0 });
    }

    #[test]
    fn bernoulli_zero_p_is_always_zero() {
        let d = Distribution::Bernoulli { p: 0.0 };
        for u in [0.0, 0.25, 0.5, 0.75, 1.0] {
            assert_eq!(d.quantile(u), 0.0);
        }
    }

    #[test]
    fn bernoulli_one_p_returns_one_above_zero() {
        let d = Distribution::Bernoulli { p: 1.0 };
        // u = 0 sits exactly on the `u <= 1 - p` threshold, so it still maps to 0.
        assert_eq!(d.quantile(0.0), 0.0);
        for u in [0.000_001, 0.25, 0.5, 0.75, 1.0] {
            assert_eq!(d.quantile(u), 1.0);
        }
    }

    #[test]
    fn bernoulli_threshold_is_inclusive_at_one_minus_p() {
        let d = Distribution::Bernoulli { p: 0.3 };
        assert_eq!(d.quantile(0.0), 0.0);
        assert_eq!(d.quantile(0.5), 0.0);
        assert_eq!(d.quantile(0.69), 0.0);
        // The threshold 1 - p (≈ 0.7) itself still maps to 0 ...
        assert_eq!(d.quantile(0.7), 0.0);
        // ... and anything just above it maps to 1.
        assert_eq!(d.quantile(0.700_000_000_001), 1.0);
        assert_eq!(d.quantile(0.99), 1.0);
        assert_eq!(d.quantile(1.0), 1.0);
    }

    #[test]
    fn bernoulli_monotone() {
        assert_monotone_non_decreasing(&Distribution::Bernoulli { p: 0.4 });
    }

    #[test]
    fn discrete_uniform_singleton() {
        let d = Distribution::DiscreteUniform { lo: 5, hi: 5 };
        for u in [0.0, 0.5, 1.0] {
            assert_eq!(d.quantile(u), 5.0);
        }
    }

    #[test]
    fn discrete_uniform_two_values() {
        let d = Distribution::DiscreteUniform { lo: 0, hi: 1 };
        assert_eq!(d.quantile(0.0), 0.0);
        assert_eq!(d.quantile(0.49), 0.0);
        assert_eq!(d.quantile(0.5), 1.0);
        assert_eq!(d.quantile(0.99), 1.0);
        assert_eq!(d.quantile(1.0), 1.0);
    }

    #[test]
    fn discrete_uniform_six_values() {
        let d = Distribution::DiscreteUniform { lo: 1, hi: 6 };
        assert_eq!(d.quantile(0.0), 1.0);
        // Just past the first bin boundary at 1/6.
        assert_eq!(d.quantile(1.0 / 6.0 + 1e-9), 2.0);
        // 0.5 * 6 = 3 lands exactly on the start of the fourth bin.
        assert_eq!(d.quantile(0.5), 4.0);
        // u = 1 is folded into the last bin.
        assert_eq!(d.quantile(1.0), 6.0);
    }

    #[test]
    fn discrete_uniform_negative_range() {
        let d = Distribution::DiscreteUniform { lo: -3, hi: 3 };
        assert_eq!(d.quantile(0.0), -3.0);
        assert_eq!(d.quantile(1.0), 3.0);
        let mid = d.quantile(0.5);
        assert_eq!(mid, 0.0);
    }

    #[test]
    fn discrete_uniform_monotone() {
        assert_monotone_non_decreasing(&Distribution::DiscreteUniform { lo: 1, hi: 10 });
    }

    #[test]
    fn distribution_serde_round_trip_for_all_variants() {
        let cases = [
            Distribution::Uniform { lo: 1.0, hi: 5.0 },
            Distribution::Normal {
                mu: 0.0,
                sigma: 2.0,
            },
            Distribution::LogNormal {
                mu_log: 1.0,
                sigma_log: 0.5,
            },
            Distribution::Triangular {
                lo: 0.0,
                mode: 0.3,
                hi: 1.0,
            },
            Distribution::Beta {
                alpha: 2.0,
                beta: 5.0,
                lo: 0.0,
                hi: 1.0,
            },
            Distribution::Gamma {
                shape: 2.0,
                scale: 1.0,
            },
            Distribution::Weibull {
                shape: 1.5,
                scale: 2.0,
            },
            Distribution::Exponential { lambda: 0.7 },
            Distribution::Bernoulli { p: 0.3 },
            Distribution::DiscreteUniform { lo: 1, hi: 6 },
        ];
        for d in cases {
            let json = serde_json::to_string(&d).expect("serialize");
            let back: Distribution = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(back, d, "round-trip {d:?} → {json} → {back:?}");
        }
    }

    #[test]
    fn quantile_at_zero_returns_lower_support_for_finite_distributions() {
        let cases = [
            (Distribution::Uniform { lo: 2.0, hi: 5.0 }, 2.0),
            (
                Distribution::Triangular {
                    lo: -1.0,
                    mode: 0.0,
                    hi: 1.0,
                },
                -1.0,
            ),
            (
                Distribution::Beta {
                    alpha: 2.0,
                    beta: 3.0,
                    lo: 0.5,
                    hi: 1.5,
                },
                0.5,
            ),
            (Distribution::Bernoulli { p: 0.4 }, 0.0),
            (Distribution::DiscreteUniform { lo: -2, hi: 2 }, -2.0),
        ];
        for (d, lo) in cases {
            assert_close(d.quantile(0.0), lo, 1e-9, "lo edge");
        }
    }

    #[test]
    fn quantile_at_one_returns_upper_support_for_finite_distributions() {
        let cases = [
            (Distribution::Uniform { lo: 2.0, hi: 5.0 }, 5.0),
            (
                Distribution::Triangular {
                    lo: -1.0,
                    mode: 0.0,
                    hi: 1.0,
                },
                1.0,
            ),
            (
                Distribution::Beta {
                    alpha: 2.0,
                    beta: 3.0,
                    lo: 0.5,
                    hi: 1.5,
                },
                1.5,
            ),
            (Distribution::Bernoulli { p: 0.4 }, 1.0),
            (Distribution::DiscreteUniform { lo: -2, hi: 2 }, 2.0),
        ];
        for (d, hi) in cases {
            assert_close(d.quantile(1.0), hi, 1e-9, "hi edge");
        }
    }
}