#![allow(clippy::type_complexity)]
#![allow(clippy::too_many_arguments)]
use crate::bits::*;
use crate::func::{shard_edge::ShardEdge, *};
use crate::traits::bit_field_slice::{BitFieldSlice, BitFieldSliceMut};
use crate::traits::{BitVecOpsMut, Word};
use crate::utils::*;
use core::error::Error;
use derivative::Derivative;
use derive_setters::*;
use dsi_progress_logger::*;
use lender::FallibleLending;
use log::info;
use num_primitive::PrimitiveNumber;
use rand::rngs::SmallRng;
use rand::{Rng, RngExt, SeedableRng};
use rayon::iter::ParallelIterator;
use rayon::slice::ParallelSlice;
use rdst::*;
use std::any::TypeId;
use std::borrow::{Borrow, Cow};
use std::hint::unreachable_unchecked;
use std::marker::PhantomData;
use std::mem::transmute;
use std::ops::{BitXor, BitXorAssign};
use std::slice::Iter;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
use thread_priority::ThreadPriority;
use value_traits::slices::{SliceByValue, SliceByValueMut};
use super::shard_edge::FuseLge3Shards;
/// Value passed as the second argument to `sig_store::new_online` and
/// `sig_store::new_offline` in this file.
///
/// NOTE(review): judging by the name this is the base-2 logarithm of the
/// maximum number of shards the signature store may later be split into;
/// confirm against the `sig_store` module.
const LOG2_MAX_SHARDS: u32 = 16;
/// Returns the default upper bound on the number of worker threads.
///
/// This is the parallelism reported by the standard library, capped at 16;
/// if the available parallelism cannot be determined the result is 1.
fn default_max_num_threads() -> usize {
    match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get().min(16),
        Err(_) => 1,
    }
}
/// Builder holding the configuration and the working state used by the
/// construction methods in this file to produce a `VFunc`.
///
/// Fields annotated with `#[setters(generate = true)]` are options with a
/// generated setter; the remaining fields are filled in while a build runs.
#[derive(Setters, Debug, Derivative)]
#[derivative(Default)]
#[setters(generate = false)]
pub struct VBuilder<D, S = [u64; 2], E = FuseLge3Shards> {
/// Expected number of keys, if known. When set, `init_shards_and_seed` uses
/// it to set up the shards and to overwrite `log2_buckets`; it is also
/// forwarded to the signature-store constructors and to the progress logger.
#[setters(generate = true, strip_option)]
#[derivative(Default(value = "None"))]
expected_num_keys: Option<usize>,
/// Upper bound on the number of solver threads; defaults to
/// `default_max_num_threads()`. The effective `num_threads` is the minimum of
/// this value and the number of shards.
#[setters(generate = true)]
#[derivative(Default(value = "default_max_num_threads()"))]
pub(crate) max_num_threads: usize,
/// Selects `sig_store::new_offline` instead of `sig_store::new_online` in
/// `try_solve_once`.
#[setters(generate = true)]
offline: bool,
/// When true, `par_solve` sorts each shard and reports
/// `SolveError::DuplicateSignature` if two adjacent signatures are equal.
#[setters(generate = true)]
check_dups: bool,
/// Peeling strategy used when lazy Gaussian elimination is not selected:
/// `Some(true)` forces the low-memory peeler; `None` picks it only when more
/// than 3 threads and more than 2 shards are in use; otherwise the
/// high-memory peeler is used.
#[setters(generate = true, strip_option)]
#[derivative(Default(value = "None"))]
low_mem: Option<bool>,
/// Seed of the pseudo-random generator that produces the per-attempt seeds
/// (see `retry_state`); also used in `par_solve` to seed the generator that
/// pre-fills the data when values are `EmptyVal`.
#[setters(generate = true)]
seed: u64,
/// Base-2 logarithm of the number of buckets passed to the signature store
/// (default 8). Overwritten by `init_shards_and_seed` when
/// `expected_num_keys` is set.
#[setters(generate = true, strip_option)]
#[derivative(Default(value = "8"))]
log2_buckets: u32,
/// Parameter forwarded to `ShardEdge::set_up_shards` (default 0.001).
#[setters(generate = true, strip_option)]
#[derivative(Default(value = "0.001"))]
pub(crate) eps: f64,
/// Bit width of the stored values; set by the build methods either from the
/// bit length of the maximum value or from an explicit argument.
pub(crate) bit_width: usize,
/// The shard/edge generator in use.
pub(crate) shard_edge: E,
/// Number of keys of the current build, as reported by the signature store.
pub(crate) num_keys: usize,
/// First component of the pair returned by `ShardEdge::set_up_graphs`; it is
/// only logged in this file.
c: f64,
/// Second component of the pair returned by `ShardEdge::set_up_graphs`; when
/// true, shards are solved with `lge_shard`.
lge: bool,
/// Number of worker threads actually used by `par_solve`.
num_threads: usize,
/// Flag set by `par_solve` when an error is received; workers and peelers
/// poll it to stop early.
failed: AtomicBool,
#[doc(hidden)]
_marker: PhantomData<(D, S)>,
}
impl<D: BitFieldSlice<Value: Word + BinSafe> + Send + Sync, S: Sig, E: ShardEdge<S, 3>>
VBuilder<D, S, E>
{
pub(crate) fn init_shards_and_seed(&mut self) -> u64 {
if let Some(expected_num_keys) = self.expected_num_keys {
self.shard_edge.set_up_shards(expected_num_keys, self.eps);
self.log2_buckets = self.shard_edge.shard_high_bits();
}
self.seed
}
pub fn set_from<D2, S2, E2>(mut self, other: &VBuilder<D2, S2, E2>) -> Self {
self.max_num_threads = other.max_num_threads;
self.offline = other.offline;
self.check_dups = other.check_dups;
self.low_mem = other.low_mem;
self.eps = other.eps;
self
}
}
/// Non-recoverable construction errors reported to the caller.
#[derive(thiserror::Error, Debug)]
pub enum BuildError {
/// Returned by `RetryState::handle_solve_result` once duplicate signatures
/// have been reported on too many attempts.
#[error("Duplicate key")]
DuplicateKey,
/// Returned by `RetryState::handle_solve_result` once duplicate local
/// signatures have been reported on too many attempts.
#[error("Duplicate local signatures: use full signatures")]
DuplicateLocalSignatures,
/// A value does not fit in the specified bit size. Not produced by the code
/// visible in this part of the file.
#[error("Value too large for specified bit size")]
ValueTooLarge,
}
/// Errors of a single construction attempt. `RetryState::handle_solve_result`
/// turns each of them into a retry with a new seed, up to a limit for the two
/// duplicate variants.
#[derive(thiserror::Error, Debug)]
pub enum SolveError {
/// Two equal signatures were found in a sorted shard (only checked when
/// `check_dups` is set).
#[error("Duplicate signature")]
DuplicateSignature,
/// Deduplicating a shard by local signature removed entries while values
/// are not `EmptyVal`.
#[error("Duplicate local signature")]
DuplicateLocalSignature,
/// The largest shard exceeds 1.01 times the average shard size.
#[error("Max shard too big")]
MaxShardTooBig,
/// The per-shard solver reported a failure.
#[error("Unsolvable shard")]
UnsolvableShard,
}
/// Outcome of `peel_by_index`.
enum PeelResult<
'a,
D: BitFieldSlice<Value: Word + BinSafe + Send + Sync> + BitFieldSliceMut + Send + Sync + 'a,
S: Sig + BinSafe,
E: ShardEdge<S, 3>,
V: BinSafe,
> {
/// Every edge was peeled and the values have been assigned.
Complete(),
/// Peeling stopped before all edges were removed; the state accumulated so
/// far is handed back to the caller.
Partial {
/// Index of the shard being solved.
shard_index: usize,
/// The signature/value pairs of the shard.
shard: Arc<Vec<SigVal<S, V>>>,
/// The chunk of the output data belonging to this shard.
data: ShardData<'a, D>,
/// Stack whose upper part holds the indices of the edges peeled so far.
double_stack: DoubleStack<E::Vertex>,
/// For each peeled edge, in peeling order, the side from which it was
/// peeled.
sides_stack: Vec<u8>,
},
}
/// Compact per-vertex summary of a 3-hypergraph used for peeling.
///
/// For every vertex it keeps the XOR of the payloads of the incident edges
/// and a byte packing the degree (upper six bits) with the XOR of the sides
/// of the incident edges (lower two bits).
struct XorGraph<X: Copy + Default + BitXor + BitXorAssign> {
    /// XOR of the payloads of the edges currently incident on each vertex.
    edges: Box<[X]>,
    /// Packed `degree << 2 | xor_of_sides` for each vertex.
    degrees_sides: Box<[u8]>,
    /// Set as soon as a degree counter wraps around.
    overflow: bool,
}

impl<X: BitXor + BitXorAssign + Default + Copy> XorGraph<X> {
    /// Creates a graph with `n` isolated vertices.
    pub fn new(n: usize) -> XorGraph<X> {
        Self {
            edges: vec![X::default(); n].into_boxed_slice(),
            degrees_sides: vec![0u8; n].into_boxed_slice(),
            overflow: false,
        }
    }

    /// Records that edge `x` is incident on vertex `v` through side `side`.
    #[inline(always)]
    pub fn add(&mut self, v: usize, x: X, side: usize) {
        debug_assert!(side < 3);
        let packed = &mut self.degrees_sides[v];
        // Adding 4 increments the degree stored above the two side bits.
        let (bumped, wrapped) = packed.overflowing_add(4);
        *packed = bumped ^ side as u8;
        self.overflow |= wrapped;
        self.edges[v] ^= x;
    }

