use crate::{Choice, Limb, Odd, OddUint, Uint, primitives::u32_min, word};
impl<const LIMBS: usize> Uint<LIMBS> {
    /// Divides `self` by `2^k` modulo the odd modulus `q`.
    ///
    /// The division is carried out in rounds of at most `Limb::BITS - 1`
    /// halvings. In each round the low limb alone is used to work out a
    /// single-limb multiplier `s` such that `x + q * s` is divisible by `2^f`;
    /// the full-width value is then replaced by `(x + q * s) / 2^f`.
    ///
    /// The number of rounds, and the number of steps inside each round, depend
    /// only on `k_upper_bound`, never on `k`. At most `k_upper_bound` halvings
    /// are performed, so when `k > k_upper_bound` the result is
    /// `self / 2^k_upper_bound mod q` instead.
    ///
    /// The returned value is congruent to the quotient modulo `q`, but is not
    /// reduced below `q` in general (with `k == 0` it is `self` unchanged).
    #[inline]
    pub(super) const fn bounded_div2k_mod_q(
        self,
        k: u32,
        k_upper_bound: u32,
        q: &Odd<Self>,
    ) -> Self {
        // 1/2 mod q; only its lowest limb is consumed by the per-round step.
        let half = OddUint::half_mod(q).0;

        // Largest number of halvings handled by one limb-level call.
        let chunk = Limb::BITS - 1;

        let mut acc = self;
        // Halvings still owed to `k`, and public budget still owed to
        // `k_upper_bound`. The loop is driven by the latter only.
        let mut k_left = k;
        let mut bound_left = k_upper_bound;

        while bound_left > 0 {
            let round_bound = u32_min(bound_left, chunk);
            let shift = u32_min(k_left, round_bound);

            // The limb-level result value is discarded: only the multiplier of
            // `q` that clears the low `shift` bits is needed here.
            let (_, multiplier) =
                acc.limbs[0].bounded_div2k_mod_q(shift, round_bound, half.limbs[0]);
            acc = q.mul_add_div2k(multiplier, &acc, shift);

            k_left -= shift;
            bound_left -= round_bound;
        }

        acc
    }

    /// Multiplies `self` by the single limb `b` and adds `addend` plus an
    /// incoming `carry`, limb by limb from least to most significant.
    ///
    /// Returns the `LIMBS`-limb result together with the carry limb left over
    /// after the most significant limb has been processed.
    ///
    /// Relies on `Limb::carrying_mul_add` returning a `(low, carry)` pair for
    /// `self * b + addend + carry` (presumably; confirm against its definition).
    #[inline]
    const fn carrying_mul_add_limb(
        mut self,
        b: Limb,
        addend: &Self,
        mut carry: Limb,
    ) -> (Self, Limb) {
        let mut idx = 0;
        while idx < LIMBS {
            let (low, high) = self.limbs[idx].carrying_mul_add(b, addend.limbs[idx], carry);
            self.limbs[idx] = low;
            carry = high;
            idx += 1;
        }
        (self, carry)
    }
}
impl Limb {
    /// Single-limb halving step used by `Uint::bounded_div2k_mod_q`.
    ///
    /// Runs exactly `k_upper_bound` iterations; iteration `i` only takes
    /// effect while `i < k`. An active iteration shifts `self` right by one
    /// bit and, if the bit shifted out was set, adds `one_half_mod_q` (with
    /// wrap-around) and sets bit `i` of the returned factor.
    ///
    /// Returns `(value, factor)`: the limb after all active iterations and the
    /// bitmask of iterations in which the correction was applied. The caller
    /// uses `factor` as the multiple of `q` to add before shifting the
    /// full-width value right.
    ///
    /// `k_upper_bound` is expected to stay below `Limb::BITS`, since
    /// `Limb::ONE.shl(i)` is evaluated for every `i < k_upper_bound`.
    const fn bounded_div2k_mod_q(
        mut self,
        k: u32,
        k_upper_bound: u32,
        one_half_mod_q: Self,
    ) -> (Self, Self) {
        let mut factor = Limb::ZERO;
        let mut step = 0;
        while step < k_upper_bound {
            // Steps at or beyond `k` still run, but are masked out.
            let active = Choice::from_u32_lt(step, k);

            let (halved, shifted_out) = self.shr1();
            // The top bit of the second `shr1` output flags a dropped set bit,
            // i.e. the value was odd before the shift.
            let was_odd = word::choice_from_msb(shifted_out.0);
            let correct = was_odd.and(active);

            let correction = Self::select(Self::ZERO, one_half_mod_q, correct);
            let factor_bit = Self::select(Self::ZERO, Self::ONE.shl(step), correct);

            self = Self::select(self, halved, active).wrapping_add(correction);
            factor = factor.bitxor(factor_bit);

            step += 1;
        }
        (self, factor)
    }
}
impl<const LIMBS: usize> OddUint<LIMBS> {
    /// Returns `floor(q / 2) + 1`, which equals `(q + 1) / 2`, i.e. the
    /// inverse of two modulo `q`, because `q` is odd.
    ///
    /// NOTE(review): the value is wrapped in `Odd` although it is not
    /// necessarily odd (for `q = 3` it is `2`); the only caller in this file
    /// immediately unwraps it with `.0`.
    const fn half_mod(q: &Self) -> Self {
        let floor_half = q.as_ref().shr1();
        Odd(floor_half.wrapping_add(&Uint::ONE))
    }

