use crate::SketchError;
/// A cluster of nearby samples summarised by its weighted mean and total weight.
#[derive(Debug, Clone, Copy)]
struct Centroid {
/// Weighted mean of every value merged into this centroid.
mean: f64,
/// Sum of the weights of the values merged into this centroid.
weight: f64,
}
/// A t-digest style sketch that summarises a stream of `f64` samples as weighted
/// centroids so that quantiles can be estimated without retaining every sample.
#[derive(Debug, Clone)]
pub struct TDigest {
/// Compression factor; larger values allow more centroids to be kept (each
/// centroid's weight limit is inversely proportional to it).
compression: f64,
/// Centroids summarising the data; they are re-sorted by mean before being read.
centroids: Vec<Centroid>,
/// Sum of the weights of all samples accepted so far.
total_weight: f64,
}
impl TDigest {
    /// Creates an empty digest with the given compression factor.
    ///
    /// # Errors
    ///
    /// Returns [`SketchError::InvalidParameter`] when `compression` is NaN, infinite,
    /// or smaller than 10.
    pub fn new(compression: f64) -> Result<Self, SketchError> {
        if !compression.is_finite() || compression < 10.0 {
            return Err(SketchError::InvalidParameter(
                "compression must be finite and greater than or equal to 10",
            ));
        }
        Ok(Self {
            compression,
            centroids: Vec::new(),
            total_weight: 0.0,
        })
    }

    /// Creates an empty digest whose compression is `ceil(10 / quantile_error)`.
    ///
    /// # Errors
    ///
    /// Returns [`SketchError::InvalidParameter`] when `quantile_error` is not finite,
    /// is outside the open interval `(0, 1)`, or is so small that the derived
    /// compression is not a finite number.
    pub fn with_error_rate(quantile_error: f64) -> Result<Self, SketchError> {
        if !quantile_error.is_finite() || quantile_error <= 0.0 || quantile_error >= 1.0 {
            return Err(SketchError::InvalidParameter(
                "quantile_error must be finite and strictly between 0 and 1",
            ));
        }
        let compression = (10.0 / quantile_error).ceil().max(10.0);
        // The division overflows to infinity for subnormal-sized inputs; report the
        // parameter the caller actually supplied instead of blaming `compression`.
        if !compression.is_finite() {
            return Err(SketchError::InvalidParameter(
                "quantile_error is too small to derive a finite compression",
            ));
        }
        Self::new(compression)
    }

    /// Returns the compression factor this digest was created with.
    pub fn compression(&self) -> f64 {
        self.compression
    }

    /// Returns the number of centroids currently stored.
    pub fn centroid_count(&self) -> usize {
        self.centroids.len()
    }

    /// Returns the total accepted weight rounded to the nearest integer.
    pub fn count(&self) -> u64 {
        // Float-to-integer `as` casts saturate, so this cannot wrap around.
        self.total_weight.round() as u64
    }

    /// Returns `true` when no weight has been accepted since creation or the last
    /// [`clear`](Self::clear).
    pub fn is_empty(&self) -> bool {
        self.total_weight == 0.0
    }

    /// Adds a single sample with weight 1.
    ///
    /// Non-finite values (NaN, infinities) are silently ignored.
    pub fn add(&mut self, value: f64) {
        self.add_weighted(value, 1.0);
    }

