use super::Uint;
use crate::ct::ConstantTimeLess;
#[cfg(feature = "alloc")]
use super::BoxedUint;
/// Computes the multiplicative inverse of `a` modulo `m` for heap-allocated integers.
///
/// Runs the extended Euclidean algorithm on `(a mod m, m)`, tracking only the Bezout
/// coefficient that belongs to `a`. Because `BoxedUint` is unsigned, that coefficient is
/// carried as a `(magnitude, is_negative)` pair and updated through `signed_sub_boxed`.
///
/// Returns `None` when `a` or `m` is zero, or when the last non-zero remainder (the gcd)
/// is not one. Otherwise returns the coefficient mapped back into the non-negative range
/// and passed through a final `reduce(m)`.
///
/// The loop runs until the remainder becomes zero, so the amount of work depends on the
/// operand values.
#[cfg(feature = "alloc")]
pub fn inv_mod_boxed(a: &BoxedUint, m: &BoxedUint) -> Option<BoxedUint> {
    if a.is_zero() {
        return None;
    }
    if m.is_zero() {
        return None;
    }

    let one = BoxedUint::from_u64(1);

    // Remainder sequence of the Euclidean algorithm.
    let mut prev_rem = a.reduce(m);
    let mut rem = m.clone();

    // Coefficient sequence as (magnitude, is_negative); starts at (+1, +0).
    let mut prev_coeff = (one.clone(), false);
    let mut coeff = (BoxedUint::zero(m.limbs()), false);

    while !rem.is_zero() {
        let (quotient, next_rem) = prev_rem.divrem(&rem);

        // next_coeff = prev_coeff - quotient * coeff, evaluated in sign-magnitude form.
        let scaled = quotient.mul(&coeff.0);
        let next_coeff = signed_sub_boxed(&prev_coeff.0, prev_coeff.1, &scaled, coeff.1);

        prev_rem = rem;
        rem = next_rem;
        prev_coeff = coeff;
        coeff = next_coeff;
    }

    // `prev_rem` now holds the last non-zero remainder; an inverse exists only if it is one.
    if prev_rem != one {
        return None;
    }

    match prev_coeff {
        // A negative coefficient -x is represented modulo m as m - x.
        (magnitude, true) => Some(m.sub(&magnitude).reduce(m)),
        (magnitude, false) => Some(magnitude.reduce(m)),
    }
}
/// Computes `a - b` for two sign-magnitude values.
///
/// Each operand is a magnitude plus a flag that is `true` when the value is negative.
/// The result is returned in the same `(magnitude, is_negative)` form.
///
/// When both operands are negative and equal, the zero result keeps the negative flag.
#[cfg(feature = "alloc")]
fn signed_sub_boxed(a: &BoxedUint, a_neg: bool, b: &BoxedUint, b_neg: bool) -> (BoxedUint, bool) {
    // Opposite signs: the magnitudes add up and the result takes the sign of `a`.
    if a_neg != b_neg {
        return (a.add(b), a_neg);
    }

    // Same sign: subtract the smaller magnitude from the larger one; the sign flips
    // exactly when `b` has the larger magnitude.
    if a.lt(b) {
        (b.sub(a), !a_neg)
    } else {
        (a.sub(b), a_neg)
    }
}
/// Computes the multiplicative inverse of `a` modulo `m`.
///
/// Runs the extended Euclidean algorithm on `(a mod m, m)`, tracking only the Bezout
/// coefficient that belongs to `a`. Because `Uint` is unsigned, that coefficient is
/// carried as a `(magnitude, is_negative)` pair and updated through `signed_sub`.
///
/// Returns `None` when `a` or `m` is zero, or when the last non-zero remainder (the gcd)
/// is not one. Otherwise returns the coefficient mapped back into the non-negative range.
///
/// The loop runs until the remainder becomes zero, so the amount of work depends on the
/// operand values.
pub fn inv_mod<const LIMBS: usize>(a: &Uint<LIMBS>, m: &Uint<LIMBS>) -> Option<Uint<LIMBS>> {
    if bool::from(a.is_zero()) {
        return None;
    }
    if bool::from(m.is_zero()) {
        return None;
    }

    let one = Uint::<LIMBS>::ONE;

    // Remainder sequence of the Euclidean algorithm.
    let mut prev_rem = a.reduce(m);
    let mut rem = *m;

    // Coefficient sequence as (magnitude, is_negative); starts at (+1, +0).
    let mut prev_coeff = (one, false);
    let mut coeff = (Uint::<LIMBS>::ZERO, false);

    while !bool::from(rem.is_zero()) {
        let (quotient, next_rem) = prev_rem.divrem(&rem);

        // next_coeff = prev_coeff - quotient * coeff, evaluated in sign-magnitude form.
        // Only the `.0` component of the wide product is kept (presumably the low half;
        // confirm against `mul_wide`).
        let scaled = quotient.mul_wide(&coeff.0).0;
        let next_coeff = signed_sub(&prev_coeff.0, prev_coeff.1, &scaled, coeff.1);

        prev_rem = rem;
        rem = next_rem;
        prev_coeff = coeff;
        coeff = next_coeff;
    }

    // `prev_rem` now holds the last non-zero remainder; an inverse exists only if it is one.
    if prev_rem != one {
        return None;
    }

    match prev_coeff {
        // A negative coefficient -x is represented modulo m as m - x.
        (magnitude, true) => Some(m.wrapping_sub(&magnitude)),
        (magnitude, false) => Some(magnitude),
    }
}
/// Computes `a - b` for two sign-magnitude values.
///
/// Each operand is a magnitude plus a flag that is `true` when the value is negative.
/// The result is returned in the same `(magnitude, is_negative)` form.
///
/// When both operands are negative and equal, the zero result keeps the negative flag.
/// Magnitudes are combined with wrapping arithmetic, so the caller is expected to keep
/// them small enough that the sum in the opposite-sign case does not wrap.
fn signed_sub<const LIMBS: usize>(
    a: &Uint<LIMBS>,
    a_neg: bool,
    b: &Uint<LIMBS>,
    b_neg: bool,
) -> (Uint<LIMBS>, bool) {
    // Opposite signs: the magnitudes add up and the result takes the sign of `a`.
    if a_neg != b_neg {
        return (a.wrapping_add(b), a_neg);
    }

    // Same sign: subtract the smaller magnitude from the larger one; the sign flips
    // exactly when `b` has the larger magnitude.
    let b_is_larger = bool::from(a.ct_lt(b));
    if b_is_larger {
        (b.wrapping_sub(a), !a_neg)
    } else {
        (a.wrapping_sub(b), a_neg)
    }
}
#[cfg(test)]
mod tests {
    use super::super::MontModulus;
    use super::*;
    use crate::ct::ConstantTimeEq;

