use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::mem::size_of_val;
use bitpacking::{BitPacker, BitPacker1x};
use num::{PrimInt, Unsigned};
use wyhash::WyHash;
use crate::mphf::{Mphf, DEFAULT_GAMMA};
/// A map from keys `K` to equal-length `u32` sequences, addressed through a
/// minimal perfect hash function ([`Mphf`]) and stored as a deduplicated,
/// bit-packed dictionary.
///
/// The generic parameters `B`, `S`, `ST` and `H` are forwarded unchanged to
/// [`Mphf`]; see that type for their meaning.
///
/// The length of the value sequences is not stored in the map: callers pass
/// it back in when reading values.
#[derive(Default)]
#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))]
#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))]
pub struct MapWithDictBitpacked<K, const B: usize = 32, const S: usize = 8, ST = u8, H = WyHash>
where
ST: PrimInt + Unsigned,
H: Hasher + Default,
{
/// Perfect hash function mapping every stored key to its slot in `keys` and `values_index`.
mphf: Mphf<B, S, ST, H>,
/// Stored keys, permuted at construction so that the key hashed to slot `i` sits at index `i`.
keys: Box<[K]>,
/// For each slot, the byte offset in `values_dict` where that key's packed values start.
values_index: Box<[usize]>,
/// Concatenated packed value sequences (identical sequences are stored once),
/// followed by `4 * VALUES_BLOCK_LEN` zero bytes of padding appended at construction
/// (presumably so decoding the last block never runs past the end of the buffer).
values_dict: Box<[u8]>,
}
/// Errors returned when building a [`MapWithDictBitpacked`].
#[derive(Debug)]
pub enum Error {
    /// Building the underlying minimal perfect hash function failed.
    MphfError(crate::mphf::MphfError),
    /// The input contained value vectors of differing lengths.
    NotEqualValuesLengths,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The inner error is only known to be `Debug` here, so format it that way.
            Error::MphfError(err) => write!(f, "failed to build MPHF: {err:?}"),
            Error::NotEqualValuesLengths => f.write_str("value vectors have different lengths"),
        }
    }
}

