use crate::UnivMon;
use crate::common::heap::HHHeap;
use crate::common::{BOTTOM_LAYER_FINDER, DataInput, hash_item64_seeded, hash64_seeded};
use crate::common::{L2HH, Vector1D};
use crate::sketches::countsketch_topk::CountL2HH;
/// Recycles [`UnivMon`] sketches that all share one configuration.
///
/// `cap` sketches are built up front by [`new`](Self::new). [`take`](Self::take) hands one
/// out, building a fresh sketch when the free list is empty, and [`put`](Self::put) clears
/// a sketch with `UnivMon::free` before making it available again.
pub struct UnivSketchPool {
    /// Cleared sketches waiting to be handed out by `take`.
    free_list: Vec<UnivMon>,
    /// Number of sketches this pool has built: the initial `cap` plus every on-demand
    /// allocation performed by `take`.
    total_allocated: usize,
    /// Forwarded to `UnivMon::init_univmon` for every sketch the pool builds.
    heap_size: usize,
    /// Forwarded to `UnivMon::init_univmon` for every sketch the pool builds.
    sketch_row: usize,
    /// Forwarded to `UnivMon::init_univmon` for every sketch the pool builds.
    sketch_col: usize,
    /// Forwarded to `UnivMon::init_univmon` for every sketch the pool builds.
    layer_size: usize,
}

impl UnivSketchPool {
    /// Creates a pool pre-filled with `cap` sketches built from the given dimensions.
    pub fn new(
        cap: usize,
        heap_size: usize,
        sketch_row: usize,
        sketch_col: usize,
        layer_size: usize,
    ) -> Self {
        let mut free_list = Vec::with_capacity(cap);
        for _ in 0..cap {
            free_list.push(UnivMon::init_univmon(
                heap_size, sketch_row, sketch_col, layer_size,
            ));
        }
        Self {
            free_list,
            total_allocated: cap,
            heap_size,
            sketch_row,
            sketch_col,
            layer_size,
        }
    }

    /// Hands out a sketch, preferring a recycled one.
    ///
    /// When no recycled sketch is available, a new one is built with the pool's
    /// dimensions and counted in [`total_allocated`](Self::total_allocated).
    pub fn take(&mut self) -> UnivMon {
        match self.free_list.pop() {
            Some(recycled) => recycled,
            None => {
                self.total_allocated += 1;
                UnivMon::init_univmon(
                    self.heap_size,
                    self.sketch_row,
                    self.sketch_col,
                    self.layer_size,
                )
            }
        }
    }

    /// Clears `sketch` (via `UnivMon::free`) and returns it to the free list.
    ///
    /// The pool does not check that `sketch` was originally produced by this pool.
    pub fn put(&mut self, mut sketch: UnivMon) {
        sketch.free();
        self.free_list.push(sketch);
    }

    /// Number of sketches currently waiting in the free list.
    pub fn available(&self) -> usize {
        self.free_list.len()
    }

