#![allow(clippy::comparison_chain)]
#![allow(clippy::type_complexity)]
use anyhow::Result;
use mem_dbg::{MemDbg, MemSize};
use rdst::RadixKey;
use std::{
borrow::{Borrow, BorrowMut},
collections::VecDeque,
fs::File,
io::*,
iter::FusedIterator,
marker::PhantomData,
ops::{BitXor, BitXorAssign},
sync::Arc,
};
use xxhash_rust::xxh3;
use zerocopy::{FromBytes, IntoBytes};
/// Marker trait for types that can be safely written to and read back from a
/// binary stream: plain-old-data (`FromBytes + IntoBytes`), `Copy`, and usable
/// across threads.
pub trait BinSafe: FromBytes + IntoBytes + Copy + Send + Sync + 'static {}
// Blanket implementation: any type satisfying the bounds is automatically `BinSafe`.
impl<T: FromBytes + IntoBytes + Copy + Send + Sync + 'static> BinSafe for T {}
/// A fixed-width signature obtained by hashing a key.
pub trait Sig: BinSafe + Default + PartialEq + Eq + std::fmt::Debug {
    /// Returns the `high_bits` most significant bits of the signature, moved
    /// into the low positions. `mask` must equal `(1 << high_bits) - 1`
    /// (checked by `debug_assert!` in the implementations).
    fn high_bits(&self, high_bits: u32, mask: u64) -> u64;
    /// Builds a signature from the current state of an xxh3 hasher.
    fn from_hasher(hasher: &xxh3::Xxh3) -> Self;
}
/// 128-bit signatures stored as `[high word, low word]`.
impl Sig for [u64; 2] {
    #[inline(always)]
    fn high_bits(&self, high_bits: u32, mask: u64) -> u64 {
        // The caller must pass the mask matching `high_bits`.
        debug_assert!(mask == (1 << high_bits) - 1);
        // Rotating left brings the top `high_bits` bits of the first word
        // into the low positions; the mask then isolates them. A rotation
        // (rather than a shift) is also well-defined for `high_bits == 0`.
        let rotated = self[0].rotate_left(high_bits);
        rotated & mask
    }
    #[inline(always)]
    fn from_hasher(hasher: &xxh3::Xxh3) -> Self {
        // Split the 128-bit digest into its high and low 64-bit halves.
        let digest = hasher.digest128();
        let high = (digest >> 64) as u64;
        let low = digest as u64;
        [high, low]
    }
}
/// 64-bit signatures stored as a single word.
impl Sig for [u64; 1] {
    #[inline(always)]
    fn high_bits(&self, high_bits: u32, mask: u64) -> u64 {
        // The caller must pass the mask matching `high_bits`.
        debug_assert!(mask == (1 << high_bits) - 1);
        // Rotate the top bits into the low positions and mask them out;
        // rotation is well-defined even when `high_bits == 0`.
        let rotated = self[0].rotate_left(high_bits);
        rotated & mask
    }
    #[inline(always)]
    fn from_hasher(hasher: &xxh3::Xxh3) -> Self {
        // The plain 64-bit xxh3 digest is the signature.
        let digest = hasher.digest();
        [digest]
    }
}
/// A signature paired with the value associated with the originating key.
///
/// This is the unit stored in buckets and shards.
#[derive(Debug, Clone, Copy, Default, MemSize, MemDbg)]
pub struct SigVal<S: BinSafe + Sig, V: BinSafe> {
    // The key's hash-derived signature.
    pub sig: S,
    // The payload associated with the key.
    pub val: V,
}
/// Radix-sort key for two-word signatures. Levels are delivered least
/// significant first: `sig[1]` (the low word) provides levels 0–7 and
/// `sig[0]` (the high word) provides levels 8–15.
impl<V: BinSafe> RadixKey for SigVal<[u64; 2], V> {
    const LEVELS: usize = 16;
    fn get_level(&self, level: usize) -> u8 {
        // `1 - level / 8` selects the low word for levels 0–7 and the high
        // word for levels 8–15; the shift extracts the byte within the word.
        (self.sig[1 - level / 8] >> ((level % 8) * 8)) as u8
    }
}
/// Radix-sort key for one-word signatures: one byte per level, least
/// significant byte first.
impl<V: BinSafe> RadixKey for SigVal<[u64; 1], V> {
    const LEVELS: usize = 8;
    fn get_level(&self, level: usize) -> u8 {
        (self.sig[0] >> ((level % 8) * 8)) as u8
    }
}
/// Equality compares signatures only: two `SigVal`s with the same signature
/// are considered equal regardless of their values (signatures are expected
/// to identify keys uniquely).
impl<S: Sig + PartialEq, V: BinSafe> PartialEq for SigVal<S, V> {
    fn eq(&self, other: &Self) -> bool {
        // Deliberately ignores `val`.
        self.sig == other.sig
    }
}
/// A zero-sized placeholder value for stores that only need signatures
/// (e.g. plain hashing, with no payload per key).
///
/// It implements the same arithmetic-ish traits as real values so generic
/// code can treat it uniformly.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    MemDbg,
    MemSize,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    FromBytes,
    IntoBytes,
)]
#[mem_size(flat)]
#[cfg_attr(
    feature = "epserde",
    derive(epserde::Epserde),
    repr(C),
    epserde(zero_copy)
)]
pub struct EmptyVal(());
impl BitXor for EmptyVal {
type Output = EmptyVal;
fn bitxor(self, _: EmptyVal) -> Self::Output {
EmptyVal(())
}
}
impl BitXorAssign for EmptyVal {
fn bitxor_assign(&mut self, _: EmptyVal) {}
}
impl From<EmptyVal> for u128 {
fn from(_: EmptyVal) -> u128 {
0
}
}
/// Componentwise XOR for one-word-signature pairs: the signature words and
/// the values are XORed independently.
impl<V: BinSafe + BitXor<Output: BinSafe>> BitXor<SigVal<[u64; 1], V>> for SigVal<[u64; 1], V> {
    type Output = SigVal<[u64; 1], V::Output>;
    fn bitxor(self, rhs: SigVal<[u64; 1], V>) -> Self::Output {
        SigVal {
            sig: [self.sig[0] ^ rhs.sig[0]],
            val: self.val ^ rhs.val,
        }
    }
}
/// Componentwise XOR for two-word-signature pairs.
impl<V: BinSafe + BitXor<Output: BinSafe>> BitXor<SigVal<[u64; 2], V>> for SigVal<[u64; 2], V> {
    type Output = SigVal<[u64; 2], V::Output>;
    fn bitxor(self, rhs: SigVal<[u64; 2], V>) -> Self::Output {
        let sig = [self.sig[0] ^ rhs.sig[0], self.sig[1] ^ rhs.sig[1]];
        let val = self.val ^ rhs.val;
        SigVal { sig, val }
    }
}
/// In-place componentwise XOR for one-word-signature pairs.
impl<V: BinSafe + BitXorAssign> BitXorAssign<SigVal<[u64; 1], V>> for SigVal<[u64; 1], V> {
    fn bitxor_assign(&mut self, rhs: SigVal<[u64; 1], V>) {
        // The two updates are independent and may occur in any order.
        self.val ^= rhs.val;
        self.sig[0] ^= rhs.sig[0];
    }
}
/// In-place componentwise XOR for two-word-signature pairs.
impl<V: BinSafe + BitXorAssign> BitXorAssign<SigVal<[u64; 2], V>> for SigVal<[u64; 2], V> {
    fn bitxor_assign(&mut self, rhs: SigVal<[u64; 2], V>) {
        self.val ^= rhs.val;
        self.sig[1] ^= rhs.sig[1];
        self.sig[0] ^= rhs.sig[0];
    }
}
/// Conversion of a key into a signature of type `S`, parameterized by a seed
/// so that different seeds yield independent signatures for the same key.
pub trait ToSig<S> {
    fn to_sig(key: impl Borrow<Self>, seed: u64) -> S;
}
// String-like `ToSig` implementations. They all funnel into the `&str`
// implementation, which in turn hashes the UTF-8 bytes via the `&[u8]`
// implementation generated by `to_sig_slice!` below.
impl ToSig<[u64; 2]> for String {
    #[inline]
    fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 2] {
        // `&**` — Borrow<String> → &String → &str.
        <&str>::to_sig(&**key.borrow(), seed)
    }
}
impl ToSig<[u64; 1]> for String {
    #[inline]
    fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 1] {
        <&str>::to_sig(&**key.borrow(), seed)
    }
}
impl ToSig<[u64; 2]> for &String {
    #[inline]
    fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 2] {
        // `&***` — Borrow<&String> → &&String → &String → &str.
        <&str>::to_sig(&***key.borrow(), seed)
    }
}
impl ToSig<[u64; 1]> for &String {
    #[inline]
    fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 1] {
        <&str>::to_sig(&***key.borrow(), seed)
    }
}
impl ToSig<[u64; 2]> for str {
    #[inline]
    fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 2] {
        <&str>::to_sig(key.borrow(), seed)
    }
}
impl ToSig<[u64; 1]> for str {
    #[inline]
    fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 1] {
        <&str>::to_sig(key.borrow(), seed)
    }
}
impl ToSig<[u64; 2]> for &str {
    #[inline]
    fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 2] {
        // Hash the raw UTF-8 bytes of the string.
        <&[u8]>::to_sig(key.borrow().as_bytes(), seed)
    }
}
impl ToSig<[u64; 1]> for &str {
    #[inline]
    fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 1] {
        <&[u8]>::to_sig(key.borrow().as_bytes(), seed)
    }
}
// Generates `ToSig` (for both signature widths) for primitive integer types
// by hashing their native-endian byte representation.
//
// NOTE: `to_ne_bytes` makes signatures endianness-dependent — they are not
// portable across platforms of different byte order.
macro_rules! to_sig_prim {
    ($($ty:ty),*) => {$(
        impl ToSig<[u64; 2]> for $ty {
            fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 2] {
                let bytes = key.borrow().to_ne_bytes();
                let mut hasher = xxh3::Xxh3::with_seed(seed);
                hasher.update(bytes.as_slice());
                <[u64; 2]>::from_hasher(&hasher)
            }
        }
        impl ToSig<[u64;1]> for $ty {
            fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 1] {
                let bytes = key.borrow().to_ne_bytes();
                let mut hasher = xxh3::Xxh3::with_seed(seed);
                hasher.update(bytes.as_slice());
                <[u64; 1]>::from_hasher(&hasher)
            }
        }
    )*};
}
to_sig_prim!(
    isize, usize, i8, i16, i32, i64, i128, u8, u16, u32, u64, u128
);
// Generates `ToSig` for slices of primitive integers by hashing the slice's
// underlying bytes. Like `to_sig_prim!`, the result is endianness-dependent.
macro_rules! to_sig_slice {
    ($($ty:ty),*) => {$(
        impl ToSig<[u64; 2]> for &[$ty] {
            fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 2] {
                // SAFETY: viewing an integer slice as bytes is always valid;
                // `u8` has alignment 1, so the prefix/suffix of `align_to`
                // are empty and `.1` covers the whole slice.
                let bytes = unsafe {key.borrow().align_to::<u8>().1 };
                let mut hasher = xxh3::Xxh3::with_seed(seed);
                hasher.update(bytes);
                <[u64; 2]>::from_hasher(&hasher)
            }
        }
        impl ToSig<[u64;1]> for &[$ty] {
            fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 1] {
                // SAFETY: as above — byte view of a primitive-integer slice.
                let bytes = unsafe {key.borrow().align_to::<u8>().1 };
                let mut hasher = xxh3::Xxh3::with_seed(seed);
                hasher.update(bytes);
                <[u64; 1]>::from_hasher(&hasher)
            }
        }
        // Unsized-slice versions delegate to the reference versions.
        impl ToSig<[u64; 2]> for [$ty] {
            fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 2] {
                <&[$ty]>::to_sig(key.borrow(), seed)
            }
        }
        impl ToSig<[u64;1]> for [$ty] {
            fn to_sig(key: impl Borrow<Self>, seed: u64) -> [u64; 1] {
                <&[$ty]>::to_sig(key.borrow(), seed)
            }
        }
    )*};
}
to_sig_slice!(
    isize, usize, i8, i16, i32, i64, i128, u8, u16, u32, u64, u128
);
/// A write-side store that accumulates signature/value pairs and can be
/// converted into a read-side [`ShardStore`].
pub trait SigStore<S: Sig + BinSafe, V: BinSafe> {
    /// Error type of [`try_push`](SigStore::try_push) (I/O for on-disk
    /// stores, infallible for in-memory stores).
    type Error: std::error::Error + Send + Sync + 'static;
    /// Adds a signature/value pair to the store.
    fn try_push(&mut self, sig_val: SigVal<S, V>) -> Result<(), Self::Error>;
    /// The read-side store produced by [`into_shard_store`](SigStore::into_shard_store).
    type ShardStore: ShardStore<S, V> + Send + Sync;
    /// Consumes the store, regrouping its contents into `2^shard_high_bits`
    /// shards. `shard_high_bits` must not exceed
    /// [`max_shard_high_bits`](SigStore::max_shard_high_bits).
    fn into_shard_store(self, shard_high_bits: u32) -> Result<Self::ShardStore>;
    /// Number of pairs pushed so far.
    fn len(&self) -> usize;
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Maximum number of shard-selection bits supported by this store.
    fn max_shard_high_bits(&self) -> u32;
    /// The temporary directory backing the store, if any (on-disk stores only).
    fn temp_dir(&self) -> Option<&tempfile::TempDir>;
}
/// Concrete [`SigStore`] implementation, generic over the bucket backend `B`
/// (`BufWriter<File>` for on-disk stores, `Vec<SigVal<S, V>>` for in-memory
/// stores).
#[derive(Debug)]
pub struct SigStoreImpl<S, V, B> {
    // Total number of pairs pushed.
    len: usize,
    // log2 of the number of buckets.
    buckets_high_bits: u32,
    // log2 of the maximum number of shards the store can be split into.
    max_shard_high_bits: u32,
    // `(1 << buckets_high_bits) - 1`, cached for bucket selection.
    buckets_mask: u64,
    // `(1 << max_shard_high_bits) - 1`, cached for shard selection.
    max_shard_mask: u64,
    // One backend per bucket.
    buckets: VecDeque<B>,
    // Number of pairs in each bucket.
    bucket_sizes: Vec<usize>,
    // Number of pairs in each (finest-granularity) shard.
    shard_sizes: Vec<usize>,
    // Holds the temporary directory alive for on-disk stores.
    temp_dir: Option<tempfile::TempDir>,
    _marker: PhantomData<(S, V)>,
}
/// Creates an on-disk signature store with `2^buckets_high_bits` buckets,
/// each backed by a temporary file, supporting splits of up to
/// `2^max_shard_high_bits` shards.
///
/// `_expected_num_keys` is accepted for signature symmetry with
/// [`new_online`] but is unused (files grow on demand).
///
/// # Errors
/// Returns an error if the temporary directory or any bucket file cannot
/// be created.
pub fn new_offline<S: BinSafe + Sig, V: BinSafe>(
    buckets_high_bits: u32,
    max_shard_high_bits: u32,
    _expected_num_keys: Option<usize>,
) -> Result<SigStoreImpl<S, V, BufWriter<File>>> {
    let temp_dir = tempfile::TempDir::new()?;
    let num_buckets = 1usize << buckets_high_bits;
    // One read/write temporary file per bucket, truncated if it exists.
    let writers = (0..num_buckets)
        .map(|i| {
            File::options()
                .read(true)
                .write(true)
                .create(true)
                .truncate(true)
                .open(temp_dir.path().join(format!("{}.tmp", i)))
                .map(BufWriter::new)
        })
        .collect::<std::io::Result<VecDeque<_>>>()?;
    Ok(SigStoreImpl {
        len: 0,
        buckets_high_bits,
        max_shard_high_bits,
        buckets_mask: (1u64 << buckets_high_bits) - 1,
        max_shard_mask: (1u64 << max_shard_high_bits) - 1,
        buckets: writers,
        bucket_sizes: vec![0; num_buckets],
        shard_sizes: vec![0; 1 << max_shard_high_bits],
        temp_dir: Some(temp_dir),
        _marker: PhantomData,
    })
}
/// Creates an in-memory signature store with `2^buckets_high_bits` buckets,
/// supporting splits of up to `2^max_shard_high_bits` shards.
///
/// If `expected_num_keys` is given, each bucket is preallocated at roughly
/// 5% above its expected share of the keys to avoid reallocation.
pub fn new_online<S: BinSafe + Sig, V: BinSafe>(
    buckets_high_bits: u32,
    max_shard_high_bits: u32,
    expected_num_keys: Option<usize>,
) -> Result<SigStoreImpl<S, V, Vec<SigVal<S, V>>>> {
    let num_buckets = 1usize << buckets_high_bits;
    // Per-bucket capacity hint; zero means "grow on demand".
    let initial_capacity = match expected_num_keys {
        Some(n) => (n.div_ceil(num_buckets) as f64 * 1.05) as usize,
        None => 0,
    };
    let writers: VecDeque<_> = (0..num_buckets)
        .map(|_| Vec::with_capacity(initial_capacity))
        .collect();
    Ok(SigStoreImpl {
        len: 0,
        buckets_high_bits,
        max_shard_high_bits,
        buckets_mask: (1u64 << buckets_high_bits) - 1,
        max_shard_mask: (1u64 << max_shard_high_bits) - 1,
        buckets: writers,
        bucket_sizes: vec![0; num_buckets],
        shard_sizes: vec![0; 1 << max_shard_high_bits],
        temp_dir: None,
        _marker: PhantomData,
    })
}
/// On-disk store: pairs are serialized to per-bucket temporary files.
impl<S: BinSafe + Sig + Send + Sync, V: BinSafe> SigStore<S, V>
    for SigStoreImpl<S, V, BufWriter<File>>
{
    type Error = std::io::Error;
    /// Serializes the pair into the bucket file selected by the signature's
    /// high bits, updating the bucket and shard counters.
    fn try_push(&mut self, sig_val: SigVal<S, V>) -> Result<(), Self::Error> {
        self.len += 1;
        let buffer = sig_val
            .sig
            .high_bits(self.buckets_high_bits, self.buckets_mask) as usize;
        let shard = sig_val
            .sig
            .high_bits(self.max_shard_high_bits, self.max_shard_mask) as usize;
        self.bucket_sizes[buffer] += 1;
        self.shard_sizes[shard] += 1;
        write_binary(&mut self.buckets[buffer], std::slice::from_ref(&sig_val))
    }
    type ShardStore = ShardStoreImpl<S, V, BufReader<File>>;
    /// Flushes each bucket writer, rewinds the underlying file, and wraps it
    /// in a reader for shard-wise consumption.
    ///
    /// NOTE(review): `self.temp_dir` is dropped here, which deletes the
    /// directory while the files are still open — relies on Unix
    /// unlink-while-open semantics; confirm for other platforms.
    fn into_shard_store(mut self, shard_high_bits: u32) -> Result<Self::ShardStore> {
        assert!(shard_high_bits <= self.max_shard_high_bits);
        let mut files = Vec::with_capacity(self.buckets.len());
        for _ in 0..1 << self.buckets_high_bits {
            let mut writer = self.buckets.pop_front().unwrap();
            writer.flush()?;
            let mut file = writer.into_inner()?;
            // Rewind so reads start at the first serialized pair.
            file.seek(SeekFrom::Start(0))?;
            files.push(BufReader::new(file));
        }
        Ok(ShardStoreImpl {
            bucket_high_bits: self.buckets_high_bits,
            shard_high_bits,
            max_shard_high_bits: self.max_shard_high_bits,
            buckets: files,
            buf_sizes: self.bucket_sizes,
            fine_shard_sizes: self.shard_sizes,
            _marker: PhantomData,
        })
    }
    fn len(&self) -> usize {
        self.len
    }
    fn max_shard_high_bits(&self) -> u32 {
        self.max_shard_high_bits
    }
    fn temp_dir(&self) -> Option<&tempfile::TempDir> {
        self.temp_dir.as_ref()
    }
}
/// In-memory store: pairs are pushed into per-bucket vectors.
impl<S: BinSafe + Sig + Send + Sync, V: BinSafe> SigStore<S, V>
    for SigStoreImpl<S, V, Vec<SigVal<S, V>>>
{
    // Pushing into a vector cannot fail.
    type Error = core::convert::Infallible;
    /// Appends the pair to the bucket selected by the signature's high bits,
    /// updating the bucket and shard counters.
    fn try_push(&mut self, sig_val: SigVal<S, V>) -> Result<(), Self::Error> {
        let bucket = sig_val
            .sig
            .high_bits(self.buckets_high_bits, self.buckets_mask) as usize;
        let shard = sig_val
            .sig
            .high_bits(self.max_shard_high_bits, self.max_shard_mask) as usize;
        self.len += 1;
        self.bucket_sizes[bucket] += 1;
        self.shard_sizes[shard] += 1;
        self.buckets[bucket].push(sig_val);
        Ok(())
    }
    type ShardStore = ShardStoreImpl<S, V, Arc<Vec<SigVal<S, V>>>>;
    /// Freezes each bucket behind an `Arc`, trimming excess capacity first.
    fn into_shard_store(self, shard_high_bits: u32) -> Result<Self::ShardStore> {
        assert!(shard_high_bits <= self.max_shard_high_bits);
        let frozen: Vec<_> = self
            .buckets
            .into_iter()
            .map(|mut bucket| {
                bucket.shrink_to_fit();
                Arc::new(bucket)
            })
            .collect();
        Ok(ShardStoreImpl {
            bucket_high_bits: self.buckets_high_bits,
            shard_high_bits,
            max_shard_high_bits: self.max_shard_high_bits,
            buckets: frozen,
            buf_sizes: self.bucket_sizes,
            fine_shard_sizes: self.shard_sizes,
            _marker: PhantomData,
        })
    }
    fn len(&self) -> usize {
        self.len
    }
    fn max_shard_high_bits(&self) -> u32 {
        self.max_shard_high_bits
    }
    /// In-memory stores have no backing temporary directory.
    fn temp_dir(&self) -> Option<&tempfile::TempDir> {
        None
    }
}
#[cfg(feature = "rayon")]
impl<S: BinSafe + Sig + Send + Sync, V: BinSafe + Send + Sync>
    SigStoreImpl<S, V, Vec<SigVal<S, V>>>
{
    /// Fills the store in parallel with the `n` pairs produced by `f(0..n)`,
    /// returning the maximum value produced (or `V::default()` if `n == 0`).
    ///
    /// Each rayon worker accumulates pairs in small per-bucket buffers and
    /// flushes a buffer into the shared, mutex-protected bucket when it
    /// fills; bucket/shard counters are combined at the end via `reduce`.
    pub fn par_populate(
        &mut self,
        n: usize,
        max_num_threads: usize,
        f: impl Fn(usize) -> SigVal<S, V> + Send + Sync,
    ) -> V
    where
        V: Default + Ord + Send,
    {
        use rayon::prelude::*;
        use std::sync::Mutex;
        let num_buckets = 1usize << self.buckets_high_bits;
        let num_shards = 1usize << self.max_shard_high_bits;
        // Copies of the routing parameters, so the closures below do not
        // borrow `self`.
        let bhb = self.buckets_high_bits;
        let bmask = self.buckets_mask;
        let shb = self.max_shard_high_bits;
        let smask = self.max_shard_mask;
        // Temporarily move the buckets behind mutexes for shared access.
        let mutexed_buckets: Vec<Mutex<Vec<SigVal<S, V>>>> =
            self.buckets.drain(..).map(Mutex::new).collect();
        // Per-thread, per-bucket buffer capacity before a flush.
        const CAP: usize = 48;
        use arrayvec::ArrayVec;
        let max_val = (0..n)
            .into_par_iter()
            // Coarse chunks keep scheduling overhead low.
            .with_min_len((n / max_num_threads).max(1_000_000))
            .fold(
                // Per-thread state: running max, per-bucket buffers, and
                // local bucket/shard counters.
                || {
                    let bufs: Box<[ArrayVec<SigVal<S, V>, CAP>]> =
                        (0..num_buckets).map(|_| ArrayVec::new()).collect();
                    let bc: Box<[usize]> = vec![0usize; num_buckets].into();
                    let sc: Box<[usize]> = vec![0usize; num_shards].into();
                    (V::default(), bufs, bc, sc)
                },
                |(mut local_max, mut local_bufs, mut bc, mut sc): (
                    V,
                    Box<[ArrayVec<SigVal<S, V>, CAP>]>,
                    Box<[usize]>,
                    Box<[usize]>,
                ),
                i| {
                    let sv = f(i);
                    local_max = Ord::max(local_max, sv.val);
                    let bucket = sv.sig.high_bits(bhb, bmask) as usize;
                    let shard = sv.sig.high_bits(shb, smask) as usize;
                    bc[bucket] += 1;
                    sc[shard] += 1;
                    local_bufs[bucket].push(sv);
                    // Flush a full local buffer into the shared bucket.
                    if local_bufs[bucket].is_full() {
                        mutexed_buckets[bucket]
                            .lock()
                            .unwrap()
                            .extend(local_bufs[bucket].drain(..));
                    }
                    (local_max, local_bufs, bc, sc)
                },
            )
            // Flush whatever remains in each thread's buffers.
            .map(|(local_max, local_bufs, bc, sc)| {
                for (bucket, buf) in local_bufs.into_vec().into_iter().enumerate() {
                    if !buf.is_empty() {
                        mutexed_buckets[bucket].lock().unwrap().extend(buf);
                    }
                }
                (local_max, bc, sc)
            })
            // Combine per-thread maxima and counters.
            .reduce(
                || {
                    let bc: Box<[usize]> = vec![0usize; num_buckets].into();
                    let sc: Box<[usize]> = vec![0usize; num_shards].into();
                    (V::default(), bc, sc)
                },
                |(max_a, mut bc_a, mut sc_a), (max_b, bc_b, sc_b)| {
                    let m = Ord::max(max_a, max_b);
                    for (a, b) in bc_a.iter_mut().zip(bc_b.iter()) {
                        *a += b;
                    }
                    for (a, b) in sc_a.iter_mut().zip(sc_b.iter()) {
                        *a += b;
                    }
                    (m, bc_a, sc_a)
                },
            );
        let (max_val, local_bc, local_sc) = max_val;
        // Move the buckets back out of the mutexes and fold the counters
        // into the store's bookkeeping.
        self.buckets
            .extend(mutexed_buckets.into_iter().map(|m| m.into_inner().unwrap()));
        for (i, c) in local_bc.iter().enumerate() {
            self.bucket_sizes[i] += c;
        }
        for (i, c) in local_sc.iter().enumerate() {
            self.shard_sizes[i] += c;
        }
        self.len += n;
        max_val
    }
}
/// Read side of a signature store: yields the stored pairs grouped into
/// shards selected by the high bits of the signatures.
pub trait ShardStore<S: Sig, V: BinSafe> {
    /// Sizes of the shards, in shard order, at the current shard granularity.
    fn shard_sizes(&self) -> Box<dyn Iterator<Item = usize> + '_>;
    /// Iterates over the shards without consuming the underlying data
    /// (may be called multiple times).
    fn iter(&mut self) -> Box<dyn Iterator<Item = Arc<Vec<SigVal<S, V>>>> + Send + Sync + '_>;
    /// Iterates over the shards, releasing the underlying storage as each
    /// shard is yielded.
    fn drain(&mut self) -> Box<dyn Iterator<Item = Arc<Vec<SigVal<S, V>>>> + Send + Sync + '_>;
    /// Changes the shard granularity; `new_bits` must not exceed
    /// [`max_shard_high_bits`](ShardStore::max_shard_high_bits).
    fn set_shard_high_bits(&mut self, new_bits: u32);
    fn max_shard_high_bits(&self) -> u32;
    /// Total number of stored pairs (sum of shard sizes).
    fn len(&self) -> usize {
        self.shard_sizes().sum()
    }
}
/// Concrete [`ShardStore`] implementation, generic over the bucket backend
/// `B` (`BufReader<File>` or `Arc<Vec<SigVal<S, V>>>`).
#[derive(Debug)]
pub struct ShardStoreImpl<S, V, B> {
    // log2 of the number of buckets the data was written into.
    bucket_high_bits: u32,
    // log2 of the number of shards currently exposed.
    shard_high_bits: u32,
    // log2 of the finest shard granularity supported.
    max_shard_high_bits: u32,
    // One backend per bucket.
    buckets: Vec<B>,
    // Number of pairs in each bucket.
    buf_sizes: Vec<usize>,
    // Number of pairs in each finest-granularity shard.
    fine_shard_sizes: Vec<usize>,
    _marker: PhantomData<(S, V)>,
}
/// Generic [`ShardStore`] facade over [`ShardStoreImpl`]; the actual shard
/// assembly lives in the backend-specific [`ShardIter`] `Iterator` impls,
/// which the `where` clause requires to exist for the chosen backend.
impl<S: BinSafe + Sig + Send + Sync, V: BinSafe + Send + Sync, B: Send + Sync> ShardStore<S, V>
    for ShardStoreImpl<S, V, B>
