asap_sketchlib 0.2.1

A high-performance sketching library for approximate stream processing
Documentation
use serde::{Deserialize, Serialize};

use crate::DataInput;

const DEFAULT_SEED: u64 = 0x9E37_79B9_7F4A_7C15;
const GAMMA: u64 = 0xBF58_476D_1CE4_E5B9;
const DELTA: u64 = 0x94D0_49BB_1331_11EB;

#[derive(Clone, Debug, Serialize, Deserialize)]
struct SampleEntry {
    priority: u64,
    value: f64,
}

impl SampleEntry {
    fn new(priority: u64, value: f64) -> Self {
        Self { priority, value }
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct UniformSampling {
    sample_rate: f64,
    total_seen: u64,
    rng_state: u64,
    entries: Vec<SampleEntry>,
}

impl UniformSampling {
    pub fn new(sample_rate: f64) -> Self {
        Self::with_seed(sample_rate, DEFAULT_SEED)
    }

    pub fn with_seed(sample_rate: f64, seed: u64) -> Self {
        assert!(
            (0.0..=1.0).contains(&sample_rate) && sample_rate > 0.0,
            "uniform sampling rate must be within (0, 1]"
        );
        let init_state = if seed == 0 { DEFAULT_SEED } else { seed };
        Self {
            sample_rate,
            total_seen: 0,
            rng_state: init_state,
            entries: Vec::new(),
        }
    }

    pub fn sample_rate(&self) -> f64 {
        self.sample_rate
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    pub fn total_seen(&self) -> u64 {
        self.total_seen
    }

    pub fn samples(&self) -> Vec<f64> {
        self.entries.iter().map(|entry| entry.value).collect()
    }

    pub fn sample_at(&self, idx: usize) -> Option<f64> {
        self.entries.get(idx).map(|entry| entry.value)
    }

    pub fn update_input(&mut self, value: &DataInput<'_>) -> Result<(), &'static str> {
        match value {
            DataInput::I32(v) => {
                self.update(*v as f64);
                Ok(())
            }
            DataInput::I64(v) => {
                self.update(*v as f64);
                Ok(())
            }
            DataInput::U32(v) => {
                self.update(*v as f64);
                Ok(())
            }
            DataInput::U64(v) => {
                self.update(*v as f64);
                Ok(())
            }
            DataInput::F32(v) => {
                self.update(*v as f64);
                Ok(())
            }
            DataInput::F64(v) => {
                self.update(*v);
                Ok(())
            }
            _ => Err("UniformSampling only supports numeric inputs"),
        }
    }

    pub fn update(&mut self, value: f64) {
        self.total_seen = self.total_seen.saturating_add(1);
        let target_size = Self::target_size(self.total_seen, self.sample_rate);
        let priority = self.next_random();
        self.insert_entry(SampleEntry::new(priority, value));
        self.truncate_to(target_size);
    }

    pub fn merge(&mut self, other: &UniformSampling) -> Result<(), &'static str> {
        if (self.sample_rate - other.sample_rate).abs() > f64::EPSILON {
            return Err("Cannot merge uniform samplers with different sampling rates");
        }
        let combined_seen = self.total_seen.saturating_add(other.total_seen);
        let mut merged = Vec::with_capacity(self.entries.len() + other.entries.len());
        merged.extend(self.entries.iter().cloned());
        merged.extend(other.entries.iter().cloned());
        merged.sort_by_key(|a| a.priority);
        let target_size = Self::target_size(combined_seen, self.sample_rate);
        merged.truncate(target_size);
        self.entries = merged;
        self.total_seen = combined_seen;
        self.mix_state(other.rng_state);
        Ok(())
    }

    fn target_size(total_seen: u64, rate: f64) -> usize {
        if total_seen == 0 {
            0
        } else {
            ((total_seen as f64) * rate).ceil() as usize
        }
    }

    fn insert_entry(&mut self, entry: SampleEntry) {
        let idx = match self
            .entries
            .binary_search_by(|probe| probe.priority.cmp(&entry.priority))
        {
            Ok(position) | Err(position) => position,
        };
        self.entries.insert(idx, entry);
    }

    fn truncate_to(&mut self, target_size: usize) {
        while self.entries.len() > target_size {
            self.entries.pop();
        }
    }

    fn next_random(&mut self) -> u64 {
        self.rng_state = self.rng_state.wrapping_add(DEFAULT_SEED);
        let mut z = self.rng_state;
        z = (z ^ (z >> 30)).wrapping_mul(GAMMA);
        z = (z ^ (z >> 27)).wrapping_mul(DELTA);
        z ^ (z >> 31)
    }

    fn mix_state(&mut self, other: u64) {
        let mixed = self.rng_state ^ other.rotate_left(19);
        self.rng_state = if mixed == 0 { DEFAULT_SEED } else { mixed };
    }
}

#[cfg(test)]
mod tests {
    use super::UniformSampling;

    fn expected_size(rate: f64, total_seen: u64) -> usize {
        if total_seen == 0 {
            0
        } else {
            ((total_seen as f64) * rate).ceil() as usize
        }
    }

    #[test]
    fn sample_count_tracks_rate() {
        let mut sampler = UniformSampling::with_seed(0.4, 0xABC1);
        for (idx, value) in (0..10).enumerate() {
            sampler.update(value as f64);
            let seen = (idx + 1) as u64;
            assert_eq!(sampler.total_seen(), seen);
            assert_eq!(sampler.len(), expected_size(0.4, seen));
        }
    }

    #[test]
    fn samples_are_drawn_from_input_stream() {
        let mut sampler = UniformSampling::with_seed(0.25, 0xBEEFFACE);
        for value in 0..128 {
            sampler.update(value as f64);
        }
        assert_eq!(sampler.total_seen(), 128);
        assert_eq!(sampler.len(), expected_size(0.25, sampler.total_seen()));
        for value in sampler.samples() {
            assert!(value.floor() == value);
            assert!((0.0..128.0).contains(&value));
        }
    }

    #[test]
    fn merge_combines_samples_using_rate_based_target() {
        let mut left = UniformSampling::with_seed(0.2, 0xDEAD);
        for value in 0..64 {
            left.update(value as f64);
        }
        let mut right = UniformSampling::with_seed(0.2, 0xBEEF);
        for value in 100..200 {
            right.update(value as f64);
        }
        let mut combined = left.clone();
        combined.merge(&right).unwrap();
        assert_eq!(
            combined.total_seen(),
            left.total_seen() + right.total_seen()
        );
        assert_eq!(combined.len(), expected_size(0.2, combined.total_seen()));
        for value in combined.samples() {
            assert!(
                (0.0..64.0).contains(&value) || (100.0..200.0).contains(&value),
                "unexpected sample {value}"
            );
        }
    }

    #[test]
    fn merge_rejects_different_rates() {
        let mut left = UniformSampling::with_seed(0.1, 0x1);
        left.update(1.0);
        let mut right = UniformSampling::with_seed(0.2, 0x2);
        right.update(2.0);
        assert!(left.merge(&right).is_err());
    }

    #[test]
    fn sample_access_is_stable() {
        let mut sampler = UniformSampling::with_seed(0.5, 0xFACEFACE);
        for value in 0..20 {
            sampler.update(value as f64);
        }
        let snapshot = sampler.samples();
        for (idx, expected) in snapshot.iter().enumerate() {
            let value = sampler.sample_at(idx).expect("sample exists");
            assert_eq!(value, *expected);
        }
    }
}