use crate::core::DocId;
use super::{AggregationResult, Aggregator, AggregatorFactory, MetricResult};
use crate::segment::reader::SegmentReader;
/// A cluster of nearby values: their weighted mean and how much weight
/// (number of values, possibly summed across merges) it represents.
#[derive(Clone, Copy, Debug)]
struct Centroid {
    mean: f64,
    weight: f64,
}

/// Mergeable t-digest sketch used to estimate percentiles.
///
/// Values are buffered as unit-weight centroids and periodically merged by
/// `compress` using the `k1` scale function from Dunning & Ertl, "Computing
/// Extremely Accurate Quantiles Using t-Digests". That scale function keeps
/// centroids small near the tails (accurate extreme percentiles) and leaves
/// at most about `compression` centroids after compression, independent of
/// how many values were added.
#[derive(Clone)]
pub struct TDigest {
    /// Centroids. Sorted by mean after `compress`; values appended by `add`
    /// since the last compression may form an unsorted tail.
    centroids: Vec<Centroid>,
    /// Accuracy/size trade-off (the paper's δ); larger keeps more centroids.
    compression: f64,
    /// Total weight added (one per accepted `add`, plus merged digests).
    total_weight: f64,
}

impl TDigest {
    /// Serialized header size: compression (f64), total weight (f64) and
    /// centroid count (u32), all little-endian.
    const HEADER_BYTES: usize = 20;
    /// Serialized size of one centroid: mean (f64) then weight (f64),
    /// little-endian.
    const CENTROID_BYTES: usize = 16;

    /// Creates an empty digest with the given compression factor.
    pub fn new(compression: f64) -> Self {
        Self {
            centroids: Vec::new(),
            compression,
            total_weight: 0.0,
        }
    }

    /// Adds one observation with weight 1.
    ///
    /// NaN is ignored: it has no place in the ordering percentiles are
    /// defined over and would otherwise poison centroid means (and used to
    /// make `compress` panic). Once more than `5 × compression` centroids are
    /// buffered, the digest is compressed.
    pub fn add(&mut self, value: f64) {
        if value.is_nan() {
            return;
        }
        self.centroids.push(Centroid {
            mean: value,
            weight: 1.0,
        });
        self.total_weight += 1.0;
        // Buffer values and compress in batches so the sort cost is amortised.
        // Because `compress` leaves only about `compression` centroids, the
        // buffer always has room again afterwards.
        if self.centroids.len() > (self.compression * 5.0) as usize {
            self.compress();
        }
    }

    /// Folds all of `other`'s centroids into `self` and re-compresses.
    pub fn merge(&mut self, other: &TDigest) {
        self.centroids.extend_from_slice(&other.centroids);
        self.total_weight += other.total_weight;
        self.compress();
    }

    /// Sorts centroids by mean and greedily merges neighbours.
    ///
    /// A run of centroids starting at cumulative quantile `q_start` is merged
    /// while its end stays within `q_limit(q_start)`, i.e. while it spans at
    /// most one unit of the k1 scale. Any two consecutive output centroids
    /// therefore span more than one unit, and k1 covers `compression / 2`
    /// units in total, so roughly `compression` centroids remain.
    fn compress(&mut self) {
        if self.centroids.is_empty() {
            return;
        }
        // `total_cmp` is a total order, so sorting cannot panic even if a NaN
        // mean arrived through `from_bytes`.
        self.centroids.sort_by(|a, b| a.mean.total_cmp(&b.mean));
        // Measure positions against the centroids' own weights so quantiles
        // stay within [0, 1] even if `total_weight` disagrees (corrupt input).
        let total: f64 = self.centroids.iter().map(|c| c.weight).sum();
        if total <= 0.0 || !total.is_finite() {
            // Degenerate weights: keep the sorted centroids rather than divide
            // by zero or infinity.
            return;
        }
        let mut compressed = Vec::new();
        let mut current = self.centroids[0];
        let mut q_start = 0.0;
        let mut q_limit = self.q_limit(q_start);
        for &next in &self.centroids[1..] {
            let q_end = q_start + (current.weight + next.weight) / total;
            if q_end <= q_limit {
                let merged_weight = current.weight + next.weight;
                current.mean =
                    (current.mean * current.weight + next.mean * next.weight) / merged_weight;
                current.weight = merged_weight;
            } else {
                q_start += current.weight / total;
                q_limit = self.q_limit(q_start);
                compressed.push(current);
                current = next;
            }
        }
        compressed.push(current);
        self.centroids = compressed;
    }