    /// Undoes a previous [`add`](Self::add) with the same arguments.
    #[inline(always)]
    pub fn remove(&mut self, v: usize, x: X, side: usize) {
        debug_assert!(side < 3);
        let packed = &mut self.degrees_sides[v];
        *packed -= 4;
        *packed ^= side as u8;
        self.edges[v] ^= x;
    }

    /// Sets the degree of `v` to zero, keeping the side bits.
    #[inline(always)]
    pub fn zero(&mut self, v: usize) {
        let packed = &mut self.degrees_sides[v];
        *packed &= 0b11;
    }

    /// Returns the accumulated payload and side of `v`; meaningful as "the"
    /// incident edge only when the degree of `v` is at most one.
    #[inline(always)]
    pub fn edge_and_side(&self, v: usize) -> (X, usize) {
        debug_assert!(self.degree(v) < 2);
        let side_bits = self.degrees_sides[v] & 0b11;
        (self.edges[v], usize::from(side_bits))
    }

    /// Returns the degree of vertex `v`.
    #[inline(always)]
    pub fn degree(&self, v: usize) -> u8 {
        self.degrees_sides[v] >> 2
    }

    /// Iterates over the degrees of all vertices, in vertex order.
    pub fn degrees(
        &self,
    ) -> std::iter::Map<std::iter::Copied<std::slice::Iter<'_, u8>>, fn(u8) -> u8> {
        let unpack: fn(u8) -> u8 = |packed| packed >> 2;
        self.degrees_sides.iter().copied().map(unpack)
    }
}
/// A stack over a buffer allocated once at construction time.
///
/// The buffer never grows: pushing more than the capacity given to
/// [`new`](FastStack::new) panics on the out-of-bounds write.
struct FastStack<X: Copy + Default> {
    /// Backing storage, fully initialized with default values.
    stack: Vec<X>,
    /// Number of elements currently on the stack.
    top: usize,
}

impl<X: Copy + Default> FastStack<X> {
    /// Creates an empty stack able to hold `n` elements.
    pub fn new(n: usize) -> FastStack<X> {
        Self {
            stack: vec![X::default(); n],
            top: 0,
        }
    }

    /// Pushes `x` on top of the stack.
    pub fn push(&mut self, x: X) {
        let slot = self.top;
        debug_assert!(slot < self.stack.len());
        self.stack[slot] = x;
        self.top = slot + 1;
    }

    /// Returns the number of elements pushed so far.
    pub fn len(&self) -> usize {
        self.top
    }

    /// Iterates over the pushed elements, oldest first.
    pub fn iter(&self) -> std::slice::Iter<'_, X> {
        let (live, _unused) = self.stack.split_at(self.top);
        live.iter()
    }
}
/// Two stacks sharing one fixed-size buffer: the lower stack grows upwards
/// from index 0, the upper stack grows downwards from the end.
#[derive(Debug)]
struct DoubleStack<V> {
    /// Shared backing storage.
    stack: Vec<V>,
    /// Index of the first free slot of the lower stack.
    lower: usize,
    /// Index of the top element of the upper stack (`stack.len()` if empty).
    upper: usize,
}

impl<V: Default + Copy> DoubleStack<V> {
    /// Creates a pair of empty stacks sharing `n` slots.
    fn new(n: usize) -> DoubleStack<V> {
        Self {
            stack: vec![V::default(); n],
            lower: 0,
            upper: n,
        }
    }
}

impl<V: Copy> DoubleStack<V> {
    /// Pushes `v` on the lower stack.
    #[inline(always)]
    fn push_lower(&mut self, v: V) {
        debug_assert!(self.lower < self.upper);
        let slot = self.lower;
        self.stack[slot] = v;
        self.lower = slot + 1;
    }

    /// Pushes `v` on the upper stack.
    #[inline(always)]
    fn push_upper(&mut self, v: V) {
        debug_assert!(self.lower < self.upper);
        let slot = self.upper - 1;
        self.upper = slot;
        self.stack[slot] = v;
    }

    /// Pops the top of the lower stack, or returns `None` if it is empty.
    #[inline(always)]
    fn pop_lower(&mut self) -> Option<V> {
        let slot = self.lower.checked_sub(1)?;
        self.lower = slot;
        Some(self.stack[slot])
    }

    /// Returns the number of elements on the upper stack.
    fn upper_len(&self) -> usize {
        self.stack.len() - self.upper
    }

    /// Iterates over the upper stack, most recently pushed element first.
    fn iter_upper(&self) -> Iter<'_, V> {
        let (_lower_part, upper_part) = self.stack.split_at(self.upper);
        upper_part.iter()
    }
}
/// The mutable-chunk iterator type of `D`, i.e. `SliceByValueMut::ChunksMut`.
type ShardDataIter<'a, D> = <D as SliceByValueMut>::ChunksMut<'a>;
/// The item yielded by [`ShardDataIter`]: the mutable chunk of the output
/// data that is handed to the solver of a single shard.
type ShardData<'a, D> = <ShardDataIter<'a, D> as Iterator>::Item;
impl<W: Word + BinSafe, S: Sig + Send + Sync, E: ShardEdge<S, 3>>
    VBuilder<BitFieldVec<Box<[W]>>, S, E>
