#![expect(clippy::needless_pass_by_value)] #![expect(clippy::inline_always)]
use crypto_bigint::{Choice, CtEq, CtSelect, CtLt as _, CtGt as _, Zero, BitOps, Limb, UintRef};
pub(crate) trait Limbs: AsRef<[Limb]> + AsMut<[Limb]> + CtEq + Zero + BitOps {
fn like_zero(&self) -> Self;
#[inline(always)]
fn swap(&mut self, b: &mut Self, choice: Choice) {
let a = &mut <_ as AsMut<[Limb]>>::as_mut(self);
let b = &mut <_ as AsMut<[Limb]>>::as_mut(b);
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
<_>::ct_swap(a, b, choice);
}
}
}
#[inline(always)]
fn approximate_a_lte_c<L: Limbs>(
a: (&mut u32, &mut L),
b_sign: &mut Choice,
c: (&mut u32, &mut L),
) {
let c_lt_a = c.0.ct_lt(a.0);
c.0.ct_swap(a.0, c_lt_a);
L::swap(a.1, c.1, c_lt_a);
*b_sign ^= c_lt_a;
}
#[inline(always)]
fn a_lte_c<L: Limbs>(a: &mut L, b_sign: &mut Choice, c: &mut L) {
let limbs = <_ as AsRef<[Limb]>>::as_ref(c).len();
let c_lt_a = UintRef::new(&c.as_ref()[.. limbs]).ct_lt(UintRef::new(&a.as_ref()[.. limbs]));
L::swap(a, c, c_lt_a);
*b_sign ^= c_lt_a;
}
#[inline(always)]
fn should_reduce_to_next_bit_except_final(
b: (&mut Choice, &mut UintRef),
b_needs_negation: Choice,
b_lte_a: &mut Choice,
a_bits: u32,
b_bits_bound: u32,
) -> Choice {
let b_gt_a = (!*b_lte_a) & b_bits_bound.ct_gt(&a_bits);
*b_lte_a = !b_gt_a;
b_gt_a &
Choice::from(u8::from(b.1.bit_vartime(b_bits_bound.wrapping_sub(1))))
.ct_eq(&!b_needs_negation)
}
#[inline(always)]
fn should_reduce_to_next_bit_final<L: Limbs>(a: &L, b: (&Choice, &UintRef)) -> Choice {
let b_abs = b.1.as_limbs();
let mut borrow = Limb::ZERO;
for (a_limb, b_limb) in <_ as AsRef<[Limb]>>::as_ref(a)[.. b_abs.len()].iter().zip(b_abs) {
let _a_diff_b_abs_limb;
(_a_diff_b_abs_limb, borrow) = a_limb.borrowing_sub(*b_limb, borrow);
}
!borrow.is_zero()
}
#[expect(clippy::too_many_arguments)]
#[inline(always)]
fn reduce_to_next_bit<L: Limbs>(
a: &L,
b: (&mut Choice, &mut UintRef),
b_needs_negation: &mut Choice,
c: &mut L,
should_reduce: Choice,
limbs: usize,
a_bits: u32,
b_bits: u32,
) {
#[cfg(debug_assertions)]
{
debug_assert!(bool::from(a.bits().ct_lt(&c.bits()) | a.bits().ct_eq(&c.bits())));
debug_assert!(limbs <= <_ as AsRef::<[Limb]>>::as_ref(a).len());
debug_assert!(limbs <= <_ as AsRef::<[Limb]>>::as_ref(&b.1).len());
debug_assert!(bool::from((!should_reduce) | UintRef::new(a.as_ref()).ct_lt(b.1)));
debug_assert!(bool::from((!should_reduce) | a_bits.ct_eq(&a.bits())));
debug_assert!(bool::from((*b_needs_negation) | (!should_reduce) | b_bits.ct_eq(&b.1.bits())));
debug_assert!(bool::from((!should_reduce) | b_bits.ct_lt(&b.1.bits_precision())));
}
let log_2_m = {
let log_2_m = b_bits.wrapping_sub(a_bits).wrapping_sub(1);
<_ as CtSelect>::ct_select(&0, &log_2_m, (!a_bits.ct_eq(&b_bits)) & should_reduce)
};
let mut m_a = c.like_zero();
{
let m_a = UintRef::new_mut(&mut <_ as AsMut<[Limb]>>::as_mut(&mut m_a)[.. limbs]);
m_a.copy_from_slice(&a.as_ref()[.. limbs]);
m_a.shl_assign(log_2_m);
}
let m_a = UintRef::new_mut(<_ as AsMut<[Limb]>>::as_mut(&mut m_a));
{
*b.0 ^= *b_needs_negation;
let b_negation_carry = Limb::from(u8::from(*b_needs_negation));
let b_negation_mask = Limb::ZERO.wrapping_sub(b_negation_carry);
let mut b_diff_m_a_carry = Limb::ZERO;
let mut two_m_a_carry = Limb::ZERO;
let mut b_diff_two_m_a_carry =
Limb::from(u8::from(should_reduce)).wrapping_add(b_negation_carry);
for (b_limb, m_a_limb) in b.1.iter_mut().zip(m_a.iter_mut()) {
{
let two_m_a_limb: Limb = ((*m_a_limb) << 1) | two_m_a_carry;
two_m_a_carry = (*m_a_limb) >> const { Limb::BITS - 1 };
let new_b_limb;
(new_b_limb, b_diff_two_m_a_carry) = ((*b_limb) ^ b_negation_mask).carrying_add(
Limb::ct_select(&Limb::ZERO, &!two_m_a_limb, should_reduce),
b_diff_two_m_a_carry,
);
*b_limb = new_b_limb;
}
{
let new_b_diff_m_a_limb;
(new_b_diff_m_a_limb, b_diff_m_a_carry) =
(*b_limb).carrying_add(*m_a_limb, b_diff_m_a_carry);
*m_a_limb = Limb::ct_select(&Limb::ZERO, &new_b_diff_m_a_limb, should_reduce);
}
}
*b_needs_negation = should_reduce & b_diff_two_m_a_carry.ct_eq(&Limb::ZERO);
}
let b_diff_m_a = m_a;
{
let m_b_diff_m_square_a = b_diff_m_a;
m_b_diff_m_square_a.shl_assign(log_2_m);
let mut borrow = Limb::ZERO;
for (c_limb, m_b_diff_m_square_a_limb) in
<_ as AsMut<[Limb]>>::as_mut(c).iter_mut().zip(m_b_diff_m_square_a.as_limbs())
{
let new_limb;
(new_limb, borrow) = c_limb.borrowing_sub(*m_b_diff_m_square_a_limb, borrow);
*c_limb = new_limb;
}
}
}
fn negate_b(b: (&mut Choice, &mut UintRef), b_needs_negation: Choice) {
*b.0 ^= b_needs_negation;
let mask = Limb::ZERO.wrapping_sub(Limb::from(u8::from(b_needs_negation)));
for b_limb in b.1.iter_mut() {
*b_limb ^= mask;
}
b.1[0] |= Limb::ONE;
}
#[inline(always)]
fn normalize<L: Limbs>(a: L, mut b: (Choice, L), c: L) -> (L, (Choice, L), L) {
b.0 |= b.1.ct_eq(&a) | a.ct_eq(&c);
(a, b, c)
}
#[inline(always)]
pub(crate) fn reduce_to_upper_bound<L: Limbs>(
log_2_bound: u32,
mut a: L,
mut b: (Choice, L),
mut c: L,
upper_bound: u32,
) -> (L, (Choice, L), L) {
#[cfg(debug_assertions)]
{
debug_assert!(bool::from(a.bits().ct_lt(&log_2_bound) | a.bits().ct_eq(&log_2_bound)));
debug_assert!(bool::from(b.1.bits().ct_lt(&log_2_bound) | b.1.bits().ct_eq(&log_2_bound)));
debug_assert!(
usize::try_from(log_2_bound.div_ceil(Limb::BITS)).unwrap() <=
<_ as AsRef::<[Limb]>>::as_ref(&a).len()
);
debug_assert!(
<_ as AsRef::<[Limb]>>::as_ref(&a).len() <= <_ as AsRef::<[Limb]>>::as_ref(&b.1).len()
);
}
let original_limbs = usize::try_from(log_2_bound.div_ceil(Limb::BITS)).unwrap();
{
let (b_sign, mut b_value) =
(&mut b.0, UintRef::new_mut(&mut <_ as AsMut<[Limb]>>::as_mut(&mut b.1)[.. original_limbs]));
let mut b_lte_a = Choice::FALSE;
let mut b_needs_negation = Choice::FALSE;
let mut limbs = original_limbs;
#[expect(clippy::range_plus_one)]
let mut bits = ((upper_bound + 1) .. (log_2_bound + 1)).rev();
let mut a_bits = a.bits();
let mut c_bits = c.bits();
{
let progress_in_partial_limb = usize::try_from(log_2_bound % Limb::BITS).unwrap();
for bits in (&mut bits).take(2 + progress_in_partial_limb) {
approximate_a_lte_c((&mut a_bits, &mut a), b_sign, (&mut c_bits, &mut c));
let should_reduce = should_reduce_to_next_bit_except_final(
(b_sign, b_value),
b_needs_negation,
&mut b_lte_a,
a_bits,
bits,
);
reduce_to_next_bit(
&a,
(b_sign, b_value),
&mut b_needs_negation,
&mut c,
should_reduce,
limbs,
a_bits,
bits,
);
debug_assert!(bool::from(
b_needs_negation | (!should_reduce) | b_value.bits().ct_lt(&bits)
));
c_bits = c.bits();
}
negate_b((b_sign, b_value), b_needs_negation);
b_needs_negation = Choice::FALSE;
if progress_in_partial_limb != 0 {
limbs -= 1;
b_value = b_value.leading_mut(limbs);
}
}
while bits.len() != 0 {
debug_assert_ne!(limbs, 0);
#[expect(clippy::as_conversions)]
for bits in (&mut bits).take(const { Limb::BITS as usize }) {
approximate_a_lte_c((&mut a_bits, &mut a), b_sign, (&mut c_bits, &mut c));
let should_reduce = should_reduce_to_next_bit_except_final(
(b_sign, b_value),
b_needs_negation,
&mut b_lte_a,
a_bits,
bits,
);
reduce_to_next_bit(
&a,
(b_sign, b_value),
&mut b_needs_negation,
&mut c,
should_reduce,
limbs,
a_bits,
bits,
);
debug_assert!(bool::from(
b_needs_negation | (!should_reduce) | b_value.bits().ct_lt(&bits)
));
c_bits = c.bits();
}
negate_b((b_sign, b_value), b_needs_negation);
b_needs_negation = Choice::FALSE;
limbs -= 1;
b_value = b_value.leading_mut(limbs);
}
{
let (b_sign, b_value) = (
&mut b.0,
UintRef::new_mut(&mut <_ as AsMut<[Limb]>>::as_mut(&mut b.1)[.. original_limbs]),
);
a_lte_c(&mut a, b_sign, &mut c);
let a_bits = a.bits();
let b_bits = b_value.bits();
let should_reduce = should_reduce_to_next_bit_final(&a, (b_sign, b_value));
reduce_to_next_bit(
&a,
(b_sign, b_value),
&mut b_needs_negation,
&mut c,
should_reduce,
original_limbs,
a_bits,
b_bits,
);
negate_b((b_sign, b_value), b_needs_negation);
}
}
(a, b, c)
}
#[expect(private_bounds)]
#[inline(always)]
pub(crate) fn partial_reduce<L: super::c::Limbs + Limbs>(
log_2_bound: u32,
a: L,
mut b: (Choice, L),
negative_discriminant_abs: &L,
) -> (L, (Choice, L), L) {
let discriminant_bits = negative_discriminant_abs.bits_vartime();
let sqrt_discriminant_bits = discriminant_bits.div_ceil(2);
#[cfg(debug_assertions)]
{
debug_assert!(
negative_discriminant_abs.bits_vartime() <
(u32::try_from(a.as_ref().len()).unwrap() * Limb::BITS)
);
debug_assert!(bool::from(
a.bits().ct_lt(&(u32::try_from(a.as_ref().len()).unwrap() * Limb::BITS))
));
debug_assert_eq!(
<L as AsRef::<[Limb]>>::as_ref(&a).len(),
<L as AsRef::<[Limb]>>::as_ref(&b.1).len()
);
}
b.1 = {
let mut two_a = a.clone();
UintRef::new_mut(two_a.as_mut()).shl1_assign();
L::rem(b.1, &two_a)
};
let c = super::c(&a, &b, negative_discriminant_abs);
let (mut a, mut b, mut c) =
reduce_to_upper_bound(log_2_bound, a, b, c, sqrt_discriminant_bits - 1);
a_lte_c(&mut a, &mut b.0, &mut c);
#[cfg(debug_assertions)]
{
debug_assert!(bool::from(
a.bits().ct_lt(&discriminant_bits.div_ceil(2)) |
a.bits().ct_eq(&discriminant_bits.div_ceil(2))
));
debug_assert!(bool::from(
b.1.bits().ct_lt(&discriminant_bits.div_ceil(2)) |
b.1.bits().ct_eq(&discriminant_bits.div_ceil(2))
));
}
(a, b, c)
}
#[inline(always)]
pub(crate) fn reduce<L: Limbs>(
log_2_bound: u32,
a: L,
b: (Choice, L),
c: L,
) -> (L, (Choice, L), L) {
let (mut a, mut b, mut c) = reduce_to_upper_bound(log_2_bound, a, b, c, 0);
a_lte_c(&mut a, &mut b.0, &mut c);
let (a, b, c) = normalize(a, b, c);
#[cfg(debug_assertions)]
{
let a = UintRef::new(AsRef::<[Limb]>::as_ref(&a));
let b_abs = UintRef::new(AsRef::<[Limb]>::as_ref(&b.1));
let c = UintRef::new(AsRef::<[Limb]>::as_ref(&c));
debug_assert!(bool::from(b_abs.ct_lt(a) | b_abs.ct_eq(&a)));
debug_assert!(bool::from(a.ct_lt(c) | a.ct_eq(&c)));
let b_eq_a_or_a_eq_c = a.ct_eq(&b_abs) | a.ct_eq(&c);
debug_assert!(bool::from((!b_eq_a_or_a_eq_c) | b.0));
}
(a, b, c)
}