#![allow(clippy::type_complexity)]
#![allow(clippy::too_many_arguments)]
use crate::bits::*;
use crate::dict::{BitSignedVFunc, SignedVFunc, VFilter};
use crate::func::{shard_edge::ShardEdge, *};
use crate::traits::BitVecOpsMut;
use crate::traits::bit_field_slice::{BitFieldSlice, BitFieldSliceMut, Word};
use crate::utils::*;
use derivative::Derivative;
use derive_setters::*;
use dsi_progress_logger::*;
use lender::FallibleLending;
use log::info;
use num_primitive::PrimitiveNumber;
use rand::rngs::SmallRng;
use rand::{Rng, RngExt, SeedableRng};
use rayon::iter::ParallelIterator;
use rayon::slice::ParallelSlice;
use rdst::*;
use std::any::TypeId;
use std::borrow::{Borrow, Cow};
use std::hint::unreachable_unchecked;
use std::marker::PhantomData;
use std::mem::transmute;
use std::ops::{BitXor, BitXorAssign};
use std::slice::Iter;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
use thread_priority::ThreadPriority;
use value_traits::slices::{SliceByValue, SliceByValueMut};
use super::shard_edge::FuseLge3Shards;
/// Second argument passed by `build_loop` to `sig_store::new_offline` and
/// `sig_store::new_online` when creating a signature store.
// NOTE(review): presumably the base-2 logarithm of the maximum number of
// shards a signature store can be split into — confirm against `sig_store`.
const LOG2_MAX_SHARDS: u32 = 16;
/// Builder for `VFunc`-based structures (functions, filters and signed
/// indices).
///
/// Fields marked `#[setters(generate = true)]` are configuration options for
/// which a setter is generated; the remaining fields are working state that is
/// written by the build methods of this type.
#[derive(Setters, Debug, Derivative)]
#[derivative(Default)]
#[setters(generate = false)]
pub struct VBuilder<
W: Word + BinSafe,
D: BitFieldSlice<W> + Send + Sync = Box<[W]>,
S = [u64; 2],
E: ShardEdge<S, 3> = FuseLge3Shards,
> {
/// Optional expected number of keys (default: `None`).
///
/// When set, `build_loop` uses it to set up the shards in advance and to
/// override `log2_buckets`; it is also handed to the signature store
/// constructors and to the progress logger as the expected number of updates.
#[setters(generate = true, strip_option)]
#[derivative(Default(value = "None"))]
expected_num_keys: Option<usize>,
/// Upper bound on the number of solver threads (default: 8); the number
/// actually used is the minimum between this value and the number of shards.
#[setters(generate = true)]
#[derivative(Default(value = "8"))]
max_num_threads: usize,
/// If true, `build_loop` creates the signature store with
/// `sig_store::new_offline` instead of `sig_store::new_online`.
#[setters(generate = true)]
offline: bool,
/// If true, `par_solve` sorts every shard and fails with
/// `SolveError::DuplicateSignature` when two adjacent signatures are equal.
#[setters(generate = true)]
check_dups: bool,
/// Peeling strategy selector used when lazy Gaussian elimination is not
/// selected (default: `None`).
///
/// `Some(true)` forces `peel_by_sig_vals_low_mem`; `None` selects it only
/// when more than three threads and more than two shards are in use; in every
/// other case `peel_by_sig_vals_high_mem` is used.
#[setters(generate = true, strip_option)]
#[derivative(Default(value = "None"))]
low_mem: Option<bool>,
/// Seed of the pseudorandom generator that produces the per-attempt seeds in
/// `build_loop`; also used to seed the random fill of filter data in
/// `par_solve`.
#[setters(generate = true)]
seed: u64,
/// Base-2 logarithm of the number of buckets of the signature store
/// (default: 8); overridden when `expected_num_keys` is set.
#[setters(generate = true, strip_option)]
#[derivative(Default(value = "8"))]
log2_buckets: u32,
/// Value passed as second argument to `ShardEdge::set_up_shards`
/// (default: 0.001).
// NOTE(review): presumably a tolerance on the shard-size imbalance — confirm
// against the `ShardEdge` documentation.
#[setters(generate = true, strip_option)]
#[derivative(Default(value = "0.001"))]
eps: f64,
/// Bit width of the stored values; set by `try_seed`.
bit_width: usize,
/// The shard/edge logic; configured by `build_loop` and `try_seed`, and
/// stored in the resulting `VFunc`.
shard_edge: E,
/// Number of keys; set by `try_seed` from the length of the signature store.
num_keys: usize,
/// First component of the pair returned by `ShardEdge::set_up_graphs`; logged
/// as `c` by `try_build_from_shard_iter`.
c: f64,
/// Second component of the pair returned by `ShardEdge::set_up_graphs`; when
/// true, shards are solved with `lge_shard`.
lge: bool,
/// Number of solver threads; set by `try_build_from_shard_iter`.
num_threads: usize,
/// Set by `par_solve` once a worker has reported an error; polled by workers
/// and peelers to stop early.
failed: AtomicBool,
#[doc(hidden)]
_marker_v: PhantomData<(W, D, S)>,
}
/// Unrecoverable errors returned by the build methods of `VBuilder`.
#[derive(thiserror::Error, Debug)]
pub enum BuildError {
/// Duplicate signatures were detected with four different seeds, so the key
/// set is assumed to contain duplicates.
#[error("Duplicate key")]
DuplicateKey,
/// Duplicate local signatures were detected with three different seeds.
#[error("Duplicate local signatures: use full signatures")]
DuplicateLocalSignatures,
/// The largest value needs more bits than the bit width requested by the
/// caller.
#[error("Value too large for specified bit size")]
ValueTooLarge,
}
/// Errors of a single construction attempt; `build_loop` reacts to each of
/// them by retrying with a different seed (a bounded number of times for the
/// duplicate-related variants).
#[derive(thiserror::Error, Debug)]
pub enum SolveError {
/// Two equal signatures were found in a shard while checking for duplicates.
#[error("Duplicate signature")]
DuplicateSignature,
/// Deduplicating the local signatures of a shard removed at least one
/// element while building a function (not a filter).
#[error("Duplicate local signature")]
DuplicateLocalSignature,
/// The largest shard exceeds the average shard size by more than 1%.
#[error("Max shard too big")]
MaxShardTooBig,
/// The solver of a shard reported a failure.
#[error("Unsolvable shard")]
UnsolvableShard,
}
/// Outcome of `peel_by_index`.
enum PeelResult<
'a,
W: Word + BinSafe + Send + Sync,
D: BitFieldSlice<W> + BitFieldSliceMut<W> + Send + Sync + 'a,
S: Sig + BinSafe,
E: ShardEdge<S, 3>,
V: BinSafe,
> {
/// Every edge was peeled and the values have already been assigned.
Complete(),
/// Only some of the edges were peeled; the peeling state is handed back to
/// the caller.
// NOTE(review): presumably consumed by the lazy Gaussian elimination path —
// confirm against `lge_shard`.
Partial {
/// Index of the shard.
shard_index: usize,
/// The signature/value pairs of the shard.
shard: Arc<Vec<SigVal<S, V>>>,
/// The chunk of the backend reserved for this shard.
data: ShardData<'a, D>,
/// Stack whose upper part contains the indices of the peeled edges.
double_stack: DoubleStack<E::Vertex>,
/// Sides through which the edges were peeled, in peeling order.
sides_stack: Vec<u8>,
_marker: PhantomData<W>,
},
}
/// Compact per-vertex bookkeeping used while peeling a 3-hypergraph.
///
/// For each vertex the structure stores the XOR of the data of all incident
/// edges (`edges`) and a single byte (`degrees_sides`) packing the degree of
/// the vertex in the upper six bits and the XOR of the sides of the incident
/// edges in the lower two bits. When a vertex has degree one, the XORs are
/// exactly the data and the side of its only incident edge.
struct XorGraph<X: Copy + Default + BitXor + BitXorAssign> {
    /// XOR of the data of the edges incident on each vertex.
    edges: Box<[X]>,
    /// Degree (upper six bits) and XOR of the sides (lower two bits).
    degrees_sides: Box<[u8]>,
    /// Set once an [`add`](XorGraph::add) made a packed degree wrap around.
    overflow: bool,
}

impl<X: BitXor + BitXorAssign + Default + Copy> XorGraph<X> {
    /// Creates a graph with `n` isolated vertices.
    pub fn new(n: usize) -> XorGraph<X> {
        let edges = vec![X::default(); n].into_boxed_slice();
        let degrees_sides = vec![0u8; n].into_boxed_slice();
        Self {
            edges,
            degrees_sides,
            overflow: false,
        }
    }

    /// Records that the edge with data `x` is incident on `v` through `side`.
    ///
    /// If the degree no longer fits in six bits, `overflow` is set.
    #[inline(always)]
    pub fn add(&mut self, v: usize, x: X, side: usize) {
        debug_assert!(side < 3);
        // The degree lives above the two side bits, hence the increment by 4.
        let (bumped, wrapped) = self.degrees_sides[v].overflowing_add(4);
        self.overflow |= wrapped;
        self.degrees_sides[v] = bumped ^ (side as u8);
        self.edges[v] ^= x;
    }

    /// Undoes an [`add`](XorGraph::add) with the same arguments.
    #[inline(always)]
    pub fn remove(&mut self, v: usize, x: X, side: usize) {
        debug_assert!(side < 3);
        let packed = &mut self.degrees_sides[v];
        *packed -= 4;
        *packed ^= side as u8;
        self.edges[v] ^= x;
    }

