lucisearch 0.8.0

Embeddable, in-process search engine — the SQLite/DuckDB of Elasticsearch
Documentation
//! T-digest approximate percentile estimation for the `percentiles` aggregation.
//!
//! Maintains a list of centroids (mean, weight), sorted by mean on every
//! compression pass, that summarize the distribution while keeping more
//! detail at the tails. Merging t-digests is cheap: concatenate the centroid
//! lists and re-compress.
//!
//! See [[feature-aggregations-v010]] and
//! [Computing Extremely Accurate Quantiles Using t-Digests](https://arxiv.org/abs/1902.04023).

use crate::core::DocId;

use super::{AggregationResult, Aggregator, AggregatorFactory, MetricResult};
use crate::segment::reader::SegmentReader;

/// A centroid in the t-digest: the mean of a cluster of values and the total
/// weight (number of values) that cluster represents.
#[derive(Clone, Copy, Debug)]
struct Centroid {
    mean: f64,
    weight: f64,
}

/// T-digest data structure for approximate percentile estimation.
///
/// Values are buffered as weight-1 centroids and periodically merged by the
/// compression pass, which bounds the number of centroids by the compression
/// parameter (δ) rather than by the number of values seen.
#[derive(Clone, Debug)]
pub struct TDigest {
    /// Centroids. Sorted by mean right after a compression pass; `add`
    /// appends unsorted entries until the next pass.
    centroids: Vec<Centroid>,
    /// The δ parameter. Larger values keep more centroids (on the order of δ
    /// after compression) and give more accurate estimates.
    compression: f64,
    /// Sum of the weights of every value added or merged in.
    total_weight: f64,
}

impl TDigest {
    /// Creates an empty digest with compression parameter `compression` (δ).
    pub fn new(compression: f64) -> Self {
        Self {
            centroids: Vec::new(),
            compression,
            total_weight: 0.0,
        }
    }

    /// Adds a single value with weight 1.
    ///
    /// NaN values are ignored: they have no position in the value ordering
    /// and would poison every centroid mean they were merged into.
    pub fn add(&mut self, value: f64) {
        if value.is_nan() {
            return;
        }
        self.centroids.push(Centroid {
            mean: value,
            weight: 1.0,
        });
        self.total_weight += 1.0;

        // Buffer raw values and compress in batches so the sort inside
        // `compress` is amortized over many calls to `add`.
        if self.centroids.len() > (self.compression * 5.0) as usize {
            self.compress();
        }
    }

    /// Merges another t-digest into this one: its centroids are appended and
    /// the combined list is re-compressed using this digest's compression.
    pub fn merge(&mut self, other: &TDigest) {
        self.centroids.extend_from_slice(&other.centroids);
        self.total_weight += other.total_weight;
        self.compress();
    }

    /// Sorts the centroids by mean and greedily merges neighbours.
    ///
    /// Uses the k1 scale function from the t-digest paper,
    /// `k(q) = δ/(2π) · asin(2q − 1)`, which maps `[0, 1]` onto
    /// `[−δ/4, δ/4]` with a steep slope near both tails. A merged centroid
    /// covering the quantile range `[q_left, q_right]` must satisfy
    /// `k(q_right) − k(q_left) ≤ 1`. Centroid weights therefore scale with the
    /// total weight (so the centroid count stays bounded by roughly δ) and
    /// shrink toward the tails, which keeps extreme percentiles accurate.
    fn compress(&mut self) {
        if self.centroids.is_empty() {
            return;
        }

        // `total_cmp` gives a total order, so this never panics (unlike
        // `partial_cmp(..).unwrap()` on NaN).
        self.centroids.sort_by(|a, b| a.mean.total_cmp(&b.mean));

        // Normalize by the centroids' own weight sum so q stays within [0, 1]
        // even if `total_weight` was deserialized inconsistently.
        let total: f64 = self.centroids.iter().map(|c| c.weight).sum();
        if total.is_nan() || total <= 0.0 {
            return;
        }

        let scale = self.compression / (2.0 * std::f64::consts::PI);
        let k = |q: f64| scale * (2.0 * q.clamp(0.0, 1.0) - 1.0).asin();

        let mut merged: Vec<Centroid> = Vec::with_capacity(self.centroids.len());
        let mut current = self.centroids[0];
        // Weight of all centroids strictly before `current`, and k at its left edge.
        let mut weight_before = 0.0f64;
        let mut k_left = k(0.0);

        for &next in &self.centroids[1..] {
            let combined = current.weight + next.weight;
            let q_right = (weight_before + combined) / total;

            if k(q_right) - k_left <= 1.0 {
                if combined > 0.0 {
                    // Incremental weighted mean: avoids overflowing mean * weight.
                    current.mean += (next.mean - current.mean) * next.weight / combined;
                }
                current.weight = combined;
            } else {
                weight_before += current.weight;
                k_left = k(weight_before / total);
                merged.push(current);
                current = next;
            }
        }
        merged.push(current);
        self.centroids = merged;
    }

