use crate::primitives::big_number::BigNumber;
use crate::primitives::reduction_context::ReductionContext;
use std::sync::Arc;
/// Precomputed parameters for Montgomery arithmetic modulo `m`.
///
/// Built by `Montgomery::new`. The `*_limbs` arrays hold the low four 64-bit
/// limbs of the corresponding `BigNumber`s and feed the fixed-width
/// `mont_mul_4` fast path, which the methods use when `limb_count == 4`.
#[derive(Debug)]
pub struct Montgomery {
/// The modulus.
pub m: BigNumber,
/// Low four limbs of `m`, least-significant first, zero-padded.
pub m_limbs: [u64; 4],
/// The Montgomery radix `R = 2^(64 * limb_count)`.
pub r: BigNumber,
/// `R^2 mod m`; multiplying by it (Montgomery-style) maps a value into Montgomery form.
pub r2: BigNumber,
/// Low four limbs of `r2`, least-significant first, zero-padded.
pub r2_limbs: [u64; 4],
/// `-m^-1 mod 2^64`, derived from the lowest limb of `m` by `compute_minv`.
pub minv: u64,
/// Number of 64-bit limbs needed to hold `m` (`ceil(bit_length / 64)`).
pub limb_count: usize,
}
/// Montgomery product of two 256-bit values: returns `a * b * 2^-256 mod m`.
///
/// All operands are little-endian arrays of four 64-bit limbs. `m_inv` must be
/// `-m^-1 mod 2^64` (as produced by `compute_minv`), which requires `m` to be odd.
/// `a` may be any 256-bit value; when `b < m` the result is fully reduced into
/// `[0, m)`.
///
/// This is the CIOS (coarsely integrated operand scanning) method. The running
/// accumulator needs two words beyond the four limbs: after `t += a_i * b` the
/// value can exceed `2^320` when the top limb of `m` is close to `2^64 - 1`
/// (e.g. the secp256k1 field prime), so the carry out of `t[4]` is kept in
/// `t[5]` instead of being discarded.
#[inline]
pub fn mont_mul_4(a: &[u64; 4], b: &[u64; 4], m: &[u64; 4], m_inv: u64) -> [u64; 4] {
    let mut t = [0u64; 6];
    for &a_limb in a {
        // Multiply step: t += a_limb * b.
        let mut carry: u128 = 0;
        for j in 0..4 {
            let prod = (a_limb as u128) * (b[j] as u128) + (t[j] as u128) + carry;
            t[j] = prod as u64;
            carry = prod >> 64;
        }
        let sum = (t[4] as u128) + carry;
        t[4] = sum as u64;
        // Top carry word; dropping it corrupts the result for moduli near 2^256.
        t[5] = (sum >> 64) as u64;

        // Reduce step: pick q with t + q*m ≡ 0 (mod 2^64), then shift the
        // accumulator down one limb (an exact division by 2^64).
        let q = t[0].wrapping_mul(m_inv);
        // The low word of q*m[0] + t[0] is zero by construction; only its carry matters.
        let mut carry = ((q as u128) * (m[0] as u128) + (t[0] as u128)) >> 64;
        for j in 1..4 {
            let prod = (q as u128) * (m[j] as u128) + (t[j] as u128) + carry;
            t[j - 1] = prod as u64;
            carry = prod >> 64;
        }
        let sum = (t[4] as u128) + carry;
        t[3] = sum as u64;
        t[4] = t[5] + (sum >> 64) as u64;
    }
    // Given b < m the accumulator is below m + b < 2m, so one conditional
    // subtraction of m yields the canonical residue.
    let mut result = [t[0], t[1], t[2], t[3]];
    if t[4] > 0 || ge_4(&result, m) {
        sub_4_inplace(&mut result, m);
    }
    result
}
/// Montgomery square: returns `a * a * 2^-256 mod m`.
///
/// Thin wrapper that passes `a` to [`mont_mul_4`] as both operands, so the same
/// preconditions apply (odd `m`, `m_inv = -m^-1 mod 2^64`, and `a < m` for a
/// fully reduced result).
#[inline]
// No caller in this file; kept as part of the fixed-width API.
#[allow(dead_code)]
pub fn mont_sqr_4(a: &[u64; 4], m: &[u64; 4], m_inv: u64) -> [u64; 4] {
    let operand = a;
    mont_mul_4(operand, operand, m, m_inv)
}
/// Returns `true` when the 256-bit value `a` is greater than or equal to `b`.
///
/// Limbs are little-endian, so comparing the reversed limb sequences
/// lexicographically compares the numbers from the most significant limb down.
#[inline(always)]
fn ge_4(a: &[u64; 4], b: &[u64; 4]) -> bool {
    a.iter().rev().ge(b.iter().rev())
}
/// Subtracts the 256-bit value `b` from `a` in place, wrapping modulo `2^256`.
///
/// Limbs are little-endian; the borrow ripples from the least significant limb
/// upward and the final borrow is discarded.
#[inline(always)]
fn sub_4_inplace(a: &mut [u64; 4], b: &[u64; 4]) {
    let mut borrow = false;
    for (limb, &sub) in a.iter_mut().zip(b.iter()) {
        let (partial, under_a) = limb.overflowing_sub(sub);
        let (diff, under_b) = partial.overflowing_sub(borrow as u64);
        *limb = diff;
        // At most one of the two subtractions can underflow, so the borrow stays 0 or 1.
        borrow = under_a || under_b;
    }
}
impl Montgomery {
    /// Builds a Montgomery context for modulus `m`.
    ///
    /// `R` is `2^(64 * limb_count)`, where `limb_count` is the number of 64-bit
    /// limbs needed to hold `m`. When `limb_count == 4`, the arithmetic methods use
    /// the fixed-width [`mont_mul_4`] kernel for inputs that fit in 256 bits and
    /// otherwise fall back to generic `BigNumber` arithmetic.
    ///
    /// NOTE(review): Montgomery arithmetic requires an odd modulus; nothing here
    /// checks that, and `minv` is meaningless for even `m`.
    pub fn new(m: &BigNumber) -> Self {
        let limb_count = m.bit_length().div_ceil(64);
        let bits = limb_count * 64;
        let r = BigNumber::one().ushln(bits);
        // R^2 mod m lets a single Montgomery multiplication map a value into
        // Montgomery form. Falls back to zero if `umod` errors.
        let r2 = r.sqr().umod(m).unwrap_or_else(|_| BigNumber::zero());
        let m_limbs = Self::low_limbs_4(m);
        let r2_limbs = Self::low_limbs_4(&r2);
        let minv = compute_minv(m_limbs[0]);
        Montgomery {
            m: m.clone(),
            m_limbs,
            r,
            r2,
            r2_limbs,
            minv,
            limb_count,
        }
    }