    /// Estimates the value at quantile `q` (0 = smallest centroid, 1 = largest).
    ///
    /// Every centroid is treated as sitting at the centre of the rank range it
    /// covers, and the estimate is linearly interpolated between the two adjacent
    /// centroid centres that bracket the target rank `q * total_weight`. Ranks below
    /// the first centre or above the last centre map to the first or last centroid
    /// mean respectively.
    ///
    /// # Errors
    ///
    /// Returns [`SketchError::InvalidParameter`] when `q` is not a finite number in
    /// `[0, 1]` or when the digest holds no centroids.
    pub fn quantile(&self, q: f64) -> Result<f64, SketchError> {
        if !q.is_finite() || !(0.0..=1.0).contains(&q) {
            return Err(SketchError::InvalidParameter(
                "q must be finite and in [0, 1]",
            ));
        }
        if self.centroids.is_empty() {
            return Err(SketchError::InvalidParameter(
                "quantile is undefined for an empty digest",
            ));
        }

        // Borrow the stored centroids when they are already ordered by mean and only
        // fall back to a sorted copy when they are not, so the common case does not
        // allocate.
        let in_order = self
            .centroids
            .windows(2)
            .all(|pair| pair[0].mean.total_cmp(&pair[1].mean).is_le());
        let resorted;
        let sorted: &[Centroid] = if in_order {
            &self.centroids
        } else {
            let mut copy = self.centroids.clone();
            copy.sort_unstable_by(|left, right| left.mean.total_cmp(&right.mean));
            resorted = copy;
            &resorted
        };

        let first = sorted[0];
        let last = sorted[sorted.len() - 1];
        if q <= 0.0 {
            return Ok(first.mean);
        }
        if q >= 1.0 {
            return Ok(last.mean);
        }

        let target = q * self.total_weight;
        // Weight of all centroids strictly before the one being visited.
        let mut cumulative = 0.0;
        // (centre rank, mean) of the centroid visited just before the current one.
        let mut previous: Option<(f64, f64)> = None;
        for centroid in sorted {
            let center = cumulative + centroid.weight * 0.5;
            if target <= center {
                let estimate = match previous {
                    // Below the first centre there is nothing to interpolate with.
                    None => centroid.mean,
                    Some((previous_center, previous_mean)) => {
                        let span = center - previous_center;
                        if span > 0.0 {
                            let t = ((target - previous_center) / span).clamp(0.0, 1.0);
                            previous_mean + t * (centroid.mean - previous_mean)
                        } else {
                            centroid.mean
                        }
                    }
                };
                return Ok(estimate);
            }
            previous = Some((center, centroid.mean));
            cumulative += centroid.weight;
        }
        // The target rank lies above the centre of the last centroid.
        Ok(last.mean)
    }

    /// Folds every centroid of `other` into `self` and then compresses `self`.
    ///
    /// # Errors
    ///
    /// Returns [`SketchError::IncompatibleSketches`] when the two digests were built
    /// with different compression factors; `self` is left untouched in that case.
    pub fn merge(&mut self, other: &Self) -> Result<(), SketchError> {
        if (self.compression - other.compression).abs() > f64::EPSILON {
            return Err(SketchError::IncompatibleSketches(
                "compression must match for merge",
            ));
        }
        for centroid in &other.centroids {
            self.add_weighted(centroid.mean, centroid.weight);
        }
        self.compress();
        Ok(())
    }

    /// Removes all centroids and resets the total weight to zero.
    pub fn clear(&mut self) {
        self.centroids.clear();
        self.total_weight = 0.0;
    }

    /// Adds `value` with the given `weight`, either folding it into the nearest
    /// centroid (when that centroid stays within its weight limit) or inserting a new
    /// centroid at the sorted position.
    ///
    /// Inputs with a non-finite value, or a non-finite or non-positive weight, are
    /// silently ignored.
    fn add_weighted(&mut self, value: f64, weight: f64) {
        if !value.is_finite() || !weight.is_finite() || weight <= 0.0 {
            return;
        }
        if self.centroids.is_empty() {
            self.centroids.push(Centroid {
                mean: value,
                weight,
            });
            self.total_weight += weight;
            return;
        }

        self.centroids
            .sort_unstable_by(|left, right| left.mean.total_cmp(&right.mean));

        // `min_by` yields the first of several equally close centroids, i.e. the one
        // with the lowest index.
        let nearest_index = self
            .centroids
            .iter()
            .enumerate()
            .map(|(index, centroid)| (index, (centroid.mean - value).abs()))
            .min_by(|left, right| left.1.total_cmp(&right.1))
            .map_or(0, |(index, _)| index);

        let q = self.centroid_quantile(nearest_index);
        let max_weight = self.max_centroid_weight(q);
        if self.centroids[nearest_index].weight + weight <= max_weight {
            // Incremental weighted-mean update.
            let centroid = &mut self.centroids[nearest_index];
            let updated_weight = centroid.weight + weight;
            centroid.mean += (value - centroid.mean) * (weight / updated_weight);
            centroid.weight = updated_weight;
        } else {
            let insert_index = self
                .centroids
                .binary_search_by(|centroid| centroid.mean.total_cmp(&value))
                .unwrap_or_else(|index| index);
            self.centroids.insert(
                insert_index,
                Centroid {
                    mean: value,
                    weight,
                },
            );
        }
        self.total_weight += weight;

        // Compact once the centroid list grows past eight times the compression
        // factor (the float-to-integer cast saturates for very large factors).
        if self.centroids.len() > (self.compression * 8.0) as usize {
            self.compress();
        }
    }

    /// Returns the quantile position, clamped to `[0, 1]`, of the centre of the
    /// centroid at `index`, assuming the centroids are ordered by mean.
    fn centroid_quantile(&self, index: usize) -> f64 {
        let cumulative_before = self.centroids[..index]
            .iter()
            .map(|centroid| centroid.weight)
            .sum::<f64>();
        let centered = cumulative_before + self.centroids[index].weight * 0.5;
        // The divisor is floored at 1 so the division stays defined.
        (centered / self.total_weight.max(1.0)).clamp(0.0, 1.0)
    }

