#![allow(clippy::type_complexity)]
use std::borrow::Borrow;
use super::shard_edge::FuseLge3Shards;
use crate::bits::BitFieldVec;
use crate::func::VFunc;
use crate::func::shard_edge::ShardEdge;
use crate::traits::Word;
use crate::utils::*;
use mem_dbg::*;
use num_primitive::{PrimitiveNumber, PrimitiveNumberAs};
use value_traits::slices::SliceByValue;
/// A two-step static function.
///
/// A narrow first-level function (`short`) maps each key either to an index
/// into a table of values (`remap`) or to a sentinel (`escape`). Keys mapped to
/// the sentinel are resolved by a second-level function (`long`) storing full
/// values.
///
/// Type parameters: `K` is the key type, `D` the storage of the inner
/// functions' data (e.g. [`BitFieldVec`]), `S` the signature type, `E` the
/// [`ShardEdge`] of the first-level function and `F` (defaulting to `E`) the
/// [`ShardEdge`] of the second-level function.
#[derive(Clone, MemSize, MemDbg)]
#[cfg_attr(feature = "epserde", derive(epserde::Epserde))]
#[cfg_attr(
feature = "epserde",
epserde(bound(
deser = "D::Value: for<'a> epserde::deser::DeserInner<DeserType<'a> = D::Value>"
))
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "serde",
serde(bound(
serialize = "D: serde::Serialize, D::Value: serde::Serialize, E: serde::Serialize, F: serde::Serialize",
deserialize = "D: serde::Deserialize<'de>, D::Value: serde::Deserialize<'de>, E: serde::Deserialize<'de>, F: serde::Deserialize<'de>"
))
)]
pub struct VFunc2<K: ?Sized, D: SliceByValue, S = [u64; 2], E = FuseLge3Shards, F = E> {
/// First-level function: returns an index into `remap`, or `escape`.
pub(crate) short: VFunc<K, D, S, E>,
/// Second-level function, queried only when `short` returns `escape`.
pub(crate) long: VFunc<K, D, S, F>,
/// Table translating first-level indices into actual values.
pub(crate) remap: Box<[D::Value]>,
/// Sentinel returned by `short` for keys whose value is not in `remap`.
pub(crate) escape: D::Value,
}
impl<K: ?Sized, D: SliceByValue, S, E, F> std::fmt::Debug for VFunc2<K, D, S, E, F>
where
    D::Value: std::fmt::Debug,
    VFunc<K, D, S, E>: std::fmt::Debug,
    VFunc<K, D, S, F>: std::fmt::Debug,
{
    /// Formats the four components of the two-step function.
    ///
    /// This impl is hand-written so that the `Debug` bounds apply to the inner
    /// functions and to `D::Value`, as listed in the `where` clause.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            short,
            long,
            remap,
            escape,
        } = self;
        let mut out = f.debug_struct("VFunc2");
        out.field("short", short);
        out.field("long", long);
        out.field("remap", remap);
        out.field("escape", escape);
        out.finish()
    }
}
impl<K: ?Sized, W: Word, S: Sig, E: ShardEdge<S, 3>, F: ShardEdge<S, 3>>
    VFunc2<K, BitFieldVec<Box<[W]>>, S, E, F>
{
    /// Returns a two-step function with no keys.
    ///
    /// Both inner functions are empty, the remapping table has no entries,
    /// and the escape value is zero.
    #[must_use]
    pub fn empty() -> Self {
        let no_values: Box<[W]> = Box::default();
        let zero = W::ZERO;
        Self {
            short: VFunc::empty(),
            long: VFunc::empty(),
            remap: no_values,
            escape: zero,
        }
    }
}
impl<
    K: ?Sized + ToSig<S>,
    D: SliceByValue<Value: Word + BinSafe + PrimitiveNumberAs<usize>>,
    S: Sig,
    E: ShardEdge<S, 3>,
    F: ShardEdge<S, 3>,
