/// A fixed number of 2-bit counters ("streams") stored as two bit planes.
///
/// In every implementation in this file, stream `i`'s value is formed from
/// bit `i` of a "high" plane and bit `i` of a "low" plane
/// (`value = high_bit << 1 | low_bit`), so values range over `0..=3`.
pub trait Sketch: Default {
    /// Number of 2-bit streams held by the sketch.
    const STREAMS: u32;

    /// Mask over a 64-bit hash. In every implementation in this file it equals
    /// `u64::MAX >> (64 - IDX_SHIFT)`, i.e. all bits below bit `IDX_SHIFT`.
    const HASH_MASK: u64;

    /// Shift paired with `HASH_MASK`; in every implementation in this file
    /// `1 << (64 - IDX_SHIFT) == STREAMS`.
    /// NOTE(review): presumably callers select a stream with `hash >> IDX_SHIFT`;
    /// confirm at the call sites.
    const IDX_SHIFT: u32;

    /// Returns the 2-bit value (`0..=3`) currently stored for `stream`.
    fn val(&self, stream: u32) -> u8;

    /// Overwrites the value of `stream`; the implementations here
    /// debug-assert `stream < STREAMS` and `value < 4`.
    fn set(&mut self, stream: u32, value: u8);

    /// Decrements every non-zero stream by one (zero stays zero). The
    /// implementations here return the number of streams that are still
    /// non-zero afterwards (those that previously held 2 or 3).
    fn decrement(&mut self) -> u32;

    /// Returns the number of streams whose value is non-zero.
    fn count(&self) -> u32;

    /// ORs `other`'s bit planes into `self`. This is a per-bit OR, not a
    /// per-stream maximum: merging 1 (`01`) with 2 (`10`) yields 3 (`11`).
    fn merge(&mut self, other: &Self);

    /// ORs `other`'s high plane into `self`'s low plane.
    fn merge_high_into_lo(&mut self, other: &Self);
}
/// 64 two-bit streams stored as a pair of 64-bit bit planes.
///
/// Bit `i` of `high` and bit `i` of `low` together form the value of stream `i`
/// (`value = high_bit << 1 | low_bit`).
#[derive(Debug, Eq, PartialEq, Hash, Clone, Default)]
#[cfg_attr(feature = "mem_dbg", derive(mem_dbg::MemDbg, mem_dbg::MemSize))]
pub struct M64 {
    low: u64,
    high: u64,
}

impl Sketch for M64 {
    const STREAMS: u32 = 64;
    // Evaluates to 0x03FF_FFFF_FFFF_FFFF: the low 58 bits.
    const HASH_MASK: u64 = u64::MAX >> (64 - Self::IDX_SHIFT);
    // Evaluates to 58 (= 64 - log2(64)).
    const IDX_SHIFT: u32 = 64 - Self::STREAMS.trailing_zeros();

    #[inline]
    #[allow(clippy::cast_possible_truncation)] // masked to a single bit before the cast
    fn val(&self, stream: u32) -> u8 {
        debug_assert!(stream < Self::STREAMS);
        let hi = ((self.high >> stream) & 1) as u8;
        let lo = ((self.low >> stream) & 1) as u8;
        (hi << 1) | lo
    }

    #[inline]
    fn set(&mut self, stream: u32, value: u8) {
        debug_assert!(stream < Self::STREAMS);
        debug_assert!(value < 4);
        let mask: u64 = 1 << stream;
        // Expand each bit of `value` into either 0 or `mask`.
        let hi = if value & 0b10 == 0 { 0 } else { mask };
        let lo = if value & 0b01 == 0 { 0 } else { mask };
        self.high = (self.high & !mask) | hi;
        self.low = (self.low & !mask) | lo;
    }

    #[inline]
    fn decrement(&mut self) -> u32 {
        // Per stream: 3 (11) -> 2 (10), 2 (10) -> 1 (01), 1 (01) -> 0, 0 -> 0,
        // which is: new high = high & low, new low = high & !low.
        let (high, low) = (self.high, self.low);
        self.high = high & low;
        self.low = high & !low;
        // Streams whose high bit was set (values 2 and 3) are exactly the ones
        // that remain non-zero.
        high.count_ones()
    }