    /// Returns the largest cumulative quantile that a centroid starting at
    /// `q_start` may extend to: `k⁻¹(k(q_start) + 1)` for the k1 scale
    /// function `k(q) = compression / (2π) · asin(2q − 1)`.
    fn q_limit(&self, q_start: f64) -> f64 {
        use std::f64::consts::{FRAC_PI_2, PI};
        let scale = self.compression / (2.0 * PI);
        let k = scale * (2.0 * q_start - 1.0).clamp(-1.0, 1.0).asin();
        // Clamping the angle makes the limit saturate at 1 near the upper tail
        // instead of wrapping around the sine.
        let angle = ((k + 1.0) / scale).clamp(-FRAC_PI_2, FRAC_PI_2);
        (angle.sin() + 1.0) / 2.0
    }

    /// Serializes the digest: a `HEADER_BYTES`-byte header followed by
    /// `CENTROID_BYTES` per centroid (layout documented on the constants).
    ///
    /// # Panics
    /// If the digest holds more than `u32::MAX` centroids.
    pub fn to_bytes(&self) -> Vec<u8> {
        let count = u32::try_from(self.centroids.len())
            .expect("t-digest holds more than u32::MAX centroids");
        let mut buf = Vec::with_capacity(
            Self::HEADER_BYTES + self.centroids.len() * Self::CENTROID_BYTES,
        );
        buf.extend_from_slice(&self.compression.to_le_bytes());
        buf.extend_from_slice(&self.total_weight.to_le_bytes());
        buf.extend_from_slice(&count.to_le_bytes());
        for c in &self.centroids {
            buf.extend_from_slice(&c.mean.to_le_bytes());
            buf.extend_from_slice(&c.weight.to_le_bytes());
        }
        buf
    }

    /// Parses bytes produced by `to_bytes`.
    ///
    /// Returns `None` if `data` is shorter than the header or than the
    /// declared number of centroids; trailing bytes beyond that are ignored.
    pub fn from_bytes(data: &[u8]) -> Option<Self> {
        let compression = Self::read_f64(data, 0)?;
        let total_weight = Self::read_f64(data, 8)?;
        let count_bytes: [u8; 4] = data.get(16..Self::HEADER_BYTES)?.try_into().ok()?;
        let count = usize::try_from(u32::from_le_bytes(count_bytes)).ok()?;
        // Checked arithmetic: on 32-bit targets `count * 16` can overflow and
        // a wrapped length would slip past the bounds check.
        let body_len = count.checked_mul(Self::CENTROID_BYTES)?;
        let body_end = Self::HEADER_BYTES.checked_add(body_len)?;
        let body = data.get(Self::HEADER_BYTES..body_end)?;
        let centroids = body
            .chunks_exact(Self::CENTROID_BYTES)
            .map(|chunk| {
                Some(Centroid {
                    mean: Self::read_f64(chunk, 0)?,
                    weight: Self::read_f64(chunk, 8)?,
                })
            })
            .collect::<Option<Vec<_>>>()?;
        Some(Self {
            centroids,
            compression,
            total_weight,
        })
    }

    /// Reads a little-endian `f64` at `offset`, or `None` if out of bounds.
    fn read_f64(data: &[u8], offset: usize) -> Option<f64> {
        let bytes: [u8; 8] = data.get(offset..offset.checked_add(8)?)?.try_into().ok()?;
        Some(f64::from_le_bytes(bytes))
    }