    /// Converts `a` into Montgomery form, returning `a * R mod m`.
    pub fn to_mont(&self, a: &BigNumber) -> BigNumber {
        if self.limb_count == 4 && Self::fits_4(a) {
            // mont_mul_4(a, R^2) = a * R^2 * R^-1 = a * R (mod m).
            let a_limbs = Self::low_limbs_4(a);
            let result = mont_mul_4(&a_limbs, &self.r2_limbs, &self.m_limbs, self.minv);
            return BigNumber::from_raw_limbs(&result);
        }
        // Plain multiplication has no R^-1 factor to cancel, so multiply by R
        // itself (multiplying by R^2 here would yield a * R^2).
        a.mul(&self.r)
            .umod(&self.m)
            .unwrap_or_else(|_| BigNumber::zero())
    }

    /// Converts `a` out of Montgomery form, returning `a * R^-1 mod m`.
    ///
    /// This is exactly [`Montgomery::reduce`].
    // `from_*` with `&self` trips clippy's naming lint; the name is public API.
    #[allow(clippy::wrong_self_convention)]
    pub fn from_mont(&self, a: &BigNumber) -> BigNumber {
        self.reduce(a)
    }

    /// Montgomery product: returns `a * b * R^-1 mod m`.
    ///
    /// For `a = x * R` and `b = y * R` this yields `x * y * R`, so products stay in
    /// Montgomery form. The 4-limb fast path expects `b < m` (true for values this
    /// context produces). Inputs wider than 256 bits take the generic path instead
    /// of being truncated.
    pub fn mul(&self, a: &BigNumber, b: &BigNumber) -> BigNumber {
        if self.limb_count == 4 && Self::fits_4(a) && Self::fits_4(b) {
            let a_limbs = Self::low_limbs_4(a);
            let b_limbs = Self::low_limbs_4(b);
            let result = mont_mul_4(&a_limbs, &b_limbs, &self.m_limbs, self.minv);
            return BigNumber::from_raw_limbs(&result);
        }
        let product = a.mul(b);
        self.reduce(&product)
    }