    #[inline]
    fn count(&self) -> u32 {
        (self.high | self.low).count_ones()
    }

    #[inline]
    fn merge(&mut self, other: &Self) {
        self.low |= other.low;
        self.high |= other.high;
    }

    #[inline]
    fn merge_high_into_lo(&mut self, other: &Self) {
        let incoming = other.high;
        self.low |= incoming;
    }
}
/// 128 two-bit streams stored as a pair of 128-bit bit planes.
///
/// Bit `i` of `high` and bit `i` of `low` together form the value of stream `i`
/// (`value = high_bit << 1 | low_bit`).
#[derive(Debug, Eq, PartialEq, Hash, Clone, Default)]
#[cfg_attr(feature = "mem_dbg", derive(mem_dbg::MemDbg, mem_dbg::MemSize))]
pub struct M128 {
    low: u128,
    high: u128,
}

impl Sketch for M128 {
    const STREAMS: u32 = 128;
    // Evaluates to 0x01FF_FFFF_FFFF_FFFF: the low 57 bits.
    const HASH_MASK: u64 = u64::MAX >> (64 - Self::IDX_SHIFT);
    // Evaluates to 57 (= 64 - log2(128)).
    const IDX_SHIFT: u32 = 64 - Self::STREAMS.trailing_zeros();

    #[inline]
    #[allow(clippy::cast_possible_truncation)] // masked to a single bit before the cast
    fn val(&self, stream: u32) -> u8 {
        debug_assert!(stream < Self::STREAMS);
        let hi = ((self.high >> stream) & 1) as u8;
        let lo = ((self.low >> stream) & 1) as u8;
        (hi << 1) | lo
    }

    #[inline]
    fn set(&mut self, stream: u32, value: u8) {
        debug_assert!(stream < Self::STREAMS);
        debug_assert!(value < 4);
        let mask: u128 = 1 << stream;
        // Expand each bit of `value` into either 0 or `mask`.
        let hi = if value & 0b10 == 0 { 0 } else { mask };
        let lo = if value & 0b01 == 0 { 0 } else { mask };
        self.high = (self.high & !mask) | hi;
        self.low = (self.low & !mask) | lo;
    }

    #[inline]
    fn decrement(&mut self) -> u32 {
        // Per stream: 3 (11) -> 2 (10), 2 (10) -> 1 (01), 1 (01) -> 0, 0 -> 0,
        // which is: new high = high & low, new low = high & !low.
        let (high, low) = (self.high, self.low);
        self.high = high & low;
        self.low = high & !low;
        // Streams whose high bit was set (values 2 and 3) are exactly the ones
        // that remain non-zero.
        high.count_ones()
    }

    #[inline]
    fn count(&self) -> u32 {
        (self.high | self.low).count_ones()
    }

    #[inline]
    fn merge(&mut self, other: &Self) {
        self.low |= other.low;
        self.high |= other.high;
    }

    #[inline]
    fn merge_high_into_lo(&mut self, other: &Self) {
        let incoming = other.high;
        self.low |= incoming;
    }
}
/// One 128-stream slice of an [`M128Reg`]: bit `i` of `high` and bit `i` of
/// `low` together hold the 2-bit value of the `i`-th stream in this slice.
#[derive(Default, Clone, Copy, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "mem_dbg", derive(mem_dbg::MemDbg, mem_dbg::MemSize))]
struct HiLoRegister {
    high: u128,
    low: u128,
}

/// Two-bit-per-stream storage for `REGISTERS * 128` streams.
///
/// Stream `s` lives in register `s / 128` at bit `s % 128`.
#[derive(Debug, Eq, PartialEq, Hash, Copy, Clone)]
#[cfg_attr(feature = "mem_dbg", derive(mem_dbg::MemDbg, mem_dbg::MemSize))]
pub struct M128Reg<const REGISTERS: usize> {
    registers: [HiLoRegister; REGISTERS],
}

// Written by hand: `[T; N]: Default` is not implemented for an arbitrary const `N`,
// so `#[derive(Default)]` would not cover every `REGISTERS`.
impl<const REGISTERS: usize> Default for M128Reg<REGISTERS> {
    fn default() -> Self {
        Self {
            registers: [HiLoRegister { high: 0, low: 0 }; REGISTERS],
        }
    }
}