    /// Total number of sketches this pool has ever built.
    pub fn total_allocated(&self) -> usize {
        self.total_allocated
    }
}
// Defaults used by `UnivMonPyramid::with_defaults`.
/// Number of leading layers built with the wide "elephant" sketch dimensions.
const DEFAULT_ELEPHANT_LAYERS: usize = 8;
/// Row count of each elephant-layer sketch.
const DEFAULT_ELEPHANT_ROW: usize = 3;
/// Column count of each elephant-layer sketch.
const DEFAULT_ELEPHANT_COL: usize = 2048;
/// Row count of each remaining ("mouse") layer sketch.
const DEFAULT_MOUSE_ROW: usize = 3;
/// Column count of each remaining ("mouse") layer sketch.
const DEFAULT_MOUSE_COL: usize = 512;
/// `heap_size` passed to `HHHeap::new` for every layer.
const DEFAULT_PYRAMID_HEAP: usize = 32;
/// Total number of layers (elephant + mouse).
const DEFAULT_PYRAMID_LAYERS: usize = 16;
/// Two-tier UnivMon: one L2 heavy-hitter sketch and one heavy-hitter heap per layer.
///
/// As built by `UnivMonPyramid::new`, the first `elephant_layers` layers use
/// `elephant_row x elephant_col` sketches and the remaining layers use the smaller
/// `mouse_row x mouse_col` sketches.
#[derive(Clone, Debug)]
pub struct UnivMonPyramid {
/// Per-layer sketches; layer `i` is seeded with `i` in `new`. Layer 0 receives every key.
pub l2_sketch_layers: Vector1D<L2HH>,
/// Per-layer heavy-hitter heaps, each created with `HHHeap::new(heap_size)`.
pub hh_layers: Vector1D<HHHeap>,
/// Number of layers (the `total_layers` argument of `new`).
pub layer_size: usize,
/// Number of leading layers that use the elephant sketch dimensions.
pub elephant_layers: usize,
/// Row count of elephant-layer sketches.
pub elephant_row: usize,
/// Column count of elephant-layer sketches.
pub elephant_col: usize,
/// Row count of mouse-layer sketches.
pub mouse_row: usize,
/// Column count of mouse-layer sketches.
pub mouse_col: usize,
/// Capacity argument passed to `HHHeap::new` for every layer.
pub heap_size: usize,
/// Running sum of inserted values; `calc_entropy` normalizes by it.
pub bucket_size: usize,
}
impl UnivMonPyramid {
    /// Builds a two-tier pyramid with `total_layers` layers.
    ///
    /// Layers `0..elephant_layers` (all layers, when `total_layers <= elephant_layers`) use an
    /// `elephant_row x elephant_col` Count Sketch; the remaining layers use the smaller
    /// `mouse_row x mouse_col` sketch. Layer `i` seeds its sketch with `i`, and every layer
    /// gets its own `HHHeap::new(heap_size)`.
    ///
    /// # Panics
    ///
    /// Panics if `total_layers == 0`; every insert and query needs at least one layer.
    pub fn new(
        heap_size: usize,
        elephant_layers: usize,
        elephant_row: usize,
        elephant_col: usize,
        mouse_row: usize,
        mouse_col: usize,
        total_layers: usize,
    ) -> Self {
        assert!(total_layers > 0, "UnivMonPyramid needs at least one layer");
        let sk_vec: Vec<L2HH> = (0..total_layers)
            .map(|i| {
                // A key reaches layer `i` only if its bottom layer is >= i (see `insert`), so
                // low-index layers carry the most keys and get the wider sketch.
                let (rows, cols) = if i < elephant_layers {
                    (elephant_row, elephant_col)
                } else {
                    (mouse_row, mouse_col)
                };
                L2HH::COUNT(CountL2HH::with_dimensions_and_seed(rows, cols, i))
            })
            .collect();
        let hh_vec: Vec<HHHeap> = (0..total_layers).map(|_| HHHeap::new(heap_size)).collect();
        UnivMonPyramid {
            l2_sketch_layers: Vector1D::from_vec(sk_vec),
            hh_layers: Vector1D::from_vec(hh_vec),
            layer_size: total_layers,
            elephant_layers,
            elephant_row,
            elephant_col,
            mouse_row,
            mouse_col,
            heap_size,
            bucket_size: 0,
        }
    }

    /// Builds a pyramid from the module's `DEFAULT_*` constants.
    pub fn with_defaults() -> Self {
        Self::new(
            DEFAULT_PYRAMID_HEAP,
            DEFAULT_ELEPHANT_LAYERS,
            DEFAULT_ELEPHANT_ROW,
            DEFAULT_ELEPHANT_COL,
            DEFAULT_MOUSE_ROW,
            DEFAULT_MOUSE_COL,
            DEFAULT_PYRAMID_LAYERS,
        )
    }

    /// Returns the deepest layer a key with bottom-layer hash `hash` belongs to.
    ///
    /// A key always belongs to layer 0 and belongs to layer `l >= 1` iff bits `1..=l` of
    /// `hash` are all set; the result is capped at the top layer. `calc_g_sum` relies on this
    /// same layout when it tests bit `i + 1` for layer `i`.
    #[inline]
    fn find_bottom_layer_num(&self, hash: u64) -> usize {
        // Bit 0 is not used; the run of consecutive ones starting at bit 1 is the depth.
        // `trailing_ones` never shifts past bit 63, unlike a per-bit `hash >> l` loop.
        let depth = (hash >> 1).trailing_ones() as usize;
        depth.min(self.layer_size - 1)
    }