    /// Montgomery reduction: returns `t * R^-1 mod m`.
    ///
    /// Accepts `t` of any width, including full double-width products such as
    /// `a.mul(b)`. The 256-bit fast path is used only when `t` fits in four limbs.
    pub fn reduce(&self, t: &BigNumber) -> BigNumber {
        if self.limb_count == 4 && Self::fits_4(t) {
            // A Montgomery multiplication by 1 divides by R.
            let one = [1u64, 0, 0, 0];
            let t_limbs = Self::low_limbs_4(t);
            let result = mont_mul_4(&t_limbs, &one, &self.m_limbs, self.minv);
            return BigNumber::from_raw_limbs(&result);
        }
        // NOTE(review): R^-1 is recomputed on every call; caching it in the
        // context would save a modular inversion per reduction.
        let r_inv = self.r.invm(&self.m).unwrap_or_else(|_| BigNumber::one());
        t.mul(&r_inv)
            .umod(&self.m)
            .unwrap_or_else(|_| BigNumber::zero())
    }

    /// Consumes the context and builds a `ReductionContext` for the same modulus.
    // Takes `self` by value despite the `to_` prefix; the name is public API.
    #[allow(clippy::wrong_self_convention)]
    pub fn to_reduction_context(self) -> Arc<ReductionContext> {
        ReductionContext::new(self.m)
    }

    /// Returns `true` when `x` fits in 256 bits, i.e. `low_limbs_4` captures it
    /// without truncation.
    fn fits_4(x: &BigNumber) -> bool {
        x.bit_length() <= 256
    }

    /// Copies the (up to) four least-significant limbs of `x` into a fixed array,
    /// zero-padding when `x` has fewer limbs.
    fn low_limbs_4(x: &BigNumber) -> [u64; 4] {
        let src = x.get_limbs();
        let mut out = [0u64; 4];
        for (dst, &limb) in out.iter_mut().zip(src.iter()) {
            *dst = limb;
        }
        out
    }
}
/// Returns `-m_low^-1 mod 2^64`, the per-limb Montgomery constant.
///
/// Uses the Newton/Hensel step `x <- x * (2 - m_low * x)`, which doubles the number
/// of correct low-order bits each time. Starting from `x = 1` (correct mod 2 for an
/// odd `m_low`), six steps reach all 64 bits. The result is only meaningful for odd
/// `m_low`; an input of `0` yields `0`.
fn compute_minv(m_low: u64) -> u64 {
    match m_low {
        0 => 0,
        m0 => (0..6)
            .fold(1u64, |inv, _| {
                inv.wrapping_mul(2u64.wrapping_sub(m0.wrapping_mul(inv)))
            })
            .wrapping_neg(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // With m = 17 and R = 2^64, R ≡ 1 (mod 17) because 2^8 = 256 ≡ 1 (mod 17).
    // These cases therefore cannot distinguish R from R^2 in the conversions; a
    // modulus with R ≢ 1 would exercise the conversion factors more strictly.

    #[test]
    fn test_montgomery_roundtrip() {
        let mont = Montgomery::new(&BigNumber::from_number(17));
        let back = mont.from_mont(&mont.to_mont(&BigNumber::from_number(7)));
        assert_eq!(back.to_number(), Some(7));
    }

    #[test]
    fn test_montgomery_mul() {
        let mont = Montgomery::new(&BigNumber::from_number(17));
        let x = mont.to_mont(&BigNumber::from_number(7));
        let y = mont.to_mont(&BigNumber::from_number(5));
        // 7 * 5 = 35 = 2 * 17 + 1
        let product = mont.from_mont(&mont.mul(&x, &y));
        assert_eq!(product.to_number(), Some(1));
    }

    #[test]
    fn test_compute_minv() {
        // minv = -17^-1 mod 2^64, so 17 * minv wraps to -1 == u64::MAX.
        assert_eq!(17u64.wrapping_mul(compute_minv(17)), u64::MAX);
    }
}