impl<const REGISTERS: usize> M128Reg<REGISTERS> {
    /// Number of streams held by each `HiLoRegister`.
    const REG_SIZE: usize = 128;

    /// Splits a stream index into `(register index, bit index within that register)`.
    #[inline]
    fn locate(stream: u32) -> (usize, usize) {
        let stream = stream as usize;
        (stream / Self::REG_SIZE, stream % Self::REG_SIZE)
    }

    /// Returns the 2-bit value (`0..=3`) of `stream`.
    ///
    /// # Panics
    /// Panics if `stream >= REGISTERS * 128` (register index out of bounds).
    #[inline]
    #[allow(clippy::cast_possible_truncation)] // masked to a single bit before the cast
    fn val(&self, stream: u32) -> u8 {
        let (register_index, bit_index) = Self::locate(stream);
        let register = &self.registers[register_index];
        let high_bit = ((register.high >> bit_index) & 1) as u8;
        let low_bit = ((register.low >> bit_index) & 1) as u8;
        (high_bit << 1) | low_bit
    }

    /// Overwrites the value of `stream`. Only the two low bits of `value` are
    /// stored; debug builds assert `value < 4`.
    ///
    /// # Panics
    /// Panics if `stream >= REGISTERS * 128` (register index out of bounds).
    #[inline]
    fn set(&mut self, stream: u32, value: u8) {
        debug_assert!(value < 4);
        let (register_index, bit_index) = Self::locate(stream);
        let value = u128::from(value);
        let mask: u128 = 1 << bit_index;
        let register = &mut self.registers[register_index];
        register.high = (register.high & !mask) | (((value >> 1) & 1) << bit_index);
        register.low = (register.low & !mask) | ((value & 1) << bit_index);
    }

    /// Decrements every non-zero stream by one (zero stays zero) and returns
    /// the number of streams that are still non-zero afterwards.
    #[inline]
    fn decrement(&mut self) -> u32 {
        let mut count = 0;
        for register in &mut self.registers {
            // Streams with the high bit set (2 or 3) are the ones that stay non-zero.
            count += register.high.count_ones();
            // Per stream: 3 (11) -> 2 (10), 2 (10) -> 1 (01), 1 (01) -> 0, 0 -> 0.
            register.low = register.high & !register.low;
            register.high &= !register.low;
        }
        count
    }

    /// Returns the number of streams whose value is non-zero.
    #[inline]
    fn count(&self) -> u32 {
        // Iterate by reference; `for r in self.registers` would copy the whole array.
        self.registers
            .iter()
            .map(|register| (register.high | register.low).count_ones())
            .sum()
    }

    /// ORs `other`'s bit planes into `self`, register by register.
    #[inline]
    fn merge(&mut self, other: &Self) {
        for (self_register, other_register) in self.registers.iter_mut().zip(other.registers.iter())
        {
            self_register.high |= other_register.high;
            self_register.low |= other_register.low;
        }
    }

