use rmp_serde::decode::Error as RmpDecodeError;
use rmp_serde::encode::Error as RmpEncodeError;
use serde::{Deserialize, Serialize};
use crate::common::input::data_input_to_f64;
use crate::common::numerical::NumericalValue;
use crate::{DataInput, Vector1D};
use super::kll::Coin;
/// Number of per-level capacities precomputed in `capacity_cache`.
const CAPACITY_CACHE_LEN: usize = 20;
/// Upper bound applied to both `k` and `m` in `init`.
/// NOTE(review): presumably chosen so cached capacities fit comfortably in
/// `u32` / keep buffer allocations bounded — confirm the derivation.
const MAX_CACHEABLE_K: usize = 26_602;
/// Geometric decay of level capacity per step below the top (KLL's c = 2/3).
const CAPACITY_DECAY: f64 = 2.0 / 3.0;
/// `k` used by the `Default` impl.
const DEFAULT_K: i32 = 200;
/// One point of an empirical CDF: while `cdf()` is being built `quantile`
/// temporarily holds the item's raw weight; after normalization it holds the
/// cumulative probability at `value`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DynamicCdfEntry {
    value: f64,
    quantile: f64,
}
/// KLL quantile sketch with all compactor levels packed into one contiguous
/// buffer (`items`); `levels` stores the segment boundaries.
///
/// Layout: `levels` has `num_levels + 1` entries ordered top level first, so
/// the level at height `h` (item weight `2^h`) occupies
/// `items[levels[num_levels - 1 - h] .. levels[num_levels - h]]`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct KLLDynamic<T: NumericalValue = f64> {
    // Packed item storage for every level.
    items: Vector1D<T>,
    // Boundaries into `items`; see layout note above.
    levels: Vector1D<usize>,
    // Accuracy parameter: capacity of the top level.
    k: usize,
    // Minimum capacity of any level.
    m: usize,
    // Number of levels currently allocated.
    num_levels: usize,
    // Coin deciding which half of a compacted level survives.
    co: Coin,
    // Derived caches below are skipped by serde and rebuilt after
    // deserialization (see `deserialize_from_bytes`).
    #[serde(skip)]
    capacity_cache: [u32; CAPACITY_CACHE_LEN],
    #[serde(skip)]
    top_height: usize,
    #[serde(skip)]
    level0_capacity: usize,
}
impl<T: NumericalValue> Default for KLLDynamic<T> {
    /// Builds a sketch with the default accuracy parameter (`k = 200`)
    /// and the default minimum level capacity (`m = 8`).
    fn default() -> Self {
        Self::init(DEFAULT_K as usize, 8)
    }
}
impl<T: NumericalValue> KLLDynamic<T> {
/// Creates a sketch with accuracy parameter `k` and minimum level
/// capacity `m`.
///
/// Both parameters are normalized: `m` is clamped into
/// `[2, MAX_CACHEABLE_K]`, and `k` into `[m, MAX_CACHEABLE_K]`, so the
/// capacity cache is always valid.
pub fn init(k: usize, m: usize) -> Self {
    let m_eff = m.min(MAX_CACHEABLE_K).max(2);
    let k_eff = k.max(m_eff).min(MAX_CACHEABLE_K);
    let mut sketch = Self {
        // Pre-size the packed buffer for roughly three full top levels.
        items: Vector1D::init(k_eff * 3),
        // Two boundaries => a single empty level 0.
        levels: Vector1D::filled(2, 0),
        k: k_eff,
        m: m_eff,
        num_levels: 1,
        co: Coin::new(),
        capacity_cache: [0; CAPACITY_CACHE_LEN],
        top_height: 0,
        level0_capacity: 0,
    };
    // Populate the serde-skipped derived caches.
    sketch.rebuild_capacity_cache();
    sketch
}
/// Builds a sketch with the given `k` and the default minimum level
/// capacity `m = 8`.
///
/// Bug fix: a negative `k` previously wrapped around through `as usize`
/// (e.g. `-1` became `usize::MAX`), which `init` then clamped to the
/// *maximum* cacheable `k` — the opposite of what a caller passing a
/// bogus small value would expect. Negative inputs are now clamped to 0
/// first, so they fall through to `init`'s lower bound (`k = m`).
pub fn init_kll(k: i32) -> Self {
    Self::init(k.max(0) as usize, 8)
}
/// Appends `value` to level 0 and compacts when level 0 exceeds its
/// cached capacity.
fn push_value(&mut self, value: T) {
    self.items.push(value);
    // Level 0 is the last segment, so the final boundary tracks items.len().
    if let Some(last) = self.levels.last_mut() {
        *last = self.items.len();
    }
    let levels_slice = self.levels.as_slice();
    // Level 0 starts at the second-to-last boundary (levels are top-first).
    let l0_start = levels_slice[self.num_levels - 1];
    let l0_count = self.items.len() - l0_start;
    if l0_count > self.level0_capacity {
        self.compress_while_needed();
    }
}
/// Public entry point: inserts one sample (weight 1) into the sketch.
pub fn update(&mut self, val: &T) {
    self.push_value(*val);
}
/// Walks the levels bottom-up, compacting each one that exceeds its
/// capacity. `compact` promotes survivors one level up, so an overflow
/// can cascade; when the top level itself overflows, a new empty level
/// is added above it first.
fn compress_while_needed(&mut self) {
    let mut h = 0;
    loop {
        // `h` is height from the bottom; levels are stored top-first,
        // so convert to an index into `levels`.
        let level_idx = self.num_levels - 1 - h;
        let cap = self.capacity_for_level(h);
        let size = self.level_size(h);
        if size <= cap {
            break;
        }
        if level_idx == 0 {
            // Overflowing top level: grow a fresh empty level above it
            // and re-check the same height with updated capacities.
            self.add_new_top_level();
            continue;
        }
        self.compact(h);
        h += 1;
    }
}
/// Capacity of the level at height `level` (0 = bottom), read from the
/// cache: roughly `max(ceil(k * (2/3)^(top_height - level)), m)`.
fn capacity_for_level(&self, level: usize) -> usize {
    // NOTE(review): `num_levels` is never 0 after `init`/`clear`, so this
    // guard looks unreachable; kept as a defensive default.
    if self.num_levels == 0 {
        return self.m;
    }
    let height_from_top = self.top_height.saturating_sub(level);
    // Depths beyond the cache all share the smallest cached capacity.
    let idx = height_from_top.min(CAPACITY_CACHE_LEN - 1);
    self.capacity_cache[idx] as usize
}
/// Recomputes the serde-skipped derived state: `top_height`, the
/// per-depth capacity table, and the cached level-0 capacity.
///
/// Cache slot `i` holds the capacity for the level `i` steps below the
/// top: `max(ceil(k * CAPACITY_DECAY^i), m)`.
fn rebuild_capacity_cache(&mut self) {
    self.top_height = self.num_levels.saturating_sub(1);
    let mut factor = 1.0_f64;
    for slot in self.capacity_cache.iter_mut() {
        let raw = ((self.k as f64) * factor).ceil() as usize;
        *slot = raw.max(self.m) as u32;
        factor *= CAPACITY_DECAY;
    }
    self.level0_capacity = self.capacity_for_level(0);
}
/// Number of items currently stored at height `h` (0 = bottom level).
#[inline]
fn level_size(&self, h: usize) -> usize {
    // Levels are stored top-first, so height h maps to index
    // num_levels - 1 - h.
    let idx = self.num_levels - 1 - h;
    let slice = self.levels.as_slice();
    slice[idx + 1] - slice[idx]
}
/// Adds an empty level above the current top (prepending a boundary,
/// since levels are stored top-first) and refreshes the derived caches.
fn add_new_top_level(&mut self) {
    // New top level spans [0, 0) — empty until a compaction promotes items.
    self.levels.insert(0, 0);
    if let Some(last) = self.levels.last_mut() {
        *last = self.items.len();
    }
    self.num_levels += 1;
    self.top_height = self.num_levels - 1;
    // Level 0 is now one step further from the top, so its capacity shrinks.
    self.level0_capacity = self.capacity_for_level(0);
}
/// Compacts the level at height `h`: sorts it, keeps every second item
/// (random start offset keeps the rank estimator unbiased), hands the
/// survivors to the level above, and closes the gap left by discarded
/// items in the packed buffer.
fn compact(&mut self, h: usize) {
    let cur_lvl_idx = self.num_levels - 1 - h;
    let levels_slice = self.levels.as_mut_slice();
    let start = levels_slice[cur_lvl_idx];
    let end = levels_slice[cur_lvl_idx + 1];
    let count = end - start;
    let items = self.items.as_mut_slice();
    items[start..end].sort_unstable_by(T::total_cmp);
    // Coin toss picks whether even- or odd-indexed items survive.
    let offset = usize::from(self.co.toss());
    let mut survivors = 0;
    let mut i = offset;
    // Compact survivors to the front of this level's segment.
    while i < count {
        items[start + survivors] = items[start + i];
        survivors += 1;
        i += 2;
    }
    let garbage_len = count - survivors;
    let start_garbage = start + survivors;
    let end_garbage = end;
    let tail_len = items.len() - end_garbage;
    if tail_len > 0 {
        // SAFETY: start_garbage <= end_garbage <= items.len(), so both
        // source and destination ranges of `tail_len` elements are
        // in-bounds, and `std::ptr::copy` permits overlapping ranges.
        // This shifts the levels stored after the compacted one left
        // over the discarded slots.
        unsafe {
            let ptr = items.as_mut_ptr();
            std::ptr::copy(ptr.add(end_garbage), ptr.add(start_garbage), tail_len);
        }
    }
    let new_len = items.len() - garbage_len;
    self.items.truncate(new_len);
    let levels_slice = self.levels.as_mut_slice();
    // Moving this boundary past the survivors re-assigns them to the
    // level above (index cur_lvl_idx - 1) without copying them.
    levels_slice[cur_lvl_idx] = start + survivors;
    // Boundaries after the compacted level shift left by the gap size.
    for pos in levels_slice
        .iter_mut()
        .take(self.num_levels + 1)
        .skip(cur_lvl_idx + 1)
    {
        *pos -= garbage_len;
    }
    levels_slice[self.num_levels] = self.items.len();
}
/// Resets the sketch to its freshly-initialized state, keeping `k`/`m`
/// (and the `items` allocation) but discarding all samples.
pub fn clear(&mut self) {
    self.num_levels = 1;
    self.levels = Vector1D::filled(2, 0);
    self.items.clear();
    self.co = Coin::new();
    // Derived caches depend on num_levels, so rebuild them last.
    self.rebuild_capacity_cache();
}
/// Debug helper: prints the sketch parameters and each level's contents
/// (top level first) to stdout.
pub fn print_compactors(&self)
where
    T: std::fmt::Debug,
{
    println!(
        "KLLDynamic Packed (k={}, levels={}, items={})",
        self.k,
        self.num_levels,
        self.items.len()
    );
    let bounds = self.levels.as_slice();
    let stored = self.items.as_slice();
    for height in (0..self.num_levels).rev() {
        let row = self.num_levels - 1 - height;
        println!(" L{}: {:?}", height, &stored[bounds[row]..bounds[row + 1]]);
    }
}
/// Materializes the empirical CDF: every stored item contributes its
/// level weight `2^h`; entries are sorted by value and the weights are
/// converted to cumulative probabilities in `(0, 1]`.
///
/// An empty sketch yields an empty entry list.
pub fn cdf(&self) -> DynamicCdf {
    let mut out = DynamicCdf {
        entries: Vector1D::init(self.buffer_size()),
    };
    let bounds = self.levels.as_slice();
    let stored = self.items.as_slice();
    let mut mass = 0usize;
    for h in 0..self.num_levels {
        let row = self.num_levels - 1 - h;
        let (lo, hi) = (bounds[row], bounds[row + 1]);
        let w = 1usize << h;
        // Stash the raw weight in `quantile`; normalized below.
        for &item in &stored[lo..hi] {
            out.entries.push(DynamicCdfEntry {
                value: item.to_f64(),
                quantile: w as f64,
            });
        }
        mass += (hi - lo) * w;
    }
    if mass == 0 {
        return out;
    }
    out.entries
        .as_mut_slice()
        .sort_by(|a, b| a.value.partial_cmp(&b.value).unwrap());
    // Prefix-sum the weights and normalize by total mass.
    let mut running = 0.0;
    for entry in out.entries.as_mut_slice() {
        running += entry.quantile;
        entry.quantile = running / mass as f64;
    }
    out
}
/// Merges all samples from `other` into `self`.
///
/// Bug fix: the previous implementation pushed every stored item of
/// `other` with weight 1, but an item at height `h` represents `2^h`
/// original samples — merging a compacted sketch silently dropped most
/// of its mass and skewed `count()` and all quantiles. Each item is now
/// re-inserted once per unit of weight, preserving total mass.
///
/// NOTE: cost is O(other.count()) pushes, which is fine for moderate
/// sketches; very large merged counts would want a level-wise merge.
pub fn merge(&mut self, other: &KLLDynamic<T>) {
    let levels = other.levels.as_slice();
    let items = other.items.as_slice();
    for h in 0..other.num_levels {
        let idx = other.num_levels - 1 - h;
        let (start, end) = (levels[idx], levels[idx + 1]);
        let weight = 1usize << h;
        for &value in &items[start..end] {
            for _ in 0..weight {
                self.push_value(value);
            }
        }
    }
}
/// Convenience wrapper: builds the CDF and returns the value at rank `q`.
/// Callers issuing many queries should cache `cdf()` instead, since it is
/// rebuilt (and re-sorted) on every call here.
pub fn quantile(&self, q: f64) -> f64 {
    let cdf = self.cdf();
    cdf.query(q)
}
/// Estimated rank of `x`: the weighted number of stored samples whose
/// value is <= `x` (each level-`h` item counts `2^h`).
pub fn rank(&self, x: f64) -> usize {
    let bounds = self.levels.as_slice();
    let stored = self.items.as_slice();
    (0..self.num_levels)
        .map(|h| {
            let row = self.num_levels - 1 - h;
            let hits = stored[bounds[row]..bounds[row + 1]]
                .iter()
                .filter(|v| v.to_f64() <= x)
                .count();
            // Each item at height h carries weight 2^h.
            hits << h
        })
        .sum()
}
/// Estimated number of samples inserted: sum of level sizes weighted by
/// `2^h`. Exact until the first compaction, approximate afterwards.
pub fn count(&self) -> usize {
    (0..self.num_levels)
        .map(|h| self.level_size(h) << h)
        .sum()
}
/// Total number of items currently stored across all levels (unweighted).
fn buffer_size(&self) -> usize {
    self.items.len()
}
/// Serializes the sketch to MessagePack bytes. The `#[serde(skip)]`
/// caches are not written; they are rebuilt on deserialization.
pub fn serialize_to_bytes(&self) -> Result<Vec<u8>, RmpEncodeError>
where
    T: Serialize,
{
    rmp_serde::to_vec(self)
}
/// Deserializes a sketch from MessagePack bytes and rebuilds the
/// `#[serde(skip)]` derived state (`capacity_cache`, `top_height`,
/// `level0_capacity`), which would otherwise be left at default values.
pub fn deserialize_from_bytes(bytes: &[u8]) -> Result<Self, RmpDecodeError>
where
    T: for<'de> Deserialize<'de>,
{
    rmp_serde::from_slice(bytes).map(|mut sketch: KLLDynamic<T>| {
        sketch.rebuild_capacity_cache();
        sketch
    })
}
}
impl KLLDynamic<f64> {
    /// Feeds one heterogeneous `DataInput` sample into the sketch.
    ///
    /// Returns the error produced by `data_input_to_f64` when the input
    /// is not numeric; the sketch is left unchanged in that case.
    pub fn update_data_input(&mut self, val: &DataInput) -> Result<(), &'static str> {
        self.push_value(data_input_to_f64(val)?);
        Ok(())
    }
}
/// Empirical CDF produced by `KLLDynamic::cdf`: entries sorted by `value`
/// with monotonically non-decreasing cumulative `quantile` in `(0, 1]`.
pub struct DynamicCdf {
    entries: Vector1D<DynamicCdfEntry>,
}
impl DynamicCdf {
/// Returns the cumulative probability at value `x` (the inverse of
/// `query`); 0.0 for an empty CDF or when `x` is below the minimum.
///
/// NOTE(review): with duplicate values, `binary_search_by` may land on
/// any of the duplicates, each carrying a different cumulative quantile;
/// `quantile_li` sidesteps this via `partition_point`.
pub fn quantile(&self, x: f64) -> f64 {
    if self.entries.is_empty() {
        return 0.0;
    }
    let slice = self.entries.as_slice();
    // NaN comparisons fall back to Less so the search cannot panic.
    match slice
        .binary_search_by(|e| e.value.partial_cmp(&x).unwrap_or(std::cmp::Ordering::Less))
    {
        Ok(idx) => slice[idx].quantile,
        Err(0) => 0.0,
        // Not found: report the largest entry whose value is below `x`.
        Err(idx) => slice[idx - 1].quantile,
    }
}
/// Debug helper: dumps all (value, cumulative-quantile) entries to stdout.
pub fn print_entries(&self) {
    println!("entries: {:?}", self.entries);
}
/// Returns the value at cumulative probability `p` (step-function
/// lookup, no interpolation); 0.0 for an empty CDF, and the extreme
/// entry when `p` falls outside the stored range.
pub fn query(&self, p: f64) -> f64 {
    if self.entries.is_empty() {
        return 0.0;
    }
    let entries = self.entries.as_slice();
    // NaN comparisons fall back to Less so the search cannot panic.
    let probe = entries.binary_search_by(|e| {
        e.quantile
            .partial_cmp(&p)
            .unwrap_or(std::cmp::Ordering::Less)
    });
    let pick = match probe {
        Ok(i) => i,
        // Insertion point past the end clamps to the last entry.
        Err(i) => i.min(entries.len() - 1),
    };
    entries[pick].value
}
/// Cumulative probability at value `x` with linear interpolation between
/// the two entries bracketing `x`; returns 0.0 below the minimum (or for
/// an empty CDF) and 1.0 at or above the maximum.
pub fn quantile_li(&self, x: f64) -> f64 {
    let entries = self.entries.as_slice();
    if entries.is_empty() {
        return 0.0;
    }
    let pos = entries.partition_point(|e| e.value < x);
    match pos {
        0 => 0.0,
        p if p == entries.len() => 1.0,
        p => {
            let (lo, hi) = (&entries[p - 1], &entries[p]);
            // Linear interpolation between (lo.value, lo.quantile) and
            // (hi.value, hi.quantile), evaluated at x.
            ((lo.value - x) * hi.quantile + (x - hi.value) * lo.quantile)
                / (lo.value - hi.value)
        }
    }
}
/// Value at cumulative probability `p` with linear interpolation between
/// the two entries bracketing `p`; clamps to the first/last stored value
/// when `p` falls outside the stored quantile range (0.0 when empty).
pub fn query_li(&self, p: f64) -> f64 {
    let entries = self.entries.as_slice();
    if entries.is_empty() {
        return 0.0;
    }
    let pos = entries.partition_point(|e| e.quantile < p);
    if pos == entries.len() {
        return entries[entries.len() - 1].value;
    }
    if pos == 0 {
        return entries[0].value;
    }
    let (lo, hi) = (&entries[pos - 1], &entries[pos]);
    // Linear interpolation between (lo.quantile, lo.value) and
    // (hi.quantile, hi.value), evaluated at p.
    ((lo.quantile - p) * hi.value + (p - hi.quantile) * lo.value)
        / (lo.quantile - hi.quantile)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::{sample_uniform_f64, sample_zipf_f64};
/// Input distributions exercised by the accuracy tests.
#[derive(Clone, Copy)]
enum TestDistribution {
    /// Uniform samples over `[min, max]`.
    Uniform {
        min: f64,
        max: f64,
    },
    /// Zipf-distributed samples over `domain` ranks with the given
    /// `exponent`, mapped into `[min, max]`.
    Zipf {
        min: f64,
        max: f64,
        domain: usize,
        exponent: f64,
    },
}
// Sketch accuracy parameter shared by the accuracy tests below.
const SKETCH_K: i32 = 200;
/// Builds a sketch of the given `k`, feeds it `sample_size` values drawn
/// from `distribution` with the given `seed`, and returns both the
/// sketch and the raw samples (for ground-truth comparison).
fn build_kll_with_distribution(
    k: i32,
    sample_size: usize,
    distribution: TestDistribution,
    seed: u64,
) -> (KLLDynamic, Vec<f64>) {
    let mut sketch = KLLDynamic::init_kll(k);
    let values = match distribution {
        TestDistribution::Uniform { min, max } => {
            sample_uniform_f64(min, max, sample_size, seed)
        }
        TestDistribution::Zipf {
            min,
            max,
            domain,
            exponent,
        } => sample_zipf_f64(min, max, domain, exponent, sample_size, seed),
    };
    for v in &values {
        sketch.update(v);
    }
    (sketch, values)
}
/// Exact quantile of pre-sorted `data` using the ceiling-rank convention:
/// the element at index `ceil(q * n) - 1`, clamped to the valid range.
///
/// # Panics
/// Panics when `data` is empty.
fn quantile_from_sorted(data: &[f64], quantile: f64) -> f64 {
    assert!(!data.is_empty(), "data set must not be empty");
    if quantile <= 0.0 {
        return data[0];
    }
    if quantile >= 1.0 {
        return data[data.len() - 1];
    }
    let len = data.len();
    let rank = (quantile * len as f64).ceil() as isize - 1;
    data[rank.clamp(0, (len - 1) as isize) as usize]
}
/// Asserts that each sketch estimate lies inside the rank-error band: the
/// estimate for quantile `q` must fall between the exact quantiles at
/// `q - tolerance` and `q + tolerance` of the sorted ground truth.
fn assert_quantiles_within_error(
    sketch: &KLLDynamic,
    sorted_truth: &[f64],
    quantiles: &[(f64, &str)],
    tolerance: f64,
    context: &str,
    sample_size: usize,
    seed: u64,
) {
    let cdf = sketch.cdf();
    for &(quantile, label) in quantiles {
        // Acceptable value band implied by the rank tolerance.
        let lower_q = (quantile - tolerance).max(0.0);
        let upper_q = (quantile + tolerance).min(1.0);
        let truth_min = quantile_from_sorted(sorted_truth, lower_q);
        let truth_max = quantile_from_sorted(sorted_truth, upper_q);
        let estimate = cdf.query(quantile);
        assert!(
            (truth_min..=truth_max).contains(&estimate),
            "{label} exceeded tolerance: context={context}, sample_size={sample_size}, seed=0x{seed:08x}, \
            quantile={quantile:.4}, truth_min={truth_min:.4}, truth_max={truth_max:.4}, \
            estimate={estimate:.4}, tolerance={tolerance:.4}, total_length={}",
            sorted_truth.len()
        );
    }
}
/// End-to-end accuracy test: across several sample sizes and two input
/// distributions, every checked quantile estimate must stay within a
/// 2% rank-error band of the exact sorted-data quantiles.
#[test]
fn distributions_quantiles_stay_within_rank_error() {
    const TOLERANCE: f64 = 0.02;
    const SAMPLE_SIZES: &[usize] = &[1_000, 5_000, 20_000, 100_000, 1_000_000, 5_000_000];
    const QUANTILES: &[(f64, &str)] = &[
        (0.0, "min"),
        (0.10, "p10"),
        (0.25, "p25"),
        (0.50, "p50"),
        (0.75, "p75"),
        (0.90, "p90"),
        (1.0, "max"),
    ];
    // A named case pairs a distribution with a seed base so each
    // (case, sample-size) combination gets a distinct deterministic seed.
    struct Case {
        name: &'static str,
        distribution: TestDistribution,
        seed_base: u64,
    }
    let cases = [
        Case {
            name: "uniform",
            distribution: TestDistribution::Uniform {
                min: 0.0,
                max: 100_000_000.0,
            },
            seed_base: 0xA5A5_0000,
        },
        Case {
            name: "zipf",
            distribution: TestDistribution::Zipf {
                min: 1_000_000.0,
                max: 10_000_000.0,
                domain: 8_192,
                exponent: 1.1,
            },
            seed_base: 0xB4B4_0000,
        },
    ];
    for case in cases {
        for (idx, &sample_size) in SAMPLE_SIZES.iter().enumerate() {
            let seed = case.seed_base + idx as u64;
            let (sketch, mut values) =
                build_kll_with_distribution(SKETCH_K, sample_size, case.distribution, seed);
            // Ground truth must be sorted for quantile_from_sorted.
            values.sort_by(|a, b| a.partial_cmp(b).unwrap());
            assert_quantiles_within_error(
                &sketch,
                &values,
                QUANTILES,
                TOLERANCE,
                case.name,
                sample_size,
                seed,
            );
        }
    }
}
/// Exercises the `DataInput` adapter: numeric variants update the sketch,
/// non-numeric input returns the expected static error message.
#[test]
fn test_data_input_api() {
    let mut kll = KLLDynamic::init_kll(128);
    kll.update_data_input(&DataInput::I32(10)).unwrap();
    kll.update_data_input(&DataInput::I64(20)).unwrap();
    kll.update_data_input(&DataInput::F64(30.5)).unwrap();
    kll.update_data_input(&DataInput::F32(40.2)).unwrap();
    kll.update_data_input(&DataInput::U32(50)).unwrap();
    let cdf = kll.cdf();
    let median = cdf.query(0.5);
    // Five values, no compaction at k=128: the median lies strictly
    // between the 2nd and 4th inputs.
    assert!(median > 20.0 && median < 40.2, "Median = {}", median);
    let result = kll.update_data_input(&DataInput::String("not a number".to_string()));
    assert!(result.is_err());
    assert_eq!(
        result.unwrap_err(),
        "KLL sketch only accepts numeric inputs"
    );
}
/// With k = m = 3, five updates force at least one compaction; the
/// median estimate must land on one of the two middle survivors.
#[test]
fn test_forced_compact() {
    let mut kll = KLLDynamic::init(3, 3);
    for v in [10.0, 20.0, 30.0, 40.0, 50.0] {
        kll.update(&v);
    }
    let median = kll.cdf().query(0.5);
    // The random coin decides which half survives, so either neighbor
    // of the true median is acceptable.
    assert!(median == 30.0 || median == 40.0, "Median = {}", median);
}
/// With k = 8, five updates fit in level 0 untouched, so the median is
/// exact.
#[test]
fn test_no_compact() {
    let mut kll = KLLDynamic::init_kll(8);
    for v in [10.0, 20.0, 30.0, 40.0, 50.0] {
        kll.update(&v);
    }
    let median = kll.cdf().query(0.5);
    assert!(median == 30.0, "Median = {}", median);
}
/// Splits one sample stream across two sketches, merges them, and checks
/// that the merged sketch's quantiles stay within the rank-error band of
/// the full stream's exact quantiles.
#[test]
fn merge_preserves_quantiles_within_tolerance() {
    const TOLERANCE: f64 = 0.02;
    const QUANTILES: &[(f64, &str)] = &[
        (0.0, "min"),
        (0.10, "p10"),
        (0.25, "p25"),
        (0.50, "p50"),
        (0.75, "p75"),
        (0.90, "p90"),
        (1.0, "max"),
    ];
    let values = sample_uniform_f64(1_000_000.0, 10_000_000.0, 10_000, 0xC0FFEE);
    let mut sketch_a = KLLDynamic::init_kll(SKETCH_K);
    let mut sketch_b = KLLDynamic::init_kll(SKETCH_K);
    // Alternate values between the two sketches.
    for (idx, value) in values.iter().copied().enumerate() {
        if idx % 2 == 0 {
            sketch_a.update(&value);
        } else {
            sketch_b.update(&value);
        }
    }
    sketch_a.merge(&sketch_b);
    let mut sorted = values.clone();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    assert_quantiles_within_error(
        &sketch_a,
        &sorted,
        QUANTILES,
        TOLERANCE,
        "merge",
        values.len(),
        0x00C0_FFEE,
    );
}
/// An empty sketch must produce a CDF whose queries all return 0.0
/// rather than panicking or indexing out of bounds.
#[test]
fn cdf_handles_empty_sketch() {
    let sketch = KLLDynamic::<f64>::init_kll(64);
    let cdf = sketch.cdf();
    assert_eq!(cdf.quantile(123.0), 0.0);
    assert_eq!(cdf.query(0.5), 0.0);
    assert_eq!(cdf.query_li(0.5), 0.0);
}
/// MessagePack round-trip: serialized state, rebuilt serde-skipped
/// caches, packed buffers, and all queried quantiles must match exactly.
#[test]
fn kll_dynamic_round_trip_rmp() {
    let mut sketch = KLLDynamic::init_kll(256);
    let samples = sample_uniform_f64(0.0, 1_000_000.0, 5_000, 0xDEAD_BEEF);
    for value in &samples {
        sketch.update(value);
    }
    let bytes = sketch
        .serialize_to_bytes()
        .expect("serialize KLLDynamic with rmp");
    assert!(!bytes.is_empty(), "serialized bytes should not be empty");
    let restored =
        KLLDynamic::deserialize_from_bytes(&bytes).expect("deserialize KLLDynamic with rmp");
    assert_eq!(sketch.k, restored.k);
    assert_eq!(sketch.m, restored.m);
    assert_eq!(sketch.num_levels, restored.num_levels);
    // These two are #[serde(skip)] — they must be rebuilt, not decoded.
    assert_eq!(sketch.top_height, restored.top_height);
    assert_eq!(sketch.level0_capacity, restored.level0_capacity);
    assert_eq!(
        sketch.levels.as_slice(),
        restored.levels.as_slice(),
        "level boundaries changed after round-trip"
    );
    assert_eq!(
        sketch.items.as_slice(),
        restored.items.as_slice(),
        "packed items changed after round-trip"
    );
    let quantiles = [0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0];
    let original_cdf = sketch.cdf();
    let restored_cdf = restored.cdf();
    for &q in &quantiles {
        assert!(
            (original_cdf.query(q) - restored_cdf.query(q)).abs() < f64::EPSILON,
            "quantile mismatch at p={q}: original={}, restored={}",
            original_cdf.query(q),
            restored_cdf.query(q)
        );
    }
}
/// Sanity check for a non-default item type (`i64`): count estimate,
/// p50/p90 quantiles, and serde round-trip must all hold.
#[test]
fn generic_kll_dynamic_i64_sanity() {
    let mut sketch = KLLDynamic::<i64>::init_kll(200);
    let n: i64 = 20_000;
    for v in 1..=n {
        sketch.update(&v);
    }
    // Weighted count is approximate after compactions; allow 5% drift.
    let count = sketch.count() as f64;
    assert!(
        (count - n as f64).abs() / (n as f64) < 0.05,
        "count={count} diverged from n={n}"
    );
    let cdf = sketch.cdf();
    let p50 = cdf.query(0.5);
    let p90 = cdf.query(0.9);
    // 2% rank error over 1..=n translates to a value tolerance of 0.02*n.
    let tol = n as f64 * 0.02;
    assert!(
        (p50 - (n as f64 * 0.5)).abs() < tol,
        "p50={p50} out of range for n={n}"
    );
    assert!(
        (p90 - (n as f64 * 0.9)).abs() < tol,
        "p90={p90} out of range for n={n}"
    );
    let bytes = sketch
        .serialize_to_bytes()
        .expect("serialize KLLDynamic<i64>");
    let restored =
        KLLDynamic::<i64>::deserialize_from_bytes(&bytes).expect("deserialize KLLDynamic<i64>");
    assert_eq!(sketch.count(), restored.count());
}
}