use crate::support::int_traits::NarrowingDiv;
use crate::support::{DInt, HInt, Int};
/// Compute `(x << e) % y` as if the shift were performed on an unbounded
/// integer, i.e. `(x * 2^e) mod y`.
///
/// Any `e: u32` is accepted. The work is linear in `e / (U::BITS - 1)`: one
/// `word_reduce` per `U::BITS - 1` bits of shift.
///
/// # Panics
///
/// Panics unless `y <= U::MAX >> 2` (at least two leading zero bits) and
/// `x < 2 * y`. The second condition also rules out `y == 0`.
pub fn linear_mul_reduction<U>(x: U, mut e: u32, mut y: U) -> U
where
    U: HInt + Int<Unsigned = U>,
    U::D: NarrowingDiv,
{
    assert!(y <= U::MAX >> 2);
    assert!(x < (y << 1));
    let _0 = U::ZERO;
    let _1 = U::ONE;

    // Power-of-two divisor: the remainder is just the low bits of the shifted value.
    if (y & (y - _1)).is_zero() {
        // If `e >= U::BITS`, every bit of `x` is shifted past `y`'s width
        // (at most `U::BITS - 2` bits), so the remainder is zero.
        return if e < U::BITS { (x << e) & (y - _1) } else { _0 };
    }

    // Use `(x << e) % y == ((x << (e + s)) % (y << s)) >> s` to normalize the
    // divisor to exactly two leading zeros, as `Reducer::new` requires.
    let s = y.leading_zeros() - 2;
    y <<= s;

    // `m` holds the remainder in a form where `x <<= k (mod y)` is cheap for
    // shifts `k < U::BITS`.
    let mut m = Reducer::new(x, y);

    // The total shift to apply is `e + s`. That sum can overflow `u32` when `e`
    // is close to `u32::MAX`, so it is never formed directly.
    // Step 1: consume `e` in whole `U::BITS - 1` steps (the fast `word_reduce`
    // special case).
    let word = U::BITS - 1;
    while e >= word {
        m.word_reduce();
        e -= word;
    }
    // Step 2: fold in `s`. Both `e` and `s` are now below `word`, so the sum
    // cannot overflow and at most one more word step is needed.
    e += s;
    if e >= word {
        m.word_reduce();
        e -= word;
    }
    // Finish with the variable-width shift; `e < U::BITS - 1` here.
    m.shift_reduce(e);

    // The partial remainder lies in `[0, 2y)`. Correct it once, then undo the
    // normalizing shift.
    let r = m.partial_remainder();
    r.checked_sub(y).unwrap_or(r) >> s
}
/// State for repeatedly replacing a remainder `x` by `x << k` modulo a fixed `n`.
///
/// With `R = 2^U::BITS`, the modulus satisfies `R/8 < n < R/4` (checked by
/// `Reducer::new`), and the remainder `x` stays in the partial range `[0, 2n)`.
/// The remainder is not stored directly. It is kept as the double-width product
/// `2 * x * q`, where `q = (R*R/2) / m`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Reducer<U>
where
    U: HInt,
{
    /// `m = 2n`, twice the modulus.
    m: U,
    /// `r = (R*R/2) % m`, so `R*R/2 = q*m + r`. The quotient `q` is only
    /// needed during construction and is not kept.
    r: U,
    /// The remainder `x`, encoded as `2 * x * q`.
    _2xq: U::D,
}
/// Operations on a `Reducer`.
///
/// Notation for the comments below:
/// - arithmetic is mathematical (no wrapping);
/// - `R = 2^U::BITS` and `m = 2n`;
/// - `R*R/2 = q*m + r` with `0 <= r < m`.
///
/// The remainder `x` is represented only through the stored product `2*x*q`.
///
/// NOTE(review): the derivations assume the `HInt`/`DInt` helpers behave as
/// their names suggest:
/// - `v.widen_hi() == v * R`;
/// - `a.widen_mul(b)` is the exact double-width product;
/// - `hi()` and `lo_hi()` split a double-width value into its high word and its
///   `(low, high)` words.
///
/// Confirm against `int_traits`.
impl<U> Reducer<U>
where
    U: HInt,
    U: Int<Unsigned = U>,
{
    /// Construct a reducer for the modulus `n` holding the initial remainder `x`.
    ///
    /// # Panics
    ///
    /// Panics unless `R/8 < n < R/4` and `x < 2n`.
    fn new(x: U, n: U) -> Self
    where
        U::D: NarrowingDiv,
    {
        let _1 = U::ONE;
        assert!(n > (_1 << (U::BITS - 3)));
        assert!(n < (_1 << (U::BITS - 2)));
        let m = n << 1;
        assert!(x < m);
        // Since `R/4 < m < R/2`, the quotient `q = (R*R/2) / m` lies in `(R, 2R)`.
        // Write `q = R + f`. Then `R*R/2 = (R + f)*m + r` rearranges to
        // `(R/2 - m) * R = f*m + r`, a single double-by-single-word division.
        // Its quotient `f` fits in `U` because `R/2 - m < R/4 < m`, so the
        // division is expected to succeed and the `unwrap` should never fire.
        let dividend = ((_1 << (U::BITS - 1)) - m).widen_hi();
        let (f, r) = dividend.checked_narrowing_div_rem(m).unwrap();
        // `2x < 2m < R` fits in `U`. Then `2xq = 2x*R + 2x*f`.
        // Since `x < m`, `2xq < 2qm <= R*R`, so the result fits in `U::D`.
        let _2x = x + x;
        let _2xq = _2x.widen_hi() + _2x.widen_mul(f);
        Self { m, r, _2xq }
    }
    /// Return the current remainder `x`, which lies in `[0, 2n)`.
    ///
    /// Only the high word `u` of `2xq = u*R + v` is used, via
    /// `x = floor(m * (u + 2) / R)`. The `+ 2` absorbs the discarded low word
    /// `v` and the `r` term in `q*m = R*R/2 - r`. `m < R/2` keeps the combined
    /// error below one unit.
    fn partial_remainder(&self) -> U {
        let u = self._2xq.hi();
        let _2 = U::ONE + U::ONE;
        self.m.widen_mul(u + _2).hi()
    }
    /// Replace the remainder `x` with `(x << k) - a*n` for a quotient `a`, and
    /// return `a`.
    ///
    /// # Panics
    ///
    /// Panics unless `k < U::BITS`.
    fn shift_reduce(&mut self, k: u32) -> U {
        assert!(k < U::BITS);
        // Split `2xq << k` (without truncation) as `a*(R*R/2) + b` with
        // `b < R*R/2`. `a` is the top `k + 1` bits of `2xq` ...
        let a = self._2xq.hi() >> (U::BITS - 1 - k);
        // ... and `b` is the low `2*U::BITS - 1` bits of `2xq << k`. The top
        // bit of the high word is masked off.
        let (low, high) = (self._2xq << k).lo_hi();
        let b = U::D::from_lo_hi(low, high & (U::MAX >> 1));
        // For the new remainder `x' = (x << k) - a*n`:
        //   2x'q = (2xq << k) - a*q*m = a*(R*R/2 - q*m) + b = a*r + b
        self._2xq = a.widen_mul(self.r) + b;
        a
    }
    /// Replace the remainder `x` with `x*(R/2) - u*n` for a quotient `u`, and
    /// return `u`.
    ///
    /// This is the special case `k == U::BITS - 1` of `shift_reduce`. There the
    /// split of `2xq` falls exactly on its word boundary, so no shifting or
    /// masking is needed.
    fn word_reduce(&mut self) -> U {
        // `2xq = u*R + v`. `v` is even, because `2xq` and `R` are both even.
        let (v, u) = self._2xq.lo_hi();
        // For `x' = x*(R/2) - u*n`:
        //   2x'q = x*q*R - u*q*m = u*(R*R/2 - q*m) + v*R/2 = u*r + (v/2)*R
        self._2xq = u.widen_mul(self.r) + U::widen_hi(v >> 1);
        u
    }
}
#[cfg(test)]
mod test {
    use crate::support::linear_mul_reduction;
    use crate::support::modular::Reducer;