    /// Serializes the digest for merge transport.
    ///
    /// Layout (all little-endian): `compression: f64`, `total_weight: f64`,
    /// `count: u32`, then `count` pairs of `mean: f64, weight: f64`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(20 + self.centroids.len() * 16);
        buf.extend_from_slice(&self.compression.to_le_bytes());
        buf.extend_from_slice(&self.total_weight.to_le_bytes());
        // Assumes fewer than u32::MAX centroids; compression keeps the count
        // on the order of δ.
        buf.extend_from_slice(&(self.centroids.len() as u32).to_le_bytes());
        for c in &self.centroids {
            buf.extend_from_slice(&c.mean.to_le_bytes());
            buf.extend_from_slice(&c.weight.to_le_bytes());
        }
        buf
    }

    /// Deserializes a digest produced by `to_bytes`.
    ///
    /// Returns `None` if `data` is shorter than the header or than the
    /// centroid count it declares. Trailing bytes are ignored.
    pub fn from_bytes(data: &[u8]) -> Option<Self> {
        const HEADER_LEN: usize = 20;
        const CENTROID_LEN: usize = 16;

        fn read_f64(bytes: &[u8]) -> Option<f64> {
            bytes.try_into().ok().map(f64::from_le_bytes)
        }

        let header = data.get(..HEADER_LEN)?;
        let compression = read_f64(&header[0..8])?;
        let total_weight = read_f64(&header[8..16])?;
        let count_bytes: [u8; 4] = header[16..20].try_into().ok()?;
        let count = u32::from_le_bytes(count_bytes) as usize;

        // Checked arithmetic: a corrupt count must not overflow the length
        // computation (possible on 32-bit targets).
        let end = count.checked_mul(CENTROID_LEN)?.checked_add(HEADER_LEN)?;
        let body = data.get(HEADER_LEN..end)?;

        let centroids = body
            .chunks_exact(CENTROID_LEN)
            .map(|chunk| -> Option<Centroid> {
                Some(Centroid {
                    mean: read_f64(&chunk[..8])?,
                    weight: read_f64(&chunk[8..])?,
                })
            })
            .collect::<Option<Vec<_>>>()?;

        Some(Self {
            centroids,
            compression,
            total_weight,
        })
    }

    /// Estimates the value at percentile `p` (0-100; out-of-range values are
    /// clamped).
    ///
    /// Each centroid is treated as sitting at the midpoint of its cumulative
    /// weight range, and the estimate interpolates linearly between the two
    /// centroids whose midpoints bracket the target rank. Returns 0.0 for an
    /// empty digest. Works whether or not the centroids are currently sorted.
    pub fn percentile(&self, p: f64) -> f64 {
        // `add` appends unsorted centroids; sort a copy only when needed.
        let sorted_copy: Vec<Centroid>;
        let centroids: &[Centroid] =
            if self.centroids.windows(2).all(|w| w[0].mean <= w[1].mean) {
                &self.centroids
            } else {
                let mut copy = self.centroids.clone();
                copy.sort_by(|a, b| a.mean.total_cmp(&b.mean));
                sorted_copy = copy;
                &sorted_copy
            };

        let (first, last) = match (centroids.first(), centroids.last()) {
            (Some(first), Some(last)) => (first, last),
            _ => return 0.0,
        };
        if centroids.len() == 1 {
            return first.mean;
        }

        let total: f64 = centroids.iter().map(|c| c.weight).sum();
        let target = (p / 100.0).clamp(0.0, 1.0) * total;

        if target <= first.weight / 2.0 {
            return first.mean;
        }

        let mut weight_before = 0.0;
        for pair in centroids.windows(2) {
            let (left, right) = (pair[0], pair[1]);
            let left_center = weight_before + left.weight / 2.0;
            let right_center = weight_before + left.weight + right.weight / 2.0;
            if target <= right_center {
                let span = right_center - left_center;
                let frac = if span > 0.0 {
                    (target - left_center) / span
                } else {
                    0.0
                };
                return left.mean + frac * (right.mean - left.mean);
            }
            weight_before += left.weight;
        }

        last.mean
    }
}

// --- Percentiles aggregation ---

pub struct PercentilesAggFactory {
    pub field_name: String,
    pub percents: Vec<f64>,
    pub compression: f64,
}

impl AggregatorFactory for PercentilesAggFactory {
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let field_id = reader
            .header()
            .fields
            .iter()
            .find(|f| f.field_name == self.field_name)
            .map(|f| f.field_id);

        let col = super::bucket::OwnedColumn::new(field_id, reader);

