#![cfg_attr(feature = "unstable", feature(iter_array_chunks))]
pub mod hash;
pub mod pack;
pub mod util;
pub mod bucket_fn;
mod bucket_idx;
mod build;
mod reduce;
mod shard;
mod sort_buckets;
#[doc(hidden)]
pub mod stats;
#[cfg(test)]
mod test;
use bitvec::{bitvec, vec::BitVec};
use bucket_fn::BucketFn;
use bucket_fn::CubicEps;
use bucket_fn::Linear;
use bucket_fn::SquareEps;
use cacheline_ef::CachelineEfVec;
use itertools::izip;
use itertools::Itertools;
use log::trace;
use log::warn;
use mem_dbg::MemSize;
use pack::MutPacked;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use rayon::prelude::*;
pub use shard::Sharding;
use stats::BucketStats;
use std::array::from_fn;
use std::{borrow::Borrow, default::Default, marker::PhantomData, time::Instant};
use crate::{hash::*, pack::Packed, reduce::*, util::log_duration};
/// Configuration parameters for building a [`PtrHash`] instance.
#[derive(Clone, Copy, Debug, MemSize)]
#[cfg_attr(feature = "epserde", derive(epserde::prelude::Epserde))]
#[cfg_attr(feature = "epserde", deep_copy)]
pub struct PtrHashParams<BF> {
    /// When true, slots `>= n` are remapped back into `[0, n)` so the
    /// resulting function is a *minimal* perfect hash.
    pub remap: bool,
    /// Load factor: target ratio of keys to slots per part (e.g. 0.99).
    pub alpha: f64,
    /// Target average number of keys per bucket.
    pub lambda: f64,
    /// Function mapping a (reduced) hash to a bucket; controls the
    /// bucket-size distribution.
    pub bucket_fn: BF,
    /// Maximum number of keys per shard when sharding is enabled.
    pub keys_per_shard: usize,
    /// Sharding strategy; `Sharding::None` disables sharding entirely.
    pub sharding: Sharding,
    /// Force a single part instead of auto-tuning the number of parts.
    pub single_part: bool,
}
impl PtrHashParams<Linear> {
pub fn default_fast() -> Self {
Self {
remap: true,
alpha: 0.99,
lambda: 3.0,
bucket_fn: Linear,
keys_per_shard: 1 << 31,
sharding: Sharding::None,
single_part: false,
}
}
}
#[doc(hidden)]
impl PtrHashParams<SquareEps> {
pub fn default_square() -> Self {
Self {
remap: true,
alpha: 0.99,
lambda: 3.5,
bucket_fn: SquareEps,
keys_per_shard: 1 << 31,
sharding: Sharding::None,
single_part: false,
}
}
}
impl PtrHashParams<CubicEps> {
pub fn default() -> Self {
Self {
remap: true,
alpha: 0.99,
lambda: 3.5,
bucket_fn: CubicEps,
keys_per_shard: 1 << 31,
sharding: Sharding::None,
single_part: false,
}
}
pub fn default_compact() -> Self {
Self {
remap: true,
alpha: 0.99,
lambda: 3.9,
bucket_fn: CubicEps,
keys_per_shard: 1 << 31,
sharding: Sharding::None,
single_part: false,
}
}
}
impl Default for PtrHashParams<CubicEps> {
    /// Delegates to the inherent `PtrHashParams::<CubicEps>::default` above.
    ///
    /// NOTE: this is *not* infinite recursion — inherent associated
    /// functions take precedence over trait methods during resolution,
    /// so `Self::default()` resolves to the inherent fn, not this one.
    fn default() -> Self {
        Self::default()
    }
}
/// Convenience alias for `PtrHash` with its default backing types.
pub type DefaultPtrHash<Hx = hash::FxHash, Key = u64, BF = bucket_fn::CubicEps> =
    PtrHash<Key, BF, CachelineEfVec, Hx, Vec<u8>>;