    /// Sets the degree of `v` to zero, leaving the side bits and the edge data
    /// untouched.
    #[inline(always)]
    pub fn zero(&mut self, v: usize) {
        let side_bits = self.degrees_sides[v] & 0b11;
        self.degrees_sides[v] = side_bits;
    }

    /// Returns the XORed edge data and the XORed sides of `v`; for a vertex of
    /// degree one these are the data and the side of its only edge.
    #[inline(always)]
    pub fn edge_and_side(&self, v: usize) -> (X, usize) {
        debug_assert!(self.degree(v) < 2);
        let edge = self.edges[v];
        let side = usize::from(self.degrees_sides[v] & 0b11);
        (edge, side)
    }

    /// Returns the degree of `v`.
    #[inline(always)]
    pub fn degree(&self, v: usize) -> u8 {
        let packed = self.degrees_sides[v];
        packed >> 2
    }

    /// Returns an iterator over the degrees of all vertices, in vertex order.
    pub fn degrees(
        &self,
    ) -> std::iter::Map<std::iter::Copied<std::slice::Iter<'_, u8>>, fn(u8) -> u8> {
        let unpack: fn(u8) -> u8 = |packed| packed >> 2;
        self.degrees_sides.iter().copied().map(unpack)
    }
}
/// A push-only stack backed by a vector allocated once at construction time.
///
/// The capacity is fixed: pushing more than the `n` elements requested at
/// construction panics.
struct FastStack<X: Copy + Default> {
    /// Preallocated storage; only the first `top` elements are meaningful.
    stack: Vec<X>,
    /// Number of elements pushed so far.
    top: usize,
}

impl<X: Copy + Default> FastStack<X> {
    /// Creates an empty stack able to hold up to `n` elements.
    pub fn new(n: usize) -> FastStack<X> {
        let stack = vec![X::default(); n];
        Self { stack, top: 0 }
    }

    /// Pushes `x` on top of the stack.
    pub fn push(&mut self, x: X) {
        let slot = self.top;
        debug_assert!(slot < self.stack.len());
        self.stack[slot] = x;
        self.top = slot + 1;
    }

    /// Returns the number of elements pushed so far.
    pub fn len(&self) -> usize {
        self.top
    }

    /// Iterates over the pushed elements, oldest first.
    pub fn iter(&self) -> std::slice::Iter<'_, X> {
        let filled = &self.stack[..self.top];
        filled.iter()
    }
}
/// Two stacks sharing a single preallocated vector.
///
/// The lower stack grows upwards from index zero and the upper stack grows
/// downwards from the end of the vector; the two must never overlap.
#[derive(Debug)]
struct DoubleStack<V> {
    /// Shared storage.
    stack: Vec<V>,
    /// Index of the first free slot above the lower stack.
    lower: usize,
    /// Index of the top element of the upper stack (`stack.len()` when empty).
    upper: usize,
}

impl<V: Default + Copy> DoubleStack<V> {
    /// Creates a pair of empty stacks sharing room for `n` elements.
    fn new(n: usize) -> DoubleStack<V> {
        let stack = vec![V::default(); n];
        Self {
            stack,
            lower: 0,
            upper: n,
        }
    }
}

impl<V: Copy> DoubleStack<V> {
    /// Pushes `v` on the lower stack.
    #[inline(always)]
    fn push_lower(&mut self, v: V) {
        debug_assert!(self.lower < self.upper);
        let slot = self.lower;
        self.stack[slot] = v;
        self.lower += 1;
    }

    /// Pushes `v` on the upper stack.
    #[inline(always)]
    fn push_upper(&mut self, v: V) {
        debug_assert!(self.lower < self.upper);
        self.upper -= 1;
        let slot = self.upper;
        self.stack[slot] = v;
    }

    /// Pops the top of the lower stack, or returns `None` if it is empty.
    #[inline(always)]
    fn pop_lower(&mut self) -> Option<V> {
        let below = self.lower.checked_sub(1)?;
        self.lower = below;
        Some(self.stack[below])
    }

    /// Returns the number of elements in the upper stack.
    fn upper_len(&self) -> usize {
        let capacity = self.stack.len();
        capacity - self.upper
    }

    /// Iterates over the upper stack, most recently pushed element first.
    fn iter_upper(&self) -> Iter<'_, V> {
        let upper_part = &self.stack[self.upper..];
        upper_part.iter()
    }
}
/// Iterator over mutable chunks of a backend `D`, i.e., the `ChunksMut`
/// associated type of its `SliceByValueMut` implementation.
type ShardDataIter<'a, D> = <D as SliceByValueMut>::ChunksMut<'a>;
/// Item of [`ShardDataIter`]: the mutable chunk of the backend that
/// `par_solve` pairs with a shard and hands to the solver of that shard.
type ShardData<'a, D> = <ShardDataIter<'a, D> as Iterator>::Item;
/// Construction of functions backed by a boxed slice of words.
impl<W: Word + BinSafe + AsU128, S: Sig + Send + Sync, E: ShardEdge<S, 3>>
    VBuilder<W, Box<[W]>, S, E>
where
    SigVal<S, W>: RadixKey,
    SigVal<E::LocalSig, W>: BitXor + BitXorAssign,
    Box<[W]>: BitFieldSliceMut<W> + BitFieldSlice<W>,
{
    /// Builds a `VFunc` mapping each key to the value in the same position,
    /// storing the values in a `Box<[W]>`.
    ///
    /// `keys` and `values` are read in lockstep and rewound every time the
    /// construction has to be retried with a different seed; `pl` receives
    /// progress information.
    ///
    /// # Errors
    ///
    /// Propagates every error returned by `build_loop`.
    pub fn try_build_func<T: ?Sized + ToSig<S> + std::fmt::Debug, B: ?Sized + Borrow<T>>(
        mut self,
        keys: impl FallibleRewindableLender<
            RewindError: std::error::Error + Send + Sync + 'static,
            Error: std::error::Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
        values: impl FallibleRewindableLender<
            RewindError: std::error::Error + Send + Sync + 'static,
            Error: std::error::Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend W>,
        pl: &mut (impl ProgressLog + Clone + Send + Sync),
    ) -> anyhow::Result<VFunc<T, W, Box<[W]>, S, E>>
    where
        for<'a> <<Box<[W]> as SliceByValueMut>::ChunksMut<'a> as Iterator>::Item:
            BitFieldSliceMut<W>,
        for<'a> ShardDataIter<'a, Box<[W]>>: Send,
        for<'a> ShardData<'a, Box<[W]>>: Send,
    {
        // The value to store is the one travelling with the signature.
        let stored_value = |_: &E, sig_val: SigVal<E::LocalSig, W>| sig_val.val;
        // The bit width is irrelevant for a backend made of whole words.
        let zeroed_backend = |_: usize, len: usize| vec![W::ZERO; len].into_boxed_slice();
        let (func, _store) = self.build_loop(
            keys,
            values,
            None,
            &stored_value,
            zeroed_backend,
            false,
            pl,
        )?;
        Ok(func)
    }
}
/// Construction of filters backed by a boxed slice of words.
impl<W: Word + BinSafe, S: Sig + Send + Sync, E: ShardEdge<S, 3>> VBuilder<W, Box<[W]>, S, E>
where
    SigVal<S, EmptyVal>: RadixKey,
    SigVal<E::LocalSig, EmptyVal>: BitXor + BitXorAssign,
    Box<[W]>: BitFieldSliceMut<W> + BitFieldSlice<W>,
{
    /// Builds a `VFilter` whose hashes use all the `W::BITS` bits of a word,
    /// storing them in a `Box<[W]>`.
    ///
    /// The value associated with each key is a hash derived from its local
    /// signature. `keys` is rewound every time the construction has to be
    /// retried with a different seed; `pl` receives progress information.
    ///
    /// # Errors
    ///
    /// Propagates every error returned by `build_loop`.
    pub fn try_build_filter<T: ?Sized + ToSig<S> + std::fmt::Debug, B: ?Sized + Borrow<T>>(
        mut self,
        keys: impl FallibleRewindableLender<
            RewindError: std::error::Error + Send + Sync + 'static,
            Error: std::error::Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
        pl: &mut (impl ProgressLog + Clone + Send + Sync),
    ) -> anyhow::Result<VFilter<W, VFunc<T, W, Box<[W]>, S, E>>>
    where
        for<'a> <<Box<[W]> as SliceByValueMut>::ChunksMut<'a> as Iterator>::Item:
            BitFieldSliceMut<W>,
        for<'a> ShardDataIter<'a, Box<[W]>>: Send,
        for<'a> ShardData<'a, Box<[W]>>: Send,
    {
        // The stored value is the mixed edge hash of the local signature.
        let hash_of = |shard_edge: &E, sig_val: SigVal<E::LocalSig, EmptyVal>| {
            W::as_from(mix64(shard_edge.edge_hash(sig_val.sig)))
        };
        // The bit width is irrelevant for a backend made of whole words.
        let zeroed_backend = |_: usize, len: usize| vec![W::ZERO; len].into_boxed_slice();
        let (func, _store) = self.build_loop(
            keys,
            // Filters have no values: provide an endless supply of empty ones.
            FromCloneableIntoIterator::from(itertools::repeat_n(
                EmptyVal::default(),
                usize::MAX,
            )),
            Some(W::BITS as usize),
            &hash_of,
            zeroed_backend,
            false,
            pl,
        )?;
        Ok(VFilter {
            func,
            filter_mask: W::MAX,
            hash_bits: W::BITS,
        })
    }
}
/// Construction of functions backed by a `BitFieldVec`.
impl<W: Word + BinSafe + AsU128, S: Sig + Send + Sync, E: ShardEdge<S, 3>>
    VBuilder<W, BitFieldVec<W>, S, E>