> VFunc2<K, D, S, E, F>
{
    /// Returns the value associated with the given signature.
    ///
    /// The first-level function is queried first. If it returns an index
    /// into the remapping table, the corresponding value is returned;
    /// otherwise the second-level function is queried.
    ///
    /// For signatures of keys outside the original key set the result is
    /// arbitrary, but this method does not panic. In that case the first-level
    /// function may return an index past the end of the remapping table, and
    /// such indices are routed to the second-level function just like the
    /// escape value.
    #[inline]
    pub fn get_by_sig(&self, sig: S) -> D::Value {
        let idx = self.short.get_by_sig(sig);
        if idx != self.escape {
            // Only `remap.len()` indices are meaningful. Out-of-set keys may
            // produce any value below the escape, so use a checked lookup
            // instead of indexing (which would panic).
            if let Some(&value) = self.remap.get(idx.as_to::<usize>()) {
                return value;
            }
        }
        self.long.get_by_sig(sig)
    }

    /// Returns the value associated with `key`.
    ///
    /// The signature is computed with the seed of the first-level function.
    /// For keys outside the original key set the result is arbitrary.
    #[inline(always)]
    pub fn get(&self, key: impl Borrow<K>) -> D::Value {
        self.get_by_sig(K::to_sig(key.borrow(), self.short.seed))
    }

    /// Returns the number of keys, as recorded by the first-level function.
    pub fn len(&self) -> usize {
        self.short.num_keys
    }

    /// Returns `true` if the function has no keys.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
use crate::traits::{TryIntoUnaligned, Unaligned};
impl<K: ?Sized, W: Word, S: Sig, E: ShardEdge<S, 3>, F: ShardEdge<S, 3>> TryIntoUnaligned
    for VFunc2<K, BitFieldVec<Box<[W]>>, S, E, F>
{
    type Unaligned = VFunc2<K, Unaligned<BitFieldVec<Box<[W]>>>, S, E, F>;

    /// Converts both inner functions to their unaligned form.
    ///
    /// The first-level function is converted first, and the first conversion
    /// error is propagated. The remapping table and escape value are moved
    /// over unchanged.
    fn try_into_unaligned(
        self,
    ) -> Result<Self::Unaligned, crate::traits::UnalignedConversionError> {
        let short = self.short.try_into_unaligned()?;
        let long = self.long.try_into_unaligned()?;
        Ok(VFunc2 {
            short,
            long,
            remap: self.remap,
            escape: self.escape,
        })
    }
}
impl<K: ?Sized, W: Word, S: Sig, E: ShardEdge<S, 3>, F: ShardEdge<S, 3>>
From<Unaligned<VFunc2<K, BitFieldVec<Box<[W]>>, S, E, F>>>
for VFunc2<K, BitFieldVec<Box<[W]>>, S, E, F>
{
fn from(vf: Unaligned<VFunc2<K, BitFieldVec<Box<[W]>>, S, E, F>>) -> Self {
VFunc2 {
short: VFunc::from(vf.short),
long: VFunc::from(vf.long),
remap: vf.remap,
escape: vf.escape,
}
}
}
#[cfg(feature = "rayon")]
mod build {
use super::*;
use crate::func::VBuilder;
use core::error::Error;
use dsi_progress_logger::ProgressLog;
use lender::*;
use rdst::RadixKey;
use std::ops::{BitXor, BitXorAssign};
use sync_cell_slice::SyncSlice;
/// A map from unsigned integer keys to copyable values.
///
/// Small keys are stored in a dense array and all other keys in a
/// [`HashMap`](std::collections::HashMap). Keys that were never inserted map
/// to `default`.
pub(crate) struct HybridMap<K, V> {
    /// Dense storage: slot `k` holds the value of key `k`.
    array: Vec<V>,
    /// Sparse storage for keys that have no slot in `array`.
    map: std::collections::HashMap<K, V>,
    /// Value reported for keys that were never inserted.
    default: V,
}

impl<K: Word + PrimitiveNumberAs<usize>, V: Copy + Eq> HybridMap<K, V> {
    /// Creates an empty map returning `default` for absent keys.
    ///
    /// The dense array has at most 1024 slots. If `max_key` is given, the
    /// array is no longer than needed to hold keys up to `max_key`.
    pub(crate) fn new(max_key: Option<K>, default: V) -> Self {
        const MAX_ARRAY_LEN: usize = 1 << 10;
        // `saturating_add` avoids overflowing when `max_key` converts to
        // `usize::MAX` (e.g. `u64::MAX` on 64-bit targets).
        let array_len = max_key.map_or(MAX_ARRAY_LEN, |mk| {
            MAX_ARRAY_LEN.min(mk.as_to::<usize>().saturating_add(1))
        });
        Self {
            array: vec![default; array_len],
            map: std::collections::HashMap::new(),
            default,
        }
    }

    /// Returns the dense-array slot of `key`, or `None` if `key` belongs in
    /// the hash map.
    ///
    /// The round trip through `K::try_from` rejects keys wider than `usize`
    /// whose (truncating) conversion would otherwise alias a small slot.
    #[inline(always)]
    fn slot(&self, key: K) -> Option<usize> {
        let k: usize = key.as_to();
        (k < self.array.len() && K::try_from(k).ok() == Some(key)).then_some(k)
    }

    /// Sets the value associated with `key`.
    pub(crate) fn insert(&mut self, key: K, value: V) {
        match self.slot(key) {
            Some(k) => self.array[k] = value,
            None => {
                self.map.insert(key, value);
            }
        }
    }

    /// Returns the value associated with `key`, or the default value if
    /// `key` was never inserted.
    #[inline(always)]
    pub(crate) fn get(&self, key: K) -> V {
        match self.slot(key) {
            Some(k) => self.array[k],
            None => self.map.get(&key).copied().unwrap_or(self.default),
        }
    }

    /// Returns the stored keys sorted by decreasing value.
    ///
    /// The keys returned are those whose dense-array slot differs from the
    /// default, plus all keys present in the hash map. Ties are broken by
    /// increasing key, so the order does not depend on hash-map iteration
    /// order.
    pub(crate) fn keys_by_desc_value(&self) -> Vec<K>
    where
        V: Ord,
    {
        let array_pairs = self
            .array
            .iter()
            .enumerate()
            .filter(|&(_, &v)| v != self.default)
            .map(|(k, &v)| {
                // Only slots reached through `slot` are ever written, and
                // those indices round-trip to `K` by construction.
                let key = K::try_from(k)
                    .ok()
                    .expect("occupied array slots always hold representable keys");
                (key, v)
            });
        let map_pairs = self.map.iter().map(|(&k, &v)| (k, v));
        let mut pairs: Vec<(K, V)> = array_pairs.chain(map_pairs).collect();
        pairs.sort_unstable_by_key(|&(k, v)| (std::cmp::Reverse(v), k));
        pairs.into_iter().map(|(k, _)| k).collect()
    }
}

impl<K: Word + PrimitiveNumberAs<usize>> HybridMap<K, usize> {
    /// Increments by one the counter associated with `key`.
    #[inline(always)]
    pub(crate) fn incr(&mut self, key: K) {
        self.add(key, 1);
    }

    /// Adds `amount` to the counter associated with `key`.
    ///
    /// Absent keys start from the default value in both the array and the
    /// hash map, which is consistent with [`get`](Self::get).
    #[inline(always)]
    pub(crate) fn add(&mut self, key: K, amount: usize) {
        match self.slot(key) {
            Some(k) => self.array[k] += amount,
            None => *self.map.entry(key).or_insert(self.default) += amount,
        }
    }
}
/// Chooses the bit width `r` of the first-level function of a two-step
/// function.
///
/// With width `r`, the first `2^r - 1` entries of `sorted_vals` are remapped
/// to indices (the all-ones value is reserved as the escape). All other keys
/// escape to a second-level function storing `w`-bit values, where `w` is the
/// bit length of `max_value`. For each `r` in `0..w` the estimated cost in
/// bits is
///
/// - `C·n·r` for the first-level function (zero when `r = 0`),
/// - `C·post·w` for the second-level function, where `post` is the number of
///   escaped keys,
/// - `pos·w_bits` for a remapping table of `pos` entries,
///
/// with `C = 1.11`. The `r` with the smallest cost is returned; on ties, the
/// smallest such `r` wins.
///
/// # Arguments
///
/// * `n` - total number of keys.
/// * `max_value` - largest value associated with a key.
/// * `sorted_vals` - distinct values, expected in order of decreasing
///   frequency.
/// * `count_of` - number of keys associated with a given value.
/// * `w_bits` - bits used by one remapping-table entry.
pub(crate) fn find_optimal_r<W: Word>(
    n: usize,
    max_value: W,
    sorted_vals: &[W],
    count_of: impl Fn(W) -> usize,
    w_bits: usize,
) -> usize {
    // Multiplicative space overhead applied to each stored bit of both
    // inner functions in the cost model.
    const C: f64 = 1.11;

    let w = max_value.bit_len() as usize;
    let m = sorted_vals.len();
    // Keys not covered by the remapping table for the current `r`.
    let mut post = n;
    // Number of remapping-table entries for the current `r`.
    let mut pos = 0usize;
    let mut best_r = 0usize;
    let mut best_cost = f64::MAX;
    for r in 0..w {
        let cost_first = if r == 0 { 0.0 } else { C * n as f64 * r as f64 };
        let cost_second = C * post as f64 * w as f64;
        let cost_remap = pos as f64 * w_bits as f64;
        let cost = cost_first + cost_second + cost_remap;
        if cost < best_cost {
            best_cost = cost;
            best_r = r;
        }
        // Going from width r to r + 1 adds 2^r table entries. `checked_shl`
        // avoids an overflowing shift when `r >= usize::BITS` (e.g. 128-bit
        // values, or 64-bit values on 32-bit targets); by then every value
        // has long been absorbed anyway.
        let to_absorb = 1usize
            .checked_shl(r as u32)
            .unwrap_or(usize::MAX)
            .min(m - pos);
        for &val in &sorted_vals[pos..pos + to_absorb] {
            post -= count_of(val);
        }
        pos += to_absorb;
    }
    best_r
}
impl<K, W, S, E, F> VFunc2<K, BitFieldVec<Box<[W]>>, S, E, F>
where
K: ?Sized + ToSig<S> + std::fmt::Debug,
W: Word + BinSafe + MemSize + mem_dbg::FlatType,
S: Sig + Send + Sync,
E: ShardEdge<S, 3> + MemSize + mem_dbg::FlatType,
F: ShardEdge<S, 3> + MemSize + mem_dbg::FlatType,
Box<[W]>: MemSize,
SigVal<S, W>: RadixKey,
SigVal<E::LocalSig, W>: BitXor + BitXorAssign,
SigVal<F::LocalSig, W>: BitXor + BitXorAssign,
{
pub fn try_new<B: ?Sized + std::borrow::Borrow<K>>(
keys: impl FallibleRewindableLender<
RewindError: Error + Send + Sync + 'static,
Error: Error + Send + Sync + 'static,
> + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
values: impl FallibleRewindableLender<
RewindError: Error + Send + Sync + 'static,
Error: Error + Send + Sync + 'static,
> + for<'lend> FallibleLending<'lend, Lend = &'lend W>,
n: usize,
pl: &mut (impl ProgressLog + Clone + Send + Sync),
) -> anyhow::Result<Self> {
Self::try_new_with_builder(keys, values, n, VBuilder::default(), pl)
}
pub fn try_new_with_builder<B: ?Sized + std::borrow::Borrow<K>>(
keys: impl FallibleRewindableLender<
RewindError: Error + Send + Sync + 'static,
Error: Error + Send + Sync + 'static,
> + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
values: impl FallibleRewindableLender<
RewindError: Error + Send + Sync + 'static,
Error: Error + Send + Sync + 'static,
> + for<'lend> FallibleLending<'lend, Lend = &'lend W>,
n: usize,
builder: VBuilder<BitFieldVec<Box<[W]>>, S, E>,
pl: &mut (impl ProgressLog + Clone + Send + Sync),
) -> anyhow::Result<Self> {
let mut builder = builder.expected_num_keys(n);
builder
.try_populate_and_build(
keys,
values,
&mut |builder, seed, mut store, _max_value, _num_keys, pl, _state: &mut ()| {
Self::try_build_from_store::<W>(
seed,
builder.shard_edge,
&mut *store,
&|v| v,
VBuilder::default()
.max_num_threads(builder.max_num_threads)
.eps(builder.eps),
pl,
)
},
pl,
(),
)
.map(|(r, _keys)| r)
}
pub fn try_build_from_store<V: BinSafe + Default + Send + Sync + Copy>(
seed: u64,
shard_edge: E,
store: &mut (impl ShardStore<S, V> + ?Sized),
get_val: &(impl Fn(V) -> W + Send + Sync),
builder: VBuilder<BitFieldVec<Box<[W]>>, S, E>,
pl: &mut (impl ProgressLog + Clone + Send + Sync),
) -> anyhow::Result<Self>
where
SigVal<S, V>: RadixKey,
SigVal<E::LocalSig, V>: BitXor + BitXorAssign,
SigVal<F::LocalSig, V>: BitXor + BitXorAssign,
{
let mut max_value = W::ZERO;
let mut counts: HybridMap<W, usize> = HybridMap::new(None, 0);
for shard in store.iter() {
for sv in shard.iter() {
let val = get_val(sv.val);
if val > max_value {
max_value = val;
}
counts.incr(val);
}
}
Self::build_from_hybrid_counts(
seed, shard_edge, store, get_val, max_value, counts, builder, pl,
)
}
fn build_from_hybrid_counts<V: BinSafe + Default + Send + Sync + Copy>(
seed: u64,
shard_edge: E,
store: &mut (impl ShardStore<S, V> + ?Sized),
get_val: &(impl Fn(V) -> W + Send + Sync),
max_value: W,
counts: HybridMap<W, usize>,
mut builder: VBuilder<BitFieldVec<Box<[W]>>, S, E>,
pl: &mut (impl ProgressLog + Clone + Send + Sync),
) -> anyhow::Result<Self>
where
SigVal<S, V>: RadixKey,
SigVal<E::LocalSig, V>: BitXor + BitXorAssign,
SigVal<F::LocalSig, V>: BitXor + BitXorAssign,
{
let sorted_vals: Vec<W> = counts.keys_by_desc_value();
let w = max_value.bit_len() as usize;
let m = sorted_vals.len();
let n = store.len();
let best_r = find_optimal_r(
n,
max_value,
&sorted_vals,
|v| counts.get(v),
W::BITS as usize,
);
let escape_usize = (1usize << best_r).wrapping_sub(1); let escape = W::try_from(escape_usize).ok().unwrap();
let num_remapped = escape_usize.min(m);
let remap: Box<[W]> = sorted_vals[..num_remapped].to_vec().into_boxed_slice();
let mut inv_map: HybridMap<W, W> = HybridMap::new(Some(max_value), escape);
for (i, &val) in remap.iter().enumerate() {
inv_map.insert(val, W::try_from(i).ok().unwrap());
}
pl.info(format_args!(
"Two-step: r={best_r}, escape={escape_usize}, {num_remapped} remapped values, \
{m} distinct values, max_value={max_value} ({w} bits)",
));
let max_shb = store.max_shard_high_bits();
let max_num_shards = 1usize << max_shb;
let max_shard_mask = (1u64 << max_shb) - 1;
let mut escaped_counts = vec![0usize; max_num_shards];
let sync_counts = escaped_counts.as_sync_slice();
let saved_max_num_threads = builder.max_num_threads;
let saved_eps = builder.eps;
pl.info(format_args!(
"Building key -> remapped index ({best_r} bits, escape={escape_usize})..."
));
let short = builder.try_build_func_with_store_and_inspect::<K, V>(
seed,
shard_edge,
escape,
store,
&|_e, sig_val| inv_map.get(get_val(sig_val.val)),
&|sv: &SigVal<S, V>| {
if inv_map.get(get_val(sv.val)) == escape {
let shard_idx = sv.sig.high_bits(max_shb, max_shard_mask) as usize;
unsafe {
let c = sync_counts[shard_idx].get();
sync_counts[shard_idx].set(c + 1);
}
}
},
pl,
)?;
let n_escaped = n - sorted_vals[..num_remapped]
.iter()
.map(|&v| counts.get(v))
.sum::<usize>();
debug_assert_eq!(
escaped_counts.iter().sum::<usize>(),
n_escaped,
"inspect-counted escaped != freq-computed escaped"
);
let mut long_shard_edge = F::default();
long_shard_edge.set_up_shards(n_escaped, saved_eps);
let long_shard_high_bits = long_shard_edge.shard_high_bits();
let long_num_shards = 1usize << long_shard_high_bits;
let shards_per_long = max_num_shards / long_num_shards;
let filtered_shard_sizes: Vec<usize> = escaped_counts
.chunks(shards_per_long)
.map(|chunk| chunk.iter().sum())
.collect();
pl.info(format_args!(
"Building key -> full value ({w} bits, {n_escaped} escaped keys, {:.1}%)...",
100.0 * n_escaped as f64 / n as f64
));
let mut filtered_store = FilteredShardStore::new(
store,
long_shard_high_bits,
|sv: &SigVal<S, V>| inv_map.get(get_val(sv.val)) == escape,
filtered_shard_sizes,
);
let long = VBuilder::<BitFieldVec<Box<[W]>>, S, F>::default()
.max_num_threads(saved_max_num_threads)
.try_build_func_with_store::<K, V>(
seed,
long_shard_edge,
max_value,
&mut filtered_store,
&|_e, sig_val| get_val(sig_val.val),
pl,
)?;
let result = Self {
short,
long,
remap,
escape,
};
let n = store.len();
let total = result.mem_size(SizeFlags::default()) * 8;
pl.info(format_args!(
"Bits/keys: {:.2} ({total} bits for {n} keys)",
total as f64 / n as f64,
));
Ok(result)
}
}
}
#[cfg(feature = "rayon")]
pub(crate) use build::{HybridMap, find_optimal_r};