where
    for<'a> ShardIter<S, V, B, &'a mut Self>: Iterator<Item = Arc<Vec<SigVal<S, V>>>> + Send + Sync,
{
    fn shard_sizes(&self) -> Box<dyn Iterator<Item = usize> + '_> {
        // Each exposed shard aggregates `coarsen` consecutive fine shards.
        let coarsen = 1usize << (self.max_shard_high_bits - self.shard_high_bits);
        let sizes = self
            .fine_shard_sizes
            .chunks(coarsen)
            .map(|chunk| chunk.iter().sum());
        Box::new(sizes)
    }
    fn set_shard_high_bits(&mut self, new_bits: u32) {
        assert!(new_bits <= self.max_shard_high_bits);
        self.shard_high_bits = new_bits;
    }
    fn max_shard_high_bits(&self) -> u32 {
        self.max_shard_high_bits
    }
    /// Non-consuming iteration (`borrowed: true`).
    fn iter(&mut self) -> Box<dyn Iterator<Item = Arc<Vec<SigVal<S, V>>>> + Send + Sync + '_> {
        let it = ShardIter {
            store: self,
            borrowed: true,
            next_bucket: 0,
            next_shard: 0,
            shards: VecDeque::new(),
            _marker: PhantomData,
        };
        Box::new(it)
    }
    /// Consuming iteration (`borrowed: false`): storage is released as
    /// shards are yielded.
    fn drain(&mut self) -> Box<dyn Iterator<Item = Arc<Vec<SigVal<S, V>>>> + Send + Sync + '_> {
        let it = ShardIter {
            store: self,
            borrowed: false,
            next_bucket: 0,
            next_shard: 0,
            shards: VecDeque::new(),
            _marker: PhantomData,
        };
        Box::new(it)
    }
}
/// Iterator over the shards of a [`ShardStoreImpl`].
#[derive(Debug)]
pub struct ShardIter<S: BinSafe + Sig, V: BinSafe, B, T: BorrowMut<ShardStoreImpl<S, V, B>>> {
    store: T,
    // `true` for non-consuming iteration (`iter`), `false` for `drain`.
    borrowed: bool,
    // Index of the next bucket to read.
    next_bucket: usize,
    // Index of the next shard to yield.
    next_shard: usize,
    // Pending shards, used when one bucket splits into several shards.
    shards: VecDeque<Vec<SigVal<S, V>>>,
    _marker: PhantomData<(B, V)>,
}
/// Shard iteration for file-backed stores: shards are reconstructed by
/// reading raw `SigVal` bytes back from the per-bucket files.
impl<
    S: BinSafe + Sig + Send + Sync,
    V: BinSafe,
    T: BorrowMut<ShardStoreImpl<S, V, BufReader<File>>>,