where
    SigVal<S, W>: RadixKey,
    SigVal<E::LocalSig, W>: BitXor + BitXorAssign,
{
    /// Builds a `VFunc` backed by a `BitFieldVec`, optionally returning the
    /// shard store used by the successful attempt.
    ///
    /// The second component of the result is what `build_loop` returns for
    /// the given `keep_store`.
    ///
    /// # Errors
    ///
    /// Propagates every error returned by `build_loop`.
    fn _try_build_func<T: ?Sized + ToSig<S> + std::fmt::Debug, B: ?Sized + Borrow<T>>(
        mut self,
        keys: impl FallibleRewindableLender<
            RewindError: std::error::Error + Send + Sync + 'static,
            Error: std::error::Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
        values: impl FallibleRewindableLender<
            RewindError: std::error::Error + Send + Sync + 'static,
            Error: std::error::Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend W>,
        keep_store: bool,
        pl: &mut (impl ProgressLog + Clone + Send + Sync),
    ) -> anyhow::Result<(
        VFunc<T, W, BitFieldVec<W>, S, E>,
        Option<AnyShardStore<S, W>>,
    )> {
        // The value to store is the one travelling with the signature.
        let stored_value = |_: &E, sig_val: SigVal<E::LocalSig, W>| sig_val.val;
        let new_backend =
            |width: usize, len: usize| BitFieldVec::<W>::new_unaligned(width, len);
        self.build_loop(
            keys,
            values,
            None,
            &stored_value,
            new_backend,
            keep_store,
            pl,
        )
    }

    /// Builds a `VFunc` mapping each key to the value in the same position,
    /// storing the values in a `BitFieldVec`.
    ///
    /// `keys` and `values` are read in lockstep and rewound every time the
    /// construction has to be retried with a different seed; `pl` receives
    /// progress information.
    ///
    /// # Errors
    ///
    /// Propagates every error returned by `build_loop`.
    pub fn try_build_func<T: ?Sized + ToSig<S> + std::fmt::Debug, B: ?Sized + Borrow<T>>(
        self,
        keys: impl FallibleRewindableLender<
            RewindError: std::error::Error + Send + Sync + 'static,
            Error: std::error::Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
        values: impl FallibleRewindableLender<
            RewindError: std::error::Error + Send + Sync + 'static,
            Error: std::error::Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend W>,
        pl: &mut (impl ProgressLog + Clone + Send + Sync),
    ) -> anyhow::Result<VFunc<T, W, BitFieldVec<W>, S, E>> {
        let (func, _store) = self._try_build_func(keys, values, false, pl)?;
        Ok(func)
    }
}
/// Construction of signed indices: functions mapping each key to its position
/// in the input, paired with per-position hashes.
impl<S: Sig + Send + Sync, E: ShardEdge<S, 3>> VBuilder<usize, BitFieldVec<usize>, S, E>
where
    SigVal<S, usize>: RadixKey,
    SigVal<E::LocalSig, usize>: BitXor + BitXorAssign,
{
    /// Builds a `BitSignedVFunc`: a function mapping each key to its position
    /// in `keys`, plus a `BitFieldVec` storing, for each position, the lowest
    /// `hash_width` bits of a hash derived from the local signature of the
    /// key.
    ///
    /// `keys` is rewound every time the construction has to be retried with a
    /// different seed; `pl` receives progress information.
    ///
    /// # Errors
    ///
    /// Propagates every error returned by `_try_build_func`.
    ///
    /// # Panics
    ///
    /// Panics if `hash_width` is zero or greater than 64.
    // NOTE(review): the hashes are stored in a `BitFieldVec<usize>`, so on
    // targets where `usize` is narrower than 64 bits a `hash_width` above
    // `usize::BITS` presumably cannot be honored — confirm.
    pub fn try_build_bit_sig_index<
        T: ?Sized + ToSig<S> + std::fmt::Debug,
        B: ?Sized + Borrow<T>,
    >(
        self,
        keys: impl FallibleRewindableLender<
            RewindError: std::error::Error + Send + Sync + 'static,
            Error: std::error::Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
        hash_width: usize,
        pl: &mut (impl ProgressLog + Clone + Send + Sync),
    ) -> anyhow::Result<BitSignedVFunc<VFunc<T, usize, BitFieldVec<usize>, S, E>, BitFieldVec>>
    {
        assert!(hash_width > 0);
        assert!(hash_width <= 64);
        // The values are the natural numbers, so each key is mapped to its
        // position in the input; the store is kept to compute the hashes.
        let (func, store) =
            self._try_build_func(keys, FromCloneableIntoIterator::from(0..), true, pl)?;
        let num_keys = func.num_keys;
        let shard_edge = &func.shard_edge;
        // Shifting a `u64` by 64 would overflow, so the full mask is special-cased.
        let hash_mask = if hash_width == 64 {
            u64::MAX
        } else {
            (1u64 << hash_width) - 1
        };
        let mut hashes = BitFieldVec::<usize>::new_unaligned(hash_width, num_keys);
        pl.item_name("hash");
        pl.expected_updates(Some(num_keys));
        pl.start("Storing hashes...");
        // The two arms are textually identical, but the stores have different types.
        match store.expect("Store should be present when keep_store is true") {
            AnyShardStore::Online(mut shard_store) => {
                for shard in shard_store.iter() {
                    for sig_val in shard.iter() {
                        let pos = sig_val.val;
                        let local_sig = shard_edge.local_sig(sig_val.sig);
                        let hash = (mix64(shard_edge.edge_hash(local_sig)) & hash_mask) as usize;
                        hashes.set_value(pos, hash);
                        pl.light_update();
                    }
                }
            }
            AnyShardStore::Offline(mut shard_store) => {
                for shard in shard_store.iter() {
                    for sig_val in shard.iter() {
                        let pos = sig_val.val;
                        let local_sig = shard_edge.local_sig(sig_val.sig);
                        let hash = (mix64(shard_edge.edge_hash(local_sig)) & hash_mask) as usize;
                        hashes.set_value(pos, hash);
                        pl.light_update();
                    }
                }
            }
        }
        pl.done();
        Ok(BitSignedVFunc {
            func,
            hashes,
            hash_mask,
        })
    }

    /// Builds a `SignedVFunc`: a function mapping each key to its position in
    /// `keys`, plus a boxed slice storing, for each position, a hash of type
    /// `H` derived from the local signature of the key.
    ///
    /// `keys` is rewound every time the construction has to be retried with a
    /// different seed; `pl` receives progress information.
    ///
    /// # Errors
    ///
    /// Propagates every error returned by `_try_build_func`.
    pub fn try_build_sig_index<
        T: ?Sized + ToSig<S> + std::fmt::Debug,
        B: ?Sized + Borrow<T>,
        H: Word,
    >(
        self,
        keys: impl FallibleRewindableLender<
            RewindError: std::error::Error + Send + Sync + 'static,
            Error: std::error::Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
        pl: &mut (impl ProgressLog + Clone + Send + Sync),
    ) -> anyhow::Result<SignedVFunc<VFunc<T, usize, BitFieldVec<usize>, S, E>, Box<[H]>>> {
        // The values are the natural numbers, so each key is mapped to its
        // position in the input; the store is kept to compute the hashes.
        let (func, store) =
            self._try_build_func(keys, FromCloneableIntoIterator::from(0..), true, pl)?;
        let num_keys = func.num_keys;
        let shard_edge = &func.shard_edge;
        let mut hashes = vec![H::ZERO; num_keys].into_boxed_slice();
        pl.item_name("hash");
        pl.expected_updates(Some(num_keys));
        pl.start("Storing hashes...");
        // The two arms are textually identical, but the stores have different types.
        match store.expect("Store should be present when keep_store is true") {
            AnyShardStore::Online(mut shard_store) => {
                for shard in shard_store.iter() {
                    for sig_val in shard.iter() {
                        let pos = sig_val.val;
                        let local_sig = shard_edge.local_sig(sig_val.sig);
                        let hash = H::as_from(mix64(shard_edge.edge_hash(local_sig)));
                        hashes.set_value(pos, hash);
                        pl.light_update();
                    }
                }
            }
            AnyShardStore::Offline(mut shard_store) => {
                for shard in shard_store.iter() {
                    for sig_val in shard.iter() {
                        let pos = sig_val.val;
                        let local_sig = shard_edge.local_sig(sig_val.sig);
                        let hash = H::as_from(mix64(shard_edge.edge_hash(local_sig)));
                        hashes.set_value(pos, hash);
                        pl.light_update();
                    }
                }
            }
        }
        pl.done();
        Ok(SignedVFunc { func, hashes })
    }
}
/// Construction of filters backed by a `BitFieldVec`.
impl<W: Word + BinSafe, S: Sig + Send + Sync, E: ShardEdge<S, 3>> VBuilder<W, BitFieldVec<W>, S, E>
where
    SigVal<S, EmptyVal>: RadixKey,
    SigVal<E::LocalSig, EmptyVal>: BitXor + BitXorAssign,
{
    /// Builds a `VFilter` whose hashes have `filter_bits` bits, storing them
    /// in a `BitFieldVec`.
    ///
    /// The value associated with each key is the lowest `filter_bits` bits of
    /// a hash derived from its local signature. `keys` is rewound every time
    /// the construction has to be retried with a different seed; `pl`
    /// receives progress information.
    ///
    /// # Errors
    ///
    /// Propagates every error returned by `build_loop`.
    ///
    /// # Panics
    ///
    /// Panics if `filter_bits` is zero or greater than `W::BITS`.
    pub fn try_build_filter<T: ?Sized + ToSig<S> + std::fmt::Debug, B: ?Sized + Borrow<T>>(
        mut self,
        keys: impl FallibleRewindableLender<
            RewindError: std::error::Error + Send + Sync + 'static,
            Error: std::error::Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
        filter_bits: usize,
        pl: &mut (impl ProgressLog + Clone + Send + Sync),
    ) -> anyhow::Result<VFilter<W, VFunc<T, W, BitFieldVec<W>, S, E>>> {
        assert!(filter_bits > 0);
        assert!(filter_bits <= W::BITS as usize);
        // Mask selecting the lowest `filter_bits` bits of a word.
        let filter_mask = W::MAX >> (W::BITS - filter_bits as u32);
        // The stored value is the masked, mixed edge hash of the local signature.
        let masked_hash = |shard_edge: &E, sig_val: SigVal<E::LocalSig, EmptyVal>| {
            W::as_from(mix64(shard_edge.edge_hash(sig_val.sig))) & filter_mask
        };
        let new_backend =
            |width: usize, len: usize| BitFieldVec::<W>::new_unaligned(width, len);
        let (func, _store) = self.build_loop(
            keys,
            // Filters have no values: provide an endless supply of empty ones.
            FromCloneableIntoIterator::from(itertools::repeat_n(
                EmptyVal::default(),
                usize::MAX,
            )),
            Some(filter_bits),
            &masked_hash,
            new_backend,
            false,
            pl,
        )?;
        Ok(VFilter {
            func,
            filter_mask,
            hash_bits: filter_bits as _,
        })
    }
}
impl<
W: Word + BinSafe,
D: BitFieldSlice<W>
+ for<'a> BitFieldSliceMut<W, ChunksMut<'a>: Iterator<Item: BitFieldSliceMut<W>>>
+ Send
+ Sync,
S: Sig + Send + Sync,
E: ShardEdge<S, 3>,
> VBuilder<W, D, S, E>
{
fn build_loop<
T: ?Sized + ToSig<S> + std::fmt::Debug,
B: ?Sized + Borrow<T>,
V: BinSafe + Default + Send + Sync + Ord + AsU128,
>(
&mut self,
mut keys: impl FallibleRewindableLender<
RewindError: std::error::Error + Send + Sync + 'static,
Error: std::error::Error + Send + Sync + 'static,
> + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
mut values: impl FallibleRewindableLender<
RewindError: std::error::Error + Send + Sync + 'static,
Error: std::error::Error + Send + Sync + 'static,
> + for<'lend> FallibleLending<'lend, Lend = &'lend V>,
bit_width: Option<usize>,
get_val: &(impl Fn(&E, SigVal<E::LocalSig, V>) -> W + Send + Sync),
new_data: fn(usize, usize) -> D,
keep_store: bool,
pl: &mut (impl ProgressLog + Clone + Send + Sync),
) -> anyhow::Result<(VFunc<T, W, D, S, E>, Option<AnyShardStore<S, V>>)>
where
SigVal<S, V>: RadixKey,
SigVal<E::LocalSig, V>: BitXor + BitXorAssign,
for<'a> ShardDataIter<'a, D>: Send,
for<'a> <ShardDataIter<'a, D> as Iterator>::Item: Send,
{
const {
assert!(
size_of::<E::Vertex>() <= size_of::<usize>(),
"ShardEdge::Vertex must fit in usize without truncation"
);
}
let mut dup_count = 0;
let mut local_dup_count = 0;
let mut prng = SmallRng::seed_from_u64(self.seed);
if let Some(expected_num_keys) = self.expected_num_keys {
self.shard_edge.set_up_shards(expected_num_keys, self.eps);
self.log2_buckets = self.shard_edge.shard_high_bits();
}
pl.info(format_args!("Using 2^{} buckets", self.log2_buckets));
loop {
let seed = prng.random();
let result = if self.offline {
self.try_seed(
seed,
sig_store::new_offline::<S, V>(
self.log2_buckets,
LOG2_MAX_SHARDS,
self.expected_num_keys,
)?,
&mut keys,
&mut values,
bit_width,
get_val,
new_data,
keep_store,
pl,
)
.map(|(func, store)| (func, store.map(AnyShardStore::Offline)))
} else {
self.try_seed(
seed,
sig_store::new_online::<S, V>(
self.log2_buckets,
LOG2_MAX_SHARDS,
self.expected_num_keys,
)?,
&mut keys,
&mut values,
bit_width,
get_val,
new_data,
keep_store,
pl,
)
.map(|(func, store)| (func, store.map(AnyShardStore::Online)))
};
match result {
Ok((func, store)) => {
return Ok((func, store));
}
Err(error) => {
match error.downcast::<SolveError>() {
Ok(vfunc_error) => match vfunc_error {
SolveError::DuplicateSignature => {
if dup_count >= 3 {
pl.error(format_args!("Duplicate keys (duplicate 128-bit signatures with four different seeds)"));
return Err(BuildError::DuplicateKey.into());
}
pl.warn(format_args!(
"Duplicate 128-bit signature, trying again with a different seed..."
));
dup_count += 1;
}
SolveError::DuplicateLocalSignature => {
if local_dup_count >= 2 {
pl.error(format_args!("Duplicate local signatures: use full signatures (duplicate local signatures with three different seeds)"));
return Err(BuildError::DuplicateLocalSignatures.into());
}
pl.warn(format_args!(
"Duplicate local signature, trying again with a different seed..."
));
local_dup_count += 1;
}
SolveError::MaxShardTooBig => {
pl.warn(format_args!(
"The maximum shard is too big, trying again with a different seed..."
));
}
SolveError::UnsolvableShard => {
pl.warn(format_args!(
"Unsolvable shard, trying again with a different seed..."
));
}
},
Err(error) => return Err(error),
}
}
}
values = values.rewind()?;
keys = keys.rewind()?;
}
}
}
/// A single construction attempt: signature computation, sharding, and
/// parallel solution of the shards.
impl<
    W: Word + BinSafe,
    D: BitFieldSlice<W>
        + for<'a> BitFieldSliceMut<W, ChunksMut<'a>: Iterator<Item: BitFieldSliceMut<W>>>
        + Send
        + Sync,
    S: Sig + Send + Sync,
    E: ShardEdge<S, 3>,