where
    SigVal<S, W>: RadixKey,
    SigVal<E::LocalSig, W>: BitXor + BitXorAssign,
{
    /// Builds a function from an already populated shard store.
    ///
    /// This is [`try_build_func_with_store_and_inspect`] with an inspection
    /// hook that does nothing.
    ///
    /// [`try_build_func_with_store_and_inspect`]: Self::try_build_func_with_store_and_inspect
    pub fn try_build_func_with_store<K: ?Sized + ToSig<S>, V: BinSafe + Default + Send + Sync>(
        &mut self,
        seed: u64,
        shard_edge: E,
        max_value: W,
        shard_store: &mut (impl ShardStore<S, V> + ?Sized),
        get_val: &(impl Fn(&E, SigVal<E::LocalSig, V>) -> W + Send + Sync),
        pl: &mut (impl ProgressLog + Clone + Send + Sync),
    ) -> anyhow::Result<VFunc<K, BitFieldVec<Box<[W]>>, S, E>>
    where
        SigVal<S, V>: RadixKey,
        SigVal<E::LocalSig, V>: BitXor + BitXorAssign,
        for<'a> ShardDataIter<'a, BitFieldVec<Box<[W]>>>: Send,
        for<'a> <ShardDataIter<'a, BitFieldVec<Box<[W]>>> as Iterator>::Item: Send,
    {
        self.try_build_func_with_store_and_inspect(
            seed,
            shard_edge,
            max_value,
            shard_store,
            get_val,
            &|_| {},
            pl,
        )
    }

    /// Builds a function from an already populated shard store, calling
    /// `inspect` on the signature/value pairs seen by the solver.
    ///
    /// The builder adopts `shard_edge` as is, takes the number of keys from
    /// the store and the bit width from the bit length of `max_value`, sets
    /// up the graphs for the largest shard, allocates the padded output
    /// vector and then solves all shards.
    ///
    /// # Errors
    ///
    /// Any `SolveError` of the solving phase, converted to `anyhow::Error`.
    pub fn try_build_func_with_store_and_inspect<
        K: ?Sized + ToSig<S>,
        V: BinSafe + Default + Send + Sync,
    >(
        &mut self,
        seed: u64,
        shard_edge: E,
        max_value: W,
        shard_store: &mut (impl ShardStore<S, V> + ?Sized),
        get_val: &(impl Fn(&E, SigVal<E::LocalSig, V>) -> W + Send + Sync),
        inspect: &(impl Fn(&SigVal<S, V>) + Send + Sync),
        pl: &mut (impl ProgressLog + Clone + Send + Sync),
    ) -> anyhow::Result<VFunc<K, BitFieldVec<Box<[W]>>, S, E>>
    where
        SigVal<S, V>: RadixKey,
        SigVal<E::LocalSig, V>: BitXor + BitXorAssign,
        for<'a> ShardDataIter<'a, BitFieldVec<Box<[W]>>>: Send,
        for<'a> <ShardDataIter<'a, BitFieldVec<Box<[W]>>> as Iterator>::Item: Send,
    {
        self.shard_edge = shard_edge;
        self.num_keys = shard_store.len();
        self.bit_width = max_value.bit_len() as usize;

        let largest_shard = shard_store.shard_sizes().max().unwrap_or(0);
        let (c, lge) = self.shard_edge.set_up_graphs(self.num_keys, largest_shard);
        self.c = c;
        self.lge = lge;

        pl.info(format_args!(
            "Number of keys: {} Max value: {max_value} Bitwidth: {}",
            self.num_keys, self.bit_width,
        ));

        // One block of `num_vertices` values per shard.
        let total_vertices = self.shard_edge.num_vertices() * self.shard_edge.num_shards();
        let data: BitFieldVec<Box<[W]>> =
            BitFieldVec::<Box<[W]>>::new_padded(self.bit_width, total_vertices);

        let func =
            self.try_build_from_shard_iter(seed, data, shard_store.iter(), get_val, inspect, pl)?;
        Ok(func)
    }
}
pub(crate) struct RetryState {
prng: SmallRng,
dup_count: u32,
local_dup_count: u32,
}
impl RetryState {
pub(crate) fn next_seed(&mut self) -> u64 {
self.prng.random()
}
pub(crate) fn handle_solve_result<R>(
&mut self,
result: anyhow::Result<R>,
pl: &mut impl ProgressLog,
) -> anyhow::Result<Option<R>> {
match result {
Ok(r) => Ok(Some(r)),
Err(error) => match error.downcast::<SolveError>() {
Ok(vfunc_error) => match vfunc_error {
SolveError::DuplicateSignature => {
if self.dup_count >= 3 {
pl.error(format_args!("Duplicate keys (duplicate 128-bit signatures with four different seeds)"));
return Err(BuildError::DuplicateKey.into());
}
pl.warn(format_args!(
"Duplicate 128-bit signature, trying again with a different seed..."
));
self.dup_count += 1;
Ok(None)
}
SolveError::DuplicateLocalSignature => {
if self.local_dup_count >= 2 {
pl.error(format_args!("Duplicate local signatures: use full signatures (duplicate local signatures with three different seeds)"));
return Err(BuildError::DuplicateLocalSignatures.into());
}
pl.warn(format_args!(
"Duplicate local signature, trying again with a different seed..."
));
self.local_dup_count += 1;
Ok(None)
}
SolveError::MaxShardTooBig => {
pl.warn(format_args!(
"The maximum shard is too big, trying again with a different seed..."
));
Ok(None)
}
SolveError::UnsolvableShard => {
pl.warn(format_args!(
"Unsolvable shard, trying again with a different seed..."
));
Ok(None)
}
},
Err(error) => Err(error),
},
}
}
}
impl<
D: BitFieldSlice<Value: Word + BinSafe> + Send + Sync,
S: Sig + Send + Sync,
E: ShardEdge<S, 3>,
> VBuilder<D, S, E>
{
/// Builds a function mapping each key to the value at the same position.
///
/// Signatures are computed and stored through
/// `try_populate_and_build`, which retries with new seeds as needed; for
/// each attempt the output is allocated with `new_data(bit_width, len)`,
/// where the bit width is the bit length of the largest value, and the
/// shards are drained from the store while solving.
///
/// Returns the function together with the key lender.
pub(crate) fn try_build_func<
    K: ?Sized + ToSig<S> + std::fmt::Debug,
    B: ?Sized + Borrow<K>,
    P: ProgressLog + Clone + Send + Sync,
    L: FallibleRewindableLender<
            RewindError: Error + Send + Sync + 'static,
            Error: Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
>(
    mut self,
    keys: L,
    values: impl FallibleRewindableLender<
        RewindError: Error + Send + Sync + 'static,
        Error: Error + Send + Sync + 'static,
    > + for<'lend> FallibleLending<'lend, Lend = &'lend D::Value>,
    new_data: fn(usize, usize) -> D,
    pl: &mut P,
) -> anyhow::Result<(VFunc<K, D, S, E>, L)>
where
    SigVal<S, D::Value>: RadixKey,
    SigVal<E::LocalSig, D::Value>: BitXor + BitXorAssign,
    D: for<'a> BitFieldSliceMut<ChunksMut<'a>: Iterator<Item: BitFieldSliceMut>>,
    for<'a> ShardDataIter<'a, D>: Send,
    for<'a> <ShardDataIter<'a, D> as Iterator>::Item: Send,
{
    // The stored value is the value associated with the key, unchanged.
    let get_val = |_shard_edge: &E, sig_val: SigVal<E::LocalSig, D::Value>| sig_val.val;
    self.try_populate_and_build(
        keys,
        values,
        &mut |builder, seed, mut store, max_value, _num_keys, pl: &mut P, _state: &mut ()| {
            let bit_width = max_value.bit_len() as usize;
            builder.bit_width = bit_width;
            let total_vertices =
                builder.shard_edge.num_vertices() * builder.shard_edge.num_shards();
            let data = new_data(bit_width, total_vertices);
            pl.info(format_args!(
                "Number of keys: {} Max value: {max_value} Bit width: {}",
                builder.num_keys, builder.bit_width,
            ));
            let shards = store.drain();
            let func =
                builder.try_build_from_shard_iter(seed, data, shards, &get_val, &|_| {}, pl)?;
            Ok(func)
        },
        pl,
        (),
    )
}
/// Like `try_build_func`, but the shard store of the successful attempt is
/// iterated rather than drained and is returned to the caller.
///
/// Returns the function, the shard store and the key lender.
pub(crate) fn try_build_func_and_store<
    K: ?Sized + ToSig<S> + std::fmt::Debug,
    B: ?Sized + Borrow<K>,
    P: ProgressLog + Clone + Send + Sync,
    L: FallibleRewindableLender<
            RewindError: Error + Send + Sync + 'static,
            Error: Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
>(
    mut self,
    keys: L,
    values: impl FallibleRewindableLender<
        RewindError: Error + Send + Sync + 'static,
        Error: Error + Send + Sync + 'static,
    > + for<'lend> FallibleLending<'lend, Lend = &'lend D::Value>,
    new_data: fn(usize, usize) -> D,
    pl: &mut P,
) -> anyhow::Result<(
    VFunc<K, D, S, E>,
    Box<dyn ShardStore<S, D::Value> + Send + Sync>,
    L,
)>
where
    SigVal<S, D::Value>: RadixKey,
    SigVal<E::LocalSig, D::Value>: BitXor + BitXorAssign,
    D: for<'a> BitFieldSliceMut<ChunksMut<'a>: Iterator<Item: BitFieldSliceMut>>,
    for<'a> ShardDataIter<'a, D>: Send,
    for<'a> <ShardDataIter<'a, D> as Iterator>::Item: Send,
{
    // The stored value is the value associated with the key, unchanged.
    let get_val = |_shard_edge: &E, sig_val: SigVal<E::LocalSig, D::Value>| sig_val.val;
    let ((func, store), keys) = self.try_populate_and_build(
        keys,
        values,
        &mut |builder, seed, mut store, max_value, _num_keys, pl: &mut P, _state: &mut ()| {
            let bit_width = max_value.bit_len() as usize;
            builder.bit_width = bit_width;
            let total_vertices =
                builder.shard_edge.num_vertices() * builder.shard_edge.num_shards();
            let data = new_data(bit_width, total_vertices);
            pl.info(format_args!(
                "Number of keys: {} Max value: {max_value} Bit width: {}",
                builder.num_keys, builder.bit_width,
            ));
            // Iterate instead of draining so that the store survives the build.
            let func = builder.try_build_from_shard_iter(
                seed,
                data,
                store.iter(),
                &get_val,
                &|_| {},
                pl,
            )?;
            Ok((func, store))
        },
        pl,
        (),
    )?;
    Ok((func, store, keys))
}
/// Builds a function over keys only (values are `EmptyVal`), where the value
/// stored for each key is computed by `get_val` from its local signature.
///
/// The bit width of the output is the `bit_width` argument rather than being
/// derived from the values. Each attempt reads all keys, stores their
/// signatures and tries to solve; when the attempt must be retried (see
/// `RetryState::handle_solve_result`) the key lender is rewound and a new
/// seed is drawn.
///
/// # Errors
///
/// Errors of the key lender or of its rewind, and the non-retryable errors
/// returned by `RetryState::handle_solve_result`.
pub(crate) fn try_build_filter<
K: ?Sized + ToSig<S> + std::fmt::Debug,
B: ?Sized + Borrow<K>,
P: ProgressLog + Clone + Send + Sync,
G: Fn(&E, SigVal<E::LocalSig, EmptyVal>) -> D::Value + Send + Sync,
>(
mut self,
mut keys: impl FallibleRewindableLender<
RewindError: Error + Send + Sync + 'static,
Error: Error + Send + Sync + 'static,
> + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
bit_width: usize,
new_data: fn(usize, usize) -> D,
get_val: &G,
pl: &mut P,
) -> anyhow::Result<VFunc<K, D, S, E>>
where
SigVal<S, EmptyVal>: RadixKey,
SigVal<E::LocalSig, EmptyVal>: BitXor + BitXorAssign,
D: for<'a> BitFieldSliceMut<ChunksMut<'a>: Iterator<Item: BitFieldSliceMut>>,
for<'a> ShardDataIter<'a, D>: Send,
for<'a> <ShardDataIter<'a, D> as Iterator>::Item: Send,
{
let mut rs = self.retry_state(pl);
loop {
let seed = rs.next_seed();
// The block limits the mutable borrow of `keys` taken by `populate`, so
// that `keys` can be rewound below.
let result = {
// Pushes one signature per key; there is no value to track, so the
// "maximum value" returned is the default `EmptyVal`.
let mut populate =
|seed: u64,
push: &mut dyn FnMut(SigVal<S, EmptyVal>) -> anyhow::Result<()>,
pl: &mut P,
_state: &mut ()| {
while let Some(key) = keys.next()? {
pl.light_update();
push(SigVal {
sig: K::to_sig(key.borrow(), seed),
val: EmptyVal::default(),
})?;
}
Ok(EmptyVal::default())
};
self.try_solve_once(
seed,
&mut populate,
&mut |builder,
seed,
mut store,
_max_value,
_num_keys,
pl: &mut P,
_state: &mut ()| {
// The width is fixed by the caller, not derived from the values.
builder.bit_width = bit_width;
let data = new_data(
builder.bit_width,
builder.shard_edge.num_vertices() * builder.shard_edge.num_shards(),
);
pl.info(format_args!(
"Number of keys: {} Bit width: {}",
builder.num_keys, builder.bit_width,
));
let func = builder.try_build_from_shard_iter(
seed,
data,
store.drain(),
get_val,
&|_| {},
pl,
)?;
Ok(func)
},
pl,
&mut (),
)
};
if let Some(r) = rs.handle_solve_result(result, pl)? {
return Ok(r);
}
// Retry: start reading the keys from the beginning again.
keys = keys.rewind()?;
}
}
/// Initializes sharding (see `init_shards_and_seed`), logs the number of
/// buckets and returns a fresh [`RetryState`] whose generator is seeded with
/// the builder's seed.
pub(crate) fn retry_state(&mut self, pl: &mut impl ProgressLog) -> RetryState {
    self.init_shards_and_seed();
    let prng = SmallRng::seed_from_u64(self.seed);
    pl.info(format_args!("Using 2^{} buckets", self.log2_buckets));
    RetryState {
        prng,
        dup_count: 0,
        local_dup_count: 0,
    }
}
/// Retry loop for sequential construction from a lender of keys and a lender
/// of values.
///
/// Each attempt draws a new seed, pushes one signature/value pair per key
/// while tracking the largest value, and hands the resulting store to
/// `build_fn` through `try_solve_once`. When the attempt has to be retried
/// both lenders are rewound.
///
/// Returns the result of `build_fn` together with the key lender.
///
/// # Panics
///
/// Panics if `values` yields fewer items than `keys`.
pub(crate) fn try_populate_and_build<
    K: ?Sized + ToSig<S> + std::fmt::Debug,
    B: ?Sized + Borrow<K>,
    V: BinSafe + Default + Send + Sync + Ord,
    R,
    P: ProgressLog + Clone + Send + Sync,
    C,
    L: FallibleRewindableLender<
            RewindError: Error + Send + Sync + 'static,
            Error: Error + Send + Sync + 'static,
        > + for<'lend> FallibleLending<'lend, Lend = &'lend B>,
>(
    &mut self,
    mut keys: L,
    mut values: impl FallibleRewindableLender<
        RewindError: Error + Send + Sync + 'static,
        Error: Error + Send + Sync + 'static,
    > + for<'lend> FallibleLending<'lend, Lend = &'lend V>,
    build_fn: &mut impl FnMut(
        &mut Self,
        u64,
        Box<dyn ShardStore<S, V> + Send + Sync>,
        V,
        usize,
        &mut P,
        &mut C,
    ) -> anyhow::Result<R>,
    pl: &mut P,
    mut state: C,
) -> anyhow::Result<(R, L)>
where
    SigVal<S, V>: RadixKey,
{
    let mut retry = self.retry_state(pl);
    loop {
        let seed = retry.next_seed();
        // The block ends the borrows of `keys` and `values` held by the
        // closure before they are returned or rewound.
        let attempt = {
            let mut populate = |seed: u64,
                                push: &mut dyn FnMut(SigVal<S, V>) -> anyhow::Result<()>,
                                pl: &mut P,
                                _state: &mut C| {
                let mut largest = V::default();
                while let Some(key) = keys.next()? {
                    pl.light_update();
                    let value = *values.next()?.expect("Not enough values");
                    largest = Ord::max(largest, value);
                    push(SigVal {
                        sig: K::to_sig(key.borrow(), seed),
                        val: value,
                    })?;
                }
                Ok(largest)
            };
            self.try_solve_once(seed, &mut populate, build_fn, pl, &mut state)
        };
        match retry.handle_solve_result(attempt, pl)? {
            Some(built) => return Ok((built, keys)),
            None => {
                values = values.rewind()?;
                keys = keys.rewind()?;
            }
        }
    }
}
/// Retry loop for construction from a slice of keys, computing signatures
/// in parallel into an in-memory store.
///
/// For each attempt a new seed is drawn, the store is populated with the
/// signature of `keys[i]` and the value `val_fn(i)`, the shards are set up
/// from the actual number of keys, and, unless the largest shard exceeds
/// 1.01 times the average, `build_fn` is invoked on the shard store.
pub(crate) fn try_par_populate_and_build<
    K: ?Sized + ToSig<S> + std::fmt::Debug + Sync,
    B: Borrow<K> + Sync,
    V: BinSafe + Default + Send + Sync + Ord + Copy,
    R,
    P: ProgressLog + Clone + Send + Sync,
    C,
    VF: Fn(usize) -> V + Send + Sync,
>(
    &mut self,
    keys: &[B],
    val_fn: &VF,
    build_fn: &mut impl FnMut(
        &mut Self,
        u64,
        Box<dyn ShardStore<S, V> + Send + Sync>,
        V,
        usize,
        &mut P,
        &mut C,
    ) -> anyhow::Result<R>,
    pl: &mut P,
    mut state: C,
) -> anyhow::Result<R>
where
    SigVal<S, V>: RadixKey,
    S: Send,
{
    let mut retry = self.retry_state(pl);
    let n = keys.len();
    loop {
        let seed = retry.next_seed();
        let attempt = {
            let mut sig_store = sig_store::new_online::<S, V>(
                self.log2_buckets,
                LOG2_MAX_SHARDS,
                self.expected_num_keys,
            )?;
            pl.expected_updates(Some(n));
            pl.item_name("key");
            pl.start(format!(
                "Computing and storing {}-bit signatures in memory (parallel) using seed 0x{seed:016x}...",
                std::mem::size_of::<S>() * 8,
            ));
            let maybe_max_value = sig_store.par_populate(n, self.max_num_threads, |i| SigVal {
                sig: K::to_sig(keys[i].borrow(), seed),
                val: val_fn(i),
            });
            pl.done();

            let num_keys = sig_store.len();
            self.shard_edge.set_up_shards(num_keys, self.eps);
            let shard_store = sig_store.into_shard_store(self.shard_edge.shard_high_bits())?;
            let largest_shard = shard_store.shard_sizes().max().unwrap_or(0);
            let size_limit = 1.01 * num_keys as f64 / self.shard_edge.num_shards() as f64;
            if largest_shard as f64 > size_limit {
                Err(SolveError::MaxShardTooBig.into())
            } else {
                let (c, lge) = self.shard_edge.set_up_graphs(num_keys, largest_shard);
                self.c = c;
                self.lge = lge;
                self.num_keys = num_keys;
                let boxed: Box<dyn ShardStore<S, V> + Send + Sync> = Box::new(shard_store);
                build_fn(self, seed, boxed, maybe_max_value, num_keys, pl, &mut state)
            }
        };
        if let Some(built) = retry.handle_solve_result(attempt, pl)? {
            return Ok(built);
        }
    }
}
/// Runs a single construction attempt with the given seed.
///
/// Creates an offline or an online signature store depending on the
/// `offline` option and delegates to `try_solve_once_inner`.
pub(crate) fn try_solve_once<
    V: BinSafe + Default + Send + Sync + Ord,
    R,
    P: ProgressLog + Clone + Send + Sync,
    C,
>(
    &mut self,
    seed: u64,
    populate: &mut impl FnMut(
        u64,
        &mut dyn FnMut(SigVal<S, V>) -> anyhow::Result<()>,
        &mut P,
        &mut C,
    ) -> anyhow::Result<V>,
    build_fn: &mut impl FnMut(
        &mut Self,
        u64,
        Box<dyn ShardStore<S, V> + Send + Sync>,
        V,
        usize,
        &mut P,
        &mut C,
    ) -> anyhow::Result<R>,
    pl: &mut P,
    state: &mut C,
) -> anyhow::Result<R>
where
    SigVal<S, V>: RadixKey,
{
    // The two stores have different types, so each branch makes its own call.
    if !self.offline {
        let online = sig_store::new_online::<S, V>(
            self.log2_buckets,
            LOG2_MAX_SHARDS,
            self.expected_num_keys,
        )?;
        return self.try_solve_once_inner(seed, online, populate, build_fn, pl, state);
    }
    let offline = sig_store::new_offline::<S, V>(
        self.log2_buckets,
        LOG2_MAX_SHARDS,
        self.expected_num_keys,
    )?;
    self.try_solve_once_inner(seed, offline, populate, build_fn, pl, state)
}
/// Runs a single construction attempt on the given signature store.
///
/// Steps, in order:
/// 1. `populate` pushes all signature/value pairs into `sig_store` and
///    returns the value later passed to `build_fn` as maximum value;
/// 2. the shards are set up from the actual number of stored keys and the
///    signature store is converted into a shard store;
/// 3. if the largest shard exceeds 1.01 times the average shard size the
///    attempt fails with `SolveError::MaxShardTooBig`;
/// 4. the graphs are set up and `build_fn` is invoked on the boxed store.
///
/// Timings of the signature and construction phases are logged with `info!`.
fn try_solve_once_inner<
V: BinSafe + Default + Send + Sync + Ord,
R,
P: ProgressLog + Clone + Send + Sync,
SS: SigStore<S, V, ShardStore: 'static>,
C,
>(
&mut self,
seed: u64,
mut sig_store: SS,
populate: &mut impl FnMut(
u64,
&mut dyn FnMut(SigVal<S, V>) -> anyhow::Result<()>,
&mut P,
&mut C,
) -> anyhow::Result<V>,
build_fn: &mut impl FnMut(
&mut Self,
u64,
Box<dyn ShardStore<S, V> + Send + Sync>,
V,
usize,
&mut P,
&mut C,
) -> anyhow::Result<R>,
pl: &mut P,
state: &mut C,
) -> anyhow::Result<R>
where
SigVal<S, V>: RadixKey,
{
pl.expected_updates(self.expected_num_keys);
pl.item_name("key");
// The message names the temporary directory of the store, or "memory" when
// the store has none.
pl.start(format!(
"Computing and storing {}-bit signatures in {} using seed 0x{:016x}...",
std::mem::size_of::<S>() * 8,
sig_store
.temp_dir()
.map(|d| d.path().to_string_lossy())
.unwrap_or(Cow::Borrowed("memory")),
seed
));
// Start of the signature-computation timing.
let start = Instant::now();
let maybe_max_value = populate(
seed,
&mut |sig_val| sig_store.try_push(sig_val).map_err(Into::into),
pl,
state,
)?;
pl.done();
// From here on the number of keys is the number actually stored, not the
// expected one.
let num_keys = sig_store.len();
info!(
"Computation of signatures from inputs completed in {:.3} seconds ({} keys, {:.3} ns/key)",
start.elapsed().as_secs_f64(),
num_keys,
start.elapsed().as_nanos() as f64 / num_keys as f64
);
let shard_edge = &mut self.shard_edge;
shard_edge.set_up_shards(num_keys, self.eps);
// Restart the clock: the second timing covers the conversion to a shard
// store and the construction.
let start = Instant::now();
let shard_store = sig_store.into_shard_store(shard_edge.shard_high_bits())?;
let max_shard = shard_store.shard_sizes().max().unwrap_or(0);
if shard_edge.shard_high_bits() != 0 {
pl.info(format_args!(
"Max shard / average shard: {:.2}%",
(100.0 * max_shard as f64) / (num_keys as f64 / shard_edge.num_shards() as f64)
));
}
// Give up on this seed if the largest shard is more than 1% above average.
if max_shard as f64 > 1.01 * num_keys as f64 / shard_edge.num_shards() as f64 {
return Err(SolveError::MaxShardTooBig.into());
}
(self.c, self.lge) = shard_edge.set_up_graphs(num_keys, max_shard);
self.num_keys = num_keys;
let store = Box::new(shard_store) as Box<dyn ShardStore<S, V> + Send + Sync>;
// The completion time is logged only when `build_fn` succeeds.
build_fn(self, seed, store, maybe_max_value, num_keys, pl, state).inspect(|_| {
info!(
"Construction from signatures completed in {:.3} seconds ({} keys, {:.3} ns/key)",
start.elapsed().as_secs_f64(),
num_keys,
start.elapsed().as_nanos() as f64 / num_keys as f64
);
})
}
}
impl<
    D: BitFieldSlice<Value: Word + BinSafe>
        + for<'a> BitFieldSliceMut<ChunksMut<'a>: Iterator<Item: BitFieldSliceMut>>
        + Send
        + Sync,
    S: Sig + Send + Sync,
    E: ShardEdge<S, 3>,
