/// Montgomery-form modular arithmetic for an odd 32-bit modulus `q`, with `R = 2^32`.
///
/// Inputs to [`mul`](Self::mul), [`add`](Self::add), [`sub`](Self::sub) and
/// [`neg`](Self::neg) must be fully reduced (`< q`, checked in debug builds), and every
/// operation returns a fully reduced value. Convert into Montgomery form with
/// [`to_mont`](Self::to_mont) and back with [`from_mont`](Self::from_mont).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct Montgomery32 {
    /// The modulus; must be odd so that it is invertible modulo `2^32`.
    q: u32,
    /// The negated inverse `-q^{-1} mod 2^32`, i.e. `q.wrapping_mul(q_inv) == u32::MAX`.
    q_inv: u32,
    /// `R^2 mod q = 2^64 mod q`, used by `to_mont`.
    r_sq: u32,
}

impl Montgomery32 {
    /// Precomputes the Montgomery constants for modulus `q`.
    ///
    /// # Panics
    ///
    /// Panics if `q` is even (including `q == 0`): reduction with `R = 2^32` needs
    /// `q` to be invertible modulo `2^32`.
    pub const fn new(q: u32) -> Self {
        assert!(q % 2 == 1, "Montgomery modulus must be odd");
        // Newton/Hensel iteration for q^{-1} mod 2^32. For odd q, q*q ≡ 1 (mod 8), so the
        // seed is correct to 3 low bits; each step doubles that (3→6→12→24→48 ≥ 32).
        let mut q_inv = q;
        let mut i = 0;
        while i < 5 {
            q_inv = q_inv.wrapping_mul(2u32.wrapping_sub(q.wrapping_mul(q_inv)));
            i += 1;
        }
        // Store the negated inverse so `reduce` adds m*q instead of subtracting it.
        q_inv = q_inv.wrapping_neg();
        assert!(u32::MAX == q_inv.wrapping_mul(q));
        // r_q = R mod q < q < 2^32, so r_q * r_q fits in u64.
        let r_q = (1u64 << 32) % q as u64;
        let r_sq = ((r_q * r_q) % q as u64) as u32;
        assert!(r_sq as u128 == (1u128 << 64) % q as u128);
        Self { q, q_inv, r_sq }
    }

    /// Montgomery reduction: returns `t * R^{-1} mod q`, fully reduced (`< q`).
    ///
    /// Requires `t < q * 2^32`. That holds for the product of two values `< q` and for
    /// any `t` that fits in a `u32`. Debug builds assert that the intermediate result is
    /// `< 2q`.
    #[inline(always)]
    pub fn reduce(&self, t: u64) -> u32 {
        // m is chosen so that t + m*q ≡ 0 (mod 2^32), making the shift below exact.
        let m = (t as u32).wrapping_mul(self.q_inv);
        let mq = (m as u64) * (self.q as u64);
        // For q ≥ 2^31 the sum can exceed u64::MAX. The lost carry is bit 64, which is
        // bit 32 after the shift.
        let (sum, overflowed) = t.overflowing_add(mq);
        let mut reduced = sum >> 32;
        if overflowed {
            reduced += 1u64 << 32;
        }
        debug_assert!(
            reduced < 2 * (self.q as u64),
            "Montgomery reduced value is too large"
        );
        // One conditional subtraction brings [0, 2q) down to [0, q).
        if let Some(out) = reduced.checked_sub(self.q as u64) {
            out as u32
        } else {
            reduced as u32
        }
    }

    /// Montgomery product: for `a = xR`, `b = yR` returns `xyR mod q`.
    #[inline(always)]
    pub fn mul(&self, a: u32, b: u32) -> u32 {
        debug_assert!(
            a < self.q && b < self.q,
            "Montgomery mul inputs must be < q"
        );
        self.reduce(a as u64 * b as u64)
    }

    /// Modular negation `-a mod q`. Because the Montgomery map is linear, this works the
    /// same on Montgomery representatives.
    #[inline(always)]
    pub fn neg(&self, a: u32) -> u32 {
        debug_assert!(a < self.q, "Montgomery neg inputs must be < q");
        // q - 0 would be q, which is not fully reduced.
        if a == 0 {
            return 0;
        }
        self.q - a
    }

    /// Modular addition `a + b mod q` (linear, so it is valid on Montgomery representatives).
    #[inline(always)]
    pub fn add(&self, a: u32, b: u32) -> u32 {
        debug_assert!(
            a < self.q && b < self.q,
            "Montgomery add inputs must be < q"
        );
        if b == 0 {
            return a;
        }
        // Widen: for q > 2^31, a + b can exceed u32::MAX.
        let sum = a as u64 + b as u64;
        if let Some(out) = sum.checked_sub(self.q as u64) {
            out as u32
        } else {
            sum as u32
        }
    }

    /// Modular subtraction `a - b mod q` (linear, so it is valid on Montgomery
    /// representatives). Equivalent to `self.add(a, self.neg(b))`.
    #[inline(always)]
    pub fn sub(&self, a: u32, b: u32) -> u32 {
        debug_assert!(
            a < self.q && b < self.q,
            "Montgomery sub inputs must be < q"
        );
        if a >= b {
            a - b
        } else {
            // a < b, so a + (q - b) < q and cannot overflow.
            a + (self.q - b)
        }
    }

    /// Converts `a` (`< q`) into Montgomery form `aR mod q`.
    #[inline(always)]
    pub fn to_mont(&self, a: u32) -> u32 {
        debug_assert!(a < self.q, "Montgomery input is not in the prime field");
        self.mul(a, self.r_sq)
    }