    /// Estimates the `p`-th percentile (`p` in `0..=100`) of the added values.
    ///
    /// Returns `0.0` for an empty digest. Centroids appended by `add` since
    /// the last compression are handled by sorting a temporary copy, so the
    /// result does not depend on insertion order.
    pub fn percentile(&self, p: f64) -> f64 {
        let is_sorted = self
            .centroids
            .windows(2)
            .all(|pair| pair[0].mean <= pair[1].mean);
        if is_sorted {
            Self::percentile_of_sorted(&self.centroids, p)
        } else {
            let mut sorted = self.centroids.clone();
            sorted.sort_by(|a, b| a.mean.total_cmp(&b.mean));
            Self::percentile_of_sorted(&sorted, p)
        }
    }

    /// `percentile` over centroids already sorted by mean.
    ///
    /// Each centroid's weight is treated as centred on its mean. The result is
    /// linearly interpolated between the two centroids whose centres bracket
    /// the target rank. Ranks before the first centre or after the last
    /// return the first or last mean, so out-of-range `p` clamps.
    fn percentile_of_sorted(centroids: &[Centroid], p: f64) -> f64 {
        let (first, last) = match centroids {
            [] => return 0.0,
            [only] => return only.mean,
            [first, .., last] => (first, last),
        };
        let total: f64 = centroids.iter().map(|c| c.weight).sum();
        let target = p / 100.0 * total;
        if target <= first.weight / 2.0 {
            return first.mean;
        }
        if target >= total - last.weight / 2.0 {
            return last.mean;
        }
        let mut cumulative = 0.0;
        for pair in centroids.windows(2) {
            let (left, right) = (&pair[0], &pair[1]);
            let left_centre = cumulative + left.weight / 2.0;
            let right_centre = cumulative + left.weight + right.weight / 2.0;
            if target <= right_centre {
                let span = right_centre - left_centre;
                let frac = if span > 0.0 {
                    (target - left_centre) / span
                } else {
                    0.0
                };
                return left.mean + frac * (right.mean - left.mean);
            }
            cumulative += left.weight;
        }
        last.mean
    }
}
/// Factory for a percentiles metric aggregation over one numeric field.
///
/// Each segment collector builds its own `TDigest`; `merge_results` combines
/// the serialized per-segment digests and reports the requested percentiles.
pub struct PercentilesAggFactory {
    /// Name of the field whose numeric values are aggregated.
    pub field_name: String,
    /// Percentiles to report, on a 0–100 scale (e.g. `99.0` for p99).
    pub percents: Vec<f64>,
    /// Compression passed to every `TDigest` this aggregation creates.
    pub compression: f64,
}

impl AggregatorFactory for PercentilesAggFactory {
    /// Builds a per-segment collector.
    ///
    /// `field_name` is looked up in the segment header's field list. The
    /// resulting id (or `None` if no field matches) is handed to
    /// `OwnedColumn::new`.
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let field_id = reader.header().fields.iter().find_map(|field| {
            if field.field_name == self.field_name {
                Some(field.field_id)
            } else {
                None
            }
        });
        let column = super::bucket::OwnedColumn::new(field_id, reader);
        Box::new(PercentilesCollector {
            digest: TDigest::new(self.compression),
            col: column,
            percents: self.percents.clone(),
        })
    }

    /// Merges the serialized digests carried in each metric's `merge_state`.
    ///
    /// Results that are not metrics, carry no state, or fail to decode are
    /// skipped. If nothing was merged, an empty metric is returned. Otherwise
    /// `extra` holds `"count"` plus one entry per requested percentile, keyed
    /// by the percent's `Display` form (e.g. `"99.9"`, `"50"`).
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let segment_digests = results.iter().filter_map(|result| {
            if let AggregationResult::Metric(metric) = result {
                metric
                    .merge_state
                    .as_ref()
                    .and_then(|bytes| TDigest::from_bytes(bytes))
            } else {
                None
            }
        });

        let mut merged = TDigest::new(self.compression);
        let mut merged_segments = 0usize;
        for digest in segment_digests {
            merged.merge(&digest);
            merged_segments += 1;
        }
        if merged_segments == 0 {
            return AggregationResult::Metric(MetricResult::single(None));
        }

        merged.compress();
        let mut result = MetricResult::single(None);
        result.extra.insert("count".into(), merged.total_weight);
        for p in self.percents.iter().copied() {
            result.extra.insert(format!("{p}"), merged.percentile(p));
        }
        AggregationResult::Metric(result)
    }
}
/// Per-segment percentile collector: feeds each collected document's numeric
/// value into a `TDigest`.
struct PercentilesCollector {
    /// Digest of the values collected in this segment so far.
    digest: TDigest,
    /// Column values are read from. When it is `None`, `collect` records
    /// nothing (presumably the field was not found; confirm in
    /// `OwnedColumn::new`).
    col: Option<super::bucket::OwnedColumn>,
    /// Percentiles (0–100 scale) to report for this segment.
    percents: Vec<f64>,
}