    /// Returns the largest weight a centroid located at quantile `q` may hold:
    /// `4 * total_weight / compression * q * (1 - q)`, but never less than 1.
    fn max_centroid_weight(&self, q: f64) -> f64 {
        let scaled = (4.0 * self.total_weight / self.compression) * q * (1.0 - q);
        scaled.max(1.0)
    }

    /// Sorts the centroids by mean and greedily merges each one into its left
    /// neighbour whenever the combined weight stays within the neighbour's limit.
    fn compress(&mut self) {
        if self.centroids.len() <= 1 {
            return;
        }
        self.centroids
            .sort_unstable_by(|left, right| left.mean.total_cmp(&right.mean));
        let old = std::mem::take(&mut self.centroids);
        let mut merged: Vec<Centroid> = Vec::with_capacity(old.len());
        // Weight of all finished centroids that precede `merged.last()`.
        let mut cumulative = 0.0;
        for centroid in old {
            if let Some(last) = merged.last_mut() {
                let q = ((cumulative + 0.5 * last.weight) / self.total_weight.max(1.0))
                    .clamp(0.0, 1.0);
                let max_weight = self.max_centroid_weight(q);
                if last.weight + centroid.weight <= max_weight {
                    let updated_weight = last.weight + centroid.weight;
                    last.mean += (centroid.mean - last.mean) * (centroid.weight / updated_weight);
                    last.weight = updated_weight;
                    continue;
                }
                // `last` is complete; its weight now precedes the next centroid.
                cumulative += last.weight;
            }
            merged.push(centroid);
        }
        self.centroids = merged;
    }
}
#[cfg(test)]
mod tests {
    use super::TDigest;

    /// Builds a digest with `compression` and feeds it every integer from `values`.
    fn digest_of(compression: f64, values: impl Iterator<Item = u64>) -> TDigest {
        let mut sketch = TDigest::new(compression).unwrap();
        values.for_each(|sample| sketch.add(sample as f64));
        sketch
    }

    /// Compression below the minimum of 10 is refused; a larger one is accepted.
    #[test]
    fn constructor_validates_compression() {
        let too_small = TDigest::new(5.0);
        let acceptable = TDigest::new(50.0);
        assert!(too_small.is_err());
        assert!(acceptable.is_ok());
    }

    /// An empty digest has no quantiles, and `q` outside `[0, 1]` is an error.
    #[test]
    fn quantile_rejects_invalid_input() {
        let empty = TDigest::new(100.0).unwrap();
        for q in [0.5, -0.1, 1.1] {
            assert!(empty.quantile(q).is_err());
        }
    }

    /// The median of 0..10_000 is estimated close to the middle of the range.
    #[test]
    fn median_estimate_is_reasonable() {
        let sketch = digest_of(120.0, 0..10_000);
        let p50 = sketch.quantile(0.5).unwrap();
        assert!(p50 > 4_400.0 && p50 < 5_600.0, "p50={p50}");
    }

    /// Upper quantiles are ordered and land in the top of the range.
    #[test]
    fn high_quantile_tracks_tail() {
        let sketch = digest_of(120.0, 0..10_000);
        let p95 = sketch.quantile(0.95).unwrap();
        let p99 = sketch.quantile(0.99).unwrap();
        assert!(p95 <= p99, "p95={p95} p99={p99}");
        assert!(p99 > 9_000.0);
    }

    /// Merging two halves of a range yields quantiles over the whole range.
    #[test]
    fn merge_combines_streams() {
        let mut lower_half = digest_of(80.0, 0..5_000);
        let upper_half = digest_of(80.0, 5_000..10_000);
        lower_half.merge(&upper_half).unwrap();
        let p90 = lower_half.quantile(0.9).unwrap();
        assert!(p90 > 8_000.0);
    }

    /// Digests built with different compression factors cannot be merged.
    #[test]
    fn merge_rejects_mismatched_compression() {
        let mut target = TDigest::new(80.0).unwrap();
        let incompatible = TDigest::new(81.0).unwrap();
        assert!(target.merge(&incompatible).is_err());
    }

    /// `clear` returns the digest to its empty state.
    #[test]
    fn clear_resets_state() {
        let mut sketch = TDigest::new(50.0).unwrap();
        for sample in [1.0, 2.0] {
            sketch.add(sample);
        }
        sketch.clear();
        assert!(sketch.is_empty());
        assert!(sketch.quantile(0.5).is_err());
    }
}