use std::f64::consts::PI;
/// A weighted point in the digest: the mean of a group of absorbed samples
/// together with their total weight.
#[derive(Debug, Clone, PartialEq)]
pub struct Centroid {
    pub mean: f64,
    pub weight: f64,
}

impl Centroid {
    /// Creates a centroid located at `mean` carrying `weight` units of
    /// sample weight.
    pub fn new(mean: f64, weight: f64) -> Self {
        Centroid { mean, weight }
    }
}
/// Streaming quantile sketch (t-digest). Points are buffered as singleton
/// centroids and periodically merged under the k1 scale function, keeping
/// accuracy highest near the tails.
#[derive(Debug, Clone)]
pub struct TDigest {
    // Cluster summaries; NOT kept sorted between compressions — `add_weighted`
    // pushes unsorted, and `quantile`/`cdf` sort a clone on each query.
    centroids: Vec<Centroid>,
    // Total weight of everything added so far (count, for unweighted adds).
    n: f64,
    // Compression parameter: larger delta => more centroids => lower error.
    delta: f64,
}
impl TDigest {
    /// Creates an empty digest. `delta` is the compression parameter:
    /// larger values retain more centroids and give tighter quantile error.
    pub fn new(delta: f64) -> Self {
        Self {
            centroids: Vec::new(),
            n: 0.0,
            delta,
        }
    }

    /// Total weight absorbed so far (the count, for unweighted points).
    pub fn total_weight(&self) -> f64 {
        self.n
    }

    /// Number of centroids currently stored, including any not yet compressed.
    pub fn num_centroids(&self) -> usize {
        self.centroids.len()
    }

    /// Adds a single observation with weight 1.
    pub fn add(&mut self, x: f64) {
        self.add_weighted(x, 1.0);
    }

    /// Adds an observation `x` with weight `w`.
    ///
    /// Non-finite values and non-positive weights are rejected: a NaN mean
    /// would poison every later sort (`partial_cmp` falls back to `Equal`)
    /// and silently corrupt all subsequent quantile/cdf queries.
    pub fn add_weighted(&mut self, x: f64, w: f64) {
        if !x.is_finite() || !w.is_finite() || w <= 0.0 {
            return;
        }
        self.centroids.push(Centroid::new(x, w));
        self.n += w;
        // Amortize compression: let the buffer grow to ~10x the nominal
        // centroid budget before merging.
        if self.centroids.len() > (10.0 * self.delta) as usize {
            self.compress();
        }
    }

    /// Merges adjacent centroids so cluster sizes respect the k1 scale
    /// function (Dunning's merging t-digest).
    ///
    /// The k-limit for a cluster is fixed when the cluster *starts*
    /// (`k(q_left) + 1`). Recomputing the limit from the cluster's moving
    /// midpoint as it grows — as the previous implementation did — lets
    /// clusters exceed the size bound the scale function is meant to
    /// guarantee, degrading tail accuracy.
    pub fn compress(&mut self) {
        if self.centroids.is_empty() {
            return;
        }
        self.centroids
            .sort_by(|a, b| a.mean.partial_cmp(&b.mean).unwrap_or(std::cmp::Ordering::Equal));
        let total = self.n;
        if total <= 0.0 {
            return;
        }
        let mut merged: Vec<Centroid> = Vec::with_capacity(self.centroids.len());
        let mut current = self.centroids[0].clone();
        // Weight of all fully emitted clusters (everything left of `current`).
        let mut weight_left = 0.0;
        // Size bound for the current cluster, fixed at its left edge.
        let mut k_limit = k1_scale(weight_left / total, self.delta) + 1.0;
        for c in &self.centroids[1..] {
            // Quantile of the right edge if `c` were absorbed into `current`.
            let q_right = (weight_left + current.weight + c.weight) / total;
            if k1_scale(q_right, self.delta) <= k_limit {
                // Absorb: weighted average of the two means.
                let w = current.weight + c.weight;
                current.mean = (current.mean * current.weight + c.mean * c.weight) / w;
                current.weight = w;
            } else {
                // Emit the finished cluster and start a new one at `c`.
                weight_left += current.weight;
                merged.push(current);
                current = c.clone();
                k_limit = k1_scale(weight_left / total, self.delta) + 1.0;
            }
        }
        merged.push(current);
        self.centroids = merged;
    }

    /// Estimated value at quantile `q` (clamped to [0, 1]).
    /// Returns NaN for an empty digest.
    pub fn quantile(&self, q: f64) -> f64 {
        if self.centroids.is_empty() || self.n == 0.0 {
            return f64::NAN;
        }
        let mut sorted = self.centroids.clone();
        sorted.sort_by(|a, b| a.mean.partial_cmp(&b.mean).unwrap_or(std::cmp::Ordering::Equal));
        if q <= 0.0 {
            return sorted.first().map(|c| c.mean).unwrap_or(f64::NAN);
        }
        if q >= 1.0 {
            return sorted.last().map(|c| c.mean).unwrap_or(f64::NAN);
        }
        let n = sorted.len();
        // Mid-rank of each centroid: cumulative weight up to its midpoint.
        let mut mid_ranks = Vec::with_capacity(n);
        let mut prefix = 0.0_f64;
        for c in &sorted {
            mid_ranks.push(prefix + c.weight / 2.0);
            prefix += c.weight;
        }
        let target = q * self.n;
        if target <= mid_ranks[0] {
            return sorted[0].mean;
        }
        if target >= mid_ranks[n - 1] {
            return sorted[n - 1].mean;
        }
        // First index whose mid-rank exceeds the target. The clamps are
        // defensive; the tail checks above already guarantee 1 <= i <= n-1.
        let i = mid_ranks
            .partition_point(|&r| r <= target)
            .min(n - 1)
            .max(1);
        let r0 = mid_ranks[i - 1];
        let r1 = mid_ranks[i];
        // Linearly interpolate between the neighbouring centroid means;
        // identical mid-ranks (zero gap) fall back to the midpoint.
        let frac = if r1 - r0 > 0.0 {
            (target - r0) / (r1 - r0)
        } else {
            0.5
        };
        sorted[i - 1].mean + frac * (sorted[i].mean - sorted[i - 1].mean)
    }