    /// Adds a signed `value` to `bucket_size`, saturating at `0` and `usize::MAX`.
    ///
    /// Casting a negative `value` to `usize` and adding overflows, which panics in debug
    /// builds; handling the sign explicitly keeps deletions (negative values) working.
    fn add_to_bucket_size(&mut self, value: i64) {
        let magnitude = value.unsigned_abs() as usize;
        self.bucket_size = if value < 0 {
            self.bucket_size.saturating_sub(magnitude)
        } else {
            self.bucket_size.saturating_add(magnitude)
        };
    }

    /// Inserts `key` with weight `value` into every layer from 0 down to its bottom layer.
    ///
    /// Layer 0's sketch is updated via `update_and_est`, deeper layers via
    /// `update_and_est_without_l2`; each layer's heap is then updated with that layer's
    /// returned estimate for `key`.
    pub fn insert(&mut self, key: &DataInput, value: i64) {
        self.add_to_bucket_size(value);
        let bottom = self.find_bottom_layer_num(hash64_seeded(BOTTOM_LAYER_FINDER, key));
        for layer in 0..=bottom {
            let est = if layer == 0 {
                self.l2_sketch_layers[layer].update_and_est(key, value)
            } else {
                self.l2_sketch_layers[layer].update_and_est_without_l2(key, value)
            };
            self.hh_layers[layer].update(key, est as i64);
        }
    }

    /// Cheaper variant of [`insert`](Self::insert) that touches at most two sketches.
    ///
    /// When the key's bottom layer is above 0, only that layer's sketch is updated and its
    /// estimate is written to the heaps of layers `1..=bottom`. Layer 0's sketch and heap are
    /// always updated. Skipping the intermediate sketches is what makes this faster than
    /// `insert`, and why its estimates can differ slightly from `insert`'s.
    pub fn fast_insert(&mut self, key: &DataInput, value: i64) {
        self.add_to_bucket_size(value);
        let bottom = self.find_bottom_layer_num(hash64_seeded(BOTTOM_LAYER_FINDER, key));
        if bottom > 0 {
            let est = self.l2_sketch_layers[bottom].update_and_est_without_l2(key, value);
            for layer in (1..=bottom).rev() {
                self.hh_layers[layer].update(key, est as i64);
            }
        }
        // Layer 0 must be updated exactly once per key, regardless of whether `bottom` is an
        // elephant or a mouse layer (e.g. when `elephant_layers == 0`).
        let est0 = self.l2_sketch_layers[0].update_and_est(key, value);
        self.hh_layers[0].update(key, est0 as i64);
    }

    /// Heap-count cutoff (exclusive) for items counted at `layer` by `calc_g_sum`.
    ///
    /// For cardinality queries, items whose count is at most 1% of the layer's `get_l2()`
    /// are ignored; otherwise every item with a positive count is used.
    fn heavy_threshold(&self, layer: usize, is_card: bool) -> i64 {
        if is_card {
            (self.l2_sketch_layers[layer].get_l2() * 0.01) as i64
        } else {
            0
        }
    }

    /// Estimates `sum_j g(f_j)` over the stream with the UnivMon recursive estimator.
    ///
    /// Starting from the top layer's plain heap sum `Y_top`, each lower layer computes
    /// `Y_i = 2 * Y_{i+1} + sum_j (1 - 2 * h_{i+1}(j)) * g(count_j)`, where `h_{i+1}(j)` is
    /// bit `i + 1` of the key's bottom-layer hash (1 iff the key also sits in layer `i + 1`).
    /// Returns `Y_0`. `is_card` enables the per-layer threshold in `heavy_threshold`.
    pub fn calc_g_sum<F>(&self, g: F, is_card: bool) -> f64
    where
        F: Fn(f64) -> f64,
    {
        let top = self.layer_size - 1;
        let top_threshold = self.heavy_threshold(top, is_card);
        let mut y = 0.0;
        for item in self.hh_layers[top].heap() {
            if item.count > top_threshold {
                y += g(item.count as f64);
            }
        }
        for i in (0..top).rev() {
            let threshold = self.heavy_threshold(i, is_card);
            let mut layer_sum = 0.0;
            for item in self.hh_layers[i].heap() {
                if item.count > threshold {
                    let hash = hash_item64_seeded(BOTTOM_LAYER_FINDER, &item.key);
                    // `checked_shr` treats bits past 63 as 0 instead of panicking on
                    // pyramids deeper than 64 layers.
                    let in_next = hash.checked_shr((i + 1) as u32).unwrap_or(0) & 1;
                    let coe = 1.0 - 2.0 * (in_next as f64);
                    layer_sum += coe * g(item.count as f64);
                }
            }
            y = 2.0 * y + layer_sum;
        }
        y
    }

