use crate::{Limb, Odd, Uint, modular::FixedMontyForm, primitives::u32_min};
use super::{ConstMontyForm, ConstMontyParams};
#[cfg(feature = "alloc")]
use super::BoxedMontyForm;
#[cfg(feature = "alloc")]
use crate::BoxedUint;
// Accumulates the sum of products `Σ aᵢ·bᵢ` over the pairs in `$a_b` directly into the limb
// buffer `$u`, interleaving one limb of Montgomery reduction with each outer (`j`) iteration.
//
// Each outer iteration `j`:
//   1. for every pair `(aᵢ, bᵢ)`, adds `aᵢ.limbs[j] · bᵢ` (all `$nlimbs` limbs of `bᵢ`) into
//      `$u`, spilling carries into `hi` / `hi_carry`;
//   2. computes `q = $u[0] · $mod_neg_inv`, adds `q · $modulus`, and shifts the accumulator
//      down by one limb, discarding the lowest limb.
//
// Parameters:
// - `$a_b`: indexable sequence with `.len()`; each element is an `(a, b)` pair whose
//   `as_montgomery().limbs` are read.
// - `$u`: mutable limb array (indexed up to `$nlimbs - 1`) used as the accumulator. It is read
//   as well as written, so its initial contents take part in the computation; every caller in
//   this file passes a zeroed buffer.
// - `$modulus`: limbs of the modulus.
// - `$mod_neg_inv`: per-column reduction multiplier; presumably `-modulus⁻¹ mod 2^Limb::BITS`,
//   so that step 2 makes the discarded low limb zero — TODO confirm against the callers.
// - `$nlimbs`: number of limbs; bounds both the outer (`j`) and inner (`k`) loops.
//
// Evaluates to the overflow limb sitting just above `$u[..$nlimbs]`; the callers fold it back
// in with a conditional subtraction of the modulus.
macro_rules! impl_longa_monty_lincomb {
    ($a_b:expr, $u:expr, $modulus:expr, $mod_neg_inv:expr, $nlimbs:expr) => {{
        let len = $a_b.len();
        // Limbs above `$u`: during a column, `hi` is the limb directly above `$u` and
        // `hi_carry` collects carries out of `hi`.
        let mut hi_carry = Limb::ZERO;
        let mut hi;
        let mut carry;
        let mut j = 0;
        while j < $nlimbs {
            // The overflow left by the previous column's shift becomes this column's `hi`.
            hi = hi_carry;
            hi_carry = Limb::ZERO;
            let mut i = 0;
            while i < len {
                let (ai, bi) = &$a_b[i];
                carry = Limb::ZERO;
                let mut k = 0;
                // $u += ai.limbs[j] * bi, limb by limb.
                while k < $nlimbs {
                    ($u[k], carry) = ai.as_montgomery().limbs[j].carrying_mul_add(
                        bi.as_montgomery().limbs[k],
                        $u[k],
                        carry,
                    );
                    k += 1;
                }
                // Fold this row's final carry into `hi`; overflow from `hi` goes to `hi_carry`.
                (hi, carry) = hi.carrying_add(carry, Limb::ZERO);
                hi_carry = hi_carry.wrapping_add(carry);
                i += 1;
            }
            // Reduction multiplier for this column.
            let q = $u[0].wrapping_mul($mod_neg_inv);
            // Low limb of `$u[0] + q·m[0]` is dropped; only its carry is kept.
            (_, carry) = q.carrying_mul_add($modulus[0], $u[0], Limb::ZERO);
            i = 1;
            // Add `q·m[i]` and shift down one limb: result limb `i` lands in `$u[i - 1]`.
            while i < $nlimbs {
                ($u[i - 1], carry) = q.carrying_mul_add($modulus[i], $u[i], carry);
                i += 1;
            }
            // The former `hi` moves down into the top limb of `$u`; its carry-out is
            // accumulated in `hi_carry`, which becomes the next column's `hi`.
            ($u[$nlimbs - 1], carry) = hi.carrying_add(carry, Limb::ZERO);
            hi_carry = hi_carry.wrapping_add(carry);
            j += 1;
        }
        hi_carry
    }};
}
/// Computes the linear combination `Σ aᵢ·bᵢ` of the [`ConstMontyForm`] pairs in `products`.
///
/// The returned `Uint` holds the Montgomery representation of the sum. The pairs are fed to
/// the accumulating Montgomery multiplier in windows of at most
/// `2^min(MOD::PARAMS.mod_leading_zeros, usize::BITS - 1)` terms. Each window's result gets a
/// single conditional subtraction of `modulus` (`try_sub_with_carry`), and the window results
/// are combined with `add_mod`. When every pair fits in one window, no `add_mod` is needed.
///
/// - `modulus`: the odd modulus whose limbs drive the reduction.
/// - `mod_neg_inv`: per-limb reduction multiplier (`q = u[0] * mod_neg_inv`); presumably
///   `-modulus⁻¹ mod 2^Limb::BITS` — confirm against the caller.
pub const fn lincomb_const_monty_form<MOD: ConstMontyParams<LIMBS>, const LIMBS: usize>(
    mut products: &[(ConstMontyForm<MOD, LIMBS>, ConstMontyForm<MOD, LIMBS>)],
    modulus: &Odd<Uint<LIMBS>>,
    mod_neg_inv: Limb,
) -> Uint<LIMBS> {
    let window_cap: usize = 1 << u32_min(MOD::PARAMS.mod_leading_zeros, usize::BITS - 1);

    // Fast path: every pair fits in a single accumulation window.
    if products.len() <= window_cap {
        let mut total = Uint::<LIMBS>::ZERO;
        let overflow =
            impl_longa_monty_lincomb!(products, total.limbs, modulus.0.limbs, mod_neg_inv, LIMBS);
        return total.try_sub_with_carry(overflow, &modulus.0).0;
    }

    // Windowed path: peel off up to `window_cap` pairs at a time and fold each partial sum
    // into the running total.
    let mut total = Uint::<LIMBS>::ZERO;
    while !products.is_empty() {
        let take = if products.len() < window_cap {
            products.len()
        } else {
            window_cap
        };
        let (window, rest) = products.split_at(take);
        products = rest;

        let mut partial = Uint::<LIMBS>::ZERO;
        let overflow =
            impl_longa_monty_lincomb!(window, partial.limbs, modulus.0.limbs, mod_neg_inv, LIMBS);
        partial = partial.try_sub_with_carry(overflow, &modulus.0).0;
        total = total.add_mod(&partial, modulus.as_nz_ref());
    }
    total
}
/// Computes the linear combination `Σ aᵢ·bᵢ` of the borrowed [`FixedMontyForm`] pairs in
/// `products`.
///
/// The returned `Uint` holds the Montgomery representation of the sum. The pairs are fed to
/// the accumulating Montgomery multiplier in windows of at most
/// `2^min(mod_leading_zeros, usize::BITS - 1)` terms. Each window's result gets a single
/// conditional subtraction of `modulus` (`try_sub_with_carry`), and the window results are
/// combined with `add_mod`. When every pair fits in one window, no `add_mod` is needed.
///
/// - `modulus`: the odd modulus whose limbs drive the reduction.
/// - `mod_neg_inv`: per-limb reduction multiplier (`q = u[0] * mod_neg_inv`); presumably
///   `-modulus⁻¹ mod 2^Limb::BITS` — confirm against the caller.
/// - `mod_leading_zeros`: determines the window size; presumably the leading-zero count of
///   `modulus` — confirm against the caller.
pub const fn lincomb_monty_form<const LIMBS: usize>(
    mut products: &[(&FixedMontyForm<LIMBS>, &FixedMontyForm<LIMBS>)],
    modulus: &Odd<Uint<LIMBS>>,
    mod_neg_inv: Limb,
    mod_leading_zeros: u32,
) -> Uint<LIMBS> {
    let window_cap: usize = 1 << u32_min(mod_leading_zeros, usize::BITS - 1);

    // Fast path: every pair fits in a single accumulation window.
    if products.len() <= window_cap {
        let mut total = Uint::<LIMBS>::ZERO;
        let overflow =
            impl_longa_monty_lincomb!(products, total.limbs, modulus.0.limbs, mod_neg_inv, LIMBS);
        return total.try_sub_with_carry(overflow, &modulus.0).0;
    }

    // Windowed path: peel off up to `window_cap` pairs at a time and fold each partial sum
    // into the running total.
    let mut total = Uint::<LIMBS>::ZERO;
    while !products.is_empty() {
        let take = if products.len() < window_cap {
            products.len()
        } else {
            window_cap
        };
        let (window, rest) = products.split_at(take);
        products = rest;

        let mut partial = Uint::<LIMBS>::ZERO;
        let overflow =
            impl_longa_monty_lincomb!(window, partial.limbs, modulus.0.limbs, mod_neg_inv, LIMBS);
        partial = partial.try_sub_with_carry(overflow, &modulus.0).0;
        total = total.add_mod(&partial, modulus.as_nz_ref());
    }
    total
}
/// Computes the linear combination `Σ aᵢ·bᵢ` of the borrowed [`BoxedMontyForm`] pairs in
/// `products`.
///
/// The returned [`BoxedUint`] has the precision of `modulus` and holds the Montgomery
/// representation of the sum. The pairs are fed to the accumulating Montgomery multiplier in
/// windows of at most `2^min(mod_leading_zeros, usize::BITS - 1)` terms. Each window's result
/// is folded back with `sub_assign_mod_with_carry`, and the window results are combined with
/// `add_mod_assign`. When every pair fits in one window, the result is accumulated directly
/// into the return value with no extra buffer.
///
/// - `modulus`: the odd modulus whose limbs drive the reduction; also fixes the limb count.
/// - `mod_neg_inv`: per-limb reduction multiplier (`q = u[0] * mod_neg_inv`); presumably
///   `-modulus⁻¹ mod 2^Limb::BITS` — confirm against the caller.
/// - `mod_leading_zeros`: determines the window size; presumably the leading-zero count of
///   `modulus` — confirm against the caller.
#[cfg(feature = "alloc")]
pub fn lincomb_boxed_monty_form(
    products: &[(&BoxedMontyForm, &BoxedMontyForm)],
    modulus: &Odd<BoxedUint>,
    mod_neg_inv: Limb,
    mod_leading_zeros: u32,
) -> BoxedUint {
    // Always >= 1, so `chunks` below cannot panic.
    let max_accum: usize = 1 << u32_min(mod_leading_zeros, usize::BITS - 1);
    let nlimbs = modulus.0.nlimbs();
    let mut ret = BoxedUint::zero_with_precision(modulus.0.bits_precision());

    if products.len() <= max_accum {
        // Single window: accumulate straight into `ret`.
        let carry =
            impl_longa_monty_lincomb!(products, ret.limbs, modulus.0.limbs, mod_neg_inv, nlimbs);
        ret.sub_assign_mod_with_carry(carry, &modulus.0, &modulus.0);
    } else {
        // One scratch buffer reused (and re-zeroed) for every window; the macro adds into
        // whatever the buffer already holds, so the reset is required.
        let mut buf = BoxedUint::zero_with_precision(modulus.0.bits_precision());
        for window in products.chunks(max_accum) {
            buf.limbs.fill(Limb::ZERO);
            let carry =
                impl_longa_monty_lincomb!(window, buf.limbs, modulus.0.limbs, mod_neg_inv, nlimbs);
            buf.sub_assign_mod_with_carry(carry, &modulus.0, &modulus.0);
            ret.add_mod_assign(&buf, modulus.as_nz_ref());
        }
    }
    ret
}