pub use crate::commons::math::decomposition::DecompositionLevel;
use crate::commons::numeric::UnsignedInteger;
use crate::prelude::{DecompositionBaseLog, DecompositionLevelCount};
use dyn_stack::{DynArray, DynStack};
use std::iter::Map;
use std::slice::IterMut;
/// Lending iterator that decomposes a whole collection of scalars at once, one
/// level per call to `next_term`.
///
/// The per-scalar working values live in a `DynArray` obtained from a
/// `DynStack` at construction time, hence the `'buffers` lifetime.
pub struct TensorSignedDecompositionLendingIter<'buffers, Scalar: UnsignedInteger> {
/// Base-2 logarithm of the decomposition base; each level consumes this many
/// bits of every state.
base_log: usize,
/// Level that the next call to `next_term` will emit. Starts at the level
/// count and is decremented on every emitted term; `0` means exhausted.
current_level: usize,
/// Mask built as `(1 << base_log) - 1`, i.e. the low `base_log` bits set,
/// used to extract the current digit from a state.
mod_b_mask: Scalar,
/// One working value per input scalar, initialised to the input shifted right
/// by `Scalar::BITS - base_log * level` and updated in place as terms are
/// produced.
states: DynArray<'buffers, Scalar>,
/// `true` right after construction, set to `false` by the first call to
/// `next_term`.
// NOTE(review): never read in the visible code — presumably consulted by a
// method outside this view; confirm before removing.
fresh: bool,
}
impl<'buffers, Scalar: UnsignedInteger> TensorSignedDecompositionLendingIter<'buffers, Scalar> {
    /// Creates a decomposer over every scalar yielded by `input`.
    ///
    /// Each input value is shifted right by `Scalar::BITS - base_log * level`
    /// bits, so only its `base_log * level` most significant bits are kept
    /// (the dropped bits are truncated; no rounding happens in this function).
    /// The shifted values are collected with `stack.collect_aligned` using
    /// `aligned_vec::CACHELINE_ALIGN` as the alignment.
    ///
    /// Returns the new iterator together with the `DynStack` handed back by
    /// `collect_aligned`.
    ///
    /// `base_log * level` is not validated against `Scalar::BITS` here.
    #[inline]
    pub(crate) fn new(
        input: impl Iterator<Item = Scalar>,
        base_log: DecompositionBaseLog,
        level: DecompositionLevelCount,
        stack: DynStack<'buffers>,
    ) -> (Self, DynStack<'buffers>) {
        // Number of low-order bits that fall outside the decomposition.
        let discarded_bits = Scalar::BITS - base_log.0 * level.0;
        let truncated = input.map(|value| value >> discarded_bits);
        let (states, remaining_stack) =
            stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, truncated);

        // Low `base_log` bits set: selects one digit out of a state.
        let mod_b_mask = (Scalar::ONE << base_log.0) - Scalar::ONE;

        let decomposer = Self {
            base_log: base_log.0,
            current_level: level.0,
            mod_b_mask,
            states,
            fresh: true,
        };
        (decomposer, remaining_stack)
    }

    /// Produces the next decomposition term for all scalars, or `None` once
    /// every level has been emitted.
    ///
    /// On success the tuple holds the level being emitted (levels are visited
    /// from the level count down to `1`), the base log, and an iterator that
    /// yields one digit per state.
    ///
    /// The returned iterator is lazy: a state is only advanced when its
    /// element is actually pulled, because the digit extraction runs inside
    /// the `map` closure. The level counter, however, is decremented
    /// immediately by this call.
    #[inline]
    pub fn next_term<'short>(
        &'short mut self,
    ) -> Option<(
        DecompositionLevel,
        DecompositionBaseLog,
        Map<IterMut<'short, Scalar>, impl FnMut(&'short mut Scalar) -> Scalar>,
    )> {
        // Any call, even one that returns `None`, marks the iterator as used.
        self.fresh = false;

        let level = self.current_level;
        if level == 0 {
            return None;
        }
        self.current_level = level - 1;

        // Copy the parameters out so the closure does not borrow `self`.
        let (log_b, mask) = (self.base_log, self.mod_b_mask);
        let terms = self
            .states
            .iter_mut()
            .map(move |state| decompose_one_level(log_b, state, mask));

        Some((DecompositionLevel(level), DecompositionBaseLog(log_b), terms))
    }
}
/// Extracts one signed digit from `state` and advances `state` to the next level.
///
/// The raw digit is `*state & mod_b_mask`; `state` is then shifted right by
/// `base_log`. A carry is derived from bit `base_log - 1` of the raw digit:
/// assuming `mod_b_mask` has exactly the low `base_log` bits set, the carry is
/// `1` when that bit is set in the digit and also set in either `digit - 1`
/// or the shifted state, and `0` otherwise. The carry is added to `state`.
///
/// Returns `digit - (carry << base_log)` with wrapping arithmetic, i.e. the
/// raw digit itself when there is no carry, or the raw digit minus
/// `2^base_log` (wrapped around the unsigned type) when there is one.
#[inline]
fn decompose_one_level<S: UnsignedInteger>(base_log: usize, state: &mut S, mod_b_mask: S) -> S {
    // Raw unsigned digit for this level.
    let digit = *state & mod_b_mask;

    // Consume the digit's bits from the state.
    *state = *state >> base_log;

    // Only bit `base_log - 1` survives the final shift, so this is the
    // "digit is in the upper half of the base" test folded with the tie-break
    // on the shifted state.
    let carry = ((digit.wrapping_sub(S::ONE) | *state) & digit) >> (base_log - 1);

    // Propagate the carry into the remaining, more significant digits.
    *state += carry;

    digit.wrapping_sub(carry << base_log)
}