> VBuilder<W, D, S, E>
{
    /// Attempts a construction using the given `seed`.
    ///
    /// Keys and values are read in lockstep; for each key a signature is
    /// computed with `T::to_sig` and pushed, together with the value, into
    /// `sig_store`. The store is then turned into a shard store, the backend
    /// is allocated with `new_data`, and the shards are solved by
    /// `try_build_from_shard_iter`.
    ///
    /// When `V` is `EmptyVal` (filters) the bit width is `bit_width`, which
    /// must be set; otherwise it is the bit length of the largest value,
    /// unless `bit_width` is set, in which case `bit_width` is used.
    ///
    /// If `keep_store` is true the shard store is returned alongside the
    /// function; otherwise it is consumed by the construction.
    ///
    /// # Errors
    ///
    /// * `BuildError::ValueTooLarge` if the largest value does not fit in
    ///   `bit_width` bits;
    /// * `SolveError::MaxShardTooBig` if the largest shard exceeds the
    ///   average shard size by more than 1%;
    /// * any error from the key/value lenders, from the signature store, or
    ///   from `try_build_from_shard_iter`.
    ///
    /// # Panics
    ///
    /// Panics if `values` yields fewer items than `keys`, or if `V` is
    /// `EmptyVal` and `bit_width` is `None`.
    fn try_seed<
        T: ?Sized + ToSig<S> + std::fmt::Debug,
        B: ?Sized + Borrow<T>,
        V: BinSafe + Default + Send + Sync + Ord + AsU128,
        G: Fn(&E, SigVal<E::LocalSig, V>) -> W + Send + Sync,
        SS: SigStore<S, V>,
    >(
        &mut self,
        seed: u64,
        mut sig_store: SS,
        keys: &mut (
                 impl FallibleRewindableLender<
            RewindError: std::error::Error + Send + Sync + 'static,
            Error: std::error::Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend B>
             ),
        values: &mut (
                 impl FallibleRewindableLender<
            RewindError: std::error::Error + Send + Sync + 'static,
            Error: std::error::Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend V>
             ),
        bit_width: Option<usize>,
        get_val: &G,
        new_data: fn(usize, usize) -> D,
        keep_store: bool,
        pl: &mut (impl ProgressLog + Clone + Send + Sync),
    ) -> anyhow::Result<(VFunc<T, W, D, S, E>, Option<SS::ShardStore>)>
    where
        SigVal<S, V>: RadixKey,
        SigVal<E::LocalSig, V>: BitXor + BitXorAssign,
        for<'a> ShardDataIter<'a, D>: Send,
        for<'a> <ShardDataIter<'a, D> as Iterator>::Item: Send,
    {
        let shard_edge = &mut self.shard_edge;
        pl.expected_updates(self.expected_num_keys);
        pl.item_name("key");
        pl.start(format!(
            "Computing and storing {}-bit signatures in {} using seed 0x{:016x}...",
            std::mem::size_of::<S>() * 8,
            sig_store
                .temp_dir()
                .map(|d| d.path().to_string_lossy())
                .unwrap_or(Cow::Borrowed("memory")),
            seed
        ));

        // Running maximum of the values seen so far (for filters, every value
        // is an `EmptyVal`).
        let mut maybe_max_value = V::default();
        let start = Instant::now();
        while let Some(key) = keys.next()? {
            pl.light_update();
            let &maybe_val = values.next()?.expect("Not enough values");
            let sig_val = SigVal {
                sig: T::to_sig(key.borrow(), seed),
                val: maybe_val,
            };
            maybe_max_value = Ord::max(maybe_max_value, maybe_val);
            sig_store.try_push(sig_val)?;
        }
        pl.done();

        self.num_keys = sig_store.len();
        self.bit_width = if TypeId::of::<V>() == TypeId::of::<EmptyVal>() {
            bit_width.expect("Bit width must be set for filters")
        } else {
            let len_width = maybe_max_value.as_u128().bit_len() as usize;
            if let Some(bit_width) = bit_width {
                if len_width > bit_width {
                    return Err(BuildError::ValueTooLarge.into());
                }
                bit_width
            } else {
                len_width
            }
        };
        info!(
            "Computation of signatures from inputs completed in {:.3} seconds ({} keys, {:.3} ns/key)",
            start.elapsed().as_secs_f64(),
            self.num_keys,
            start.elapsed().as_nanos() as f64 / self.num_keys as f64
        );

        // Now that the actual number of keys is known, (re)compute the sharding.
        shard_edge.set_up_shards(self.num_keys, self.eps);
        let start = Instant::now();
        let mut shard_store = sig_store.into_shard_store(shard_edge.shard_high_bits())?;
        let max_shard = shard_store.shard_sizes().iter().copied().max().unwrap_or(0);
        let filter = TypeId::of::<V>() == TypeId::of::<EmptyVal>();
        (self.c, self.lge) = shard_edge.set_up_graphs(self.num_keys, max_shard);

        if filter {
            pl.info(format_args!(
                "Number of keys: {} Bit width: {}",
                self.num_keys, self.bit_width,
            ));
        } else {
            pl.info(format_args!(
                "Number of keys: {} Max value: {} Bit width: {}",
                self.num_keys,
                {
                    let v: u128 = maybe_max_value.as_u128();
                    v
                },
                self.bit_width,
            ));
        }
        if shard_edge.shard_high_bits() != 0 {
            pl.info(format_args!(
                "Max shard / average shard: {:.2}%",
                (100.0 * max_shard as f64)
                    / (self.num_keys as f64 / shard_edge.num_shards() as f64)
            ));
        }

        // Give up on this seed if the largest shard is more than 1% larger
        // than the average shard.
        if max_shard as f64 > 1.01 * self.num_keys as f64 / shard_edge.num_shards() as f64 {
            Err(SolveError::MaxShardTooBig.into())
        } else {
            // One chunk of `num_vertices()` values for each shard.
            let data = new_data(
                self.bit_width,
                shard_edge.num_vertices() * shard_edge.num_shards(),
            );
            if keep_store {
                // Iterate by reference so that the store can be returned.
                let func = self
                    .try_build_from_shard_iter(seed, data, shard_store.iter(), get_val, pl)
                    .inspect(|_| {
                        info!(
                            "Construction from signatures completed in {:.3} seconds ({} keys, {:.3} ns/key)",
                            start.elapsed().as_secs_f64(),
                            self.num_keys,
                            start.elapsed().as_nanos() as f64 / self.num_keys as f64
                        );
                    })?;
                Ok((func, Some(shard_store)))
            } else {
                // The store is not needed afterwards: consume it.
                let func = self
                    .try_build_from_shard_iter(seed, data, shard_store.into_iter(), get_val, pl)
                    .inspect(|_| {
                        info!(
                            "Construction from signatures completed in {:.3} seconds ({} keys, {:.3} ns/key)",
                            start.elapsed().as_secs_f64(),
                            self.num_keys,
                            start.elapsed().as_nanos() as f64 / self.num_keys as f64
                        );
                    })?;
                Ok((func, None))
            }
        }
    }

    /// Solves all the shards returned by `shard_iter`, writing the solution
    /// into `data`, and wraps the result in a `VFunc`.
    ///
    /// The number of threads is the minimum between the number of shards and
    /// `max_num_threads`, but at least one. The per-shard solver is chosen as
    /// follows: `lge_shard` if `self.lge` is set; otherwise
    /// `peel_by_sig_vals_low_mem` if `low_mem` is `Some(true)`, or if it is
    /// `None` and more than three threads and more than two shards are in
    /// use; otherwise `peel_by_sig_vals_high_mem`.
    ///
    /// # Errors
    ///
    /// Returns the `SolveError` reported by `par_solve`, if any.
    fn try_build_from_shard_iter<
        T: ?Sized + ToSig<S>,
        I,
        P,
        V: BinSafe + Default + Send + Sync,
        G: Fn(&E, SigVal<E::LocalSig, V>) -> W + Send + Sync,
    >(
        &mut self,
        seed: u64,
        mut data: D,
        shard_iter: I,
        get_val: &G,
        pl: &mut P,
    ) -> Result<VFunc<T, W, D, S, E>, SolveError>
    where
        SigVal<S, V>: RadixKey,
        SigVal<E::LocalSig, V>: BitXor + BitXorAssign,
        P: ProgressLog + Clone + Send + Sync,
        I: Iterator<Item = Arc<Vec<SigVal<S, V>>>> + Send,
        for<'a> ShardDataIter<'a, D>: Send,
        for<'a> <ShardDataIter<'a, D> as Iterator>::Item: Send,
    {
        let shard_edge = &self.shard_edge;
        // At least one worker is needed: with `max_num_threads == 0` the
        // solver would otherwise compute `0usize.ilog2()`, which panics.
        self.num_threads = shard_edge
            .num_shards()
            .min(self.max_num_threads)
            .max(1);
        pl.info(format_args!("{}", self.shard_edge));
        pl.info(format_args!(
            "c: {}, Overhead: {:+.4}% Number of threads: {}",
            self.c,
            100. * ((shard_edge.num_vertices() * shard_edge.num_shards()) as f64
                / (self.num_keys as f64)
                - 1.),
            self.num_threads
        ));

        if self.lge {
            pl.info(format_args!("Peeling towards lazy Gaussian elimination"));
            self.par_solve(
                shard_iter,
                &mut data,
                |this, shard_index, shard, data, pl| {
                    this.lge_shard(shard_index, shard, data, get_val, pl)
                },
                &mut pl.concurrent(),
                pl,
            )?;
        } else if self.low_mem == Some(true)
            || self.low_mem.is_none() && self.num_threads > 3 && shard_edge.num_shards() > 2
        {
            self.par_solve(
                shard_iter,
                &mut data,
                |this, shard_index, shard, data, pl| {
                    this.peel_by_sig_vals_low_mem(shard_index, shard, data, get_val, pl)
                },
                &mut pl.concurrent(),
                pl,
            )?;
        } else {
            self.par_solve(
                shard_iter,
                &mut data,
                |this, shard_index, shard, data, pl| {
                    this.peel_by_sig_vals_high_mem(shard_index, shard, data, get_val, pl)
                },
                &mut pl.concurrent(),
                pl,
            )?;
        }

        pl.info(format_args!(
            "Bits/keys: {} ({:+.4}%)",
            data.len() as f64 * self.bit_width as f64 / self.num_keys as f64,
            100.0 * (data.len() as f64 / self.num_keys as f64 - 1.),
        ));
        Ok(VFunc {
            seed,
            shard_edge: self.shard_edge,
            num_keys: self.num_keys,
            data,
            _marker_t: std::marker::PhantomData,
            _marker_w: std::marker::PhantomData,
            _marker_s: std::marker::PhantomData,
        })
    }
}
// Removes an edge that has just been peeled through side `$side` from the
// other two vertices of the edge.
//
// Arguments: `$xor_graph` is the `XorGraph`; `$e` is the edge, indexed by
// side; `$side` is the side of the vertex through which the edge was peeled;
// `$edge` is the edge data to XOR out; `$stack.$push(...)` is how a vertex is
// scheduled for a visit; `$conv` converts a vertex index to the type stored in
// the stack.
//
// A vertex whose degree is two before the removal (and thus one after it) is
// pushed on the stack, as it can be peeled next.
macro_rules! remove_edge {
($xor_graph: ident, $e: ident, $side: ident, $edge: ident, $stack: ident, $push:ident, $conv: expr) => {
match $side {
// Peeled through side 0: update the vertices on sides 1 and 2.
0 => {
if $xor_graph.degree($e[1]) == 2 {
$stack.$push($conv($e[1]));
}
$xor_graph.remove($e[1], $edge, 1);
if $xor_graph.degree($e[2]) == 2 {
$stack.$push($conv($e[2]));
}
$xor_graph.remove($e[2], $edge, 2);
}
// Peeled through side 1: update the vertices on sides 0 and 2.
1 => {
if $xor_graph.degree($e[0]) == 2 {
$stack.$push($conv($e[0]));
}
$xor_graph.remove($e[0], $edge, 0);
if $xor_graph.degree($e[2]) == 2 {
$stack.$push($conv($e[2]));
}
$xor_graph.remove($e[2], $edge, 2);
}
// Peeled through side 2: update the vertices on sides 0 and 1.
2 => {
if $xor_graph.degree($e[0]) == 2 {
$stack.$push($conv($e[0]));
}
$xor_graph.remove($e[0], $edge, 0);
if $xor_graph.degree($e[1]) == 2 {
$stack.$push($conv($e[1]));
}
$xor_graph.remove($e[1], $edge, 1);
}
// NOTE(review): undefined behavior if `$side` is not 0, 1 or 2; the
// visible callers obtain it from `XorGraph::edge_and_side`, which returns
// a two-bit value — confirm that the value 3 cannot occur.
_ => unsafe { unreachable_unchecked() },
}
};
}
impl<
W: Word + BinSafe + Send + Sync,
D: BitFieldSlice<W>
+ for<'a> BitFieldSliceMut<W, ChunksMut<'a>: Iterator<Item: BitFieldSliceMut<W>>>
+ Send
+ Sync,
S: Sig + BinSafe,
E: ShardEdge<S, 3>,
> VBuilder<W, D, S, E>
{
/// Sorts `data` in place by `ShardEdge::sort_key` using a stable counting
/// sort with `ShardEdge::num_sort_keys` buckets.
fn count_sort<V: BinSafe>(&self, data: &mut [SigVal<S, V>]) {
    let shard_edge = &self.shard_edge;
    // First the size, then the starting offset of each bucket.
    let mut offsets = vec![0usize; shard_edge.num_sort_keys()];
    let mut scratch = Box::new_uninit_slice(data.len());
    // A single pass both counts the keys and copies the data.
    for (slot, &sig_val) in scratch.iter_mut().zip(data.iter()) {
        offsets[shard_edge.sort_key(sig_val.sig)] += 1;
        slot.write(sig_val);
    }
    // SAFETY: `scratch` has exactly `data.len()` slots, and the loop above
    // wrote each of them.
    let scratch = unsafe { scratch.assume_init() };
    // Exclusive prefix sums turn the sizes into starting offsets.
    let mut next_start = 0;
    for bucket in offsets.iter_mut() {
        let bucket_len = *bucket;
        *bucket = next_start;
        next_start += bucket_len;
    }
    // Scatter the copies back, each one into the next free slot of its bucket.
    for &sig_val in scratch.iter() {
        let target = &mut offsets[shard_edge.sort_key(sig_val.sig)];
        data[*target] = sig_val;
        *target += 1;
    }
}
/// Largest number of keys for which `par_solve` skips the check for duplicate
/// local signatures; above it (and only when local signatures differ in type
/// from full signatures) each shard is sorted and deduplicated by local
/// signature.
const MAX_NO_LOCAL_SIG_CHECK: usize = 1 << 33;
/// Solves all the shards in parallel.
///
/// A producer thread pairs each shard returned by `shard_iter` with a chunk
/// of `data` of `shard_edge.num_vertices()` elements and sends the pairs,
/// tagged with their index, over a bounded channel. `self.num_threads`
/// worker threads receive the pairs, prepare each shard (duplicate check,
/// sorting or deduplication) and then call `solve_shard` on it.
///
/// `main_pl` logs the progress over the shards, whereas each worker uses a
/// clone of `pl` to log the work on a single shard.
///
/// # Errors
///
/// Returns the first error reported by a worker, after setting
/// `self.failed`:
/// * `SolveError::DuplicateSignature` if `check_dups` is set and a shard
///   contains two equal signatures;
/// * `SolveError::DuplicateLocalSignature` if local signatures are checked,
///   `V` is not `EmptyVal`, and a shard contains duplicates;
/// * `SolveError::UnsolvableShard` if `solve_shard` fails.
///
/// # Panics
///
/// Panics if `self.num_threads` is zero (because of `ilog2`), or if
/// `data.try_chunks_mut` fails.
fn par_solve<
'b,
V: BinSafe,
I: IntoIterator<Item = Arc<Vec<SigVal<S, V>>>> + Send,
SS: Fn(&Self, usize, Arc<Vec<SigVal<S, V>>>, ShardData<'b, D>, &mut P) -> Result<(), ()>
+ Send
+ Sync
+ Copy,
C: ConcurrentProgressLog + Send + Sync,
P: ProgressLog + Clone + Send + Sync,
>(
&self,
shard_iter: I,
data: &'b mut D,
solve_shard: SS,
main_pl: &mut C,
pl: &mut P,
) -> Result<(), SolveError>
where
I::IntoIter: Send,
SigVal<S, V>: RadixKey,
for<'a> ShardDataIter<'a, D>: Send,
for<'a> <ShardDataIter<'a, D> as Iterator>::Item: Send,
{
main_pl
.item_name("shard")
.expected_updates(Some(self.shard_edge.num_shards()))
.display_memory(true)
.start("Solving shards...");
self.failed.store(false, Ordering::Relaxed);
let num_shards = self.shard_edge.num_shards();
// The number of shards that can be queued grows logarithmically with the
// number of threads.
let buffer_size = self.num_threads.ilog2() as usize;
// Each worker sends at most one error before exiting, so a capacity of
// `num_threads` guarantees that sending an error never blocks.
let (err_send, err_recv) = crossbeam_channel::bounded::<_>(self.num_threads);
let (data_send, data_recv) = crossbeam_channel::bounded::<(
usize,
(Arc<Vec<SigVal<S, V>>>, ShardData<'_, D>),
)>(buffer_size);
let result = std::thread::scope(|scope| {
// Producer: pairs shards with chunks of the backend and feeds the workers.
scope.spawn(move || {
// The outcome of the priority change is deliberately ignored.
let _ = thread_priority::set_current_thread_priority(ThreadPriority::Max);
for val in shard_iter
.into_iter()
.zip(data.try_chunks_mut(self.shard_edge.num_vertices()).unwrap())
.enumerate()
{
// Sending fails only when all receivers are gone: nothing left to do.
if data_send.send(val).is_err() {
break;
}
}
drop(data_send);
});
for _thread_id in 0..self.num_threads {
let mut main_pl = main_pl.clone();
let mut pl = pl.clone();
let err_send = err_send.clone();
let data_recv = data_recv.clone();
// Worker: processes shards until the channel is closed or an error occurs.
scope.spawn(move || {
loop {
match data_recv.recv() {
// The producer is done and the channel is empty.
Err(_) => return,
Ok((shard_index, (shard, mut data))) => {
// NOTE(review): an empty shard terminates this worker, which will
// not receive any further shard — confirm this is intended when
// empty and nonempty shards can coexist.
if shard.is_empty() {
return;
}
main_pl.info(format_args!(
"Analyzing shard {}/{}...",
shard_index + 1,
num_shards
));
pl.start(format!(
"Sorting shard {}/{}...",
shard_index + 1,
num_shards
));
{
// Obtains a mutable reference to the shard through the raw pointer
// of the `Arc`.
// NOTE(review): sound only if nothing else accesses the shard while
// this worker mutates it — confirm against the shard stores.
let shard = unsafe {
&mut *(Arc::as_ptr(&shard) as *mut Vec<SigVal<S, V>>)
};
if self.check_dups {
// After sorting, equal signatures are adjacent.
shard.radix_sort_builder().sort();
if shard.par_windows(2).any(|w| w[0].sig == w[1].sig) {
let _ = err_send.send(SolveError::DuplicateSignature);
return;
}
}
// Local signatures are checked for duplicates only if they differ
// in type from full signatures and there are enough keys.
if TypeId::of::<E::LocalSig>() != TypeId::of::<S>()
&& self.num_keys > Self::MAX_NO_LOCAL_SIG_CHECK
{
// Reinterprets the shard as a vector of `E::SortSigVal<V>`.
// NOTE(review): requires the two element types to have the same
// layout — confirm against the `ShardEdge` documentation.
let shard = unsafe {
transmute::<
&mut Vec<SigVal<S, V>>,
&mut Vec<E::SortSigVal<V>>,
>(shard)
};
let builder = shard.radix_sort_builder();
// With a single thread allowed, sorting must not be parallel either.
if self.max_num_threads == 1 {
builder
.with_single_threaded_tuner()
.with_parallel(false)
} else {
builder
}
.sort();
let shard_len = shard.len();
shard.dedup();
if TypeId::of::<V>() == TypeId::of::<EmptyVal>() {
// For filters, duplicates are just dropped.
pl.info(format_args!(
"Removed signatures: {}",
shard_len - shard.len()
));
} else {
// For functions, duplicates force a new attempt.
if shard_len != shard.len() {
let _ = err_send
.send(SolveError::DuplicateLocalSignature);
return;
}
}
} else if self.shard_edge.num_sort_keys() != 1 {
self.count_sort::<V>(shard);
}
}
pl.done_with_count(shard.len());
main_pl.info(format_args!(
"Solving shard {}/{}...",
shard_index + 1,
num_shards
));
// Another worker has already failed: no point in solving this shard.
if self.failed.load(Ordering::Relaxed) {
return;
}
if TypeId::of::<V>() == TypeId::of::<EmptyVal>() {
// For filters, the chunk is filled with pseudorandom bytes before
// solving.
Mwc192::seed_from_u64(self.seed).fill_bytes(unsafe {
data.as_mut_slice().align_to_mut::<u8>().1
});
}
if solve_shard(self, shard_index, shard, data, &mut pl).is_err() {
let _ = err_send.send(SolveError::UnsolvableShard);
return;
}
if self.failed.load(Ordering::Relaxed) {
return;
}
main_pl.info(format_args!(
"Completed shard {}/{}",
shard_index + 1,
num_shards
));
main_pl.update_and_display();
}
}
}
});
}
// Drop the local endpoints, so that the error channel is closed as soon as
// all workers are done and the producer notices when all workers are gone.
drop(err_send);
drop(data_recv);
// Blocks until the first error arrives or all workers have terminated.
if let Some(error) = err_recv.into_iter().next() {
self.failed.store(true, Ordering::Relaxed);
return Err(error);
}
Ok(())
});
main_pl.done();
result
}
/// Peels the hypergraph of a shard representing edges by their index in the
/// shard, and assigns the values if the peeling is complete.
///
/// Returns `PeelResult::Complete` after calling `assign` if all edges were
/// peeled; otherwise returns `PeelResult::Partial` carrying the shard, its
/// chunk of the backend, and the stacks describing the peeled edges.
///
/// # Errors
///
/// Returns `Err(())` if `self.failed` is found set after the graph has been
/// generated.
///
/// # Panics
///
/// Panics if the degree of a vertex overflows.
fn peel_by_index<'a, V: BinSafe, G: Fn(&E, SigVal<E::LocalSig, V>) -> W + Send + Sync>(
&self,
shard_index: usize,
shard: Arc<Vec<SigVal<S, V>>>,
data: ShardData<'a, D>,
get_val: &G,
pl: &mut impl ProgressLog,
) -> Result<PeelResult<'a, W, D, S, E, V>, ()> {
let shard_edge = &self.shard_edge;
let num_vertices = shard_edge.num_vertices();
let num_shards = shard_edge.num_shards();
pl.start(format!(
"Generating graph for shard {}/{}...",
shard_index + 1,
num_shards
));
// The data XORed at each vertex is the index of the edge in the shard.
let mut xor_graph = XorGraph::<E::Vertex>::new(num_vertices);
for (edge_index, sig_val) in shard.iter().enumerate() {
for (side, &v) in shard_edge
.local_edge(shard_edge.local_sig(sig_val.sig))
.iter()
.enumerate()
{
xor_graph.add(v, E::Vertex::as_from(edge_index), side);
}
}
pl.done_with_count(shard.len());
assert!(
!xor_graph.overflow,
"Degree overflow for shard {}/{}",
shard_index + 1,
num_shards
);
if self.failed.load(Ordering::Relaxed) {
return Err(());
}
// The lower stack contains the vertices to visit, the upper stack the
// indices of the peeled edges.
let mut double_stack = DoubleStack::<E::Vertex>::new(num_vertices);
// Sides through which the edges were peeled, in peeling order.
let mut sides_stack = Vec::<u8>::new();
pl.start(format!(
"Peeling graph for shard {}/{} by edge indices...",
shard_index + 1,
num_shards
));
// Peeling starts from all vertices of degree one.
for (v, degree) in xor_graph.degrees().enumerate() {
if degree == 1 {
double_stack.push_lower(E::Vertex::as_from(v));
}
}
while let Some(v) = double_stack.pop_lower() {
let v: usize = v.as_to();
// The degree may have dropped to zero after the vertex was scheduled.
if xor_graph.degree(v) == 0 {
continue;
}
debug_assert!(xor_graph.degree(v) == 1);
let (edge_index, side) = xor_graph.edge_and_side(v);
xor_graph.zero(v);
double_stack.push_upper(edge_index);
sides_stack.push(side as u8);
let edge: usize = edge_index.as_to();
let e = shard_edge.local_edge(shard_edge.local_sig(shard[edge].sig));
// Remove the edge from its other two vertices, scheduling those that
// reach degree one.
remove_edge!(
xor_graph,
e,
side,
edge_index,
double_stack,
push_lower,
E::Vertex::as_from
);
}
pl.done();
// Not every edge was peeled: hand the state back to the caller.
if shard.len() != double_stack.upper_len() {
pl.info(format_args!(
"Peeling failed for shard {}/{} (peeled {} out of {} edges)",
shard_index + 1,
num_shards,
double_stack.upper_len(),
shard.len(),
));
return Ok(PeelResult::Partial {
shard_index,
shard,
data,
double_stack,
sides_stack,
_marker: std::marker::PhantomData,
});
}
// The upper stack is iterated from the most recently peeled edge, so the
// sides are reversed to keep each edge paired with its own side.
self.assign(
shard_index,
data,
double_stack
.iter_upper()
.map(|&edge_index| {
let edge: usize = edge_index.as_to();
let sig_val = &shard[edge];
let local_sig = shard_edge.local_sig(sig_val.sig);
(
local_sig,
get_val(
shard_edge,
SigVal {
sig: local_sig,
val: sig_val.val,
},
),
)
})
.zip(sides_stack.into_iter().rev()),
pl,
);
Ok(PeelResult::Complete())
}
/// Peels a shard using an XOR graph whose per-vertex payload is a whole
/// local signature/value pair, recording the peeling order in two
/// dedicated stacks (one for the pairs, one for the sides).
///
/// The graph is built from `shard`, after which this method releases its
/// handle on the shard. Vertices of degree one are then visited until none
/// are left; if every edge has been peeled, the peeled pairs and their
/// sides are replayed in reverse peeling order through `assign`.
///
/// # Errors
///
/// Returns `Err(())` if `self.failed` is set once the graph has been
/// built, or if fewer edges than the shard contains were peeled.
///
/// # Panics
///
/// Panics if the graph reports a degree overflow.
fn peel_by_sig_vals_high_mem<V: BinSafe, G: Fn(&E, SigVal<E::LocalSig, V>) -> W + Send + Sync>(
    &self,
    shard_index: usize,
    shard: Arc<Vec<SigVal<S, V>>>,
    data: ShardData<'_, D>,
    get_val: &G,
    pl: &mut impl ProgressLog,
) -> Result<(), ()>
where
    SigVal<E::LocalSig, V>: BitXor + BitXorAssign + Default,
{
    let shard_edge = &self.shard_edge;
    let num_vertices = shard_edge.num_vertices();
    let num_shards = shard_edge.num_shards();
    let num_edges = shard.len();

    pl.start(format!(
        "Generating graph for shard {}/{}...",
        shard_index + 1,
        num_shards
    ));
    let mut graph = XorGraph::<SigVal<E::LocalSig, V>>::new(num_vertices);
    for sig_val in shard.iter() {
        let local_sig = shard_edge.local_sig(sig_val.sig);
        // The same local pair is added to each of the three vertices of
        // the edge, together with the side on which the vertex appears.
        let local_sig_val = SigVal {
            sig: local_sig,
            val: sig_val.val,
        };
        for (side, &v) in shard_edge.local_edge(local_sig).iter().enumerate() {
            graph.add(v, local_sig_val, side);
        }
    }
    pl.done_with_count(num_edges);

    // The graph now carries local copies of the pairs, so this method no
    // longer needs its handle on the shard.
    drop(shard);

    assert!(
        !graph.overflow,
        "Degree overflow for shard {}/{}",
        shard_index + 1,
        num_shards
    );

    // NOTE(review): presumably set by another worker whose shard failed;
    // confirm against the code that writes `failed`.
    if self.failed.load(Ordering::Relaxed) {
        return Err(());
    }

    let mut peeled_sig_vals = FastStack::<SigVal<E::LocalSig, V>>::new(num_edges);
    let mut peeled_sides = FastStack::<u8>::new(num_edges);
    // Initial capacity only: the vector is free to grow.
    let mut to_visit = Vec::<E::Vertex>::with_capacity(num_vertices / 3);

    pl.start(format!(
        "Peeling graph for shard {}/{} by signatures (high-mem)...",
        shard_index + 1,
        num_shards
    ));

    // Seed the visit stack with every vertex of degree one.
    for (v, _) in graph
        .degrees()
        .enumerate()
        .filter(|(_, degree)| *degree == 1)
    {
        to_visit.push(E::Vertex::as_from(v));
    }

    while let Some(vertex) = to_visit.pop() {
        let v: usize = vertex.as_to();
        // Vertices whose degree has dropped to zero in the meantime have
        // nothing left to peel.
        if graph.degree(v) == 0 {
            continue;
        }
        let (sig_val, side) = graph.edge_and_side(v);
        graph.zero(v);
        peeled_sig_vals.push(sig_val);
        peeled_sides.push(side as u8);

        let vertices = shard_edge.local_edge(sig_val.sig);
        remove_edge!(
            graph,
            vertices,
            side,
            sig_val,
            to_visit,
            push,
            E::Vertex::as_from
        );
    }
    pl.done();

    let num_peeled = peeled_sig_vals.len();
    if num_peeled != num_edges {
        pl.info(format_args!(
            "Peeling failed for shard {}/{} (peeled {} out of {} edges)",
            shard_index + 1,
            num_shards,
            num_peeled,
            num_edges
        ));
        return Err(());
    }

    // Replay the peeled edges from the most recently peeled to the first.
    let sigs_and_values = peeled_sig_vals
        .iter()
        .rev()
        .map(|&sig_val| (sig_val.sig, get_val(shard_edge, sig_val)));
    let sides = peeled_sides.iter().copied().rev();
    self.assign(shard_index, data, sigs_and_values.zip(sides), pl);

    Ok(())
}
/// Peels a shard using an XOR graph whose per-vertex payload is a whole
/// local signature/value pair, keeping both the vertices still to visit
/// and the vertices already peeled in a single `DoubleStack`.
///
/// No separate stacks of pairs and sides are kept: during assignment the
/// pair and the side of each peeled vertex are read back from the graph.
///
/// # Errors
///
/// Returns `Err(())` if `self.failed` is set once the graph has been
/// built, or if fewer edges than the shard contains were peeled.
///
/// # Panics
///
/// Panics if the graph reports a degree overflow.
fn peel_by_sig_vals_low_mem<V: BinSafe, G: Fn(&E, SigVal<E::LocalSig, V>) -> W + Send + Sync>(
    &self,
    shard_index: usize,
    shard: Arc<Vec<SigVal<S, V>>>,
    data: ShardData<'_, D>,
    get_val: &G,
    pl: &mut impl ProgressLog,
) -> Result<(), ()>
where
    SigVal<E::LocalSig, V>: BitXor + BitXorAssign + Default,
{
    let shard_edge = &self.shard_edge;
    let num_vertices = shard_edge.num_vertices();
    let num_shards = shard_edge.num_shards();
    let num_edges = shard.len();

    pl.start(format!(
        "Generating graph for shard {}/{}...",
        shard_index + 1,
        num_shards,
    ));
    let mut graph = XorGraph::<SigVal<E::LocalSig, V>>::new(num_vertices);
    for sig_val in shard.iter() {
        let local_sig = shard_edge.local_sig(sig_val.sig);
        // The same local pair is added to each of the three vertices of
        // the edge, together with the side on which the vertex appears.
        let local_sig_val = SigVal {
            sig: local_sig,
            val: sig_val.val,
        };
        for (side, &v) in shard_edge.local_edge(local_sig).iter().enumerate() {
            graph.add(v, local_sig_val, side);
        }
    }
    pl.done_with_count(num_edges);

    // The graph now carries local copies of the pairs, so this method no
    // longer needs its handle on the shard.
    drop(shard);

    assert!(
        !graph.overflow,
        "Degree overflow for shard {}/{}",
        shard_index + 1,
        num_shards
    );

    // NOTE(review): presumably set by another worker whose shard failed;
    // confirm against the code that writes `failed`.
    if self.failed.load(Ordering::Relaxed) {
        return Err(());
    }

    // Lower side: vertices still to visit. Upper side: peeled vertices.
    let mut stack = DoubleStack::<E::Vertex>::new(num_vertices);

    pl.start(format!(
        "Peeling graph for shard {}/{} by signatures (low-mem)...",
        shard_index + 1,
        num_shards
    ));

    // Seed the visit side with every vertex of degree one.
    for (v, _) in graph
        .degrees()
        .enumerate()
        .filter(|(_, degree)| *degree == 1)
    {
        stack.push_lower(E::Vertex::as_from(v));
    }

    while let Some(vertex) = stack.pop_lower() {
        let v: usize = vertex.as_to();
        // Vertices whose degree has dropped to zero in the meantime have
        // nothing left to peel.
        if graph.degree(v) == 0 {
            continue;
        }
        let (sig_val, side) = graph.edge_and_side(v);
        graph.zero(v);
        // Only the vertex is recorded; its pair and side are read back
        // from the graph at assignment time.
        stack.push_upper(E::Vertex::as_from(v));

        let vertices = shard_edge.local_edge(sig_val.sig);
        remove_edge!(
            graph,
            vertices,
            side,
            sig_val,
            stack,
            push_lower,
            E::Vertex::as_from
        );
    }
    pl.done();

    let num_peeled = stack.upper_len();
    if num_peeled != num_edges {
        pl.info(format_args!(
            "Peeling failed for shard {}/{} (peeled {} out of {} edges)",
            shard_index + 1,
            num_shards,
            num_peeled,
            num_edges
        ));
        return Err(());
    }

    // NOTE(review): this reads `edge_and_side` for vertices on which
    // `zero` was already called, so it relies on `zero` leaving that
    // information readable; confirm against `XorGraph`.
    let peeled = stack.iter_upper().map(|&vertex| {
        let (sig_val, side) = graph.edge_and_side(vertex.as_to());
        let value = get_val(shard_edge, sig_val);
        ((sig_val.sig, value), side as u8)
    });
    self.assign(shard_index, data, peeled, pl);

    Ok(())
}
/// Solves a shard by peeling and, when peeling leaves some edges behind,
/// by lazy Gaussian elimination on the remaining edges.
///
/// The shard is first peeled with `peel_by_index`. If peeling is complete
/// there is nothing else to do. Otherwise a modulo-2 system is built with
/// one equation per unpeeled edge and solved; the solution is written into
/// `data` for the variables that occur in some equation, and finally the
/// peeled edges are assigned through `assign`.
///
/// # Errors
///
/// Returns `Err(())` if `peel_by_index` fails, if `self.failed` is set
/// after the system has been generated, or if the elimination fails.
fn lge_shard<V: BinSafe, G: Fn(&E, SigVal<E::LocalSig, V>) -> W + Send + Sync>(
    &self,
    shard_index: usize,
    shard: Arc<Vec<SigVal<S, V>>>,
    data: ShardData<'_, D>,
    get_val: &G,
    pl: &mut impl ProgressLog,
) -> Result<(), ()> {
    let shard_edge = &self.shard_edge;

    // A complete peeling has already assigned the shard: only a partial
    // result needs further work.
    let PeelResult::Partial {
        shard_index,
        shard,
        mut data,
        double_stack,
        sides_stack,
        _marker: PhantomData,
    } = self.peel_by_index(shard_index, shard, data, get_val, pl)?
    else {
        return Ok(());
    };

    pl.info(format_args!("Switching to lazy Gaussian elimination..."));
    pl.start(format!(
        "Generating system for shard {}/{}...",
        shard_index + 1,
        shard_edge.num_shards()
    ));

    let num_vertices = shard_edge.num_vertices();
    // One bit per edge of the shard: set if the edge has been peeled.
    let mut is_peeled = BitVec::new(shard.len());
    // One bit per vertex: set if the vertex occurs in some equation.
    let mut is_used = BitVec::new(num_vertices);
    for &edge_index in double_stack.iter_upper() {
        is_peeled.set(edge_index.as_to(), true);
    }

    // SAFETY (review note): the variables of each equation are the
    // vertices of an unpeeled edge, sorted before use.
    // TODO confirm that this is all the `from_parts` constructors require.
    let mut system = unsafe {
        crate::utils::mod2_sys::Modulo2System::from_parts(
            num_vertices,
            shard
                .iter()
                .enumerate()
                .filter(|(edge_index, _)| !is_peeled[*edge_index])
                .map(|(_, sig_val)| {
                    let local_sig = shard_edge.local_sig(sig_val.sig);
                    let mut variables: Vec<_> = shard_edge
                        .local_edge(local_sig)
                        .iter()
                        .map(|&x| {
                            is_used.set(x, true);
                            x as u32
                        })
                        .collect();
                    variables.sort_unstable();
                    let constant = get_val(
                        shard_edge,
                        SigVal {
                            sig: local_sig,
                            val: sig_val.val,
                        },
                    );
                    crate::utils::mod2_sys::Modulo2Equation::from_parts(variables, constant)
                })
                .collect(),
        )
    };

    // NOTE(review): presumably set by another worker whose shard failed;
    // confirm against the code that writes `failed`.
    if self.failed.load(Ordering::Relaxed) {
        return Err(());
    }

    pl.start("Solving system...");
    let solution = system.lazy_gaussian_elimination().map_err(|_| ())?;
    pl.done_with_count(system.num_equations());

    // Only variables that appear in some equation are written back.
    for (v, &value) in solution.iter().enumerate() {
        if is_used[v] {
            data.set_value(v, value);
        }
    }

    // Peeled edges are assigned last, in the order given by the upper
    // stack, paired with the sides in reverse push order.
    let peeled = double_stack.iter_upper().map(|&edge_index| {
        let edge: usize = edge_index.as_to();
        let sig_val = &shard[edge];
        let local_sig = shard_edge.local_sig(sig_val.sig);
        let value = get_val(
            shard_edge,
            SigVal {
                sig: local_sig,
                val: sig_val.val,
            },
        );
        (local_sig, value)
    });
    self.assign(
        shard_index,
        data,
        peeled.zip(sides_stack.into_iter().rev()),
        pl,
    );

    Ok(())
}
/// Writes into `data` the values determined by a sequence of peeled
/// edges.
///
/// Each item of `sigs_vals_sides` is a local signature, a value and a
/// side. The local signature is turned into an edge by `local_edge`, and
/// the entry of `data` at the vertex on the given side is set to the value
/// XORed with the entries at the other two vertices. Right after the
/// write, the XOR of the three entries of the edge is thus the value.
///
/// Nothing is done if `self.failed` is already set.
fn assign(
    &self,
    shard_index: usize,
    mut data: ShardData<'_, D>,
    sigs_vals_sides: impl Iterator<Item = ((E::LocalSig, W), u8)>,
    pl: &mut impl ProgressLog,
) where
    for<'a> ShardData<'a, D>: SliceByValueMut<Value = W>,
{
    if self.failed.load(Ordering::Relaxed) {
        return;
    }

    let shard_edge = &self.shard_edge;
    pl.start(format!(
        "Assigning values for shard {}/{}...",
        shard_index + 1,
        shard_edge.num_shards()
    ));

    for ((local_sig, value), side) in sigs_vals_sides {
        let vertices = shard_edge.local_edge(local_sig);

        // SAFETY (review note): nothing here checks that `side` is 0, 1
        // or 2, or that the vertices are valid positions of `data`; both
        // are assumed to be guaranteed by the callers and by `local_edge`.
        let (first, second) = unsafe {
            match side {
                0 => (
                    data.get_value_unchecked(vertices[1]),
                    data.get_value_unchecked(vertices[2]),
                ),
                1 => (
                    data.get_value_unchecked(vertices[0]),
                    data.get_value_unchecked(vertices[2]),
                ),
                2 => (
                    data.get_value_unchecked(vertices[0]),
                    data.get_value_unchecked(vertices[1]),
                ),
                _ => core::hint::unreachable_unchecked(),
            }
        };

        // SAFETY (review note): same assumptions as for the reads above.
        unsafe {
            data.set_value_unchecked(vertices[usize::from(side)], value ^ (first ^ second));
        }
    }

    pl.done();
}
}