    /// Builds a single-limb integer from a `u64`.
    fn limb(value: u64) -> Uint<1> {
        Uint::<1>::from_u64(value)
    }

    /// Known inverses for small single-limb moduli, including an even modulus.
    #[test]
    fn small_inverses() {
        // (a, m, expected inverse of a modulo m)
        let table: [(u64, u64, u64); 3] = [(3, 11, 4), (7, 15, 13), (3, 10, 7)];
        for (a, m, expected) in table {
            assert_eq!(inv_mod(&limb(a), &limb(m)), Some(limb(expected)));
        }

        // One is its own inverse.
        assert_eq!(inv_mod(&Uint::<1>::ONE, &limb(97)), Some(Uint::<1>::ONE));
    }

    /// Inputs sharing a factor with the modulus, and a zero input, have no inverse.
    #[test]
    fn non_invertible_returns_none() {
        // (a, m) pairs with gcd(a, m) > 1
        let pairs: [(u64, u64); 2] = [(3, 15), (4, 10)];
        for (a, m) in pairs {
            assert_eq!(inv_mod(&limb(a), &limb(m)), None);
        }

        assert_eq!(inv_mod(&Uint::<1>::ZERO, &limb(7)), None);
    }

    /// Whenever an inverse is reported, `a * inv mod m` must equal one.
    #[test]
    fn inverse_property_u64() {
        let moduli: [u64; 4] = [97, 0xFFFF_FFFF_FFFF_FFFF, 1_000_003, 0x1_0000_0000];
        let vals: [u64; 4] = [2, 3, 0x1234_5678, 0xfedc_ba98_7654_3211];

        for m in moduli {
            // Reduce every value first and drop the ones that collapse to zero.
            for a in vals.iter().map(|&v| v % m).filter(|&v| v != 0) {
                // Non-invertible combinations are skipped rather than treated as failures.
                if let Some(inv) = inv_mod(&limb(a), &limb(m)) {
                    // Widen to u128 so the product cannot overflow before reduction.
                    let wide = a as u128 * inv.as_limbs()[0] as u128;
                    let prod = (wide % m as u128) as u64;
                    assert_eq!(prod, 1, "a={a} m={m}");
                }
            }
        }
    }

    /// Checks a two-limb odd modulus by multiplying back through `MontModulus::mul_mod`.
    #[test]
    fn inverse_property_128bit_odd() {
        let m = Uint::<2>::from_limbs([0x1234_5678_9abc_def1, 0x0fed_cba9_8765_4321]);
        let a = Uint::<2>::from_u64(0x9e3779b97f4a7c15);
        let modulus = MontModulus::new(m);

        let inv = inv_mod(&a, &m).expect("a coprime to m");

        let product = modulus.mul_mod(&a, &inv);
        assert!(bool::from(product.ct_eq(&Uint::ONE)));
    }

    /// Inverts a public-exponent-like value modulo an even modulus.
    #[test]
    fn rsa_style_even_modulus() {
        let phi = Uint::<2>::from_u64(0x0003_a8f2_1c4b_d7e8);
        let e = Uint::<2>::from_u64(65537);

        let d = inv_mod(&e, &phi).expect("65537 coprime to phi");

        let prod = e.mul_wide(&d).0;
        let (_, remainder) = prod.divrem(&phi);
        assert_eq!(remainder, Uint::ONE);
    }
}