        Box::new(PercentilesCollector {
            digest: TDigest::new(self.compression),
            col,
            percents: self.percents.clone(),
        })
    }

    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let mut merged = TDigest::new(self.compression);
        let mut has_data = false;

        for r in &results {
            if let AggregationResult::Metric(m) = r {
                if let Some(ref bytes) = m.merge_state {
                    if let Some(segment_digest) = TDigest::from_bytes(bytes) {
                        merged.merge(&segment_digest);
                        has_data = true;
                    }
                }
            }
        }

        if !has_data {
            return AggregationResult::Metric(MetricResult::single(None));
        }

        merged.compress();
        let mut result = MetricResult::single(None);
        result.extra.insert("count".into(), merged.total_weight);
        for &p in &self.percents {
            let val = merged.percentile(p);
            result.extra.insert(format!("{p}"), val);
        }
        AggregationResult::Metric(result)
    }
}

/// Per-segment collector that feeds one field's numeric values into a t-digest.
struct PercentilesCollector {
    digest: TDigest,
    col: Option<super::bucket::OwnedColumn>,
    percents: Vec<f64>,
}

// SAFETY: NOTE(review): no justification is given in this file. The collector
// holds a `TDigest` and a `Vec<f64>` (plain owned data) plus an optional
// `OwnedColumn` from `super::bucket`. This impl is only sound if that column
// type may be moved to another thread; confirm against its definition.
unsafe impl Send for PercentilesCollector {}

impl Aggregator for PercentilesCollector {
    /// Records the document's numeric value, if the column exists and the
    /// document has one.
    fn collect(&mut self, doc_id: DocId) {
        let value = match &self.col {
            Some(column) => column.numeric_value(doc_id.as_u32()),
            None => None,
        };
        if let Some(v) = value {
            self.digest.add(v);
        }
    }

    /// Produces this segment's percentiles plus the serialized digest
    /// (`merge_state`) used for cross-segment merging.
    fn finish(self: Box<Self>) -> AggregationResult {
        // Nothing was collected: report an empty metric with no merge state.
        if self.digest.total_weight == 0.0 {
            return AggregationResult::Metric(MetricResult::single(None));
        }

        let collector = *self;
        let mut digest = collector.digest;
        digest.compress();

        let mut metric = MetricResult::single(None);
        metric.extra.insert("count".into(), digest.total_weight);
        for p in collector.percents.iter().copied() {
            let estimate = digest.percentile(p);
            metric.extra.insert(format!("{p}"), estimate);
        }
        // Attach the serialized digest so `merge_results` can combine segments.
        metric.merge_state = Some(digest.to_bytes());
        AggregationResult::Metric(metric)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tdigest_empty() {
        let d = TDigest::new(100.0);
        assert_eq!(d.percentile(50.0), 0.0);
    }

    #[test]
    fn tdigest_single_value() {
        let mut d = TDigest::new(100.0);
        d.add(42.0);
        assert_eq!(d.percentile(50.0), 42.0);
        assert_eq!(d.percentile(1.0), 42.0);
        assert_eq!(d.percentile(99.0), 42.0);
    }

    #[test]
    fn tdigest_uniform_distribution() {
        let mut d = TDigest::new(100.0);
        for i in 0..10000 {
            d.add(i as f64);
        }
        d.compress();

        let p50 = d.percentile(50.0);
        assert!((p50 - 5000.0).abs() < 500.0, "p50: {p50}");

        let p99 = d.percentile(99.0);
        assert!((p99 - 9900.0).abs() < 500.0, "p99: {p99}");

        let p1 = d.percentile(1.0);
        assert!((p1 - 100.0).abs() < 500.0, "p1: {p1}");
    }

    #[test]
    fn tdigest_merge() {
        let mut d1 = TDigest::new(100.0);
        let mut d2 = TDigest::new(100.0);
        for i in 0..5000 {
            d1.add(i as f64);
        }
        for i in 5000..10000 {
            d2.add(i as f64);
        }
        d1.merge(&d2);

        let p50 = d1.percentile(50.0);
        assert!((p50 - 5000.0).abs() < 500.0, "merged p50: {p50}");
    }

    /// The merge-transport encoding must reproduce the digest exactly.
    #[test]
    fn tdigest_bytes_round_trip() {
        let mut d = TDigest::new(100.0);
        for i in 0..1000 {
            d.add(i as f64);
        }
        let bytes = d.to_bytes();
        let restored = TDigest::from_bytes(&bytes).expect("valid encoding");

        assert_eq!(restored.compression, d.compression);
        assert_eq!(restored.total_weight, d.total_weight);
        assert_eq!(restored.centroids.len(), d.centroids.len());
        for &p in &[1.0, 50.0, 99.0] {
            assert_eq!(restored.percentile(p), d.percentile(p), "p{p}");
        }
    }

    /// Truncated or header-less input must be rejected rather than misread.
    #[test]
    fn tdigest_from_bytes_rejects_truncated_input() {
        assert!(TDigest::from_bytes(&[]).is_none());
        assert!(TDigest::from_bytes(&[0u8; 19]).is_none());

        let mut d = TDigest::new(100.0);
        d.add(1.0);
        d.add(2.0);
        let bytes = d.to_bytes();
        assert!(TDigest::from_bytes(&bytes[..bytes.len() - 1]).is_none());
    }
}