// Lets callers propagate this error with `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for Error {}
impl<K, const B: usize, const S: usize, ST, H> MapWithDictBitpacked<K, B, S, ST, H>
where
    K: Hash + PartialEq + Clone,
    ST: PrimInt + Unsigned,
    H: Hasher + Default,
{
    /// Builds a map from `(key, values)` pairs.
    ///
    /// All value vectors must have the same length. Identical value vectors are
    /// packed into the dictionary only once and shared between keys. `gamma` is
    /// forwarded unchanged to `Mphf::from_slice`.
    ///
    /// # Errors
    ///
    /// * [`Error::NotEqualValuesLengths`] if any value vector differs in length
    ///   from the first one.
    /// * [`Error::MphfError`] if `Mphf::from_slice` fails for the collected keys.
    ///
    /// # Panics
    ///
    /// Panics if the freshly built MPHF returns no index for one of the keys it
    /// was built from, which would be a bug in the MPHF.
    pub fn from_iter_with_params<I>(iter: I, gamma: f32) -> Result<Self, Error>
    where
        I: IntoIterator<Item = (K, Vec<u32>)>,
    {
        let mut keys = vec![];
        let mut offsets_cache = HashMap::new();
        let mut values_index = vec![];
        let mut values_dict = vec![];

        let mut iter = iter.into_iter().peekable();
        // The first entry fixes the length every other value vector must match.
        let v_len = iter.peek().map_or(0, |(_, v)| v.len());

        for (k, v) in iter {
            if v.len() != v_len {
                return Err(Error::NotEqualValuesLengths);
            }
            keys.push(k);

            // One hash lookup per entry; the vector is moved into the cache instead of cloned.
            let offset = match offsets_cache.entry(v) {
                std::collections::hash_map::Entry::Occupied(seen) => *seen.get(),
                std::collections::hash_map::Entry::Vacant(slot) => {
                    let offset = values_dict.len();
                    pack_values(slot.key(), &mut values_dict);
                    slot.insert(offset);
                    offset
                }
            };
            values_index.push(offset);
        }

        // Only the bytes actually used by a (possibly short) block are stored, but the
        // decoder is handed the rest of the dictionary and decodes a whole block at a
        // time, so pad the tail with one maximum-size block worth of zero bytes.
        values_dict.resize(values_dict.len() + 4 * VALUES_BLOCK_LEN, 0);

        let mphf = Mphf::from_slice(&keys, gamma).map_err(Error::MphfError)?;

        // Permute `keys` and `values_index` in place so that each key ends up at the
        // index the MPHF assigns to it. Every swap puts one key in its final slot,
        // so the inner loop terminates once slot `i` holds its own key.
        for i in 0..keys.len() {
            loop {
                let idx = mphf
                    .get(&keys[i])
                    .expect("MPHF must return an index for every key it was built from");
                if idx == i {
                    break;
                }
                keys.swap(i, idx);
                values_index.swap(i, idx);
            }
        }

        Ok(MapWithDictBitpacked {
            mphf,
            keys: keys.into_boxed_slice(),
            values_index: values_index.into_boxed_slice(),
            values_dict: values_dict.into_boxed_slice(),
        })
    }

    /// Looks up `key` and, if present, decodes its values into `values`.
    ///
    /// Returns `true` and fills `values` when the key is stored; returns `false`
    /// and leaves `values` untouched otherwise.
    ///
    /// `values.len()` should equal the value-vector length used at construction,
    /// because the map does not store that length. A shorter buffer receives a
    /// prefix of the stored values. A longer one makes the decoder read bytes that
    /// belong to other entries or to the trailing padding, which yields meaningless
    /// values and may panic.
    #[inline]
    pub fn get_values<Q>(&self, key: &Q, values: &mut [u32]) -> bool
    where
        K: Borrow<Q> + PartialEq<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let Some(idx) = self.mphf.get(key) else {
            return false;
        };

        // SAFETY: relies on `Mphf::get` only returning indices below the number of keys
        // the MPHF was built from, which is `self.keys.len()`.
        let stored_key = unsafe { self.keys.get_unchecked(idx) };
        // The MPHF may map an unknown key onto some valid slot, so confirm the match.
        if stored_key != key {
            return false;
        }

        // SAFETY: `values_index` is filled once per key at construction, so it has the same
        // length as `keys` and `idx` is in bounds for it as well.
        let value_idx = unsafe { *self.values_index.get_unchecked(idx) };
        // SAFETY: every entry of `values_index` was taken from `values_dict.len()` while the
        // dictionary was being built and the dictionary only grew afterwards, so the offset
        // never exceeds `values_dict.len()`.
        let dict = unsafe { self.values_dict.get_unchecked(value_idx..) };
        unpack_values(dict, values);
        true
    }

    /// Returns the number of keys stored in the map.
    #[inline]
    pub fn len(&self) -> usize {
        self.keys.len()
    }

    /// Returns `true` if the map stores no keys.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.keys.is_empty()
    }

    /// Returns `true` if `key` is stored in the map.
    #[inline]
    pub fn contains_key<Q>(&self, key: &Q) -> bool
    where
        K: Borrow<Q> + PartialEq<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.mphf.get(key).is_some_and(|idx| {
            // SAFETY: relies on `Mphf::get` only returning indices below the number of keys
            // the MPHF was built from, which is `self.keys.len()`.
            let stored_key = unsafe { self.keys.get_unchecked(idx) };
            stored_key == key
        })
    }

    /// Iterates over all `(key, values)` pairs, decoding `n` values per key into a
    /// freshly allocated vector.
    ///
    /// `n` should equal the value-vector length used at construction; see
    /// [`get_values`](Self::get_values) for what happens otherwise.
    #[inline]
    pub fn iter(&self, n: usize) -> impl Iterator<Item = (&K, Vec<u32>)> {
        self.keys().zip(self.values_index.iter()).map(move |(key, &value_idx)| {
            let mut values = vec![0; n];
            // Checked slicing: the per-item allocation above dwarfs the bounds check.
            unpack_values(&self.values_dict[value_idx..], &mut values);
            (key, values)
        })
    }

    /// Iterates over the stored keys in slot order.
    #[inline]
    pub fn keys(&self) -> impl Iterator<Item = &K> {
        self.keys.iter()
    }

    /// Iterates over the values of every key (in the same order as [`keys`](Self::keys)),
    /// decoding `n` values per key into a freshly allocated vector.
    ///
    /// `n` should equal the value-vector length used at construction; see
    /// [`get_values`](Self::get_values) for what happens otherwise.
    #[inline]
    pub fn values(&self, n: usize) -> impl Iterator<Item = Vec<u32>> + '_ {
        self.values_index.iter().map(move |&value_idx| {
            let mut values = vec![0; n];
            // Checked slicing: the per-item allocation above dwarfs the bounds check.
            unpack_values(&self.values_dict[value_idx..], &mut values);
            values
        })
    }

    /// Returns the number of bytes occupied by the map: the struct itself, whatever
    /// `Mphf::size` reports, and the three boxed slices.
    ///
    /// Heap memory owned by individual keys (if `K` owns any) is not included.
    pub fn size(&self) -> usize {
        size_of_val(self)
            + self.mphf.size()
            + size_of_val(self.keys.as_ref())
            + size_of_val(self.values_index.as_ref())
            + size_of_val(self.values_dict.as_ref())
    }
}
/// Converts a `HashMap` into a map with the default const parameters, building the
/// MPHF with `DEFAULT_GAMMA`.
impl<K> TryFrom<HashMap<K, Vec<u32>>> for MapWithDictBitpacked<K>
where
    K: Hash + PartialEq + Clone,
{
    type Error = Error;

    #[inline]
    fn try_from(value: HashMap<K, Vec<u32>>) -> Result<Self, Self::Error> {
        let entries = value.into_iter();
        Self::from_iter_with_params(entries, DEFAULT_GAMMA)
    }
}
/// Number of `u32` values packed or unpacked per block; taken from `BitPacker1x::BLOCK_LEN`.
const VALUES_BLOCK_LEN: usize = BitPacker1x::BLOCK_LEN;
/// Appends the bit-packed form of `values` to `dict`.
///
/// The input is processed in chunks of up to `VALUES_BLOCK_LEN` values. Each chunk is
/// zero-padded to a full block, and for each chunk one byte holding the bit width is
/// written, followed by only as many packed bytes as the real (unpadded) values need.
fn pack_values(values: &[u32], dict: &mut Vec<u8>) {
    let packer = BitPacker1x::new();

    for chunk in values.chunks(VALUES_BLOCK_LEN) {
        let used = chunk.len();

        // Zero padding cannot raise the bit width required by the real values.
        let mut unpacked = [0u32; VALUES_BLOCK_LEN];
        unpacked[..used].copy_from_slice(chunk);

        let width = packer.num_bits(&unpacked);
        let mut packed = [0u8; 4 * VALUES_BLOCK_LEN];
        packer.compress(&unpacked, &mut packed, width);

        // Bytes needed for `used` values of `width` bits each, rounded up.
        let packed_len = (used * usize::from(width)).div_ceil(8);

        dict.push(width);
        dict.extend_from_slice(&packed[..packed_len]);
    }
}
/// Decodes `res.len()` values from the front of `dict` into `res`.
///
/// `dict` must start at a block header written by `pack_values`. Each block is read as
/// one bit-width byte followed by the packed payload.
///
/// # Panics
///
/// Panics if `dict` runs out of bytes while blocks are still being decoded.
fn unpack_values(dict: &[u8], res: &mut [u32]) {
    let unpacker = BitPacker1x::new();
    let mut remaining = dict;

    for out in res.chunks_mut(VALUES_BLOCK_LEN) {
        let mut decoded = [0u32; VALUES_BLOCK_LEN];

        // Block header: a single byte with the bit width of the packed values.
        let width = remaining[0];
        let payload = &remaining[1..];

        // Only the bytes needed for `out.len()` values were stored for this block.
        let wanted = out.len();
        let consumed = (wanted * usize::from(width)).div_ceil(8);

        unpacker.decompress(payload, &mut decoded, width);
        remaining = &payload[consumed..];
        out.copy_from_slice(&decoded[..wanted]);
    }
}
#[cfg(feature = "rkyv_derive")]
impl<K, const B: usize, const S: usize, ST, H> ArchivedMapWithDictBitpacked<K, B, S, ST, H>
where
    K: PartialEq + Hash + rkyv::Archive,
    K::Archived: PartialEq<K>,
    ST: PrimInt + Unsigned + rkyv::Archive<Archived = ST>,
    H: Hasher + Default,
{
    /// Looks up `key` in the archived map and, if present, decodes its values into `values`.
    ///
    /// Returns `true` when the key is stored and `values` was filled, `false` otherwise.
    /// `values.len()` should equal the value-vector length the map was built with.
    #[inline]
    pub fn get_values(&self, key: &K, values: &mut [u32]) -> bool {
        let Some(idx) = self.mphf.get(key) else {
            return false;
        };

        // SAFETY: assumes the archived MPHF only returns indices below `keys.len()`.
        let archived_key = unsafe { self.keys.get_unchecked(idx) };
        // An unknown key can still hash to a valid slot, so compare against the stored key.
        if archived_key != key {
            return false;
        }

        // SAFETY: assumes the archive mirrors the original map, where `values_index` has one
        // entry per key and every entry is an offset within `values_dict`.
        let dict = unsafe {
            let value_idx = *self.values_index.get_unchecked(idx) as usize;
            self.values_dict.get_unchecked(value_idx..)
        };
        unpack_values(dict, values);
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use paste::paste;
    use proptest::prelude::*;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;
    use std::collections::{hash_map::RandomState, HashSet};
    use test_case::test_case;

    // Packs `values`, checks that unpacking a padded copy round-trips, and returns the
    // unpadded packed bytes so each case can pin the exact encoding.
    #[test_case(
        &[] => Vec::<u8>::new();
        "empty values"
    )]
    #[test_case(
        &[0] => vec![0];
        "single 0-bit value"
    )]
    #[test_case(
        &[0; 10] => vec![0];
        "10 0-bit value"
    )]
    #[test_case(
        &[0; 77] => vec![0, 0, 0];
        "77 0-bit values (3 blocks)"
    )]
    #[test_case(
        &[1] => vec![1, 1];
        "single 1-bit value"
    )]
    #[test_case(
        &[1; 10] => vec![1, 0b11111111, 0b00000011];
        "10 1-bit value"
    )]
    #[test_case(
        &[1; 32] => vec![1, 0b11111111, 0b11111111, 0b11111111, 0b11111111];
        "32 1-bit value"
    )]
    #[test_case(
        &[1; 33] => vec![1, 0b11111111, 0b11111111, 0b11111111, 0b11111111, 1, 0b00000001];
        "33 1-bit value"
    )]
    #[test_case(
        &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] => vec![4, 0b0010_0001, 0b0100_0011, 0b0110_0101, 0b1000_0111, 0b1010_1001];
        "10 4-bit value"
    )]
    fn test_pack_unpack(values: &[u32]) -> Vec<u8> {
        let mut dict = vec![];
        pack_values(values, &mut dict);

        // Mirror the padding the map appends to its dictionary before decoding.
        let mut padded_dict = dict.clone();
        padded_dict.resize(dict.len() + 4 * VALUES_BLOCK_LEN, 0);

        let mut unpacked_values = vec![0; values.len()];
        unpack_values(&padded_dict, &mut unpacked_values);
        assert_eq!(values, unpacked_values);

        dict
    }

    // Round-trips random inputs for every length in 1..=200 and every bit width in 0..=32.
    #[test]
    fn test_pack_unpack_random() {
        let max_n = 200;
        let mut rng = ChaCha8Rng::seed_from_u64(123);
        let mut dict = vec![];
        let mut values = vec![];
        let mut unpacked_values = vec![];

        for n in 1..=max_n {
            for num_bits in 0..=32u32 {
                // All-ones mask of exactly `num_bits` bits. `checked_shr` handles both
                // extremes without overflowing the shift: 0 bits gives an empty mask and
                // 32 bits gives a full mask, so full-width values are really exercised.
                let mask = u32::MAX.checked_shr(32 - num_bits).unwrap_or(0);

                values.clear();
                values.extend((0..n).map(|_| rng.gen::<u32>() & mask));

                dict.clear();
                pack_values(&values, &mut dict);
                assert!(!dict.is_empty());
                dict.resize(dict.len() + 4 * VALUES_BLOCK_LEN, 0);

                unpacked_values.resize(n, 0);
                unpack_values(&dict, &mut unpacked_values);
                assert_eq!(values, unpacked_values);
            }
        }
    }

    // Deterministic random map with `items_num` entries of `values_num` values in 1..=10.
    fn gen_map(items_num: usize, values_num: usize) -> HashMap<u64, Vec<u32>> {
        let mut rng = ChaCha8Rng::seed_from_u64(123);

        (0..items_num)
            .map(|_| {
                let key = rng.gen::<u64>();
                let value = (0..values_num).map(|_| rng.gen_range(1..=10)).collect();
                (key, value)
            })
            .collect()
    }

    #[test]
    fn test_map_with_dict_bitpacked() {
        let items_num = 1000;
        let values_num = 10;
        let original_map = gen_map(items_num, values_num);

        let map = MapWithDictBitpacked::try_from(original_map.clone()).unwrap();

        assert_eq!(map.len(), original_map.len());
        assert_eq!(map.is_empty(), original_map.is_empty());

        let mut values_buf = vec![0; values_num];
        for (key, value) in &original_map {
            assert!(map.get_values(key, &mut values_buf));
            assert_eq!(value, &values_buf);
            assert!(map.contains_key(key));
        }

        for (&k, v) in map.iter(values_num) {
            assert_eq!(original_map.get(&k), Some(&v));
        }

        for k in map.keys() {
            assert!(original_map.contains_key(k));
        }

        for v in map.values(values_num) {
            assert!(original_map.values().any(|val| val == &v));
        }

        // Pins the in-memory footprint for this fixed input.
        assert_eq!(map.size(), 22664);
    }

    #[cfg(feature = "rkyv_derive")]
    #[test]
    fn test_rkyv() {
        let items_num = 1000;
        let values_num = 10;
        let original_map = gen_map(items_num, values_num);

        let map = MapWithDictBitpacked::try_from(original_map.clone()).unwrap();
        let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&map).unwrap();

        // Pins the serialized size for this fixed input.
        assert_eq!(rkyv_bytes.len(), 18516);

        let rkyv_map = rkyv::check_archived_root::<MapWithDictBitpacked<u64>>(&rkyv_bytes).unwrap();

        let mut values_buf = vec![0; values_num];
        for (k, v) in original_map {
            // Every original key must be found; otherwise the buffer holds stale data.
            assert!(rkyv_map.get_values(&k, &mut values_buf));
            assert_eq!(v, values_buf);
        }
    }

    // Generates one model-based property test per `(B, S, gamma * 100, values per key)` tuple,
    // comparing the map against a plain `HashMap` for both stored and arbitrary keys.
    macro_rules! proptest_map_with_dict_bitpacked_model {
        ($(($b:expr, $s:expr, $gamma:expr, $n:expr)),* $(,)?) => {
            $(
                paste! {
                    proptest! {
                        #[test]
                        fn [<proptest_map_with_dict_bitpacked_model_ $b _ $s _ $n _ $gamma>](model: HashMap<u64, [u32; $n]>, arbitrary: HashSet<u64>) {
                            let entropy_map: MapWithDictBitpacked<u64, $b, $s> = MapWithDictBitpacked::from_iter_with_params(
                                model.iter().map(|(&k, v)| (k, Vec::from(v))),
                                $gamma as f32 / 100.0
                            ).unwrap();

                            assert_eq!(entropy_map.len(), model.len());
                            assert_eq!(entropy_map.is_empty(), model.is_empty());

                            assert_eq!(
                                HashSet::<_, RandomState>::from_iter(entropy_map.keys()),
                                HashSet::from_iter(model.keys())
                            );

                            assert_eq!(
                                HashSet::<_, RandomState>::from_iter(entropy_map.values($n)),
                                HashSet::from_iter(model.values().map(Vec::from))
                            );

                            for (k, v) in &model {
                                assert!(entropy_map.contains_key(&k));

                                let mut buf = [0u32; $n];
                                assert!(entropy_map.get_values(&k, &mut buf));
                                assert_eq!(&buf, v);
                            }

                            for k in arbitrary {
                                assert_eq!(
                                    model.contains_key(&k),
                                    entropy_map.contains_key(&k),
                                );

                                let mut buf = [0u32; $n];
                                let contains = entropy_map.get_values(&k, &mut buf);
                                assert_eq!(contains, model.contains_key(&k));

                                if contains {
                                    assert_eq!(Some(&buf), model.get(&k));
                                }
                            }
                        }
                    }
                }
            )*
        };
    }

    proptest_map_with_dict_bitpacked_model!(
        (2, 8, 100, 10),
        (4, 8, 100, 10),
        (7, 8, 100, 10),
        (8, 8, 100, 10),
        (15, 8, 100, 10),
        (16, 8, 100, 10),
        (23, 8, 100, 10),
        (24, 8, 100, 10),
        (31, 8, 100, 10),
        (32, 8, 100, 10),
        (33, 8, 100, 10),
        (48, 8, 100, 10),
        (53, 8, 100, 10),
        (61, 8, 100, 10),
        (63, 8, 100, 10),
        (64, 8, 100, 10),
        (32, 7, 100, 10),
        (32, 5, 100, 10),
        (32, 4, 100, 10),
        (32, 3, 100, 10),
        (32, 1, 100, 10),
        (32, 0, 100, 10),
        (32, 8, 200, 10),
        (32, 6, 200, 10),
    );
}