// SAFETY: NOTE(review): no invariant is stated where this impl is defined.
// It presumably exists because `OwnedColumn` is not auto-`Send` (e.g. it
// holds raw pointers into segment data). Confirm that the column's backing
// data outlives the collector and is never mutated while the collector is
// moved to another thread.
unsafe impl Send for PercentilesCollector {}

impl Aggregator for PercentilesCollector {
    /// Adds the document's numeric value to the digest, if it has one.
    fn collect(&mut self, doc_id: DocId) {
        let value = match &self.col {
            Some(column) => column.numeric_value(doc_id.as_u32()),
            None => None,
        };
        if let Some(v) = value {
            self.digest.add(v);
        }
    }

    /// Produces this segment's result.
    ///
    /// An empty digest yields an empty metric with no merge state. Otherwise
    /// the result carries `"count"`, one entry per requested percentile, and
    /// the serialized digest in `merge_state` for cross-segment merging.
    fn finish(self: Box<Self>) -> AggregationResult {
        if self.digest.total_weight == 0.0 {
            return AggregationResult::Metric(MetricResult::single(None));
        }
        let Self {
            mut digest,
            percents,
            ..
        } = *self;
        digest.compress();

        let mut result = MetricResult::single(None);
        result.extra.insert("count".into(), digest.total_weight);
        for p in percents.iter().copied() {
            result.extra.insert(format!("{p}"), digest.percentile(p));
        }
        result.merge_state = Some(digest.to_bytes());
        AggregationResult::Metric(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a compression-100 digest holding every integer in `values`.
    fn digest_of(values: std::ops::Range<i32>) -> TDigest {
        let mut digest = TDigest::new(100.0);
        for v in values {
            digest.add(f64::from(v));
        }
        digest
    }

    #[test]
    fn tdigest_empty() {
        let empty = TDigest::new(100.0);
        assert_eq!(empty.percentile(50.0), 0.0);
    }

    #[test]
    fn tdigest_single_value() {
        let mut digest = TDigest::new(100.0);
        digest.add(42.0);
        for &p in &[50.0, 1.0, 99.0] {
            assert_eq!(digest.percentile(p), 42.0);
        }
    }

    #[test]
    fn tdigest_uniform_distribution() {
        let mut digest = digest_of(0..10000);
        digest.compress();

        let p50 = digest.percentile(50.0);
        assert!((p50 - 5000.0).abs() < 500.0, "p50: {p50}");

        let p99 = digest.percentile(99.0);
        assert!((p99 - 9900.0).abs() < 500.0, "p99: {p99}");

        let p1 = digest.percentile(1.0);
        assert!((p1 - 100.0).abs() < 500.0, "p1: {p1}");
    }

    #[test]
    fn tdigest_merge() {
        let mut lower = digest_of(0..5000);
        let upper = digest_of(5000..10000);
        lower.merge(&upper);

        let p50 = lower.percentile(50.0);
        assert!((p50 - 5000.0).abs() < 500.0, "merged p50: {p50}");
    }
}