    /// Converts out of Montgomery form: returns `a * R^{-1} mod q`.
    #[inline(always)]
    pub fn from_mont(&self, a: u32) -> u32 {
        self.reduce(a as u64)
    }
}
#[test]
fn mont_roundtrip() {
    // 455 * 9 * 2^20 + 1 = 2^32 - 2^20 + 1 is above 2^31, so the conversions hit the
    // carry path in `reduce`.
    let q = 455 * 2u32.pow(20) * 9 + 1;
    let mont = Montgomery32::new(q);
    let check = |value: u32| assert_eq!(value, mont.from_mont(mont.to_mont(value)));
    // Small consecutive residues first, then random residues.
    (0..1000).for_each(check);
    (0..1000).map(|_| rand::random::<u32>() % q).for_each(check);
}
#[test]
fn mont_muls() {
    // Test a modulus below 2^31 and one above it. Only the latter can make t + m*q
    // exceed u64::MAX, so only it exercises the carry branch in `reduce`.
    for &q in &[2u32.pow(31) - 1, 455 * 2u32.pow(20) * 9 + 1] {
        let mont = Montgomery32::new(q);
        let check = |a: u32, b: u32| {
            let expected = ((a as u64 * b as u64) % q as u64) as u32;
            let product = mont.mul(mont.to_mont(a), mont.to_mont(b));
            assert_eq!(expected, mont.from_mont(product), "{} * {} mod {}", a, b, q);
        };
        // Boundary operands, including the largest residue q - 1.
        for &a in &[0, 1, q - 1] {
            for &b in &[0, 1, q - 1] {
                check(a, b);
            }
        }
        for _ in 0..1000 {
            check(rand::random::<u32>() % q, rand::random::<u32>() % q);
        }
    }
}
#[test]
fn mont_adds() {
    // With the second modulus (> 2^31), the sum of two Montgomery representatives can
    // exceed u32::MAX. That exercises the u64 widening inside `add`.
    for &q in &[2u32.pow(31) - 1, 455 * 2u32.pow(20) * 9 + 1] {
        let mont = Montgomery32::new(q);
        let check = |a: u32, b: u32| {
            let expected = ((a as u64 + b as u64) % q as u64) as u32;
            let sum = mont.add(mont.to_mont(a), mont.to_mont(b));
            assert_eq!(expected, mont.from_mont(sum), "{} + {} mod {}", a, b, q);
        };
        // Boundary operands, including the largest residue q - 1.
        for &a in &[0, 1, q - 1] {
            for &b in &[0, 1, q - 1] {
                check(a, b);
            }
        }
        for _ in 0..1000 {
            check(rand::random::<u32>() % q, rand::random::<u32>() % q);
        }
    }
}
#[test]
fn mont_subs() {
    let q = 2u32.pow(31) - 1;
    let mont = Montgomery32::new(q);
    let modulus = q as u64;
    for _ in 0..1000 {
        let (x, y) = (rand::random::<u32>() % q, rand::random::<u32>() % q);
        // Adding 2q keeps the reference difference non-negative before reducing.
        let expected = ((x as u64 + (2 * modulus - y as u64)) % modulus) as u32;
        // Subtraction is expressed as addition of the negation.
        let diff_mont = mont.add(mont.to_mont(x), mont.neg(mont.to_mont(y)));
        assert_eq!(expected, mont.from_mont(diff_mont));
    }
}
#[test]
fn mont_basic_properties() {
    let q = 2u32.pow(31) - 1;
    let mont = Montgomery32::new(q);
    let random_residue = || rand::random::<u32>() % q;

    // Zero maps to zero in both directions.
    let zero = mont.to_mont(0);
    assert_eq!(zero, 0, "Montgomery form of 0 should be 0");
    assert_eq!(mont.from_mont(zero), 0);

    // The Montgomery image of 1 must round-trip and act as the multiplicative identity.
    let one = mont.to_mont(1);
    println!("Montgomery form of 1: {}", one);
    assert_eq!(mont.from_mont(one), 1);
    for a in (0..10).map(|_| random_residue()) {
        let product = mont.mul(mont.to_mont(a), one);
        assert_eq!(mont.from_mont(product), a, "a * 1 should equal a");
    }

    // Conversions into and out of Montgomery form are lossless.
    for a in (0..100).map(|_| random_residue()) {
        let back = mont.from_mont(mont.to_mont(a));
        assert_eq!(a, back, "Round trip failed for {}", a);
    }
}
#[test]
fn mont_r_squared_check() {
    // Compute the reference 2^64 mod q directly in u128, rather than squaring 2^32 mod q
    // the same way `new` does, so the test is independent of the implementation. Also
    // cover a modulus above 2^31.
    for &q in &[2u32.pow(31) - 1, 455 * 2u32.pow(20) * 9 + 1] {
        let mont = Montgomery32::new(q);
        let expected_r_sq = ((1u128 << 64) % q as u128) as u32;
        assert_eq!(mont.r_sq, expected_r_sq, "R^2 mod q for q = {}", q);
    }
    // Hand-checked: 2^3 ≡ 1 (mod 7), so 2^64 = 2^(3*21 + 1) ≡ 2 (mod 7).
    assert_eq!(Montgomery32::new(7).r_sq, 2);
}
#[test]
fn mont_q_inv_check() {
    let q = 2u32.pow(31) - 1;
    let q_inv = Montgomery32::new(q).q_inv;
    // q_inv holds the negated inverse, so q * q_inv must be -1, i.e. all ones.
    let product = q_inv.wrapping_mul(q);
    println!("q: {}", q);
    println!("q_inv: {}", q_inv);
    println!("q * q_inv: {:#x} (should be 0xFFFFFFFF)", product);
    assert_eq!(product, u32::MAX);
}