use super::Divstep;
use crate::u256::U256;
impl Divstep {
    /// Reads the modular inverse out of this divstep state, consuming it.
    ///
    /// Returns `None` when any of the following holds:
    /// * `g` is not zero,
    /// * `f` is not equal to `U256::ONE`,
    /// * the resulting candidate value is zero.
    ///
    /// Otherwise the candidate is `d`, or `0 - d` reduced modulo `modulus`
    /// when the `f_neg` flag is set, and it is returned wrapped in `Some`.
    ///
    /// NOTE(review): presumably meant to be called on a state that has
    /// already been driven to completion (the tests in this file always call
    /// `run()` first) — confirm against other callers.
    pub(crate) fn inverse(self) -> Option<U256> {
        // Both terminal conditions must hold; `g` is inspected first and the
        // `f` comparison is skipped when `g` is already non-zero.
        if !self.g.is_zero() || self.f != U256::ONE {
            return None;
        }

        // A set `f_neg` flag means the stored `d` has to be negated within
        // the modulus before it can be handed back.
        let candidate = if self.f_neg {
            U256::ZERO.sub_mod(&self.d, &self.modulus)
        } else {
            self.d
        };

        // Zero is never reported as an inverse.
        Some(candidate).filter(|value| !value.is_zero())
    }
}
#[cfg(test)]
mod ai_tests {
    use super::*;

    /// Modulus used by every test: 2^256 - 2^32 - 977 (the secp256k1 field prime).
    const P: U256 = U256::from_be_limbs([
        0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF,
        0xFFFFFFFFFFFFFFFF, 0xFFFFFFFEFFFFFC2F,
    ]);

    /// Large 256-bit operand shared by the generator and batch tests
    /// (the x-coordinate of the secp256k1 generator point).
    const GX: U256 = U256::from_be_limbs([
        0x79BE667EF9DCBBAC, 0x55A06295CE870B07,
        0x029BFCDB2DCE28D9, 0x59F2815B16F81798,
    ]);

    /// Builds a `U256` whose only non-zero limb is the least significant one.
    fn small(n: u64) -> U256 {
        U256::from_be_limbs([0, 0, 0, n])
    }

    /// Runs the divstep pipeline for `value` modulo `P` and extracts the result.
    fn invert(value: U256) -> Option<U256> {
        Divstep::new(P, value).run().inverse()
    }

    /// Divstep and the Fermat-based inversion agree on 7.
    #[test]
    fn inverse_of_seven() {
        let seven = small(7);
        let expected = seven.mod_inv_fermat(&P).unwrap();
        let actual = invert(seven).unwrap();
        assert_eq!(actual, expected);
    }

    /// One is its own inverse.
    #[test]
    fn inverse_of_one() {
        let actual = invert(U256::ONE).unwrap();
        assert_eq!(actual, U256::ONE);
    }

    /// Zero must be rejected rather than yielding a value.
    #[test]
    fn zero_has_no_inverse() {
        assert!(invert(U256::ZERO).is_none());
    }

    /// Multiplying a value by its computed inverse gives one.
    #[test]
    fn roundtrip() {
        let a = small(42);
        let a_inv = invert(a).unwrap();
        assert_eq!(a.mul_mod(&a_inv, &P), U256::ONE);
    }

    /// Divstep agrees with `mod_inv` (the Lehmer path, going by the test name)
    /// on a full-width operand.
    #[test]
    fn generator_matches_lehmer() {
        let from_divstep = invert(GX).unwrap();
        let from_lehmer = GX.mod_inv(&P).unwrap();
        assert_eq!(from_divstep, from_lehmer);
    }

    /// `P - 1` is congruent to -1 and is therefore its own inverse.
    #[test]
    fn inverse_of_p_minus_one() {
        let minus_one = P - U256::ONE;
        assert_eq!(invert(minus_one).unwrap(), minus_one);
    }

    /// Smallest even operand: the product with its inverse must be one.
    #[test]
    fn inverse_of_two() {
        let two = small(2);
        let two_inv = invert(two).unwrap();
        assert_eq!(two.mul_mod(&two_inv, &P), U256::ONE);
    }

    /// Divstep and the Fermat-based inversion agree across a mix of operands.
    #[test]
    fn batch_matches_fermat() {
        let values = [small(3), small(100), small(65537), GX];
        for v in values.iter() {
            let divstep = invert(*v).unwrap();
            let fermat = v.mod_inv_fermat(&P).unwrap();
            assert_eq!(divstep, fermat, "mismatch for {v:?}");
        }
    }
}