    /// ORs `other`'s high plane into `self`'s low plane, register by register.
    #[inline]
    fn merge_high_into_lo(&mut self, other: &Self) {
        for (self_register, other_register) in self.registers.iter_mut().zip(other.registers.iter())
        {
            self_register.low |= other_register.high;
        }
    }
}
pub type M256 = M128Reg<2>;
impl Sketch for M256 {
const STREAMS: u32 = 256;
const HASH_MASK: u64 =
0b0000_0000_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111;
const IDX_SHIFT: u32 = 56;
#[inline]
fn val(&self, stream: u32) -> u8 {
debug_assert!(stream < Self::STREAMS);
self.val(stream)
}
#[inline]
fn set(&mut self, stream: u32, value: u8) {
debug_assert!(stream < Self::STREAMS);
debug_assert!(value < 4);
self.set(stream, value);
}
#[inline]
fn decrement(&mut self) -> u32 {
self.decrement()
}
#[inline]
fn count(&self) -> u32 {
self.count()
}
#[inline]
fn merge(&mut self, other: &Self) {
self.merge(other);
}
#[inline]
fn merge_high_into_lo(&mut self, other: &Self) {
self.merge_high_into_lo(other);
}
}
pub type M512 = M128Reg<4>;
impl Sketch for M512 {
const STREAMS: u32 = 512;
const HASH_MASK: u64 =
0b0000_0000_0111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111;
const IDX_SHIFT: u32 = 55;
#[inline]
fn val(&self, stream: u32) -> u8 {
debug_assert!(stream < Self::STREAMS);
self.val(stream)
}
#[inline]
fn set(&mut self, stream: u32, value: u8) {
debug_assert!(stream < Self::STREAMS);
debug_assert!(value < 4);
self.set(stream, value);
}
#[inline]
fn decrement(&mut self) -> u32 {
self.decrement()
}
#[inline]
fn count(&self) -> u32 {
self.count()
}
#[inline]
fn merge(&mut self, other: &Self) {
self.merge(other);
}
#[inline]
fn merge_high_into_lo(&mut self, other: &Self) {
self.merge_high_into_lo(other);
}
}
pub type M1024 = M128Reg<8>;
impl Sketch for M1024 {
const STREAMS: u32 = 1024;
const HASH_MASK: u64 =
0b0000_0000_0011_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111;
const IDX_SHIFT: u32 = 54;
#[inline]
fn val(&self, stream: u32) -> u8 {
debug_assert!(stream < Self::STREAMS);
self.val(stream)
}
#[inline]
fn set(&mut self, stream: u32, value: u8) {
debug_assert!(stream < Self::STREAMS);
debug_assert!(value < 4);
self.set(stream, value);
}
#[inline]
fn decrement(&mut self) -> u32 {
self.decrement()
}
#[inline]
fn count(&self) -> u32 {
self.count()
}
#[inline]
fn merge(&mut self, other: &Self) {
self.merge(other);
}
#[inline]
fn merge_high_into_lo(&mut self, other: &Self) {
self.merge_high_into_lo(other);
}
}
pub type M2048 = M128Reg<16>;
impl Sketch for M2048 {
const STREAMS: u32 = 2048;
const HASH_MASK: u64 =
0b0000_0000_0001_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111;
const IDX_SHIFT: u32 = 53;
#[inline]
fn val(&self, stream: u32) -> u8 {
debug_assert!(stream < Self::STREAMS);
self.val(stream)
}
#[inline]
fn set(&mut self, stream: u32, value: u8) {
debug_assert!(stream < Self::STREAMS);
debug_assert!(value < 4);
self.set(stream, value);
}
#[inline]
fn decrement(&mut self) -> u32 {
self.decrement()
}
#[inline]
fn count(&self) -> u32 {
self.count()
}
#[inline]
fn merge(&mut self, other: &Self) {
self.merge(other);
}
#[inline]
fn merge_high_into_lo(&mut self, other: &Self) {
self.merge_high_into_lo(other);
}
}
pub type M4096 = M128Reg<32>;
impl Sketch for M4096 {
const STREAMS: u32 = 4096;
const HASH_MASK: u64 =
0b0000_0000_0000_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111;
const IDX_SHIFT: u32 = 52;
#[inline]
fn val(&self, stream: u32) -> u8 {
debug_assert!(stream < Self::STREAMS);
self.val(stream)
}
#[inline]
fn set(&mut self, stream: u32, value: u8) {
debug_assert!(stream < Self::STREAMS);
debug_assert!(value < 4);
self.set(stream, value);
}
#[inline]
fn decrement(&mut self) -> u32 {
self.decrement()
}
#[inline]
fn count(&self) -> u32 {
self.count()
}
#[inline]
fn merge(&mut self, other: &Self) {
self.merge(other);
}
#[inline]
fn merge_high_into_lo(&mut self, other: &Self) {
self.merge_high_into_lo(other);
}
}
pub type M8192 = M128Reg<64>;
impl Sketch for M8192 {
const STREAMS: u32 = 8192;
const HASH_MASK: u64 =
0b0000_0000_0000_0111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111;
const IDX_SHIFT: u32 = 51;
#[inline]
fn val(&self, stream: u32) -> u8 {
debug_assert!(stream < Self::STREAMS);
self.val(stream)
}
#[inline]
fn set(&mut self, stream: u32, value: u8) {
debug_assert!(stream < Self::STREAMS);
debug_assert!(value < 4);
self.set(stream, value);
}
#[inline]
fn decrement(&mut self) -> u32 {
self.decrement()
}
#[inline]
fn count(&self) -> u32 {
self.count()
}
#[inline]
fn merge(&mut self, other: &Self) {
self.merge(other);
}
#[inline]
fn merge_high_into_lo(&mut self, other: &Self) {
self.merge_high_into_lo(other);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `S`'s hashing constants agree with each other and with `STREAMS`.
    fn check_constants<S: Sketch>() {
        assert_eq!(1u64 << (64 - S::IDX_SHIFT), u64::from(S::STREAMS));
        assert_eq!(S::HASH_MASK, u64::MAX >> (64 - S::IDX_SHIFT));
    }

    /// Exercises `set`, `val`, `count` and `decrement` on every stream of `S`.
    fn test<S: Sketch>() {
        check_constants::<S>();
        let mut s = S::default();
        assert_eq!(s.count(), 0);
        for i in 0..S::STREAMS {
            assert_eq!(s.val(i), 0);
            for r in 1..=3 {
                s.set(i, r);
                assert_eq!(s.val(i), r);
            }
            // Setting stream `i` must not disturb any other stream.
            for j in 0..S::STREAMS {
                let expected = if j == i { 3 } else { 0 };
                assert_eq!(s.val(j), expected);
            }
            assert_eq!(s.count(), 1);
            s.set(i, 0);
            assert_eq!(s.val(i), 0);
        }
        assert_eq!(s.count(), 0);
        for i in 0..S::STREAMS {
            s.set(i, 3);
            assert_eq!(s.val(i), 3);
        }
        assert_eq!(s.count(), S::STREAMS);
        for r in (0..=2).rev() {
            // `decrement` reports how many streams are still non-zero afterwards.
            let remaining = s.decrement();
            let expected = if r > 0 { S::STREAMS } else { 0 };
            assert_eq!(remaining, expected);
            assert_eq!(s.count(), expected);
            for i in 0..S::STREAMS {
                assert_eq!(s.val(i), r);
            }
        }
        // Decrementing an all-zero sketch keeps it at zero.
        assert_eq!(s.decrement(), 0);
        for i in 0..S::STREAMS {
            assert_eq!(s.val(i), 0);
        }
    }

    /// Exercises `merge` (per-bit OR) and `merge_high_into_lo` on `S`.
    fn test_merge<S: Sketch>() {
        let last = S::STREAMS - 1;
        let mut a = S::default();
        a.set(0, 1);
        a.set(last, 3);
        let mut b = S::default();
        b.set(0, 2);
        b.set(1, 2);
        b.set(last, 1);

        let mut merged = S::default();
        merged.merge(&a);
        merged.merge(&b);
        assert_eq!(merged.val(0), 3); // 0b01 | 0b10
        assert_eq!(merged.val(1), 2);
        assert_eq!(merged.val(last), 3);
        assert_eq!(merged.count(), 3);

        let mut lo = S::default();
        lo.merge_high_into_lo(&b);
        assert_eq!(lo.val(0), 1);
        assert_eq!(lo.val(1), 1);
        assert_eq!(lo.val(last), 0);
        assert_eq!(lo.count(), 2);
    }

    #[test]
    fn test_m64() {
        test::<M64>();
        test_merge::<M64>();
    }
    #[test]
    fn test_m128() {
        test::<M128>();
        test_merge::<M128>();
    }
    #[test]
    fn test_m256() {
        test::<M256>();
        test_merge::<M256>();
    }
    #[test]
    fn test_m512() {
        test::<M512>();
        test_merge::<M512>();
    }
    #[test]
    fn test_m1024() {
        test::<M1024>();
        test_merge::<M1024>();
    }
    #[test]
    fn test_m2048() {
        test::<M2048>();
        test_merge::<M2048>();
    }
    #[test]
    fn test_m4096() {
        test::<M4096>();
        test_merge::<M4096>();
    }
    #[test]
    fn test_m8192() {
        test::<M8192>();
        test_merge::<M8192>();
    }
}