    /// Estimated fraction of total weight at or below `x`.
    /// Returns NaN for an empty digest.
    pub fn cdf(&self, x: f64) -> f64 {
        if self.centroids.is_empty() || self.n == 0.0 {
            return f64::NAN;
        }
        let mut sorted = self.centroids.clone();
        sorted.sort_by(|a, b| a.mean.partial_cmp(&b.mean).unwrap_or(std::cmp::Ordering::Equal));
        if x < sorted[0].mean {
            return 0.0;
        }
        if x >= sorted[sorted.len() - 1].mean {
            return 1.0;
        }
        let mut cumulative = 0.0;
        for i in 0..sorted.len() {
            let c = &sorted[i];
            if x < c.mean {
                if i == 0 {
                    return 0.0;
                }
                let prev = &sorted[i - 1];
                // Fractional position of x between the two bracketing means.
                let frac = if c.mean - prev.mean != 0.0 {
                    (x - prev.mean) / (c.mean - prev.mean)
                } else {
                    0.5
                };
                // Interpolate between the mid-ranks of the bracketing
                // centroids. `cumulative` here includes `prev`'s weight.
                let prev_cum_mid = cumulative - prev.weight / 2.0;
                let this_cum_mid = cumulative + c.weight / 2.0;
                return (prev_cum_mid + frac * (this_cum_mid - prev_cum_mid)) / self.n;
            }
            cumulative += c.weight;
        }
        1.0
    }

    /// Absorbs all of `other`'s centroids into `self`, then compresses.
    pub fn merge(&mut self, other: &TDigest) {
        self.centroids.reserve(other.centroids.len());
        for c in &other.centroids {
            self.centroids.push(c.clone());
            self.n += c.weight;
        }
        self.compress();
    }
}
/// The k1 scale function of the t-digest: k(q) = δ/(2π)·asin(2q − 1).
/// It is steepest near q = 0 and q = 1, which is what concentrates small
/// clusters (and therefore accuracy) at the tails. The input is clamped
/// just inside (0, 1) to keep `asin` finite at the boundaries.
fn k1_scale(q: f64, delta: f64) -> f64 {
    let q = q.clamp(1e-10, 1.0 - 1e-10);
    delta / (2.0 * PI) * (2.0 * q - 1.0).asin()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a digest over the integers 1..=n with delta = 100.
    fn build_digest(n: usize) -> TDigest {
        let mut td = TDigest::new(100.0);
        for v in 1..=n {
            td.add(v as f64);
        }
        td
    }

    // Median of 1..=100 should land near 50 (loose tolerance: the digest
    // is an approximate sketch).
    #[test]
    fn test_median_on_1_to_100() {
        let mut td = build_digest(100);
        td.compress();
        let median = td.quantile(0.5);
        assert!(
            (median - 50.0).abs() < 2.0,
            "Expected median ≈ 50, got {median}"
        );
    }

    // q = 0 must return (approximately) the minimum observed value.
    #[test]
    fn test_quantile_0_is_min() {
        let mut td = build_digest(50);
        td.compress();
        let min = td.quantile(0.0);
        assert!(
            (min - 1.0).abs() < 1.0,
            "Expected min ≈ 1, got {min}"
        );
    }

    // q = 1 must return (approximately) the maximum observed value.
    #[test]
    fn test_quantile_1_is_max() {
        let mut td = build_digest(50);
        td.compress();
        let max = td.quantile(1.0);
        assert!(
            (max - 50.0).abs() < 1.0,
            "Expected max ≈ 50, got {max}"
        );
    }

    // Merging two digests covering 1..=50 and 51..=100 should behave like
    // one digest over 1..=100.
    #[test]
    fn test_merge_two_digests() {
        let mut td_a = TDigest::new(100.0);
        let mut td_b = TDigest::new(100.0);
        for v in 1..=50 {
            td_a.add(v as f64);
        }
        for v in 51..=100 {
            td_b.add(v as f64);
        }
        td_a.merge(&td_b);
        let median = td_a.quantile(0.5);
        assert!(
            (median - 50.0).abs() < 3.0,
            "Merged median expected ≈ 50, got {median}"
        );
    }

    // The CDF evaluated at the true median should be close to 0.5.
    #[test]
    fn test_cdf_at_median_is_approx_half() {
        let mut td = build_digest(100);
        td.compress();
        let cdf_at_50 = td.cdf(50.0);
        assert!(
            (cdf_at_50 - 0.5).abs() < 0.05,
            "CDF(50) expected ≈ 0.5, got {cdf_at_50}"
        );
    }

    // Queries against an empty digest have no answer; both report NaN.
    #[test]
    fn test_empty_digest_returns_nan() {
        let td = TDigest::new(100.0);
        assert!(td.quantile(0.5).is_nan());
        assert!(td.cdf(0.0).is_nan());
    }
}