/// Bound for key types: hashable and usable across threads.
pub trait KeyT: Send + Sync + std::hash::Hash {}
impl<T: Send + Sync + std::hash::Hash> KeyT for T {}
// Reduction types mapping hashes onto part/bucket/slot ranges.
type Rp = FastReduce;
type Rb = FastReduce;
// A pilot is the small per-bucket value searched for during construction;
// it is hashed (see `hash_pilot`) before being mixed into the slot index.
type Pilot = u64;
type PilotHash = u64;
/// A (minimal) perfect hash function over a fixed set of keys.
///
/// Keys are hashed with seed `seed`, assigned to a part and a bucket, and
/// each bucket stores one 8-bit pilot that resolves its keys to free slots.
#[cfg_attr(feature = "epserde", derive(epserde::prelude::Epserde))]
#[derive(Clone, MemSize)]
pub struct PtrHash<
    Key: KeyT + ?Sized = u64,
    BF: BucketFn = bucket_fn::CubicEps,
    F: Packed = CachelineEfVec,
    Hx: Hasher<Key> = hash::FxHash,
    V: AsRef<[u8]> = Vec<u8>,
> {
    /// Parameters used at construction time.
    params: PtrHashParams<BF>,
    /// Number of keys.
    n: usize,
    /// Number of parts.
    parts: usize,
    /// Number of shards (1 when sharding is disabled).
    shards: usize,
    /// Parts per shard.
    parts_per_shard: usize,
    /// Total number of slots over all parts.
    slots_total: usize,
    /// Total number of buckets over all parts.
    buckets_total: usize,
    /// Slots per part.
    slots: usize,
    /// Buckets per part.
    buckets: usize,
    // Precomputed fast-reduce instances, one per target range.
    rem_shards: Rp,
    rem_parts: Rp,
    rem_buckets: Rb,
    rem_buckets_total: Rb,
    rem_slots: Rp,
    /// Global seed applied to all key and pilot hashing.
    seed: u64,
    /// One 8-bit pilot per bucket.
    pilots: V,
    /// Remap table from slots `>= n` to free slots `< n` (minimal mode).
    remap: F,
    _key: PhantomData<Key>,
    _hx: PhantomData<Hx>,
}
impl<Key: KeyT, BF: BucketFn, F: MutPacked, Hx: Hasher<Key>> Default
for PtrHash<Key, BF, F, Hx, Vec<u8>>
where
PtrHashParams<BF>: Default,
{
fn default() -> Self {
PtrHash {
params: <PtrHashParams<BF> as Default>::default(),
n: 0,
parts: 0,
shards: 0,
parts_per_shard: 0,
slots_total: 0,
buckets_total: 0,
slots: 0,
buckets: 0,
rem_shards: FastReduce::new(0),
rem_parts: FastReduce::new(0),
rem_buckets: FastReduce::new(0),
rem_buckets_total: FastReduce::new(0),
rem_slots: FastReduce::new(0),
seed: 0,
pilots: vec![],
remap: F::default(),
_key: PhantomData,
_hx: PhantomData,
}
}
}
impl<Key: KeyT, BF: BucketFn, F: MutPacked, Hx: Hasher<Key>> PtrHash<Key, BF, F, Hx, Vec<u8>> {
/// Build a `PtrHash` over `keys`.
///
/// Panics if construction fails (same as unwrapping [`Self::try_new`]).
pub fn new(keys: &[Key], params: PtrHashParams<BF>) -> Self {
    Self::try_new(keys, params).unwrap()
}
/// Like [`Self::new`], but also returns per-bucket construction statistics.
#[doc(hidden)]
pub fn new_with_stats(keys: &[Key], params: PtrHashParams<BF>) -> (Self, BucketStats) {
    let mut ph = Self::init(keys.len(), params);
    let stats = ph.compute_pilots(keys.par_iter()).unwrap();
    (ph, stats)
}
/// Build a `PtrHash` over `keys`, returning `None` if construction fails
/// (no global seed found within the retry limit).
pub fn try_new(keys: &[Key], params: PtrHashParams<BF>) -> Option<Self> {
    let mut ph = Self::init(keys.len(), params);
    match ph.compute_pilots(keys.par_iter()) {
        Some(_) => Some(ph),
        None => None,
    }
}
/// Build a `PtrHash` from a cloneable parallel iterator over exactly `n` keys.
///
/// The key count `n` must be passed explicitly since the iterator is
/// consumed (cloned) multiple times during construction.
///
/// Panics if construction fails. Previously the `Option` returned by
/// `compute_pilots` was silently discarded here, so a failed build
/// returned an unusable `PtrHash`; now we panic, consistent with `new`.
pub fn new_from_par_iter<'a>(
    n: usize,
    keys: impl ParallelIterator<Item = impl Borrow<Key>> + Clone + 'a,
    params: PtrHashParams<BF>,
) -> Self {
    let mut ptr_hash = Self::init(n, params);
    ptr_hash
        .compute_pilots(keys)
        .expect("PtrHash construction failed: no global seed found");
    ptr_hash
}
/// Compute the number of shards, parts, buckets and slots for `n` keys and
/// return a `PtrHash` with empty pilots/remap, ready for `compute_pilots`.
fn init(n: usize, mut params: PtrHashParams<BF>) -> Self {
    // Hashes are 64 bits, so very large inputs risk hash collisions.
    assert!(n < (1 << 40), "Number of keys must be less than 2^40.");
    // Only shard when explicitly requested; otherwise everything is one shard.
    let shards = match params.sharding {
        Sharding::None => 1,
        _ => n.div_ceil(params.keys_per_shard),
    };
    let mut keys_per_part;
    let mut parts_per_shard;
    let mut buckets_per_part;
    let mut parts;
    let mut buckets_total;
    let mut slots_total;
    let mut slots_per_part;
    // Start from ~1024 keys per part, rounded to a power of two that is a
    // multiple of the shard count so shards get equal numbers of parts.
    parts = (n / 1024).next_power_of_two().next_multiple_of(shards);
    if params.single_part {
        parts = 1;
    }
    loop {
        keys_per_part = n / parts;
        parts_per_shard = parts / shards;
        // alpha controls the per-part load factor; lambda the bucket size.
        slots_per_part = (keys_per_part as f64 / params.alpha) as usize;
        slots_total = parts * slots_per_part;
        // +3 gives a little slack in the number of buckets per part.
        buckets_per_part = (keys_per_part as f64 / params.lambda).ceil() as usize + 3;
        buckets_total = parts * buckets_per_part;
        if parts == 1 {
            break;
        }
        // Keys distribute over parts roughly binomially. Estimate the size
        // of the largest part as mean + a few standard deviations; if it
        // doesn't fit in a part's slots, halve the number of parts and retry.
        let exp_keys_per_part = n as f64 / parts as f64;
        let stddev = exp_keys_per_part.sqrt();
        let stddevs_away = ((parts as f64).ln() * 2.).sqrt();
        let exp_max = exp_keys_per_part + stddev * stddevs_away;
        let buf_max = exp_max + 2.0 * stddev;
        if buf_max < slots_per_part as f64 {
            break;
        }
        parts = (parts / 2).next_multiple_of(shards);
    }
    trace!(" keys: {n:>10}");
    trace!(" shards: {shards:>10}");
    trace!(" parts: {parts:>10}");
    trace!(" slots/prt: {slots_per_part:>10}");
    trace!(" slots tot: {slots_total:>10}");
    trace!(" real alpha: {:>10.4}", n as f64 / slots_total as f64);
    trace!(" buckets/prt: {buckets_per_part:>10}");
    trace!(" buckets tot: {buckets_total:>10}");
    trace!("keys/ bucket: {:>13.2}", n as f64 / buckets_total as f64);
    // Let the bucket function precompute anything depending on bucket count.
    params
        .bucket_fn
        .set_buckets_per_part(buckets_per_part as u64);
    Self {
        params,
        n,
        parts,
        shards,
        parts_per_shard,
        slots_total,
        slots: slots_per_part,
        buckets_total,
        buckets: buckets_per_part,
        rem_shards: Rp::new(shards),
        rem_parts: Rp::new(parts),
        rem_buckets: Rb::new(buckets_per_part),
        rem_buckets_total: Rb::new(buckets_total),
        rem_slots: Rp::new(slots_per_part),
        seed: 0,
        pilots: Default::default(),
        remap: F::default(),
        _key: PhantomData,
        _hx: PhantomData,
    }
}
/// Search for a global seed and per-bucket pilots covering all keys.
///
/// Retries with a fresh seed on duplicate hashes, failed pilot search, or
/// remap construction failure. Returns `None` after 10 failed attempts.
fn compute_pilots<'a>(
    &mut self,
    keys: impl ParallelIterator<Item = impl Borrow<Key>> + Clone + 'a,
) -> Option<BucketStats> {
    let overall_start = std::time::Instant::now();
    // One taken-slot bitmask per part; allocations are reused across retries.
    let mut taken: Vec<BitVec> = vec![];
    let mut pilots: Vec<u8> = vec![];
    let mut tries = 0;
    const MAX_TRIES: usize = 10;
    // Fixed-seed RNG so construction is deterministic.
    let mut rng = ChaCha8Rng::seed_from_u64(31415);
    let stats = 's: loop {
        tries += 1;
        if tries > MAX_TRIES {
            warn!("PtrHash failed to find a global seed after {MAX_TRIES} tries.");
            return None;
        }
        if tries > 1 {
            trace!("NEW TRY Try {tries} for global seed.");
        }
        self.seed = rng.random();
        // Reset state left over from any previous attempt.
        pilots.clear();
        pilots.resize(self.buckets_total, 0);
        for taken in taken.iter_mut() {
            taken.clear();
            taken.resize(self.slots, false);
        }
        taken.resize_with(self.parts, || bitvec![0; self.slots]);
        // Process one shard at a time; each shard covers a contiguous range
        // of parts, and hence contiguous chunks of `pilots` and `taken`.
        let shard_hashes = self.shards(keys.clone());
        let shard_pilots = pilots.chunks_mut((self.buckets * self.parts_per_shard).max(1));
        let shard_taken = taken.chunks_mut(self.parts_per_shard);
        let mut stats = BucketStats::default();
        for (shard, (hashes, pilots, taken)) in
            izip!(shard_hashes, shard_pilots, shard_taken).enumerate()
        {
            let start = std::time::Instant::now();
            // Sort hashes into parts; a duplicate hash forces a new seed.
            let Some((hashes, part_starts)) = self.sort_parts(shard, hashes) else {
                trace!("Found duplicate hashes");
                continue 's;
            };
            let start = log_duration("sort buckets", start);
            // Find a pilot for every bucket in this shard, or retry globally.
            if let Some(shard_stats) =
                self.build_shard(shard, &hashes, &part_starts, pilots, taken)
            {
                stats.merge(shard_stats);
                log_duration("find pilots", start);
            } else {
                trace!("Could not find pilots");
                continue 's;
            }
        }
        // All shards succeeded; build the free-slot remap for minimal hashing.
        let start = std::time::Instant::now();
        let remap = self.remap_free_slots(&taken);
        log_duration("remap free", start);
        if remap.is_err() {
            trace!("Failed to construct CachelineEF");
            continue 's;
        }
        break 's stats;
    };
    self.pilots = pilots;
    let (p, r) = self.bits_per_element();
    trace!("bits/element: {}", p + r);
    log_duration("total build", overall_start);
    Some(stats)
}
/// Build the remap table sending each taken slot `>= n` to a free slot `< n`.
///
/// Returns `Err(())` if the packed representation `F` cannot be constructed.
fn remap_free_slots(&mut self, taken: &Vec<BitVec>) -> Result<(), ()> {
    assert_eq!(
        taken.iter().map(|t| t.count_zeros()).sum::<usize>(),
        self.slots_total - self.n,
        "Not the right number of free slots left!\n total slots {} - n {}",
        self.slots_total,
        self.n
    );
    // Nothing to remap for non-minimal hashing or when slots == keys.
    if !self.params.remap || self.slots_total == self.n {
        return Ok(());
    }
    // v[j] holds the free slot (< n) that slot n+j remaps to.
    let mut v = Vec::with_capacity(self.slots_total - self.n);
    // Global slot index -> taken bit, across the per-part bitmasks.
    let get = |t: &Vec<BitVec>, idx: usize| t[idx / self.slots][idx % self.slots];
    // Iterate over all free slots below n, in increasing order.
    for i in taken
        .iter()
        .enumerate()
        .flat_map(|(p, t)| {
            let offset = p * self.slots;
            t.iter_zeros().map(move |i| offset + i)
        })
        .take_while(|&i| i < self.n)
    {
        // Slots >= n that are themselves free still get (filler) entries,
        // so that v stays densely indexed by `slot - n`.
        while !get(&taken, self.n + v.len()) {
            v.push(i as u64);
        }
        v.push(i as u64);
    }
    self.remap = MutPacked::try_new(v).ok_or(())?;
    Ok(())
}
}
impl<Key: KeyT, BF: BucketFn, F: Packed, Hx: Hasher<Key>, V: AsRef<[u8]>>
PtrHash<Key, BF, F, Hx, V>
{
/// Space usage per key, in bits, returned as `(pilots, remap)`.
pub fn bits_per_element(&self) -> (f64, f64) {
    let keys = self.n as f64;
    let pilot_bytes_per_key = self.pilots.as_ref().size_in_bytes() as f64 / keys;
    let remap_bytes_per_key = self.remap.size_in_bytes() as f64 / keys;
    (8. * pilot_bytes_per_key, 8. * remap_bytes_per_key)
}
/// The number of keys this function was built over.
pub fn n(&self) -> usize {
    self.n
}
/// Exclusive upper bound on indices returned without remapping.
pub fn max_index(&self) -> usize {
    self.slots_total
}
/// The number of slots in each part.
pub fn slots_per_part(&self) -> usize {
    self.slots
}
/// Map `key` to a slot in `[0, max_index())`, without remapping slots
/// `>= n` back into `[0, n)`.
#[inline]
pub fn index_no_remap(&self, key: &Key) -> usize {
    let hash = self.hash_key(key);
    let bucket = self.bucket(hash);
    self.slot(hash, self.pilots.as_ref().index(bucket))
}
/// Minimal index for a `PtrHash` built with a single part: skips the
/// part computation and goes straight to the in-part bucket.
#[inline]
pub fn index_single_part(&self, key: &Key) -> usize {
    let hash = self.hash_key(key);
    let bucket = self.bucket_in_part(hash.high());
    let pilot = self.pilots.as_ref().index(bucket);
    let slot = self.slot_in_part(hash, pilot);
    if slot >= self.n {
        // Remap out-of-range slots back into [0, n).
        self.remap.index(slot - self.n) as usize
    } else {
        slot
    }
}
/// Minimal index of `key` in `[0, n)`.
#[inline]
pub fn index(&self, key: &Key) -> usize {
    let hash = self.hash_key(key);
    let bucket = self.bucket(hash);
    let pilot = self.pilots.as_ref().index(bucket);
    let slot = self.slot(hash, pilot);
    if slot >= self.n {
        // Remap out-of-range slots back into [0, n).
        self.remap.index(slot - self.n) as usize
    } else {
        slot
    }
}
/// Streaming indexing: yields the index of every key, keeping `B` lookups
/// in flight and prefetching each pilot cacheline `B` iterations before it
/// is read. When `MINIMAL` is true, slots `>= n` are remapped into `[0, n)`.
///
/// The returned iterator only supports internal iteration (`fold`-based
/// consumers such as `for` loops and `sum`); calling `next()` panics.
#[inline]
pub fn index_stream<'a, const B: usize, const MINIMAL: bool, Q: Borrow<Key> + 'a>(
    &'a self,
    keys: impl IntoIterator<Item = Q> + 'a,
) -> impl Iterator<Item = usize> + 'a {
    let mut keys = keys.into_iter();
    // Ring buffers holding the B in-flight hashes and their buckets.
    let mut next_hashes: [Hx::H; B] = [Hx::H::default(); B];
    let mut next_buckets: [usize; B] = [0; B];
    // Number of ring positions left unfilled because the input had fewer
    // than B keys; the drain loop in `fold` skips exactly these.
    let mut leftover = B;
    for idx in 0..B {
        let hx = keys
            .next()
            .map(|k| {
                leftover -= 1;
                self.hash_key(k.borrow())
            })
            .unwrap_or_default();
        next_hashes[idx] = hx;
        next_buckets[idx] = self.bucket(next_hashes[idx]);
        // Start fetching the pilot cacheline well before it is needed.
        crate::util::prefetch_index(self.pilots.as_ref(), next_buckets[idx]);
    }
    // Hand-rolled iterator implementing only `fold`, so consumers get a
    // tight internal loop without per-item Option handling.
    struct It<
        'a,
        const B: usize,
        const MINIMAL: bool,
        Key: KeyT,
        Q: Borrow<Key> + 'a,
        KeyIt: Iterator<Item = Q> + 'a,
        BF: BucketFn,
        F: Packed,
        Hx: Hasher<Key>,
        V: AsRef<[u8]>,
    > {
        ph: &'a PtrHash<Key, BF, F, Hx, V>,
        keys: KeyIt,
        next_hashes: [Hx::H; B],
        next_buckets: [usize; B],
        leftover: usize,
    }
    impl<
        'a,
        const B: usize,
        const MINIMAL: bool,
        Key: KeyT,
        Q: Borrow<Key> + 'a,
        KeyIt: Iterator<Item = Q> + 'a,
        BF: BucketFn,
        F: Packed,
        Hx: Hasher<Key>,
        V: AsRef<[u8]>,
    > Iterator for It<'a, B, MINIMAL, Key, Q, KeyIt, BF, F, Hx, V>
    {
        type Item = usize;
        // External iteration is intentionally unsupported.
        fn next(&mut self) -> Option<usize> {
            unimplemented!("Use a method that calls `fold()` instead.");
        }
        #[inline(always)]
        fn fold<BB, FF>(mut self, init: BB, mut f: FF) -> BB
        where
            Self: Sized,
            FF: FnMut(BB, Self::Item) -> BB,
        {
            let mut accum = init;
            let mut i = 0;
            // Main loop: for each fresh key, retire the oldest buffered
            // (hash, bucket) pair and prefetch the fresh key's pilot.
            for key in self.keys {
                let next_hash = self.ph.hash_key(key.borrow());
                let idx = i % B;
                let cur_hash = self.next_hashes[idx];
                let cur_bucket = self.next_buckets[idx];
                self.next_hashes[idx] = next_hash;
                self.next_buckets[idx] = self.ph.bucket(self.next_hashes[idx]);
                crate::util::prefetch_index(self.ph.pilots.as_ref(), self.next_buckets[idx]);
                let pilot = self.ph.pilots.as_ref().index(cur_bucket);
                let slot = self.ph.slot(cur_hash, pilot);
                let slot = if MINIMAL && slot >= self.ph.n {
                    self.ph.remap.index(slot - self.ph.n) as usize
                } else {
                    slot
                };
                accum = f(accum, slot);
                i += 1;
            }
            // Drain the B - leftover real entries still in the buffer.
            for _ in 0..B - self.leftover {
                let idx = i % B;
                let cur_hash = self.next_hashes[idx];
                let cur_bucket = self.next_buckets[idx];
                let pilot = self.ph.pilots.as_ref().index(cur_bucket);
                let slot = self.ph.slot(cur_hash, pilot);
                let slot = if MINIMAL && slot >= self.ph.n {
                    self.ph.remap.index(slot - self.ph.n) as usize
                } else {
                    slot
                };
                accum = f(accum, slot);
                i += 1;
            }
            accum
        }
    }
    It::<B, MINIMAL, _, _, _, _, _, _, _> {
        ph: self,
        keys,
        next_hashes,
        next_buckets,
        leftover,
    }
}
/// Compute the indices of a fixed-size batch of `K` keys, prefetching all
/// pilot cachelines before reading any of them.
#[inline]
pub fn index_batch<'a, const K: usize, const MINIMAL: bool, Q: Borrow<Key> + 'a>(
    &'a self,
    xs: [Q; K],
) -> [usize; K] {
    let hashes = xs.map(|x| self.hash_key(x.borrow()));
    // Compute each bucket and immediately issue a prefetch for its pilot.
    let buckets: [usize; K] = from_fn(|i| {
        let bucket = self.bucket(hashes[i]);
        crate::util::prefetch_index(self.pilots.as_ref(), bucket);
        bucket
    });
    from_fn(
        #[inline(always)]
        move |i| {
            let pilot = self.pilots.as_ref().index(buckets[i]);
            let slot = self.slot(hashes[i], pilot);
            if MINIMAL && slot >= self.n {
                self.remap.index(slot - self.n) as usize
            } else {
                slot
            }
        },
    )
}
/// Index keys in chunks of exactly `K` using the unstable
/// `Iterator::array_chunks` (nightly, behind the `unstable` feature).
///
/// NOTE(review): keys beyond the last full chunk of `K` end up in the
/// chunker's remainder and are never yielded — callers should ensure the
/// number of keys is a multiple of `K`, or accept the truncation.
#[doc(hidden)]
#[cfg(feature = "unstable")]
#[inline]
pub fn index_batch_exact<'a, const K: usize, const MINIMAL: bool>(
    &'a self,
    xs: impl IntoIterator<Item = &'a Key> + 'a,
) -> impl Iterator<Item = usize> + 'a {
    let mut buckets: [usize; K] = [0; K];
    let mut f = {
        #[inline(always)]
        move |hx: [Hx::H; K]| {
            // Prefetch all K pilot cachelines before reading any of them.
            for idx in 0..K {
                buckets[idx] = self.bucket(hx[idx]);
                crate::util::prefetch_index(self.pilots.as_ref(), buckets[idx]);
            }
            (0..K).map(
                #[inline(always)]
                move |idx| {
                    let pilot = self.pilots.as_ref().index(buckets[idx]);
                    let slot = self.slot(hx[idx], pilot);
                    if MINIMAL && slot >= self.n {
                        self.remap.index(slot - self.n) as usize
                    } else {
                        slot
                    }
                },
            )
        }
    };
    let array_chunks = xs.into_iter().map(|x| self.hash_key(x)).array_chunks::<K>();
    array_chunks.into_iter().flat_map(
        #[inline(always)]
        move |chunk| f(chunk),
    )
}
/// Batched indexing that works on stable Rust: keeps a rolling buffer of
/// `K` hashes and computes/prefetches buckets a full group of `K` at a time.
///
/// Exactly as many indices are yielded as there are input keys: the input
/// is padded with `K` default hashes so the final real hashes are flushed
/// out of the buffer.
#[doc(hidden)]
#[inline]
pub fn index_batch_exact2<'a, const K: usize, const MINIMAL: bool>(
    &'a self,
    xs: impl IntoIterator<Item = &'a Key, IntoIter: ExactSizeIterator> + 'a,
) -> impl Iterator<Item = usize> + 'a {
    let mut buckets: [usize; K] = [0; K];
    let mut hs: [Hx::H; K] = [Hx::H::default(); K];
    // Hash the keys, then append K dummy hashes as padding.
    let mut xs = xs
        .into_iter()
        .map(|x| self.hash_key(x))
        .chain([Default::default(); K]);
    // Pre-fill the buffer with the first K hashes.
    for i in 0..K {
        hs[i] = xs.next().unwrap();
    }
    let mut idx = K;
    xs.map(move |hx| {
        // At each group boundary, compute and prefetch the next K buckets.
        if idx == K {
            idx = 0;
            for idx in 0..K {
                buckets[idx] = self.bucket(hs[idx]);
                crate::util::prefetch_index(self.pilots.as_ref(), buckets[idx]);
            }
        }
        let pilot = self.pilots.as_ref().index(buckets[idx]);
        let slot = self.slot(hs[idx], pilot);
        // Overwrite the consumed entry with the incoming hash.
        hs[idx] = hx;
        idx += 1;
        if MINIMAL && slot >= self.n {
            self.remap.index(slot - self.n) as usize
        } else {
            slot
        }
    })
}
/// Hash a key with the current global seed.
fn hash_key(&self, x: &Key) -> Hx::H {
    Hx::hash(x, self.seed)
}
/// Hash a pilot value with the global seed (multiplicative hash).
fn hash_pilot(&self, p: Pilot) -> PilotHash {
    MulHash::hash(&p, self.seed)
}
/// Shard of a key hash, derived from its high bits.
fn shard(&self, hx: Hx::H) -> usize {
    self.rem_shards.reduce(hx.high())
}
/// Part of a key hash, derived from its high bits.
fn part(&self, hx: Hx::H) -> usize {
    self.rem_parts.reduce(hx.high())
}
/// Bucket within a part for the value `x` (the hash's high bits, or the
/// remainder left over after the part reduction).
fn bucket_in_part(&self, x: u64) -> usize {
    if !BF::B_OUTPUT {
        // The bucket function's output must still be reduced into range.
        self.rem_buckets.reduce(self.params.bucket_fn.call(x))
    } else {
        // The bucket function directly produces a bucket index.
        self.params.bucket_fn.call(x) as usize
    }
}
/// Global bucket index of a key hash.
fn bucket(&self, hx: Hx::H) -> usize {
    if BF::LINEAR {
        // Linear bucket functions reduce directly over all buckets.
        self.rem_buckets_total.reduce(hx.high())
    } else {
        // Otherwise: split the high bits into a part plus a remainder,
        // and offset the in-part bucket by the part's bucket range.
        let (part, rem) = self.rem_parts.reduce_with_remainder(hx.high());
        part * self.buckets + self.bucket_in_part(rem)
    }
}
/// Global slot of a key: the part's slot offset plus the in-part slot.
fn slot(&self, hx: Hx::H, pilot: u64) -> usize {
    (self.part(hx) * self.slots) + self.slot_in_part(hx, pilot)
}
/// Slot within a part, given the raw pilot value.
fn slot_in_part(&self, hx: Hx::H, pilot: Pilot) -> usize {
    self.slot_in_part_hp(hx, self.hash_pilot(pilot))
}
/// Slot within a part, given an already-hashed pilot.
fn slot_in_part_hp(&self, hx: Hx::H, hp: PilotHash) -> usize {
    self.rem_slots
        .reduce(MulHash::C.wrapping_mul(hx.low() ^ hp))
}
}