    /// Estimated L1 norm (sum of frequencies).
    pub fn calc_l1(&self) -> f64 {
        self.calc_g_sum(|x| x, false)
    }

    /// Estimated L2 norm (square root of the sum of squared frequencies).
    pub fn calc_l2(&self) -> f64 {
        self.calc_g_sum(|x| x * x, false).sqrt()
    }

    /// Estimated base-2 Shannon entropy: `log2(N) - sum_j f_j log2 f_j / N`, `N = bucket_size`.
    ///
    /// Returns `0.0` when nothing has been inserted, instead of dividing by zero (NaN).
    pub fn calc_entropy(&self) -> f64 {
        if self.bucket_size == 0 {
            return 0.0;
        }
        let total = self.bucket_size as f64;
        let tmp = self.calc_g_sum(|x| if x > 0.0 { x * x.log2() } else { 0.0 }, false);
        total.log2() - tmp / total
    }

    /// Estimated number of distinct keys (with the `is_card` noise threshold enabled).
    pub fn calc_card(&self) -> f64 {
        self.calc_g_sum(|_| 1.0, true)
    }

    /// Resets `bucket_size` and clears every layer's sketch and heap.
    pub fn free(&mut self) {
        self.bucket_size = 0;
        for i in 0..self.layer_size {
            self.l2_sketch_layers[i].clear();
            self.hh_layers[i].clear();
        }
    }

    /// Merges `other` into `self`.
    ///
    /// Sketches are merged layer by layer via `L2HH::merge`. For each item in `other`'s heap,
    /// its count is added to the matching key's count in `self`'s heap (if present) and the
    /// sum is passed to `update_heap_item`. `bucket_size` is combined as well, so entropy
    /// estimates stay normalized by the merged stream's total.
    ///
    /// # Panics
    ///
    /// Panics if the two pyramids have different `layer_size`.
    pub fn merge(&mut self, other: &UnivMonPyramid) {
        assert_eq!(
            self.layer_size, other.layer_size,
            "layer size must match for merge"
        );
        for i in 0..self.layer_size {
            self.l2_sketch_layers[i].merge(&other.l2_sketch_layers[i]);
            for item in other.hh_layers[i].heap() {
                let count = if let Some(index) = self.hh_layers[i].find_heap_item(&item.key) {
                    self.hh_layers[i].heap()[index].count + item.count
                } else {
                    item.count
                };
                self.hh_layers[i].update_heap_item(&item.key, count);
            }
        }
        self.bucket_size = self.bucket_size.saturating_add(other.bucket_size);
    }