> Iterator for ShardIter<S, V, BufReader<File>, T>
{
    type Item = Arc<Vec<SigVal<S, V>>>;
    fn next(&mut self) -> Option<Self::Item> {
        let store = self.store.borrow_mut();
        if store.bucket_high_bits >= store.shard_high_bits {
            // Case 1: each shard is the concatenation of one or more whole
            // buckets.
            if self.next_bucket >= store.buckets.len() {
                return None;
            }
            // Number of buckets per shard.
            let to_aggr = 1 << (store.bucket_high_bits - store.shard_high_bits);
            // Number of fine (max-resolution) shards per exposed shard.
            let coarsen = 1usize << (store.max_shard_high_bits - store.shard_high_bits);
            let base = self.next_shard * coarsen;
            let len: usize = store.fine_shard_sizes[base..base + coarsen].iter().sum();
            let mut shard = Vec::<SigVal<S, V>>::with_capacity(len);
            #[allow(clippy::uninit_vec)]
            unsafe {
                // The uninitialized contents are fully overwritten by the
                // `read_exact` calls below before any element is read.
                shard.set_len(len);
            }
            {
                // SAFETY: `u8` has alignment 1, so prefix/suffix are empty
                // and `buf` views the whole vector as bytes.
                let (pre, mut buf, post) = unsafe { shard.align_to_mut::<u8>() };
                assert!(pre.is_empty());
                assert!(post.is_empty());
                for i in self.next_bucket..self.next_bucket + to_aggr {
                    let bytes = store.buf_sizes[i] * core::mem::size_of::<SigVal<S, V>>();
                    // Rewind each time: `iter` may be called repeatedly.
                    store.buckets[i].seek(SeekFrom::Start(0)).unwrap();
                    store.buckets[i].read_exact(&mut buf[..bytes]).unwrap();
                    if !self.borrowed {
                        // Draining: truncate the bucket file once consumed
                        // (errors deliberately ignored — best effort).
                        let _ = store.buckets[i].get_mut().set_len(0);
                    }
                    buf = &mut buf[bytes..];
                }
            }
            let res = shard;
            self.next_bucket += to_aggr;
            self.next_shard += 1;
            Some(Arc::new(res))
        } else {
            // Case 2: each bucket splits into several shards. Split the next
            // bucket into `split_into` shards on demand, then yield them one
            // by one from `self.shards`.
            if self.shards.is_empty() {
                if self.next_bucket == store.buckets.len() {
                    return None;
                }
                let split_into = 1 << (store.shard_high_bits - store.bucket_high_bits);
                let shard_offset = self.next_bucket * split_into;
                let coarsen = 1usize << (store.max_shard_high_bits - store.shard_high_bits);
                // Pre-size the output shards from the fine shard counts.
                for shard in shard_offset..shard_offset + split_into {
                    let base = shard * coarsen;
                    let cap: usize = store.fine_shard_sizes[base..base + coarsen].iter().sum();
                    self.shards.push_back(Vec::with_capacity(cap));
                }
                let mut len = store.buf_sizes[self.next_bucket];
                // Read the bucket in fixed-size chunks of `SigVal`s.
                let buf_size = 1024;
                let mut buffer = Vec::<SigVal<S, V>>::with_capacity(buf_size);
                #[allow(clippy::uninit_vec)]
                unsafe {
                    // Overwritten by `read_exact` before being read.
                    buffer.set_len(buf_size);
                }
                let shard_mask = (1 << store.shard_high_bits) - 1;
                store.buckets[self.next_bucket]
                    .seek(SeekFrom::Start(0))
                    .unwrap();
                while len > 0 {
                    let to_read = buf_size.min(len);
                    unsafe {
                        buffer.set_len(to_read);
                    }
                    // SAFETY: byte view of the buffer; `u8` alignment is 1.
                    let (pre, buf, after) = unsafe { buffer.align_to_mut::<u8>() };
                    debug_assert!(pre.is_empty());
                    debug_assert!(after.is_empty());
                    store.buckets[self.next_bucket].read_exact(buf).unwrap();
                    // Route each pair to its shard within this bucket.
                    for &v in &buffer {
                        let shard = v.sig.high_bits(store.shard_high_bits, shard_mask) as usize
                            - shard_offset;
                        self.shards[shard].push(v);
                    }
                    len -= to_read;
                }
                self.next_bucket += 1;
            }
            self.next_shard += 1;
            Some(Arc::new(self.shards.pop_front().unwrap()))
        }
    }
    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: backed by `ExactSizeIterator::len`.
        (self.len(), Some(self.len()))
    }
}
/// Shard iteration for in-memory stores. Buckets are `Arc`-shared vectors,
/// so the one-bucket-per-shard case is a cheap `Arc` clone (or move, when
/// draining).
impl<
    S: BinSafe + Sig + Send + Sync,
    V: BinSafe,
    T: BorrowMut<ShardStoreImpl<S, V, Arc<Vec<SigVal<S, V>>>>>,