> VBuilder<D, S, E>
{
    /// Solves all shards produced by `shard_iter`, writing the solution into
    /// `data`, and wraps the result in a `VFunc`.
    ///
    /// The number of worker threads is the minimum of the number of shards
    /// and `max_num_threads`, but never less than one. The per-shard solver
    /// is chosen as follows:
    ///
    /// * lazy Gaussian elimination (`lge_shard`) if `set_up_graphs` asked
    ///   for it;
    /// * otherwise the low-memory peeler if `low_mem` is `Some(true)`, or if
    ///   it is unset and more than 3 threads and more than 2 shards are used;
    /// * otherwise the high-memory peeler.
    ///
    /// `get_val` computes the value to store for a local signature/value
    /// pair and `inspect` is invoked by the solvers on the pairs they
    /// process. While the shards are being solved the progress logger is
    /// switched to the `Debug` level; it is restored to `Info` afterwards.
    ///
    /// # Errors
    ///
    /// The first `SolveError` reported by `par_solve`.
    pub(crate) fn try_build_from_shard_iter<
        K: ?Sized + ToSig<S>,
        I,
        P,
        V: BinSafe + Default + Send + Sync,
        G: Fn(&E, SigVal<E::LocalSig, V>) -> D::Value + Send + Sync,
        H: Fn(&SigVal<S, V>) + Send + Sync,
    >(
        &mut self,
        seed: u64,
        mut data: D,
        shard_iter: I,
        get_val: &G,
        inspect: &H,
        pl: &mut P,
    ) -> Result<VFunc<K, D, S, E>, SolveError>
    where
        SigVal<S, V>: RadixKey,
        SigVal<E::LocalSig, V>: BitXor + BitXorAssign,
        P: ProgressLog + Clone + Send + Sync,
        I: Iterator<Item = Arc<Vec<SigVal<S, V>>>> + Send,
        for<'a> ShardDataIter<'a, D>: Send,
        for<'a> <ShardDataIter<'a, D> as Iterator>::Item: Send,
    {
        // Vertices are converted to `usize` indices by the solvers.
        const {
            assert!(
                size_of::<E::Vertex>() <= size_of::<usize>(),
                "ShardEdge::Vertex must fit in usize without truncation"
            );
        }
        let shard_edge = &self.shard_edge;
        // `max_num_threads` comes from a public setter and may be zero; with
        // zero threads `par_solve` would panic computing `0.ilog2()`, so at
        // least one worker is always used.
        self.num_threads = shard_edge
            .num_shards()
            .min(self.max_num_threads)
            .max(1);
        pl.info(format_args!("{}", self.shard_edge));
        pl.info(format_args!(
            "c: {}, Overhead: {:+.4}% Number of threads: {}",
            self.c,
            100. * ((shard_edge.num_vertices() * shard_edge.num_shards()) as f64
                / (self.num_keys as f64)
                - 1.),
            self.num_threads
        ));
        // Per-shard messages are emitted at the debug level.
        pl.log_level(log::Level::Debug);
        if self.lge {
            pl.info(format_args!(
                "Peeling with lazy Gaussian elimination fallback"
            ));
            self.par_solve(
                shard_iter,
                &mut data,
                |this, shard_index, shard, data, pl| {
                    this.lge_shard(shard_index, shard, data, get_val, inspect, pl)
                },
                &mut pl.concurrent(),
                pl,
            )?;
        } else if self.low_mem == Some(true)
            || self.low_mem.is_none() && self.num_threads > 3 && shard_edge.num_shards() > 2
        {
            self.par_solve(
                shard_iter,
                &mut data,
                |this, shard_index, shard, data, pl| {
                    this.peel_by_sig_vals_low_mem(shard_index, shard, data, get_val, inspect, pl)
                },
                &mut pl.concurrent(),
                pl,
            )?;
        } else {
            self.par_solve(
                shard_iter,
                &mut data,
                |this, shard_index, shard, data, pl| {
                    this.peel_by_sig_vals_high_mem(shard_index, shard, data, get_val, inspect, pl)
                },
                &mut pl.concurrent(),
                pl,
            )?;
        }
        pl.log_level(log::Level::Info);
        pl.info(format_args!(
            "Bits/keys: {} ({:+.4}%)",
            data.len() as f64 * self.bit_width as f64 / self.num_keys as f64,
            100.0 * (data.len() as f64 / self.num_keys as f64 - 1.),
        ));
        Ok(VFunc {
            seed,
            shard_edge: self.shard_edge,
            num_keys: self.num_keys,
            data,
            _marker: std::marker::PhantomData,
        })
    }
}
/// Removes edge `$edge` from the two endpoints of `$e` other than the one on
/// side `$side`.
///
/// The endpoints are processed in increasing side order. For each of them,
/// if its degree is exactly 2 before the removal (so it is about to become
/// 1), `$conv(vertex)` is pushed on `$stack` with method `$push`; then the
/// edge is removed from the vertex with the vertex's own side.
///
/// `$side` must be 0, 1 or 2; any other value is undefined behavior.
macro_rules! remove_edge {
    ($xor_graph: ident, $e: ident, $side: ident, $edge: ident, $stack: ident, $push:ident, $conv: expr) => {{
        // The two sides different from `$side`, in increasing order.
        let (first, second): (usize, usize) = match $side {
            0 => (1, 2),
            1 => (0, 2),
            2 => (0, 1),
            _ => unsafe { unreachable_unchecked() },
        };
        if $xor_graph.degree($e[first]) == 2 {
            $stack.$push($conv($e[first]));
        }
        $xor_graph.remove($e[first], $edge, first);
        if $xor_graph.degree($e[second]) == 2 {
            $stack.$push($conv($e[second]));
        }
        $xor_graph.remove($e[second], $edge, second);
    }};
}
impl<
D: BitFieldSlice<Value: Word + BinSafe + Send + Sync>
+ for<'a> BitFieldSliceMut<ChunksMut<'a>: Iterator<Item: BitFieldSliceMut>>
+ Send
+ Sync,
S: Sig + BinSafe,
E: ShardEdge<S, 3>,
> VBuilder<D, S, E>
{
/// Sorts `data` in place by `ShardEdge::sort_key` of the signatures using a
/// counting sort; elements with the same key keep their relative order.
fn count_sort<V: BinSafe>(&self, data: &mut [SigVal<S, V>]) {
    let shard_edge = &self.shard_edge;
    let mut offsets = vec![0; shard_edge.num_sort_keys()];

    // One pass counts the keys and copies the input into a scratch buffer.
    let mut scratch = Box::new_uninit_slice(data.len());
    for (slot, &sig_val) in scratch.iter_mut().zip(data.iter()) {
        offsets[shard_edge.sort_key(sig_val.sig)] += 1;
        slot.write(sig_val);
    }
    // SAFETY: `scratch` has exactly `data.len()` slots and the loop above
    // wrote one element of `data` into each of them.
    let scratch = unsafe { scratch.assume_init() };

    // Exclusive prefix sum: each count becomes the first output position of
    // its key.
    let mut next_free = 0;
    for offset in offsets.iter_mut() {
        let bucket_len = std::mem::replace(offset, next_free);
        next_free += bucket_len;
    }

    // Scatter the copies back, advancing the cursor of each key.
    for &sig_val in scratch.iter() {
        let bucket = shard_edge.sort_key(sig_val.sig);
        data[offsets[bucket]] = sig_val;
        offsets[bucket] += 1;
    }
}
/// Threshold on the number of keys used by `par_solve`: shards are sorted
/// and deduplicated by local signature only when the number of keys exceeds
/// this value (and `E::LocalSig` differs from `S`). On 64-bit targets the
/// threshold is 2^33.
#[cfg(target_pointer_width = "64")]
const MAX_NO_LOCAL_SIG_CHECK: usize = 1 << 33;
/// On targets that are not 64-bit the threshold is `usize::MAX`, so the
/// `num_keys > MAX_NO_LOCAL_SIG_CHECK` test in `par_solve` never succeeds.
#[cfg(not(target_pointer_width = "64"))]
const MAX_NO_LOCAL_SIG_CHECK: usize = usize::MAX;
/// Solves all shards using `self.num_threads` worker threads.
///
/// A producer thread pairs each shard of `shard_iter` with a chunk of
/// `num_vertices()` values of `data` and sends the pairs, numbered from 0,
/// through a bounded channel. Each worker repeatedly receives a pair, sorts
/// the shard (checking for duplicates where required) and calls
/// `solve_shard` on it.
///
/// # Errors
///
/// Returns the first error sent by a worker:
/// `SolveError::DuplicateSignature`, `SolveError::DuplicateLocalSignature`
/// or `SolveError::UnsolvableShard`. In that case `self.failed` is set so
/// that the remaining workers stop as soon as they poll it.
///
/// # Panics
///
/// Panics if `self.num_threads` is 0 (integer logarithm of zero), or if
/// `data.try_chunks_mut` fails.
fn par_solve<
'b,
V: BinSafe,
I: IntoIterator<IntoIter: Send, Item = Arc<Vec<SigVal<S, V>>>> + Send,
SS: Fn(&Self, usize, Arc<Vec<SigVal<S, V>>>, ShardData<'b, D>, &mut P) -> Result<(), ()>
+ Send
+ Sync
+ Copy,
C: ConcurrentProgressLog + Send + Sync,
P: ProgressLog + Clone + Send + Sync,
>(
&self,
shard_iter: I,
data: &'b mut D,
solve_shard: SS,
main_pl: &mut C,
pl: &mut P,
) -> Result<(), SolveError>
where
SigVal<S, V>: RadixKey,
for<'a> ShardDataIter<'a, D>: Send,
for<'a> <ShardDataIter<'a, D> as Iterator>::Item: Send,
{
main_pl
.item_name("shard")
.expected_updates(Some(self.shard_edge.num_shards()))
.display_memory(true)
.start("Solving shards...");
// A previous attempt may have left the flag set.
self.failed.store(false, Ordering::Relaxed);
let num_shards = self.shard_edge.num_shards();
// Capacity of the shard channel: floor(log2(num_threads)), which is 0 (a
// rendezvous channel) with a single thread.
let buffer_size = self.num_threads.ilog2() as usize;
// One slot per worker: each worker sends at most one error before exiting.
let (err_send, err_recv) = crossbeam_channel::bounded::<_>(self.num_threads);
let (data_send, data_recv) = crossbeam_channel::bounded::<(
usize,
(Arc<Vec<SigVal<S, V>>>, ShardData<'_, D>),
)>(buffer_size);
let result = std::thread::scope(|scope| {
// Producer: feeds (shard index, (shard, data chunk)) to the workers.
scope.spawn(move || {
// Best effort: a failure to raise the priority is ignored.
let _ = thread_priority::set_current_thread_priority(ThreadPriority::Max);
for val in shard_iter
.into_iter()
.zip(data.try_chunks_mut(self.shard_edge.num_vertices()).unwrap())
.enumerate()
{
// All receivers are gone: nobody will process further shards.
if data_send.send(val).is_err() {
break;
}
}
drop(data_send);
});
for _thread_id in 0..self.num_threads {
let mut main_pl = main_pl.clone();
let mut pl = pl.clone();
let err_send = err_send.clone();
let data_recv = data_recv.clone();
scope.spawn(move || {
loop {
match data_recv.recv() {
// The producer has finished and the channel is empty.
Err(_) => return,
Ok((shard_index, (shard, mut data))) => {
// An empty shard terminates this worker.
if shard.is_empty() {
return;
}
main_pl.info(format_args!(
"Analyzing shard {}/{}...",
shard_index + 1,
num_shards
));
pl.start(format!(
"Sorting shard {}/{}...",
shard_index + 1,
num_shards
));
{
// The shard is sorted in place through a mutable reference
// obtained from the `Arc` pointer.
// NOTE(review): this is sound only if no other reference
// to the vector is used while this one is alive — confirm
// against the shard store implementation.
let shard = unsafe {
&mut *(Arc::as_ptr(&shard) as *mut Vec<SigVal<S, V>>)
};
if self.check_dups {
shard.radix_sort_builder().sort();
// After sorting, equal signatures are adjacent.
if shard.par_windows(2).any(|w| w[0].sig == w[1].sig) {
let _ = err_send.send(SolveError::DuplicateSignature);
return;
}
}
// Local signatures are checked only when they are a
// different type from the full signatures and the number
// of keys exceeds the threshold.
#[allow(clippy::absurd_extreme_comparisons)]
if TypeId::of::<E::LocalSig>() != TypeId::of::<S>()
&& self.num_keys > Self::MAX_NO_LOCAL_SIG_CHECK
{
// Reinterpret the elements as `E::SortSigVal<V>` so that
// sorting and deduplication use that type's ordering and
// equality.
// NOTE(review): relies on `E::SortSigVal<V>` having the
// same layout as `SigVal<S, V>` — confirm in `ShardEdge`.
let shard = unsafe {
transmute::<
&mut Vec<SigVal<S, V>>,
&mut Vec<E::SortSigVal<V>>,
>(shard)
};
let builder = shard.radix_sort_builder();
if self.max_num_threads == 1 {
builder
.with_single_threaded_tuner()
.with_parallel(false)
} else {
builder
}
.sort();
let shard_len = shard.len();
shard.dedup();
// Without values, removed duplicates are only reported;
// with values, any removed duplicate is an error.
if TypeId::of::<V>() == TypeId::of::<EmptyVal>() {
pl.info(format_args!(
"Removed signatures: {}",
shard_len - shard.len()
));
} else {
if shard_len != shard.len() {
let _ = err_send
.send(SolveError::DuplicateLocalSignature);
return;
}
}
} else if self.shard_edge.num_sort_keys() != 1 {
self.count_sort::<V>(shard);
}
}
pl.done_with_count(shard.len());
main_pl.info(format_args!(
"Solving shard {}/{}...",
shard_index + 1,
num_shards
));
if self.failed.load(Ordering::Relaxed) {
return;
}
// Without values, the chunk is first filled with bytes from
// a generator seeded with the builder's seed.
if TypeId::of::<V>() == TypeId::of::<EmptyVal>() {
Mwc192::seed_from_u64(self.seed).fill_bytes(unsafe {
data.as_mut_slice().align_to_mut::<u8>().1
});
}
if solve_shard(self, shard_index, shard, data, &mut pl).is_err() {
let _ = err_send.send(SolveError::UnsolvableShard);
return;
}
if self.failed.load(Ordering::Relaxed) {
return;
}
main_pl.info(format_args!(
"Completed shard {}/{}",
shard_index + 1,
num_shards
));
main_pl.update_and_display();
}
}
}
});
}
// Drop the original endpoints so that the channels close once the
// worker clones are gone.
drop(err_send);
drop(data_recv);
// Blocks until a worker sends an error or all workers have finished.
if let Some(error) = err_recv.into_iter().next() {
self.failed.store(true, Ordering::Relaxed);
return Err(error);
}
Ok(())
});
main_pl.done();
result
}
/// Peels the hypergraph of a shard, representing edges by their index in
/// the shard.
///
/// Returns `Ok(PeelResult::Complete())` after calling `assign` when every
/// edge has been peeled, and `Ok(PeelResult::Partial { .. })` with the state
/// accumulated so far when peeling stops early. Returns `Err(())` if the
/// builder's `failed` flag is found set after the graph has been built.
///
/// `inspect` is called once on every signature/value pair of the shard;
/// `get_val` computes the value handed to `assign` for each peeled edge.
///
/// # Panics
///
/// Panics if the degree of a vertex overflows while building the graph.
fn peel_by_index<
'a,
V: BinSafe,
G: Fn(&E, SigVal<E::LocalSig, V>) -> D::Value + Send + Sync,
H: Fn(&SigVal<S, V>) + Send + Sync,
>(
&self,
shard_index: usize,
shard: Arc<Vec<SigVal<S, V>>>,
data: ShardData<'a, D>,
get_val: &G,
inspect: &H,
pl: &mut impl ProgressLog,
) -> Result<PeelResult<'a, D, S, E, V>, ()> {
let shard_edge = &self.shard_edge;
let num_vertices = shard_edge.num_vertices();
let num_shards = shard_edge.num_shards();
pl.start(format!(
"Generating graph for shard {}/{}...",
shard_index + 1,
num_shards
));
// The payload XORed into each vertex is the index of the edge in the shard.
let mut xor_graph = XorGraph::<E::Vertex>::new(num_vertices);
for (edge_index, sig_val) in shard.iter().enumerate() {
inspect(sig_val);
for (side, &v) in shard_edge
.local_edge(shard_edge.local_sig(sig_val.sig))
.iter()
.enumerate()
{
xor_graph.add(v, E::Vertex::as_from(edge_index), side);
}
}
pl.done_with_count(shard.len());
assert!(
!xor_graph.overflow,
"Degree overflow for shard {}/{}",
shard_index + 1,
num_shards
);
if self.failed.load(Ordering::Relaxed) {
return Err(());
}
// Lower stack: vertices to visit; upper stack: indices of peeled edges.
let mut double_stack = DoubleStack::<E::Vertex>::new(num_vertices);
let mut sides_stack = Vec::<u8>::new();
pl.start(format!(
"Peeling graph for shard {}/{} by edge indices...",
shard_index + 1,
num_shards
));
// Peeling starts from the vertices of degree one.
for (v, degree) in xor_graph.degrees().enumerate() {
if degree == 1 {
double_stack.push_lower(E::Vertex::as_from(v));
}
}
while let Some(v) = double_stack.pop_lower() {
let v: usize = v.as_to();
// The only edge of this vertex was peeled from another endpoint after
// the vertex had been pushed.
if xor_graph.degree(v) == 0 {
continue;
}
debug_assert!(xor_graph.degree(v) == 1);
// With degree one, the XOR of the incident edges is the edge itself.
let (edge_index, side) = xor_graph.edge_and_side(v);
xor_graph.zero(v);
double_stack.push_upper(edge_index);
sides_stack.push(side as u8);
let edge: usize = edge_index.as_to();
let e = shard_edge.local_edge(shard_edge.local_sig(shard[edge].sig));
// Remove the edge from its other two endpoints, scheduling those whose
// degree drops to one.
remove_edge!(
xor_graph,
e,
side,
edge_index,
double_stack,
push_lower,
E::Vertex::as_from
);
}
pl.done();
// Some edges could not be peeled: return the partial state.
if shard.len() != double_stack.upper_len() {
pl.info(format_args!(
"Peeling failed for shard {}/{} (peeled {} out of {} edges)",
shard_index + 1,
num_shards,
double_stack.upper_len(),
shard.len(),
));
return Ok(PeelResult::Partial {
shard_index,
shard,
data,
double_stack,
sides_stack,
});
}
// `iter_upper` yields the most recently peeled edge first, so the sides
// are reversed to stay paired with their edges.
self.assign(
shard_index,
data,
double_stack
.iter_upper()
.map(|&edge_index| {
let edge: usize = edge_index.as_to();
let sig_val = &shard[edge];
let local_sig = shard_edge.local_sig(sig_val.sig);
(
local_sig,
get_val(
shard_edge,
SigVal {
sig: local_sig,
val: sig_val.val,
},
),
)
})
.zip(sides_stack.into_iter().rev()),
pl,
);
Ok(PeelResult::Complete())
}
/// Peels a shard by signature/value pairs, keeping the peeled pairs and
/// their sides on two dedicated stacks.
///
/// An XOR graph with `num_vertices` vertices is filled with the edges of
/// the shard (each signature is first converted to a local signature), and
/// `inspect` is called on every signature/value pair of the shard. Vertices
/// of degree one are then visited: the edge stored in the graph for the
/// vertex is read back, pushed on the stacks together with its side, and
/// passed to `remove_edge!` along with the visit stack.
///
/// When as many edges as the shard contains have been peeled, they are
/// handed to `assign` from the last one peeled to the first one, and
/// `Ok(())` is returned. `Err(())` is returned if `self.failed` is set once
/// the graph has been built, or if fewer edges than the shard length were
/// peeled.
///
/// # Panics
///
/// Panics if the XOR graph reports a degree overflow.
fn peel_by_sig_vals_high_mem<
    V: BinSafe,
    G: Fn(&E, SigVal<E::LocalSig, V>) -> D::Value + Send + Sync,
    H: Fn(&SigVal<S, V>) + Send + Sync,
>(
    &self,
    shard_index: usize,
    shard: Arc<Vec<SigVal<S, V>>>,
    data: ShardData<'_, D>,
    get_val: &G,
    inspect: &H,
    pl: &mut impl ProgressLog,
) -> Result<(), ()>
where
    SigVal<E::LocalSig, V>: BitXor + BitXorAssign + Default,
{
    let shard_edge = &self.shard_edge;
    let num_vertices = shard_edge.num_vertices();
    let num_shards = shard_edge.num_shards();
    let shard_len = shard.len();
    // One-based shard number shown in every log message.
    let shard_number = shard_index + 1;

    pl.start(format!(
        "Generating graph for shard {}/{}...",
        shard_number, num_shards
    ));
    let mut graph = XorGraph::<SigVal<E::LocalSig, V>>::new(num_vertices);
    for &sig_val in shard.iter() {
        inspect(&sig_val);
        // The graph stores the local signature together with the value.
        let local_sig_val = SigVal {
            sig: shard_edge.local_sig(sig_val.sig),
            val: sig_val.val,
        };
        let vertices = shard_edge.local_edge(local_sig_val.sig);
        for (side, &vertex) in vertices.iter().enumerate() {
            graph.add(vertex, local_sig_val, side);
        }
    }
    pl.done_with_count(shard_len);

    // The signatures are not read again below: give up our handle on them.
    drop(shard);

    assert!(
        !graph.overflow,
        "Degree overflow for shard {}/{}",
        shard_number,
        num_shards
    );

    if self.failed.load(Ordering::Relaxed) {
        return Err(());
    }

    let mut peeled_sig_vals = FastStack::<SigVal<E::LocalSig, V>>::new(shard_len);
    let mut peeled_sides = FastStack::<u8>::new(shard_len);
    let mut to_visit = Vec::<E::Vertex>::with_capacity(num_vertices / 3);

    pl.start(format!(
        "Peeling graph for shard {}/{} by signatures (high-mem)...",
        shard_number, num_shards
    ));

    // Seed the visit stack with every vertex that currently has degree one.
    to_visit.extend(
        graph
            .degrees()
            .enumerate()
            .filter(|(_, degree)| *degree == 1)
            .map(|(vertex, _)| E::Vertex::as_from(vertex)),
    );

    while let Some(vertex) = to_visit.pop() {
        let vertex: usize = vertex.as_to();
        // No edge is left on this vertex: nothing to peel.
        if graph.degree(vertex) == 0 {
            continue;
        }
        let (sig_val, side) = graph.edge_and_side(vertex);
        graph.zero(vertex);
        peeled_sig_vals.push(sig_val);
        peeled_sides.push(side as u8);

        let edge = shard_edge.local_edge(sig_val.sig);
        remove_edge!(graph, edge, side, sig_val, to_visit, push, |w| {
            E::Vertex::as_from(w)
        });
    }
    pl.done();

    let num_peeled = peeled_sig_vals.len();
    if num_peeled != shard_len {
        pl.info(format_args!(
            "Peeling failed for shard {}/{} (peeled {} out of {} edges)",
            shard_number, num_shards, num_peeled, shard_len
        ));
        return Err(());
    }

    // Replay the peeled edges starting from the one peeled last.
    let sig_vals = peeled_sig_vals
        .iter()
        .rev()
        .map(|&sig_val| (sig_val.sig, get_val(shard_edge, sig_val)));
    let sides = peeled_sides.iter().copied().rev();
    self.assign(shard_index, data, sig_vals.zip(sides), pl);

    Ok(())
}
/// Peels a shard by signature/value pairs, recording only the peeled
/// vertices.
///
/// The XOR graph is built exactly as in the high-memory variant: each
/// signature is converted to a local signature, `inspect` is called on every
/// signature/value pair, and the local pair is added on each vertex of its
/// edge. A single `DoubleStack` is then used both as visit stack (through
/// `push_lower`/`pop_lower`) and as record of the peeled vertices (through
/// `push_upper`). No signature/value pair or side is stored while peeling:
/// they are read back from the graph with `edge_and_side` when the values
/// are finally assigned.
///
/// When as many vertices as the shard has edges have been recorded, the
/// corresponding edges are handed to `assign` in the order yielded by
/// `iter_upper`, and `Ok(())` is returned. `Err(())` is returned if
/// `self.failed` is set once the graph has been built, or if fewer edges
/// than the shard length were peeled.
///
/// # Panics
///
/// Panics if the XOR graph reports a degree overflow.
fn peel_by_sig_vals_low_mem<
    V: BinSafe,
    G: Fn(&E, SigVal<E::LocalSig, V>) -> D::Value + Send + Sync,
    H: Fn(&SigVal<S, V>) + Send + Sync,
>(
    &self,
    shard_index: usize,
    shard: Arc<Vec<SigVal<S, V>>>,
    data: ShardData<'_, D>,
    get_val: &G,
    inspect: &H,
    pl: &mut impl ProgressLog,
) -> Result<(), ()>
where
    SigVal<E::LocalSig, V>: BitXor + BitXorAssign + Default,
{
    let shard_edge = &self.shard_edge;
    let num_vertices = shard_edge.num_vertices();
    let num_shards = shard_edge.num_shards();
    let shard_len = shard.len();
    // One-based shard number shown in every log message.
    let shard_number = shard_index + 1;

    pl.start(format!(
        "Generating graph for shard {}/{}...",
        shard_number, num_shards,
    ));
    let mut graph = XorGraph::<SigVal<E::LocalSig, V>>::new(num_vertices);
    for &sig_val in shard.iter() {
        inspect(&sig_val);
        // The graph stores the local signature together with the value.
        let local_sig_val = SigVal {
            sig: shard_edge.local_sig(sig_val.sig),
            val: sig_val.val,
        };
        let vertices = shard_edge.local_edge(local_sig_val.sig);
        for (side, &vertex) in vertices.iter().enumerate() {
            graph.add(vertex, local_sig_val, side);
        }
    }
    pl.done_with_count(shard_len);

    // The signatures are not read again below: give up our handle on them.
    drop(shard);

    assert!(
        !graph.overflow,
        "Degree overflow for shard {}/{}",
        shard_number,
        num_shards
    );

    if self.failed.load(Ordering::Relaxed) {
        return Err(());
    }

    let mut stack = DoubleStack::<E::Vertex>::new(num_vertices);

    pl.start(format!(
        "Peeling graph for shard {}/{} by signatures (low-mem)...",
        shard_number, num_shards
    ));

    // Seed the lower stack with every vertex that currently has degree one.
    let degree_one = graph
        .degrees()
        .enumerate()
        .filter(|(_, degree)| *degree == 1);
    for (vertex, _) in degree_one {
        stack.push_lower(E::Vertex::as_from(vertex));
    }

    while let Some(vertex) = stack.pop_lower() {
        let vertex: usize = vertex.as_to();
        // No edge is left on this vertex: nothing to peel.
        if graph.degree(vertex) == 0 {
            continue;
        }
        let (sig_val, side) = graph.edge_and_side(vertex);
        graph.zero(vertex);
        // Only the vertex is recorded; its edge stays readable in the graph.
        stack.push_upper(E::Vertex::as_from(vertex));

        let edge = shard_edge.local_edge(sig_val.sig);
        remove_edge!(graph, edge, side, sig_val, stack, push_lower, |w| {
            E::Vertex::as_from(w)
        });
    }
    pl.done();

    let num_peeled = stack.upper_len();
    if num_peeled != shard_len {
        pl.info(format_args!(
            "Peeling failed for shard {}/{} (peeled {} out of {} edges)",
            shard_number, num_shards, num_peeled, shard_len
        ));
        return Err(());
    }

    // Recover edge and side of each recorded vertex from the graph.
    let peeled = stack.iter_upper().map(|&vertex| {
        let (sig_val, side) = graph.edge_and_side(vertex.as_to());
        ((sig_val.sig, get_val(shard_edge, sig_val)), side as u8)
    });
    self.assign(shard_index, data, peeled, pl);

    Ok(())
}
/// Solves a shard by peeling, falling back to lazy Gaussian elimination
/// when peeling does not remove every edge.
///
/// `peel_by_index` is tried first: an error is propagated, and a complete
/// peeling requires no further work. On a partial peeling, every edge that
/// was not peeled becomes an equation of a modulo-2 system whose variables
/// are the vertices of the edge and whose constant is the value returned by
/// `get_val`. The system is solved with `lazy_gaussian_elimination`, the
/// solution is stored in `data` for the variables appearing in at least one
/// equation, and the peeled edges are finally handed to `assign`.
///
/// Returns `Err(())` if peeling reports an error, if `self.failed` is set
/// after the system has been generated, or if the system cannot be solved.
fn lge_shard<
    V: BinSafe,
    G: Fn(&E, SigVal<E::LocalSig, V>) -> D::Value + Send + Sync,
    H: Fn(&SigVal<S, V>) + Send + Sync,
>(
    &self,
    shard_index: usize,
    shard: Arc<Vec<SigVal<S, V>>>,
    data: ShardData<'_, D>,
    get_val: &G,
    inspect: &H,
    pl: &mut impl ProgressLog,
) -> Result<(), ()> {
    let shard_edge = &self.shard_edge;
    match self.peel_by_index(shard_index, shard, data, get_val, inspect, pl) {
        Err(()) => Err(()),
        Ok(PeelResult::Complete()) => Ok(()),
        Ok(PeelResult::Partial {
            shard_index,
            shard,
            mut data,
            double_stack,
            sides_stack,
        }) => {
            pl.info(format_args!("Switching to lazy Gaussian elimination..."));
            pl.start(format!(
                "Generating system for shard {}/{}...",
                shard_index + 1,
                shard_edge.num_shards()
            ));
            let num_vertices = shard_edge.num_vertices();
            // Edges already removed by peeling: they yield no equation.
            let mut peeled_edges: BitVec = BitVec::new(shard.len());
            // Vertices appearing in at least one equation of the system.
            let mut used_vars: BitVec = BitVec::new(num_vertices);
            for &edge in double_stack.iter_upper() {
                peeled_edges.set(edge.as_to(), true);
            }
            // One equation per non-peeled edge.
            //
            // NOTE(review): the raw `from_parts` constructors are called in
            // an unsafe block; presumably they require the variables of each
            // equation to be sorted and in [0..num_vertices). Sorting is done
            // below with `sort_unstable`; confirm the remaining
            // preconditions against the `mod2_sys` documentation.
            let mut system = unsafe {
                crate::utils::mod2_sys::Modulo2System::from_parts(
                    num_vertices,
                    shard
                        .iter()
                        .enumerate()
                        .filter(|(edge_index, _)| !peeled_edges[*edge_index])
                        .map(|(_edge_index, sig_val)| {
                            let local_sig = shard_edge.local_sig(sig_val.sig);
                            let mut eq: Vec<_> = shard_edge
                                .local_edge(local_sig)
                                .iter()
                                .map(|&x| {
                                    used_vars.set(x, true);
                                    // NOTE(review): `as u32` truncates
                                    // silently; assumes vertex indices fit
                                    // in 32 bits.
                                    x as u32
                                })
                                .collect();
                            eq.sort_unstable();
                            crate::utils::mod2_sys::Modulo2Equation::from_parts(
                                eq,
                                get_val(
                                    shard_edge,
                                    SigVal {
                                        sig: local_sig,
                                        val: sig_val.val,
                                    },
                                ),
                            )
                        })
                        .collect(),
                )
            };
            // Close the "Generating system" phase before checking for
            // failure or starting the next phase, like every other phase
            // of the construction does.
            pl.done();
            if self.failed.load(Ordering::Relaxed) {
                return Err(());
            }
            pl.start("Solving system...");
            let result = system.lazy_gaussian_elimination().map_err(|_| ())?;
            pl.done_with_count(system.num_equations());
            // Store the solution, but only for variables that occur in the
            // system; the other entries of `data` are left untouched here.
            for (v, &value) in result.iter().enumerate().filter(|(v, _)| used_vars[*v]) {
                data.set_value(v, value);
            }
            // Assign the peeled edges, pairing the edges from `iter_upper`
            // with the sides taken from the last one pushed to the first.
            self.assign(
                shard_index,
                data,
                double_stack
                    .iter_upper()
                    .map(|&edge_index| {
                        let edge: usize = edge_index.as_to();
                        let sig_val = &shard[edge];
                        let local_sig = shard_edge.local_sig(sig_val.sig);
                        (
                            local_sig,
                            get_val(
                                shard_edge,
                                SigVal {
                                    sig: local_sig,
                                    val: sig_val.val,
                                },
                            ),
                        )
                    })
                    .zip(sides_stack.into_iter().rev()),
                pl,
            );
            Ok(())
        }
    }
}
/// Writes into `data` the values determined by a sequence of peeled edges.
///
/// Each item of `sigs_vals_sides` is a `((local signature, value), side)`
/// triple. The vertices of the edge are obtained from the local signature;
/// the vertex at position `side` receives the XOR of the value with the
/// values currently stored at the two other vertices of the edge. Items are
/// processed in iteration order.
///
/// Nothing is done (and nothing is logged) if `self.failed` is set.
fn assign(
    &self,
    shard_index: usize,
    mut data: ShardData<'_, D>,
    sigs_vals_sides: impl Iterator<Item = ((E::LocalSig, D::Value), u8)>,
    pl: &mut impl ProgressLog,
) where
    for<'a> ShardData<'a, D>: SliceByValueMut,
{
    if self.failed.load(Ordering::Relaxed) {
        return;
    }

    let shard_edge = &self.shard_edge;
    pl.start(format!(
        "Assigning values for shard {}/{}...",
        shard_index + 1,
        shard_edge.num_shards()
    ));

    for ((sig, val), side) in sigs_vals_sides {
        let vertices = shard_edge.local_edge(sig);
        let pivot = usize::from(side);
        // NOTE(review): nothing here checks that `side` is 0, 1 or 2, nor
        // that the vertices are valid indices for `data`; both are
        // presumably guaranteed by the callers and by `local_edge`.
        unsafe {
            // The two vertices of the edge other than the pivot.
            let (first, second) = match pivot {
                0 => (vertices[1], vertices[2]),
                1 => (vertices[0], vertices[2]),
                2 => (vertices[0], vertices[1]),
                _ => core::hint::unreachable_unchecked(),
            };
            let others = data.get_value_unchecked(first) ^ data.get_value_unchecked(second);
            data.set_value_unchecked(vertices[pivot], val ^ others);
        }
    }

    pl.done();
}
}