    /// Exhaustively check every `Reducer` operation for `u8`. This covers each
    /// valid modulus `n` in `(2^5, 2^6)`, each starting remainder `x < 2n`, and
    /// each shift `k < 8`.
    #[test]
    fn reducer_ops() {
        for n in 33..=63_u8 {
            let n32 = u32::from(n);
            for x in 0..2 * n {
                let fresh = Reducer::new(x, n);
                let x0 = fresh.partial_remainder() as u32;
                // A new reducer reports its initial remainder unchanged.
                assert_eq!(u32::from(x), x0);
                for k in 0..=7 {
                    let mut shifted = fresh.clone();
                    let quot = shifted.shift_reduce(k) as u32;
                    let x1 = shifted.partial_remainder() as u32;
                    // The new remainder is `(x0 << k) - quot * n` ...
                    assert_eq!(x1, (x0 << k) - quot * n32);
                    // ... it stays within the partial range `[0, 2n)` ...
                    assert!(x1 < 2 * n32);
                    // ... and the stored `2xq` is still a multiple of `2 * x1`.
                    assert!((shifted._2xq as u32).is_multiple_of(2 * x1));
                    if k != 7 {
                        continue;
                    }
                    // `word_reduce` must agree with `shift_reduce(U::BITS - 1)`.
                    let mut by_word = fresh.clone();
                    let w = by_word.word_reduce() as u32;
                    assert_eq!(quot, w);
                    assert_eq!(by_word, shifted);
                }
            }
        }
    }

    /// Compare every valid `u8` input against a reference that doubles the
    /// remainder one bit at a time.
    #[test]
    fn reduction_u8() {
        for y in 1..64u8 {
            let double_mod_y = |v: u8| {
                let d = v << 1;
                if d >= y { d - y } else { d }
            };
            for x in 0..2 * y {
                let mut expected = x % y;
                for e in 0..100 {
                    assert_eq!(expected, linear_mul_reduction(x, e, y));
                    expected = double_mod_y(expected);
                }
            }
        }
    }

    #[test]
    fn reduction_u128() {
        assert_eq!(
            linear_mul_reduction::<u128>(17, 100, 123456789),
            (17 << 100) % 123456789
        );
        // Power-of-two divisor: only the low 116 bits of `x << 100` survive.
        assert_eq!(
            linear_mul_reduction(0xdead_beef, 100, 1_u128 << 116),
            0xbeef << 100
        );

        // Run a long sequence against a divisor that is not a power of two.
        // `10^37` is coprime to `11^36`, so the remainder can never reach zero.
        let x = 10_u128.pow(37);
        let y = 11_u128.pow(36);
        assert!(x < y);
        let double_mod_y = |v: u128| {
            let d = v << 1;
            if d >= y { d - y } else { d }
        };
        let mut expected = x;
        for e in 0..1000 {
            assert_eq!(expected, linear_mul_reduction(x, e, y));
            expected = double_mod_y(expected);
            assert!(expected != 0);
        }
    }
}