> Iterator for ShardIter<S, V, Arc<Vec<SigVal<S, V>>>, T>
{
    type Item = Arc<Vec<SigVal<S, V>>>;
    fn next(&mut self) -> Option<Self::Item> {
        let store = self.store.borrow_mut();
        if store.bucket_high_bits == store.shard_high_bits {
            // Case 1: shards and buckets coincide — share or move the `Arc`.
            if self.next_bucket >= store.buckets.len() {
                return None;
            }
            let res = if self.borrowed {
                store.buckets[self.next_bucket].clone()
            } else {
                // Draining: leave an empty bucket behind.
                std::mem::take(&mut store.buckets[self.next_bucket])
            };
            self.next_bucket += 1;
            self.next_shard += 1;
            Some(res)
        } else if store.bucket_high_bits > store.shard_high_bits {
            // Case 2: each shard concatenates several whole buckets.
            if self.next_bucket >= store.buckets.len() {
                return None;
            }
            let to_aggr = 1 << (store.bucket_high_bits - store.shard_high_bits);
            // Number of fine shards per exposed shard.
            let coarsen = 1usize << (store.max_shard_high_bits - store.shard_high_bits);
            let base = self.next_shard * coarsen;
            let len: usize = store.fine_shard_sizes[base..base + coarsen].iter().sum();
            let mut shard = Vec::with_capacity(len);
            for i in self.next_bucket..self.next_bucket + to_aggr {
                if self.borrowed {
                    shard.extend(store.buckets[i].iter());
                } else {
                    shard.extend(std::mem::take(&mut store.buckets[i]).iter());
                }
            }
            let res = shard;
            self.next_bucket += to_aggr;
            self.next_shard += 1;
            Some(Arc::new(res))
        } else {
            // Case 3: each bucket splits into several shards; split the next
            // bucket on demand and yield the pending shards one by one.
            if self.shards.is_empty() {
                if self.next_bucket == store.buckets.len() {
                    return None;
                }
                let split_into = 1 << (store.shard_high_bits - store.bucket_high_bits);
                let shard_offset = self.next_bucket * split_into;
                let coarsen = 1usize << (store.max_shard_high_bits - store.shard_high_bits);
                // Pre-size the output shards from the fine shard counts.
                for shard in shard_offset..shard_offset + split_into {
                    let base = shard * coarsen;
                    let cap: usize = store.fine_shard_sizes[base..base + coarsen].iter().sum();
                    self.shards.push_back(Vec::with_capacity(cap));
                }
                let shard_mask = (1 << store.shard_high_bits) - 1;
                // Route each pair to its shard within this bucket.
                for &v in store.buckets[self.next_bucket].iter() {
                    let shard =
                        v.sig.high_bits(store.shard_high_bits, shard_mask) as usize - shard_offset;
                    self.shards[shard].push(v);
                }
                if !self.borrowed {
                    // Draining: free the bucket once split.
                    drop(std::mem::take(&mut store.buckets[self.next_bucket]));
                }
                self.next_bucket += 1;
            }
            self.next_shard += 1;
            Some(Arc::new(self.shards.pop_front().unwrap()))
        }
    }
    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: backed by `ExactSizeIterator::len`.
        (self.len(), Some(self.len()))
    }
}
/// The number of shards is known up front: `2^shard_high_bits` total, minus
/// those already yielded.
impl<
    S: BinSafe + Sig + Send + Sync,
    V: BinSafe,
    B: Send + Sync,
    T: BorrowMut<ShardStoreImpl<S, V, B>>,
