use crate::traits::PrimitiveUint;
/// Returns `(lhs + rhs) mod M`.
///
/// Assumes both operands are already reduced modulo `M`; under that
/// assumption a single conditional subtraction of `M` is enough to bring the
/// sum back into range, whether the raw sum merely reached `M` or overflowed
/// the width of `T` (in which case the wrapping subtraction undoes the wrap).
///
/// # Panics
///
/// Panics if `M` does not fit into `T`.
pub(crate) fn add<T, const M: u64>(lhs: &T, rhs: &T) -> T
where
    T: PrimitiveUint,
{
    let m = T::from_u64(M).expect("the modulus fits into `T`");
    let sum = lhs.wrapping_add(rhs);
    // A sum that wrapped around is strictly smaller than `lhs`.
    let overflowed = sum < *lhs;
    if overflowed || sum >= m {
        return sum.wrapping_sub(&m);
    }
    sum
}
/// Returns `(lhs - rhs) mod M`.
///
/// Assumes both operands are already reduced modulo `M`. When `rhs > lhs`
/// the wrapping difference equals `lhs - rhs + 2^bits`; a further wrapping
/// addition of `M` lands on `lhs - rhs + M`, which is back in range.
///
/// # Panics
///
/// Panics if `M` does not fit into `T`.
pub(crate) fn sub<T, const M: u64>(lhs: &T, rhs: &T) -> T
where
    T: PrimitiveUint,
{
    let m = T::from_u64(M).expect("the modulus fits into `T`");
    let diff = lhs.wrapping_sub(rhs);
    let borrowed = rhs > lhs;
    if !borrowed {
        return diff;
    }
    diff.wrapping_add(&m)
}
/// Returns `(lhs * rhs) mod M`.
///
/// Both operands are widened with `to_wide` before multiplying, and the wide
/// product is reduced back into `T` by `T::reduce_from_wide::<M>`. The wide
/// type is presumably large enough to hold the full product — TODO confirm
/// against the `PrimitiveUint` implementations.
pub(crate) fn mul<T, const M: u64>(lhs: &T, rhs: &T) -> T
where
    T: PrimitiveUint,
{
    let (wide_lhs, wide_rhs) = (lhs.to_wide(), rhs.to_wide());
    let product = wide_lhs * wide_rhs;
    T::reduce_from_wide::<M>(product)
}
/// Returns the additive inverse `(-arg) mod M`.
///
/// Zero is its own negation; any other value maps to `M - arg`. Assumes
/// `arg` is already reduced (`arg < M`); otherwise the subtraction would
/// underflow.
///
/// # Panics
///
/// Panics if `arg` is nonzero and `M` does not fit into `T`.
pub(crate) fn neg<T, const M: u64>(arg: &T) -> T
where
    T: PrimitiveUint,
{
    if *arg != T::ZERO {
        let modulus = T::from_u64(M).expect("the modulus fits into `T`");
        modulus - *arg
    } else {
        T::ZERO
    }
}
/// Computes the multiplicative inverse of `arg` modulo `M`, if one exists.
///
/// Uses the extended Euclidean algorithm. Because `T` is unsigned, the Bézout
/// coefficients (which alternate in sign) are tracked in sign-magnitude form.
///
/// Returns `None` when `M <= 1`, when `arg` is zero, or when
/// `gcd(arg, M) > 1`. Otherwise returns `x` with `arg * x ≡ 1 (mod M)`. A
/// negative coefficient `-c` is mapped to its representative `M - c`.
///
/// # Panics
///
/// Panics if `M` does not fit into `T`.
pub(crate) fn modular_inverse<T, const M: u64>(arg: &T) -> Option<T>
where
    T: PrimitiveUint,
{
    // Sign-magnitude integer built on top of an unsigned `T`.
    #[derive(Clone, Copy)]
    struct Signed<T> {
        value: T,
        is_negative: bool,
    }

    let modulus = T::from_u64(M).expect("the modulus fits into `T`");
    if modulus <= T::ONE || *arg == T::ZERO {
        return None;
    }

    // Loop invariant (mod M): `arg * coeff ≡ rem` and `arg * next_coeff ≡ next_rem`.
    let mut rem = *arg;
    let mut next_rem = modulus;
    let mut coeff = Signed {
        value: T::ONE,
        is_negative: false,
    };
    let mut next_coeff = Signed {
        value: T::ZERO,
        is_negative: false,
    };

    while rem > T::ONE {
        // The remainders ran out before reaching 1, so `rem` is gcd(arg, M) > 1.
        if next_rem == T::ZERO {
            return None;
        }
        let quotient = rem / next_rem;
        let remainder = rem % next_rem;
        rem = next_rem;
        next_rem = remainder;

        // updated = coeff - quotient * next_coeff, computed in sign-magnitude form.
        let scaled = quotient * next_coeff.value;
        let same_sign = coeff.is_negative == next_coeff.is_negative;
        // A zero result in the sign-flipping branch would carry the wrong sign;
        // the alternating signs of the Bézout coefficients are expected to rule it out.
        debug_assert!(!(same_sign && coeff.value == scaled));
        let updated = if !same_sign {
            // Subtracting a quantity of the opposite sign adds the magnitudes.
            Signed {
                value: coeff.value + scaled,
                is_negative: coeff.is_negative,
            }
        } else if coeff.value > scaled {
            Signed {
                value: coeff.value - scaled,
                is_negative: coeff.is_negative,
            }
        } else {
            // `scaled` dominates, so the result takes the opposite sign.
            Signed {
                value: scaled - coeff.value,
                is_negative: !coeff.is_negative,
            }
        };
        coeff = next_coeff;
        next_coeff = updated;
    }

    let inverse = if coeff.is_negative {
        modulus - coeff.value
    } else {
        coeff.value
    };
    Some(inverse)
}
#[cfg(test)]
mod tests {
    use super::modular_inverse;
    use proptest::prelude::*;

    /// Test modulus, `2^64 - 363`. The property test below relies on it being
    /// prime, so that every nonzero residue is invertible.
    const M: u64 = 0xfffffffffffffe95u64;

    #[test]
    fn inverse_of_zero() {
        assert!(modular_inverse::<u64, M>(&0).is_none());
    }

    #[test]
    fn inverse_of_one() {
        assert_eq!(modular_inverse::<u64, M>(&1), Some(1));
    }

    #[test]
    fn inverse_of_minus_one() {
        // (-1) * (-1) == 1, so `M - 1` is its own inverse.
        assert_eq!(modular_inverse::<u64, M>(&(M - 1)), Some(M - 1));
    }

    proptest! {
        // Inputs are drawn from the reduced nonzero residues `1..M` only.
        // Drawing from all of `u64` would also produce `x == M` (≡ 0, which has
        // no inverse) and other unreduced values outside the functions' domain.
        #[test]
        fn inverse(x in 1..M) {
            let inv = modular_inverse::<u64, M>(&x)
                .expect("every nonzero residue is invertible modulo a prime");
            prop_assert!(inv < M);
            let should_be_one = ((inv as u128) * (x as u128) % (M as u128)) as u64;
            prop_assert_eq!(should_be_one, 1);
        }
    }
}