    /// Computes `(self * b + addend) >> k` using one extra limb of headroom
    /// for the multiply-add.
    ///
    /// The carry limb of the multiply-add is shifted left by `Limb::BITS - k`
    /// and handed to `shr_limb_with_carry` (presumably to fill the bits vacated
    /// at the top of the result; confirm against its definition).
    ///
    /// `k` must not exceed `Limb::BITS`, otherwise `Limb::BITS - k` underflows.
    #[allow(clippy::integer_division_remainder_used, reason = "needs triage")]
    const fn mul_add_div2k(&self, b: Limb, addend: &Uint<LIMBS>, k: u32) -> Uint<LIMBS> {
        let (low, overflow) = self.as_ref().carrying_mul_add_limb(b, addend, Limb::ZERO);
        // Line the overflow limb up with the top of `low` after the shift.
        let spill = overflow.unbounded_shl(Limb::BITS - k);
        let shifted = low.shr_limb_with_carry(k, spill);
        shifted.0
    }
}
#[cfg(test)]
mod tests {
    use crate::{Limb, U128, Uint};

    /// Full-width division by `2^k` modulo several odd moduli, including the
    /// cases where `k` and `k_upper_bound` differ.
    #[test]
    fn test_uint_bounded_div2k_mod_q() {
        let two_pow_64 = U128::ONE.shl_vartime(64);
        let eight = U128::from(8u64);

        // Modulus 3.
        let modulus = U128::from(3u64).to_odd().unwrap();
        // Zero halvings leave the input untouched.
        assert_eq!(two_pow_64.bounded_div2k_mod_q(0, 0, &modulus), two_pow_64);
        // No odd intermediate value: a plain shift.
        assert_eq!(
            two_pow_64.bounded_div2k_mod_q(5, 5, &modulus),
            U128::ONE.shl_vartime(59)
        );
        // 1/2 mod 3 == 2.
        assert_eq!(
            U128::ONE.bounded_div2k_mod_q(1, 1, &modulus),
            U128::from(2u64)
        );
        assert_eq!(eight.bounded_div2k_mod_q(17, 17, &modulus), U128::ONE);

        // A larger single-limb modulus.
        let modulus = U128::from(2864434311u64).to_odd().unwrap();
        assert_eq!(
            eight.bounded_div2k_mod_q(17, 17, &modulus),
            U128::from(303681787u64)
        );

        // A multi-limb modulus; the bound caps the number of halvings, and a
        // bound above `k` does not change the result.
        let modulus = U128::from_be_hex("0000AAAABBBB33330000AAAABBBB3333")
            .to_odd()
            .unwrap();
        let after_30_halvings = U128::from_be_hex("000071EEB6013E76000071EEB6013E76");
        let cases: [(u32, u32, U128); 4] = [
            (
                71,
                71,
                U128::from_be_hex("00002D6F169DBBF300002D6F169DBBF3"),
            ),
            (71, 0, U128::MAX),
            (71, 30, after_30_halvings),
            (30, 127, after_30_halvings),
        ];
        for (k, k_upper_bound, expected) in cases {
            assert_eq!(
                U128::MAX.bounded_div2k_mod_q(k, k_upper_bound, &modulus),
                expected
            );
        }
    }

    /// Limb-level step: checks both the resulting limb and the factor mask.
    #[test]
    fn test_limb_bounded_div2k_mod_q() {
        // All ones except the four lowest bits.
        let value = Limb::MAX.wrapping_sub(Limb::from(15u32));
        let modulus = Limb::from(55u32);
        let half = modulus.shr1().0.wrapping_add(Limb::ONE);

        // k = 0: nothing happens, whatever the bound.
        let (out, mask) = value.bounded_div2k_mod_q(0, 3, half);
        assert_eq!(out, value);
        assert_eq!(mask, Limb::ZERO);

        // Four halvings only drop zero bits.
        let (out, mask) = value.bounded_div2k_mod_q(4, 4, half);
        assert_eq!(out, value.shr(4));
        assert_eq!(mask, Limb::ZERO);

        // The fifth halving drops a set bit, so the correction kicks in.
        let (out, mask) = value.bounded_div2k_mod_q(5, 5, half);
        assert_eq!(out, value.shr(5).wrapping_add(half));
        assert_eq!(mask, Limb::ONE.shl(4));

        // The bound caps the work at four halvings.
        let (out, mask) = value.bounded_div2k_mod_q(5, 4, half);
        assert_eq!(out, value.shr(4));
        assert_eq!(mask, Limb::ZERO);
    }

    /// Multiply-add by a single limb, for multipliers 0, 1 and `Limb::MAX`.
    #[test]
    fn test_carrying_mul_add_limb() {
        let addend = U128::from_be_hex("ABCDEF98765432100123456789FEDCBA");
        let base = U128::MAX;

        // Multiplier 0: the addend passes through unchanged.
        let (sum, carry_out) = base.carrying_mul_add_limb(Limb::ZERO, &addend, Limb::ZERO);
        assert_eq!(sum, addend);
        assert_eq!(carry_out, Limb::ZERO);

        // Multiplier 1: plain addition, which overflows by one.
        let (sum, carry_out) = base.carrying_mul_add_limb(Limb::ONE, &addend, Limb::ZERO);
        assert_eq!(sum, addend.wrapping_add(&base));
        assert_eq!(carry_out, Limb::ONE);

        // Multiplier MAX: compare against widening multiply plus carrying add.
        let multiplier = Limb::MAX;
        let (sum, carry_out) = base.carrying_mul_add_limb(multiplier, &addend, Limb::ZERO);
        let (product_lo, product_hi) = base.widening_mul(&Uint::new([multiplier; 1]));
        let (expected_lo, carry) = product_lo.carrying_add(&addend, Limb::ZERO);
        let (expected_hi, carry) = product_hi.carrying_add(&Uint::ZERO, carry);
        assert_eq!(sum, expected_lo);
        assert_eq!(carry_out, expected_hi.limbs[0]);
        assert_eq!(carry, Limb::ZERO);
    }
}