> ExactSizeIterator for ShardIter<S, V, B, T>
where
    for<'a> ShardIter<S, V, B, T>: Iterator,
{
    #[inline(always)]
    fn len(&self) -> usize {
        (1usize << self.store.borrow().shard_high_bits) - self.next_shard
    }
}
/// Both `next` implementations keep returning `None` once exhausted, so the
/// iterator is fused.
impl<
    S: BinSafe + Sig + Send + Sync,
    V: BinSafe,
    B: Send + Sync,
    T: BorrowMut<ShardStoreImpl<S, V, B>>,
> FusedIterator for ShardIter<S, V, B, T>
where
    for<'a> ShardIter<S, V, B, T>: Iterator,
{
}
/// Serializes a slice of `SigVal`s into `writer` as raw bytes.
///
/// NOTE(review): this writes the in-memory representation, including any
/// padding bytes — assumes `SigVal<S, V>` is padding-free; confirm for all
/// instantiations.
fn write_binary<S: BinSafe + Sig, V: BinSafe>(
    writer: &mut impl Write,
    tuples: &[SigVal<S, V>],
) -> std::io::Result<()> {
    // SAFETY: viewing the slice as bytes is valid; `u8` has alignment 1, so
    // the prefix/suffix of `align_to` are always empty.
    let (pre, buf, post) = unsafe { tuples.align_to::<u8>() };
    debug_assert!(pre.is_empty());
    debug_assert!(post.is_empty());
    writer.write_all(buf)
}
/// A view over another [`ShardStore`] that yields only the pairs accepted by
/// a predicate.
pub struct FilteredShardStore<'a, SS: ?Sized, S, V, F> {
    // The wrapped store.
    inner: &'a mut SS,
    // Predicate deciding which pairs are visible through this view.
    filter: F,
    // Cached per-shard sizes after filtering, at the current granularity.
    shard_sizes: Vec<usize>,
    _marker: std::marker::PhantomData<(S, V)>,
}
impl<'a, SS: ?Sized, S, V, F> FilteredShardStore<'a, SS, S, V, F>
where
    SS: ShardStore<S, V>,
    S: Sig + BinSafe + Send + Sync,
    V: BinSafe + Copy,
    F: Fn(&SigVal<S, V>) -> bool,
{
    /// Wraps `inner`, setting its shard granularity to `shard_high_bits`.
    ///
    /// NOTE(review): `shard_sizes` is taken as given and must already hold
    /// the filtered counts at `shard_high_bits` granularity — confirm at
    /// call sites.
    pub fn new(
        inner: &'a mut SS,
        shard_high_bits: u32,
        filter: F,
        shard_sizes: Vec<usize>,
    ) -> Self {
        inner.set_shard_high_bits(shard_high_bits);
        Self {
            inner,
            filter,
            shard_sizes,
            _marker: std::marker::PhantomData,
        }
    }
}
/// [`ShardStore`] façade that applies the predicate on the fly.
impl<'a, SS: ?Sized, S, V, F> ShardStore<S, V> for FilteredShardStore<'a, SS, S, V, F>
where
    SS: ShardStore<S, V>,
    S: Sig + BinSafe + Send + Sync,
    V: BinSafe + Copy + Send + Sync,
    F: Fn(&SigVal<S, V>) -> bool + Send + Sync,
{
    /// Returns the cached, filtered per-shard sizes.
    fn shard_sizes(&self) -> Box<dyn Iterator<Item = usize> + '_> {
        Box::new(self.shard_sizes.iter().copied())
    }
    /// Changes the granularity of the inner store and recomputes the cached
    /// filtered sizes by scanning every shard once.
    fn set_shard_high_bits(&mut self, new_bits: u32) {
        self.inner.set_shard_high_bits(new_bits);
        self.shard_sizes = self
            .inner
            .iter()
            .map(|shard| shard.iter().filter(|sv| (self.filter)(sv)).count())
            .collect();
    }
    fn max_shard_high_bits(&self) -> u32 {
        self.inner.max_shard_high_bits()
    }
    /// Yields each inner shard with the predicate applied.
    ///
    /// The inner shard handles are collected eagerly so the mutable borrow of
    /// `inner` ends before the iterator is returned, but filtering is done
    /// lazily, one shard at a time — only one filtered shard is materialized
    /// at any moment (the original built them all up front).
    fn iter(&mut self) -> Box<dyn Iterator<Item = Arc<Vec<SigVal<S, V>>>> + Send + Sync + '_> {
        let filter = &self.filter;
        let inner_shards: Vec<_> = self.inner.iter().collect();
        Box::new(inner_shards.into_iter().map(move |shard| {
            Arc::new(
                shard
                    .iter()
                    .filter(|sv| filter(sv))
                    .copied()
                    .collect::<Vec<_>>(),
            )
        }))
    }
    /// Same as [`iter`](Self::iter): the filtered view never consumes the
    /// inner store's data.
    fn drain(&mut self) -> Box<dyn Iterator<Item = Arc<Vec<SigVal<S, V>>>> + Send + Sync + '_> {
        self.iter()
    }
}
/// Forwarding implementation so a boxed trait object can be used wherever a
/// concrete [`ShardStore`] is expected. All methods delegate to the inner
/// object.
impl<S: Sig, V: BinSafe> ShardStore<S, V> for Box<dyn ShardStore<S, V> + Send + Sync> {
    fn shard_sizes(&self) -> Box<dyn Iterator<Item = usize> + '_> {
        (**self).shard_sizes()
    }
    fn set_shard_high_bits(&mut self, new_bits: u32) {
        (**self).set_shard_high_bits(new_bits)
    }
    fn max_shard_high_bits(&self) -> u32 {
        (**self).max_shard_high_bits()
    }
    fn iter(&mut self) -> Box<dyn Iterator<Item = Arc<Vec<SigVal<S, V>>>> + Send + Sync + '_> {
        (**self).iter()
    }
    fn drain(&mut self) -> Box<dyn Iterator<Item = Arc<Vec<SigVal<S, V>>>> + Send + Sync + '_> {
        (**self).drain()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{RngExt, SeedableRng, rngs::SmallRng};
    /// Pushes 10 000 random pairs, then verifies that every element of the
    /// `i`-th yielded shard has high bits equal to `i`, both for repeated
    /// `iter` calls and for a final `drain`.
    fn _test_sig_store<S: BinSafe + Sig + Send + Sync>(
        mut sig_store: impl SigStore<S, u64>,
        get_rand_sig: fn(&mut SmallRng) -> S,
    ) -> anyhow::Result<()> {
        let mut rand = SmallRng::seed_from_u64(0);
        let shard_high_bits = sig_store.max_shard_high_bits();
        for _ in (0..10000).rev() {
            sig_store.try_push(SigVal {
                sig: get_rand_sig(&mut rand),
                val: rand.random(),
            })?;
        }
        let mut shard_store = sig_store.into_shard_store(shard_high_bits).unwrap();
        // Iterate twice to check that `iter` does not consume the store.
        for _ in 0..2 {
            let mut count = 0;
            for shard in shard_store.iter() {
                for &w in shard.iter() {
                    assert_eq!(
                        count,
                        w.sig.high_bits(shard_high_bits, (1 << shard_high_bits) - 1)
                    );
                }
                count += 1;
            }
            assert_eq!(count, 1 << shard_high_bits);
        }
        // `drain` must yield the same shard structure once.
        let mut count = 0;
        for shard in shard_store.drain() {
            for &w in shard.iter() {
                assert_eq!(
                    count,
                    w.sig.high_bits(shard_high_bits, (1 << shard_high_bits) - 1)
                );
            }
            count += 1;
        }
        assert_eq!(count, 1 << shard_high_bits);
        Ok(())
    }
    /// Exercises both store kinds across combinations of bucket and shard
    /// bits, including the bucket-splitting and bucket-aggregating paths.
    #[test]
    fn test_sig_store() -> anyhow::Result<()> {
        for max_shard_bits in [0, 2, 8, 9] {
            for buckets_high_bits in [0, 2, 8, 9] {
                for shard_high_bits in [0, 2, 8, 9] {
                    if shard_high_bits > max_shard_bits {
                        continue;
                    }
                    _test_sig_store(
                        new_online(buckets_high_bits, max_shard_bits, None)?,
                        |rand| [rand.random(), rand.random()],
                    )?;
                    _test_sig_store(
                        new_offline(buckets_high_bits, max_shard_bits, None)?,
                        |rand| [rand.random(), rand.random()],
                    )?;
                }
            }
        }
        Ok(())
    }
    /// Same shard-membership check with a one-byte value type, to exercise a
    /// `SigVal` layout different from the `u64`-valued case.
    fn _test_u8<S: BinSafe + Sig>(
        mut sig_store: impl SigStore<S, u8>,
        get_rand_sig: fn(&mut SmallRng) -> S,
    ) -> anyhow::Result<()> {
        let mut rand = SmallRng::seed_from_u64(0);
        for _ in (0..1000).rev() {
            sig_store.try_push(SigVal {
                sig: get_rand_sig(&mut rand),
                val: rand.random(),
            })?;
        }
        let mut shard_store = sig_store.into_shard_store(2)?;
        let mut count = 0;
        for shard in shard_store.iter() {
            for &w in shard.iter() {
                assert_eq!(count, w.sig.high_bits(2, (1 << 2) - 1));
            }
            count += 1;
        }
        assert_eq!(count, 4);
        Ok(())
    }
    #[test]
    fn test_u8() -> anyhow::Result<()> {
        _test_u8(new_online(2, 2, None)?, |rand| {
            [rand.random(), rand.random()]
        })?;
        _test_u8(new_offline(2, 2, None)?, |rand| {
            [rand.random(), rand.random()]
        })?;
        Ok(())
    }
}