    /// Mutable access to the heavy-hitter heap of `layer`.
    pub fn heap_at_layer(&mut self, layer: usize) -> &mut HHHeap {
        &mut self.hh_layers[layer]
    }
}
// Unit tests for `UnivSketchPool` and `UnivMonPyramid`.
#[cfg(test)]
mod tests {
use super::*;
use crate::DataInput;
// take() drains the pre-built free list, then allocates on demand (bumping
// total_allocated); put() returns sketches to the free list.
#[test]
fn pool_basic_take_put() {
let mut pool = UnivSketchPool::new(2, 16, 2, 5, 2);
assert_eq!(pool.available(), 2);
assert_eq!(pool.total_allocated(), 2);
let s0 = pool.take();
assert_eq!(pool.available(), 1);
let s1 = pool.take();
assert_eq!(pool.available(), 0);
// Free list is empty here, so this take() must allocate a third sketch.
let s2 = pool.take();
assert_eq!(pool.available(), 0);
assert_eq!(pool.total_allocated(), 3);
pool.put(s1);
assert_eq!(pool.available(), 1);
// Reuses the returned sketch; no new allocation.
let s3 = pool.take();
assert_eq!(pool.available(), 0);
assert_eq!(pool.total_allocated(), 3);
pool.put(s0);
pool.put(s2);
pool.put(s3);
assert_eq!(pool.available(), 3);
}
// A sketch returned through put() comes back from take() with zero bucket_size
// and zero layer-0 L2.
#[test]
fn pool_free_resets_sketch() {
let mut pool = UnivSketchPool::new(1, 16, 2, 5, 2);
let mut sketch = pool.take();
sketch.insert(&DataInput::I64(42), 100);
assert!(sketch.bucket_size > 0);
pool.put(sketch);
let sketch2 = pool.take();
assert_eq!(sketch2.bucket_size, 0);
assert!((sketch2.l2_sketch_layers[0].get_l2()).abs() < 1e-9);
}
// Four inserts over three distinct keys: bucket_size and L1 must equal the total
// weight (65) and the cardinality estimate must be exactly 3.
#[test]
fn pyramid_basic_insert_and_query() {
let mut um = UnivMonPyramid::with_defaults();
let cases: Vec<(&str, i64)> = vec![("hello", 10), ("world", 20), ("hello", 5), ("foo", 30)];
for (key, val) in &cases {
um.insert(&DataInput::Str(key), *val);
}
assert_eq!(um.bucket_size, 65);
assert!((um.calc_l1() - 65.0).abs() < 1e-6, "L1 = {}", um.calc_l1());
assert_eq!(um.calc_card(), 3.0);
}
// insert() and fast_insert() fed the same stream (100 keys x 5): bucket_size must
// match exactly; L1 within 10% and cardinality within 15% relative difference.
#[test]
fn pyramid_fast_insert_matches_standard() {
let mut standard = UnivMonPyramid::new(32, 8, 3, 2048, 3, 512, 16);
let mut fast = UnivMonPyramid::new(32, 8, 3, 2048, 3, 512, 16);
for i in 0..500i64 {
let key = DataInput::I64(i % 100);
standard.insert(&key, 1);
fast.fast_insert(&key, 1);
}
assert_eq!(standard.bucket_size, fast.bucket_size);
let l1_diff = (standard.calc_l1() - fast.calc_l1()).abs();
let card_diff = (standard.calc_card() - fast.calc_card()).abs();
assert!(
l1_diff / standard.calc_l1() < 0.10,
"L1 diverged: std={}, fast={}",
standard.calc_l1(),
fast.calc_l1()
);
assert!(
card_diff / standard.calc_card().max(1.0) < 0.15,
"Card diverged: std={}, fast={}",
standard.calc_card(),
fast.calc_card()
);
}
// new() records the requested total and elephant layer counts.
#[test]
fn pyramid_two_tier_dimensions() {
let um = UnivMonPyramid::new(32, 4, 5, 2048, 3, 256, 8);
assert_eq!(um.layer_size, 8);
assert_eq!(um.elephant_layers, 4);
}
// free() zeroes bucket_size and the layer-0 L2 estimate.
#[test]
fn pyramid_free_resets_state() {
let mut um = UnivMonPyramid::with_defaults();
for i in 0..100i64 {
um.insert(&DataInput::I64(i), 10);
}
assert!(um.bucket_size > 0);
um.free();
assert_eq!(um.bucket_size, 0);
assert!((um.l2_sketch_layers[0].get_l2()).abs() < 1e-9);
}
// Merging two pyramids built over disjoint key ranges: the merged L1 must be within
// 10% of the sum of the two individual L1 estimates.
#[test]
fn pyramid_merge_combines_data() {
let mut left = UnivMonPyramid::with_defaults();
let mut right = UnivMonPyramid::with_defaults();
for i in 0..50i64 {
left.insert(&DataInput::I64(i), 10);
}
for i in 50..100i64 {
right.insert(&DataInput::I64(i), 10);
}
let left_l1 = left.calc_l1();
let right_l1 = right.calc_l1();
left.merge(&right);
let merged_l1 = left.calc_l1();
let expected = left_l1 + right_l1;
let err = (merged_l1 - expected).abs() / expected;
assert!(
err < 0.10,
"Merged L1 error {:.2}%: got {}, expected {}",
err * 100.0,
merged_l1,
expected
);
}
// Exact (L1, L2, cardinality, base-2 entropy) of a frequency table; entropy is 0.0
// when the table's total is 0.
fn ground_truth(freq: &std::collections::HashMap<i64, i64>) -> (f64, f64, f64, f64) {
let l1: f64 = freq.values().map(|&v| v as f64).sum();
let l2: f64 = freq
.values()
.map(|&v| (v as f64).powi(2))
.sum::<f64>()
.sqrt();
let card = freq.len() as f64;
let entropy = if l1 > 0.0 {
let term: f64 = freq
.values()
.map(|&v| {
let f = v as f64;
if f > 0.0 { f * f.log2() } else { 0.0 }
})
.sum();
l1.log2() - term / l1
} else {
0.0
};
(l1, l2, card, entropy)
}
// Skewed stream via insert(): key 0 x5000, keys 1..=20 x200 each, keys 21..=500 once
// each. L1, L2, cardinality and entropy must all be within 15% relative error.
#[test]
fn pyramid_accuracy_zipf() {
use std::collections::HashMap;
let mut um = UnivMonPyramid::new(64, 8, 5, 2048, 3, 512, 16);
let mut freq: HashMap<i64, i64> = HashMap::new();
for _ in 0..5000 {
um.insert(&DataInput::I64(0), 1);
*freq.entry(0).or_insert(0) += 1;
}
for key in 1..=20i64 {
for _ in 0..200 {
um.insert(&DataInput::I64(key), 1);
*freq.entry(key).or_insert(0) += 1;
}
}
for key in 21..=500i64 {
um.insert(&DataInput::I64(key), 1);
*freq.entry(key).or_insert(0) += 1;
}
let (true_l1, true_l2, true_card, true_entropy) = ground_truth(&freq);
let err = |name: &str, est: f64, truth: f64| {
let rel = (est - truth).abs() / truth.max(1e-12);
assert!(
rel < 0.15,
"Pyramid {name}: error {:.2}%, est={est:.2}, truth={truth:.2}",
rel * 100.0
);
};
err("L1", um.calc_l1(), true_l1);
err("L2", um.calc_l2(), true_l2);
err("Card", um.calc_card(), true_card);
err("Entropy", um.calc_entropy(), true_entropy);
}
// Same 15% relative-error bound, but through fast_insert(): key 0 x3000 and
// keys 1..=50 x100 each.
#[test]
fn pyramid_fast_insert_accuracy() {
use std::collections::HashMap;
let mut um = UnivMonPyramid::new(64, 8, 5, 2048, 3, 512, 16);
let mut freq: HashMap<i64, i64> = HashMap::new();
for _ in 0..3000 {
um.fast_insert(&DataInput::I64(0), 1);
*freq.entry(0).or_insert(0) += 1;
}
for key in 1..=50i64 {
for _ in 0..100 {
um.fast_insert(&DataInput::I64(key), 1);
*freq.entry(key).or_insert(0) += 1;
}
}
let (true_l1, true_l2, true_card, true_entropy) = ground_truth(&freq);
let err = |name: &str, est: f64, truth: f64| {
let rel = (est - truth).abs() / truth.max(1e-12);
assert!(
rel < 0.15,
"Pyramid fast {name}: error {:.2}%, est={est:.2}, truth={truth:.2}",
rel * 100.0
);
};
err("L1", um.calc_l1(), true_l1);
err("L2", um.calc_l2(), true_l2);
err("Card", um.calc_card(), true_card);
err("Entropy", um.calc_entropy(), true_entropy);
}
// Pure arithmetic; no sketch is built. With 8 elephant layers of 2048 columns and
// 8 mouse layers of 512, the pyramid uses 20480 columns vs 32768 for 16 uniform
// elephant layers, i.e. 37.5% fewer (the test requires > 30%).
#[test]
fn pyramid_memory_savings_vs_uniform() {
let elephant_col = 2048;
let mouse_col = 512;
let elephant_layers = 8;
let total_layers = 16;
let uniform_cols = elephant_col * total_layers;
let pyramid_cols =
elephant_col * elephant_layers + mouse_col * (total_layers - elephant_layers);
assert!(
pyramid_cols < uniform_cols,
"Pyramid ({pyramid_cols}) should use fewer columns than uniform ({uniform_cols})"
);
let savings = 1.0 - (pyramid_cols as f64 / uniform_cols as f64);
assert!(
savings > 0.30,
"Expected >30% column savings, got {:.1}%",
savings * 100.0
);
}
}