use crate::cpc::CpcSketch;
use crate::cpc::DEFAULT_LG_K;
use crate::cpc::Flavor;
use crate::cpc::count_bits_set_in_matrix;
use crate::cpc::determine_correct_offset;
use crate::cpc::pair_table::PairTable;
use crate::hash::DEFAULT_UPDATE_SEED;
/// Accumulates the union of any number of [`CpcSketch`]es.
///
/// Sketches are merged in with `update`, and the combined result is turned
/// back into a standalone sketch with `to_sketch`.
#[derive(Debug, Clone)]
pub struct CpcUnion {
/// Base-2 logarithm of the union's current `k`. `update` lowers it (via
/// `reduce_k`) when a sketch with a smaller `lg_k` is merged.
lg_k: u8,
/// Seed passed to every `CpcSketch::with_seed` call this union makes;
/// `update` asserts that incoming sketches report the same seed.
seed: u64,
/// Current internal representation of the merged data.
state: UnionState,
}
impl Default for CpcUnion {
fn default() -> Self {
Self::new(DEFAULT_LG_K)
}
}
impl CpcUnion {
pub fn new(lg_k: u8) -> Self {
Self::with_seed(lg_k, DEFAULT_UPDATE_SEED)
}
/// Creates a union with the given `lg_k` and `seed`.
///
/// The union starts in accumulator form, wrapping a sketch freshly built by
/// `CpcSketch::with_seed(lg_k, seed)`.
pub fn with_seed(lg_k: u8, seed: u64) -> Self {
    Self {
        lg_k,
        seed,
        state: UnionState::Accumulator(CpcSketch::with_seed(lg_k, seed)),
    }
}
/// Returns the union's current `lg_k`.
///
/// The value can be smaller than the one given at construction, because
/// `update` reduces it when a sketch with a smaller `lg_k` is merged.
pub fn lg_k(&self) -> u8 {
    self.lg_k
}
/// Materializes the current union as a standalone [`CpcSketch`].
///
/// * An empty accumulator yields a fresh `CpcSketch::with_seed(lg_k, seed)`.
/// * A non-empty accumulator is cloned and returned with `merge_flag` set.
/// * A bit matrix is converted into a sketch consisting of an 8-column
///   sliding window per row plus a pair table holding the remaining entries,
///   also with `merge_flag` set.
///
/// # Panics
///
/// Panics if a non-empty accumulator is not of `Flavor::Sparse`, or if an
/// entry derived from the bit matrix is reported by the new table as already
/// present.
pub fn to_sketch(&self) -> CpcSketch {
    let matrix = match &self.state {
        UnionState::BitMatrix(matrix) => matrix,
        UnionState::Accumulator(accumulated) if accumulated.is_empty() => {
            return CpcSketch::with_seed(self.lg_k, self.seed);
        }
        UnionState::Accumulator(accumulated) => {
            let mut copy = accumulated.clone();
            assert_eq!(copy.flavor(), Flavor::Sparse);
            copy.merge_flag = true;
            return copy;
        }
    };

    // From here on the union is a bit matrix that must be re-encoded.
    let lg_k = self.lg_k;
    let mut result = CpcSketch::with_seed(lg_k, self.seed);
    let num_coupons = count_bits_set_in_matrix(matrix);
    result.num_coupons = num_coupons;
    let offset = determine_correct_offset(lg_k, num_coupons);
    result.window_offset = offset;

    let num_rows = 1usize << lg_k;
    let mut window = vec![0u8; num_rows];
    let table_lg_size = (lg_k - 4).max(2);
    let mut table = PairTable::new(table_lg_size, 6 + lg_k);

    // Zero in the eight window columns starting at `offset`, one elsewhere.
    let outside_window = !(0xFFu64 << offset);
    // One in every column below the window. Flipping those columns means a
    // set bit afterwards marks an entry for the table (presumably a zero that
    // is "surprising" that early in the row -- confirm against the CPC spec).
    let below_window = (1u64 << offset) - 1;
    let mut surprises_seen = 0u64;

    for row in 0..num_rows {
        let bits = matrix[row];
        window[row] = ((bits >> offset) & 0xFF) as u8;

        let mut pending = (bits & outside_window) ^ below_window;
        surprises_seen |= pending;
        while pending != 0 {
            let col = pending.trailing_zeros();
            // Clear the lowest set bit, i.e. the one at `col`.
            pending &= pending - 1;
            let is_novel = table.maybe_insert(((row as u32) << 6) | col);
            assert!(is_novel);
        }
    }

    // Lowest column that produced a table entry, capped at the window offset
    // (`trailing_zeros` is 64 when there were no entries at all).
    result.first_interesting_column = (surprises_seen.trailing_zeros() as u8).min(offset);
    result.sliding_window = window;
    result.surprising_value_table = Some(table);
    result.merge_flag = true;
    result
}
/// Merges `sketch` into this union.
///
/// Empty sketches are ignored. If `sketch` has a smaller `lg_k` than the
/// union, the union is first reduced to that `lg_k`. The union stays in
/// accumulator form only while both it and the incoming sketch are at most
/// `Flavor::Sparse`; otherwise it is held as a bit matrix and the source is
/// OR-ed into it.
///
/// # Panics
///
/// Panics if `sketch` was built with a different seed than this union, or if
/// an internal flavor invariant is violated.
pub fn update(&mut self, sketch: &CpcSketch) {
    assert_eq!(self.seed, sketch.seed());

    let flavor = sketch.flavor();
    if flavor == Flavor::Empty {
        return;
    }

    if sketch.lg_k() < self.lg_k {
        self.reduce_k(sketch.lg_k());
    }

    // A source beyond Sparse is only ever merged into a bit matrix, so
    // convert an accumulator before dispatching.
    let promoted = match &self.state {
        UnionState::Accumulator(accumulated) if flavor > Flavor::Sparse => {
            Some(accumulated.build_bit_matrix())
        }
        UnionState::Accumulator(_) | UnionState::BitMatrix(_) => None,
    };
    if let Some(bit_matrix) = promoted {
        self.state = UnionState::BitMatrix(bit_matrix);
    }

    match &mut self.state {
        UnionState::Accumulator(accumulated) => {
            if flavor != Flavor::Sparse {
                unreachable!("unexpected flavor in union accumulator: {:?}", flavor);
            }
            let old_flavor = accumulated.flavor();
            if !matches!(old_flavor, Flavor::Sparse | Flavor::Empty) {
                unreachable!(
                    "unexpected old flavor in union accumulator: {:?}",
                    old_flavor
                );
            }

            // Nothing accumulated yet and same size: adopt a copy of the source.
            if old_flavor == Flavor::Empty && self.lg_k == sketch.lg_k() {
                *accumulated = sketch.clone();
                return;
            }

            walk_table_updating_sketch(accumulated, sketch.surprising_value_table());
            // The merged accumulator may have grown past Sparse; if so, switch
            // to the bit-matrix representation.
            if accumulated.flavor() > Flavor::Sparse {
                let bit_matrix = accumulated.build_bit_matrix();
                self.state = UnionState::BitMatrix(bit_matrix);
            }
        }
        UnionState::BitMatrix(matrix) => match flavor {
            Flavor::Sparse => {
                or_table_into_matrix(matrix, self.lg_k, sketch.surprising_value_table());
            }
            Flavor::Hybrid | Flavor::Pinned => {
                or_window_into_matrix(
                    matrix,
                    self.lg_k,
                    &sketch.sliding_window,
                    sketch.window_offset,
                    sketch.lg_k(),
                );
                or_table_into_matrix(matrix, self.lg_k, sketch.surprising_value_table());
            }
            // Everything else must be Sliding: expand the source into its own
            // bit matrix and OR that in row by row.
            _ => {
                assert_eq!(flavor, Flavor::Sliding);
                let src_matrix = sketch.build_bit_matrix();
                or_matrix_into_matrix(matrix, self.lg_k, &src_matrix, sketch.lg_k());
            }
        },
    }
}
/// Rebuilds the union's state at the smaller size `new_lg_k` and records the
/// new `lg_k`.
///
/// * An empty accumulator is replaced by a fresh sketch of the new size.
/// * A non-empty accumulator is replayed into a fresh sketch of the new size;
///   the result stays an accumulator if it is still `Flavor::Sparse` and
///   becomes a bit matrix otherwise.
/// * A bit matrix is folded into a zeroed matrix with `1 << new_lg_k` rows.
///
/// # Panics
///
/// Panics if replaying a non-empty accumulator produces an empty sketch, or
/// (bit-matrix case) if `new_lg_k` is larger than the current `lg_k`.
fn reduce_k(&mut self, new_lg_k: u8) {
    let reduced = match &self.state {
        UnionState::Accumulator(current) if current.is_empty() => {
            UnionState::Accumulator(CpcSketch::with_seed(new_lg_k, self.seed))
        }
        UnionState::Accumulator(current) => {
            let mut rebuilt = CpcSketch::with_seed(new_lg_k, self.seed);
            walk_table_updating_sketch(&mut rebuilt, current.surprising_value_table());
            let final_new_flavor = rebuilt.flavor();
            assert_ne!(final_new_flavor, Flavor::Empty);
            if final_new_flavor == Flavor::Sparse {
                UnionState::Accumulator(rebuilt)
            } else {
                UnionState::BitMatrix(rebuilt.build_bit_matrix())
            }
        }
        UnionState::BitMatrix(matrix) => {
            let mut folded = vec![0u64; 1usize << new_lg_k];
            or_matrix_into_matrix(&mut folded, new_lg_k, matrix, self.lg_k);
            UnionState::BitMatrix(folded)
        }
    };
    self.lg_k = new_lg_k;
    self.state = reduced;
}
}
impl CpcUnion {
    /// Returns the number of coupons currently held by the union.
    ///
    /// In accumulator form this is the wrapped sketch's `num_coupons` field;
    /// in bit-matrix form it is recomputed by counting the set bits.
    pub fn num_coupons(&self) -> u32 {
        match &self.state {
            UnionState::BitMatrix(matrix) => count_bits_set_in_matrix(matrix),
            UnionState::Accumulator(accumulated) => accumulated.num_coupons,
        }
    }
}
/// ORs a sketch's sliding window into a bit matrix, folding rows when the
/// destination is smaller than the source.
///
/// Each byte of `src_window` is shifted left by `src_offset` columns and OR-ed
/// into destination row `src_row & ((1 << dst_lg_k) - 1)`.
///
/// * `dst_matrix` - destination rows; the first `1 << dst_lg_k` are used.
/// * `src_window` - one window byte per source row; the first `1 << src_lg_k`
///   are used.
/// * `src_offset` - column at which the window's lowest bit sits.
///
/// # Panics
///
/// Panics if `dst_lg_k > src_lg_k`, or if either slice is shorter than its
/// `lg_k` implies. Both lengths are checked before anything is written, so a
/// failed call leaves `dst_matrix` untouched.
fn or_window_into_matrix(
    dst_matrix: &mut [u64],
    dst_lg_k: u8,
    src_window: &[u8],
    src_offset: u8,
    src_lg_k: u8,
) {
    assert!(dst_lg_k <= src_lg_k);
    let dst_k = 1usize << dst_lg_k;
    let src_k = 1usize << src_lg_k;

    // Slice once up front: this validates both lengths before any mutation
    // and removes the per-row bounds checks from the loop below.
    let dst_rows = &mut dst_matrix[..dst_k];
    let src_rows = &src_window[..src_k];

    // `src_k` is a multiple of `dst_k` (both are powers of two), so source
    // row `r` maps to `r % dst_k`: every `dst_k`-sized chunk of the source
    // lines up with the destination rows one-to-one.
    for chunk in src_rows.chunks_exact(dst_k) {
        for (dst_row, &byte) in dst_rows.iter_mut().zip(chunk) {
            *dst_row |= u64::from(byte) << src_offset;
        }
    }
}
/// ORs every occupied slot of `src_table` into `dst_matrix`.
///
/// A slot value packs a row and a column as `(row << 6) | col`; slots equal
/// to `u32::MAX` are skipped as unoccupied. The row is folded into the
/// destination with `row & ((1 << dst_lg_k) - 1)`, and bit `col` of that
/// destination row is set.
fn or_table_into_matrix(dst_matrix: &mut [u64], dst_lg_k: u8, src_table: &PairTable) {
    let row_mask = (1usize << dst_lg_k) - 1;
    let slots = src_table.slots();
    let occupied = slots
        .iter()
        .copied()
        .filter(|&row_col| row_col != u32::MAX);
    for row_col in occupied {
        let column = row_col & 63;
        let row = ((row_col >> 6) as usize) & row_mask;
        dst_matrix[row] |= 1u64 << column;
    }
}
/// ORs every row of `src_matrix` into `dst_matrix`, folding rows when the
/// destination is smaller than the source.
///
/// Source row `r` is OR-ed into destination row `r & ((1 << dst_lg_k) - 1)`.
/// Only the first `1 << dst_lg_k` destination rows and the first
/// `1 << src_lg_k` source rows take part.
///
/// # Panics
///
/// Panics if `dst_lg_k > src_lg_k`, or if either slice is shorter than its
/// `lg_k` implies. Both lengths are checked before anything is written, so a
/// failed call leaves `dst_matrix` untouched.
fn or_matrix_into_matrix(dst_matrix: &mut [u64], dst_lg_k: u8, src_matrix: &[u64], src_lg_k: u8) {
    assert!(dst_lg_k <= src_lg_k);
    let dst_k = 1usize << dst_lg_k;
    let src_k = 1usize << src_lg_k;

    // Slice once up front: this validates both lengths before any mutation
    // and removes the per-row bounds checks from the loop below.
    let dst_rows = &mut dst_matrix[..dst_k];
    let src_rows = &src_matrix[..src_k];

    // `src_k` is a multiple of `dst_k` (both are powers of two), so source
    // row `r` maps to `r % dst_k`: every `dst_k`-sized chunk of the source
    // lines up with the destination rows one-to-one.
    for chunk in src_rows.chunks_exact(dst_k) {
        for (dst_row, &bits) in dst_rows.iter_mut().zip(chunk) {
            *dst_row |= bits;
        }
    }
}
/// Feeds every occupied slot of `table` into `sketch` via `row_col_update`.
///
/// Each slot value is masked so that only the low `sketch.lg_k()` row bits
/// (above the six column bits) survive, which folds rows from a larger source
/// into the destination's size. Slots equal to `u32::MAX` are skipped as
/// unoccupied.
///
/// The slots are visited with a fixed odd stride of roughly 0.618 times the
/// table size, wrapping with `& (num_slots - 1)`. The wrap assumes the slot
/// count is a power of two; with an odd stride, `num_slots` steps then visit
/// every slot exactly once.
///
/// # Panics
///
/// Panics if `sketch.lg_k()` exceeds 26 or if the table is too small for the
/// stride computation to reach at least 3.
fn walk_table_updating_sketch(sketch: &mut CpcSketch, table: &PairTable) {
    assert!(sketch.lg_k() <= 26);
    let slots = table.slots();
    let num_slots = slots.len() as u32;

    // Low `lg_k` row bits plus all six column bits; with lg_k <= 26 this
    // fits in 32 bits, so the narrowing cast loses nothing.
    let row_col_mask = ((((1u64 << sketch.lg_k()) - 1) << 6) | 63) as u32;

    let mut stride = (0.6180339887498949 * f64::from(num_slots)) as u32;
    assert!(stride >= 2);
    // Force the stride odd (adds one only when it is even).
    stride |= 1;
    assert!((stride >= 3) && (stride < num_slots));

    let slot_mask = num_slots - 1;
    let mut probe = 0u32;
    for _ in 0..num_slots {
        probe &= slot_mask;
        let row_col = slots[probe as usize];
        if row_col != u32::MAX {
            sketch.row_col_update(row_col & row_col_mask);
        }
        probe += stride;
    }
}
/// Internal representation of the data merged into a `CpcUnion`.
#[derive(Debug, Clone)]
enum UnionState {
/// The merged data is held as a regular sketch. `CpcUnion::update` keeps
/// this form only while the accumulated sketch's flavor is `Empty` or
/// `Sparse`.
Accumulator(CpcSketch),
/// The merged data is held as one `u64` bit pattern per row (`1 << lg_k`
/// rows); column `c` of a row is bit `1u64 << c`.
BitMatrix(Vec<u64>),
}