//! Modular inversion and greatest common divisor computation for odd moduli,
//! based on the "safegcd" (Bernstein-Yang) division-steps algorithm, with
//! constant-time and variable-time code paths.
#[cfg(feature = "alloc")]
pub(crate) mod boxed;
use core::fmt;
use crate::{Choice, CtOption, I64, Int, Limb, Odd, U64, Uint, primitives::u32_min};
/// Number of division steps processed per batch; 62 steps keep the 64-bit
/// approximations of `f` and `g` and the signed transition-matrix entries in range.
const GCD_BATCH_SIZE: u32 = 62;
/// Precomputed state for inversion modulo an odd modulus using the safegcd
/// (Bernstein-Yang) algorithm.
#[derive(Clone, Debug)]
pub(crate) struct SafeGcdInverter<const LIMBS: usize> {
    /// The odd modulus to invert against.
    pub(super) modulus: Odd<Uint<LIMBS>>,
    /// Inverse of the modulus modulo 2^64.
    inverse: u64,
    /// Multiplier folded into the result: `invert` returns `adjuster * value^-1 mod modulus`.
    adjuster: Uint<LIMBS>,
}
/// 2x2 matrix of signed transition coefficients produced by a batch of division steps.
type Matrix = [[i64; 2]; 2];
impl<const LIMBS: usize> SafeGcdInverter<LIMBS> {
    /// Create a new inverter for the given odd modulus, computing the 64-bit
    /// inverse of the modulus on the fly. Only compiled for tests; other callers
    /// supply a precomputed inverse via [`Self::new_with_inverse`].
    #[cfg(test)]
    pub(crate) const fn new(modulus: &Odd<Uint<LIMBS>>, adjuster: &Uint<LIMBS>) -> Self {
Self::new_with_inverse(
modulus,
U64::from_u64(modulus.as_uint_ref().invert_mod_u64()),
adjuster,
)
}
    /// Create a new inverter from a precomputed inverse of the modulus modulo 2^64.
    #[inline]
    pub(crate) const fn new_with_inverse(
modulus: &Odd<Uint<LIMBS>>,
inverse: U64,
adjuster: &Uint<LIMBS>,
) -> Self {
Self {
modulus: *modulus,
inverse: inverse.as_uint_ref().lowest_u64(),
adjuster: *adjuster,
}
}
    /// Compute `adjuster * value^-1 mod modulus` in constant time.
    ///
    /// Returns the "none" value when `value` is zero or shares a common factor
    /// with the modulus.
    pub const fn invert(&self, value: &Uint<LIMBS>) -> CtOption<Uint<LIMBS>> {
invert_odd_mod_precomp::<LIMBS, false>(value, &self.modulus, self.inverse, &self.adjuster)
}
    /// Variable-time variant of [`Self::invert`]: typically faster, but execution
    /// time may depend on the value being inverted.
    pub const fn invert_vartime(&self, value: &Uint<LIMBS>) -> CtOption<Uint<LIMBS>> {
invert_odd_mod_precomp::<LIMBS, true>(value, &self.modulus, self.inverse, &self.adjuster)
}
}
/// Compute `a^-1 mod m` for an odd modulus `m`.
///
/// Returns the "none" value when `a` is zero or not coprime to `m`. When
/// `VARTIME` is `true`, execution time may depend on the value of `a`.
#[inline]
pub const fn invert_odd_mod<const LIMBS: usize, const VARTIME: bool>(
a: &Uint<LIMBS>,
m: &Odd<Uint<LIMBS>>,
) -> CtOption<Uint<LIMBS>> {
let mi = m.as_uint_ref().invert_mod_u64();
invert_odd_mod_precomp::<LIMBS, VARTIME>(a, m, mi, &Uint::ONE)
}
/// Core inversion routine: compute `e * a^-1 mod m` for odd `m`, given the
/// precomputed inverse `mi = m^-1 mod 2^64` and the adjusting multiplier `e`.
const fn invert_odd_mod_precomp<const LIMBS: usize, const VARTIME: bool>(
a: &Uint<LIMBS>,
m: &Odd<Uint<LIMBS>>,
mi: u64,
e: &Uint<LIMBS>,
) -> CtOption<Uint<LIMBS>> {
    let a_nonzero = a.is_nonzero();
    // `f` and `g` track the GCD computation, while `d` and `e` carry the
    // corresponding coefficients of `a` modulo `m` (starting from 0 and `e`).
    let (mut f, mut g) = (SignedInt::from_uint(*m.as_ref()), SignedInt::from_uint(*a));
    let (mut d, mut e) = (SignedInt::<LIMBS>::ZERO, SignedInt::from_uint(*e));
let mut steps = iterations(Uint::<LIMBS>::BITS);
let mut delta = 1;
let mut t;
    // Run batched division steps until the worst-case bound is exhausted, or
    // (in variable time only) until `g` reaches zero.
    while steps > 0 {
if VARTIME && g.is_zero_vartime() {
break;
}
let batch = u32_min(steps, GCD_BATCH_SIZE);
(delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
(f, g) = update_fg(&f, &g, t, batch);
(d, e) = update_de(&d, &e, m.as_ref(), mi, t, batch);
steps -= batch;
}
    // `f` now holds +/-gcd(a, m); fold its sign into `d` and normalize to the
    // range `[0, m)`. The inverse only exists if the GCD is one.
    let d = d.norm(f.is_negative(), m.as_ref());
    CtOption::new(d, Uint::eq(&f.magnitude, &Uint::ONE).and(a_nonzero))
}
/// Compute the greatest common divisor of `f` and `g`, where `f` is odd.
///
/// Since `f` is odd, the resulting GCD is odd as well. When `VARTIME` is `true`,
/// the loop exits as soon as `g` reaches zero.
pub const fn gcd_odd<const LIMBS: usize, const VARTIME: bool>(
f: &Odd<Uint<LIMBS>>,
g: &Uint<LIMBS>,
) -> Odd<Uint<LIMBS>> {
let (mut f, mut g) = (SignedInt::from_uint(*f.as_ref()), SignedInt::from_uint(*g));
let mut steps = iterations(Uint::<LIMBS>::BITS);
let mut delta = 1;
let mut t;
while steps > 0 {
if VARTIME && g.is_zero_vartime() {
break;
}
let batch = u32_min(steps, GCD_BATCH_SIZE);
(delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
(f, g) = update_fg(&f, &g, t, batch);
steps -= batch;
}
f.magnitude().to_odd().expect_copied("odd by construction")
}
/// Apply `batch` division steps to the 64-bit approximations `f` and `g`,
/// returning the updated `delta` and the accumulated transition matrix.
#[inline]
const fn jump<const VARTIME: bool>(
mut f: i64,
mut g: i64,
mut delta: i64,
mut batch: u32,
) -> (i64, Matrix) {
debug_assert!(f & 1 == 1, "f must be odd");
let mut t = [[1i64, 0], [0, 1]];
while batch > 0 {
(f, g, delta, t) = if VARTIME {
jump_step_vartime(f, g, delta, t)
} else {
jump_step(f, g, delta, t)
};
batch -= 1;
}
(delta, t)
}
/// One constant-time division step ("divstep") on the 64-bit approximations,
/// updating the transition matrix `t` alongside `f`, `g` and `delta`.
#[inline(always)]
#[allow(clippy::cast_sign_loss)]
const fn jump_step(
mut f: i64,
mut g: i64,
mut delta: i64,
mut t: Matrix,
) -> (i64, i64, i64, Matrix) {
    // `delta > 0` (sign bit clear and value nonzero)?
    let d_gtz = Choice::from_u64_nz((delta & !(delta >> 63)) as u64);
    let g_odd = Choice::from_u64_lsb((g & 1) as u64);
    // When `g` is odd it absorbs +/-`f`; when additionally `delta > 0`, the roles
    // of `f` and `g` are swapped.
    let g_adj = g_odd.select_i64(0, f);
    let swap = d_gtz.and(g_odd);
    delta = swap.select_i64(2i64.wrapping_add(delta), 2i64.wrapping_sub(delta));
    f = swap.select_i64(f, g);
    g = swap.select_i64(g.wrapping_add(g_adj), g.wrapping_sub(g_adj)) >> 1;
t = [
[
swap.select_i64(t[0][0], t[1][0]) << 1,
swap.select_i64(t[0][1], t[1][1]) << 1,
],
[
t[1][0].wrapping_add(g_odd.select_i64(0, d_gtz.select_i64(t[0][0], -t[0][0]))),
t[1][1].wrapping_add(g_odd.select_i64(0, d_gtz.select_i64(t[0][1], -t[0][1]))),
],
];
(f, g, delta, t)
}
/// Variable-time counterpart of [`jump_step`], branching on the low bit of `g`
/// and the sign of `delta`.
#[inline(always)]
const fn jump_step_vartime(
mut f: i64,
mut g: i64,
mut delta: i64,
mut t: Matrix,
) -> (i64, i64, i64, Matrix) {
if (g & 1) != 0 {
(f, g, delta, t) = if delta > 0 {
(
g,
g.wrapping_sub(f),
2i64.wrapping_sub(delta),
[
t[1],
[t[1][0].wrapping_sub(t[0][0]), t[1][1].wrapping_sub(t[0][1])],
],
)
} else {
(
f,
g.wrapping_add(f),
2i64.wrapping_add(delta),
[
t[0],
[t[1][0].wrapping_add(t[0][0]), t[1][1].wrapping_add(t[0][1])],
],
)
};
} else {
delta = 2i64.wrapping_add(delta);
}
g >>= 1;
t[0][0] <<= 1;
t[0][1] <<= 1;
(f, g, delta, t)
}
/// Apply the transition matrix `t` to the pair `(a, b)` (the `f`/`g` state),
/// shifting each result right by `shift` bits.
#[inline]
const fn update_fg<const LIMBS: usize>(
a: &SignedInt<LIMBS>,
b: &SignedInt<LIMBS>,
t: Matrix,
shift: u32,
) -> (SignedInt<LIMBS>, SignedInt<LIMBS>) {
(
SignedInt::lincomb_int_reduce_shift(
a,
b,
&I64::from_i64(t[0][0]),
&I64::from_i64(t[0][1]),
shift,
),
SignedInt::lincomb_int_reduce_shift(
a,
b,
&I64::from_i64(t[1][0]),
&I64::from_i64(t[1][1]),
shift,
),
)
}
/// Apply the transition matrix `t` to the coefficients `(d, e)`, dividing by
/// `2^shift` modulo `m` with the help of the precomputed `mi = m^-1 mod 2^64`.
#[inline]
const fn update_de<const LIMBS: usize>(
d: &SignedInt<LIMBS>,
e: &SignedInt<LIMBS>,
m: &Uint<LIMBS>,
mi: u64,
t: Matrix,
shift: u32,
) -> (SignedInt<LIMBS>, SignedInt<LIMBS>) {
(
SignedInt::lincomb_int_reduce_shift_mod(
d,
e,
&Int::from_i64(t[0][0]),
&Int::from_i64(t[0][1]),
shift,
m,
U64::from_u64(mi),
),
SignedInt::lincomb_int_reduce_shift_mod(
d,
e,
&Int::from_i64(t[1][0]),
&Int::from_i64(t[1][1]),
shift,
m,
U64::from_u64(mi),
),
)
}
/// Negate the double-width value `(hi, lo)` in two's complement form when
/// `flag` is set; leave it unchanged otherwise.
#[inline]
const fn conditional_negate_in_place_wide<const L: usize, const H: usize>(
lo: &mut Uint<L>,
hi: &mut Uint<H>,
flag: Choice,
) {
let (neg, carry) = lo.carrying_neg();
let hi_neg = hi
.not()
.wrapping_add(&Uint::select(&Uint::ZERO, &Uint::ONE, carry));
*lo = Uint::select(lo, &neg, flag);
*hi = Uint::select(hi, &hi_neg, flag);
}
/// Shift the double-width value `(hi, lo)` right by `shift` bits in place.
/// Callers within this module always pass a nonzero `shift` below `Uint::<H>::BITS`.
#[inline]
const fn shr_in_place_wide<const L: usize, const H: usize>(
lo: &mut Uint<L>,
hi: &mut Uint<H>,
shift: u32,
) {
debug_assert!(H <= L);
debug_assert!(shift < Uint::<H>::BITS);
let copy = hi.shl_vartime(Uint::<H>::BITS - shift);
*hi = hi.shr_vartime(shift);
*lo = lo.shr_vartime(shift);
let mut offs = shift.div_ceil(Limb::BITS) as usize;
lo.limbs[L - offs] = lo.limbs[L - offs].bitor(copy.limbs[H - offs]);
loop {
offs -= 1;
if offs == 0 {
break;
}
lo.limbs[L - offs] = copy.limbs[H - offs];
}
}
/// Worst-case number of division steps required for inputs of `bits` bits,
/// roughly `2.304 * bits` (591 steps for 256-bit inputs).
#[inline]
#[allow(clippy::integer_division_remainder_used, reason = "public parameter")]
const fn iterations(bits: u32) -> u32 {
    (45907 * bits + 30179) / 19929
}
/// Big signed integer in sign-magnitude representation.
#[derive(Clone, Copy)]
struct SignedInt<const LIMBS: usize> {
    /// Set when the value is negative. A zero magnitude may carry either sign.
    sign: Choice,
    /// Absolute value.
    magnitude: Uint<LIMBS>,
}
impl<const LIMBS: usize> SignedInt<LIMBS> {
    /// The value zero, with a positive sign.
    pub const ZERO: Self = Self::from_uint(Uint::ZERO);
    /// Construct a non-negative value from a `Uint`.
    pub const fn from_uint(uint: Uint<LIMBS>) -> Self {
Self {
sign: Choice::FALSE,
magnitude: uint,
}
}
    /// Construct a value from a magnitude and an explicit sign.
    pub const fn from_uint_sign(magnitude: Uint<LIMBS>, sign: Choice) -> Self {
Self { sign, magnitude }
}
    /// The absolute value.
    pub const fn magnitude(&self) -> Uint<LIMBS> {
self.magnitude
}
    /// Whether the magnitude is nonzero, in constant time.
    pub const fn is_nonzero(&self) -> Choice {
self.magnitude.is_nonzero()
}
    /// Whether the magnitude is zero, in variable time.
    pub const fn is_zero_vartime(&self) -> bool {
self.magnitude.is_zero_vartime()
}
    /// The sign bit; set when the value is negative.
    pub const fn is_negative(&self) -> Choice {
self.sign
}
    /// The lowest 63 bits of the magnitude as an `i64`, negated when the sign
    /// is set.
    #[allow(clippy::cast_possible_wrap)]
    pub const fn lowest(&self) -> i64 {
let mag = (self.magnitude.as_uint_ref().lowest_u64() & (u64::MAX >> 1)) as i64;
self.sign.select_i64(mag, mag.wrapping_neg())
}
    /// Compute `a * c + b * d` as a double-width value, returned in
    /// sign-magnitude form as `(low, high, sign)`.
    #[inline]
    pub(crate) const fn lincomb_int<const RHS: usize>(
a: &SignedInt<LIMBS>,
b: &SignedInt<LIMBS>,
c: &Int<RHS>,
d: &Int<RHS>,
) -> (Uint<LIMBS>, Uint<RHS>, Choice) {
let (c, c_sign) = c.abs_sign();
let (d, d_sign) = d.abs_sign();
let (mut x, mut x_hi) = a.magnitude.widening_mul(&c);
let x_neg = a.sign.xor(c_sign);
let (mut y, mut y_hi) = b.magnitude.widening_mul(&d);
let y_neg = b.sign.xor(d_sign);
let odd_neg = x_neg.xor(y_neg);
conditional_negate_in_place_wide(&mut y, &mut y_hi, odd_neg.not());
let mut borrow;
(x, borrow) = x.borrowing_sub(&y, Limb::ZERO);
(x_hi, borrow) = x_hi.borrowing_sub(&y_hi, borrow);
let swap = borrow.is_nonzero().and(odd_neg);
conditional_negate_in_place_wide(&mut x, &mut x_hi, swap);
let sign = x_neg.and(swap.not()).or(y_neg.and(swap));
(x, x_hi, sign)
}
    /// Compute `(a * c + b * d) >> shift` as a signed value; the result is
    /// assumed to fit in `LIMBS` limbs after the shift.
    pub(crate) const fn lincomb_int_reduce_shift<const S: usize>(
a: &Self,
b: &Self,
c: &Int<S>,
d: &Int<S>,
shift: u32,
) -> Self {
debug_assert!(shift < Uint::<S>::BITS);
let (mut a, mut a_hi, a_sign) = Self::lincomb_int(a, b, c, d);
shr_in_place_wide(&mut a, &mut a_hi, shift);
SignedInt::from_uint_sign(a, a_sign)
}
    /// Compute `(a * c + b * d) / 2^shift mod m`, performing the exact division
    /// by `2^shift` Montgomery-style using the precomputed inverse `mi` of `m`.
    pub(crate) const fn lincomb_int_reduce_shift_mod<const S: usize>(
a: &Self,
b: &Self,
c: &Int<S>,
d: &Int<S>,
shift: u32,
m: &Uint<LIMBS>,
mi: Uint<S>,
) -> SignedInt<LIMBS> {
debug_assert!(shift < Uint::<S>::BITS);
let (mut x, mut x_hi, mut x_sign) = SignedInt::lincomb_int(a, b, c, d);
let mut mf = x.resize::<S>().wrapping_mul(&mi);
mf = mf.bitand(&Uint::MAX.shr_vartime(Uint::<S>::BITS - shift));
let (xa, xa_hi) = m.widening_mul(&mf);
let mut borrow;
(x, borrow) = x.borrowing_sub(&xa, Limb::ZERO);
(x_hi, borrow) = x_hi.borrowing_sub(&xa_hi, borrow);
let swap = borrow.is_nonzero();
conditional_negate_in_place_wide(&mut x, &mut x_hi, swap);
x_sign = x_sign.xor(swap);
shr_in_place_wide(&mut x, &mut x_hi, shift);
debug_assert!(
x_hi.shr1().is_nonzero().not().to_bool_vartime(),
"overflow was larger than one bit"
);
x = x.try_sub_with_carry(x_hi.limbs[0], m).0;
SignedInt::from_uint_sign(x, x_sign)
}
    /// Normalize to a residue in `[0, m)`, folding in the sign `f_sign` of the
    /// final `f`: the result is `sign(f) * self mod m`.
    const fn norm(&self, f_sign: Choice, m: &Uint<LIMBS>) -> Uint<LIMBS> {
let swap = f_sign.xor(self.sign).and(self.is_nonzero());
Uint::select(&self.magnitude, &m.wrapping_sub(&self.magnitude), swap)
}
    /// Constant-time equality; positive and negative zero compare equal.
    pub const fn eq(a: &Self, b: &Self) -> Choice {
Uint::eq(&a.magnitude, &b.magnitude).and(a.sign.eq(b.sign).or(a.is_nonzero().not()))
}
}
impl<const LIMBS: usize> fmt::Debug for SignedInt<LIMBS> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_fmt(format_args!(
"{}0x{}",
if self.sign.to_bool_vartime() {
"-"
} else {
"+"
},
&self.magnitude
))
}
}
impl<const LIMBS: usize> PartialEq for SignedInt<LIMBS> {
fn eq(&self, other: &Self) -> bool {
Self::eq(self, other).to_bool_vartime()
}
}
#[cfg(test)]
mod tests {
use super::SafeGcdInverter;
use crate::{U128, U256, modular::safegcd::shr_in_place_wide};
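    // A hedged sketch against the schoolbook result gcd(15, 10) = 5, exercising
    // the public `gcd_odd` entry point directly; `U128` is used only for brevity.
    #[test]
    fn gcd_odd_small() {
        use super::gcd_odd;
        let f = U128::from_u128(15).to_odd().unwrap();
        let g = U128::from_u128(10);
        let gcd = gcd_odd::<{ U128::LIMBS }, false>(&f, &g);
        assert_eq!(*gcd.as_ref(), U128::from_u128(5));
    }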
#[test]
fn invert() {
let g =
U256::from_be_hex("00000000CBF9350842F498CE441FC2DC23C7BF47D3DE91C327B2157C5E4EED77");
let modulus =
U256::from_be_hex("FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551")
.to_odd()
.unwrap();
let inverter = SafeGcdInverter::new(&modulus, &U256::ONE);
let result = inverter.invert(&g).unwrap();
assert_eq!(
U256::from_be_hex("FB668F8F509790BC549B077098918604283D42901C92981062EB48BC723F617B"),
result
);
}
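    // Two hedged sketches around the inversion entry points: the variable-time
    // path should agree with the constant-time result of `invert` above, and the
    // generic `invert_odd_mod` should match a hand-checked inverse
    // (3 * 5 = 15, which is 1 mod 7).
    #[test]
    fn invert_vartime_matches() {
        let g =
            U256::from_be_hex("00000000CBF9350842F498CE441FC2DC23C7BF47D3DE91C327B2157C5E4EED77");
        let modulus =
            U256::from_be_hex("FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551")
                .to_odd()
                .unwrap();
        let inverter = SafeGcdInverter::new(&modulus, &U256::ONE);
        assert_eq!(
            inverter.invert(&g).unwrap(),
            inverter.invert_vartime(&g).unwrap()
        );
    }
    #[test]
    fn invert_odd_mod_small() {
        use super::invert_odd_mod;
        let m = U128::from_u128(7).to_odd().unwrap();
        let inv = invert_odd_mod::<{ U128::LIMBS }, false>(&U128::from_u128(3), &m).unwrap();
        assert_eq!(inv, U128::from_u128(5));
    }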
#[test]
fn shr_wide() {
let hi = U128::from_u128(0x11111111222222223333333344444444);
let lo = U256::MAX;
let (mut a, mut a_hi) = (lo, hi);
shr_in_place_wide(&mut a, &mut a_hi, 16);
assert_eq!(
a,
U256::from_be_hex("4444FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF")
);
assert_eq!(a_hi, U128::from_u128(0x1111111122222222333333334444));
let (mut b, mut b_hi) = (lo, hi);
shr_in_place_wide(&mut b, &mut b_hi, 68);
assert_eq!(
b,
U256::from_be_hex("23333333344444444FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF")
);
assert_eq!(b_hi, U128::from_u128(0x111111112222222));
}
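    // A hedged sketch of the wide two's-complement helper: negating the 384-bit
    // value (hi, lo) = (0, 1) should yield all ones, while a cleared flag must
    // leave both halves untouched. The flag is built via `Choice::from_u64_lsb`,
    // mirroring how the divstep code constructs its choices.
    #[test]
    fn conditional_negate_wide() {
        use super::conditional_negate_in_place_wide;
        use crate::Choice;
        let (mut lo, mut hi) = (U256::ONE, U128::ZERO);
        conditional_negate_in_place_wide(&mut lo, &mut hi, Choice::from_u64_lsb(1));
        assert_eq!(lo, U256::MAX);
        assert_eq!(hi, U128::MAX);
        let (mut lo, mut hi) = (U256::ONE, U128::ZERO);
        conditional_negate_in_place_wide(&mut lo, &mut hi, Choice::from_u64_lsb(0));
        assert_eq!(lo, U256::ONE);
        assert_eq!(hi, U128::ZERO);
    }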
}