use super::boxed::BoxedUint;
use super::mul::mac;
use super::uint::{Limb, adc, sbb};
use crate::ct::{Choice, ConditionallySelectable};
use alloc::vec;
use alloc::vec::Vec;
/// Computes the multiplicative inverse of `n` modulo 2^64.
///
/// Starting from the seed `1`, every step `x <- x * (2 - n * x)` squares the
/// error term `1 - n * x`. For odd `n` the seed is already correct modulo 2,
/// so six steps make the result correct in all 64 bits. An even `n` has no
/// inverse; the value returned in that case carries no meaning.
fn inv_mod_2_64(n: u64) -> u64 {
    (0..6).fold(1u64, |inv, _| {
        let correction = 2u64.wrapping_sub(n.wrapping_mul(inv));
        inv.wrapping_mul(correction)
    })
}
/// Adds two limb slices limb by limb, propagating the carry from index 0
/// upward.
///
/// `carry_in` is fed into the lowest limb. Returns the `a.len()`-limb sum and
/// the carry out of the last limb. `b` must provide at least `a.len()` limbs
/// (indexing panics otherwise); any limbs of `b` beyond that are ignored.
fn adc_limbs(a: &[Limb], b: &[Limb], carry_in: Limb) -> (Vec<Limb>, Limb) {
    let mut carry = carry_in;
    let mut sum = Vec::with_capacity(a.len());
    for (idx, &lhs) in a.iter().enumerate() {
        let (limb, carry_out) = adc(lhs, b[idx], carry);
        sum.push(limb);
        carry = carry_out;
    }
    (sum, carry)
}
/// Subtracts `b` from `a` limb by limb, propagating the borrow from index 0
/// upward.
///
/// `borrow_in` is applied to the lowest limb. Returns the `a.len()`-limb
/// difference and the borrow out of the last limb. `b` must provide at least
/// `a.len()` limbs (indexing panics otherwise); any limbs of `b` beyond that
/// are ignored.
fn sbb_limbs(a: &[Limb], b: &[Limb], borrow_in: Limb) -> (Vec<Limb>, Limb) {
    let mut borrow = borrow_in;
    let difference: Vec<Limb> = a
        .iter()
        .enumerate()
        .map(|(idx, &lhs)| {
            let (limb, borrow_out) = sbb(lhs, b[idx], borrow);
            borrow = borrow_out;
            limb
        })
        .collect();
    (difference, borrow)
}
/// Builds an `a.len()`-limb vector by choosing, limb by limb, between `a` and
/// `b` with `Limb::conditional_select`, using the same `choice` for every
/// limb.
///
/// NOTE(review): every caller in this file passes the value it wants when
/// `choice` is set as `a`; the actual polarity is defined by
/// `ConditionallySelectable::conditional_select` in `crate::ct` — confirm
/// there. `b` must provide at least `a.len()` limbs (indexing panics
/// otherwise).
fn select_limbs(a: &[Limb], b: &[Limb], choice: Choice) -> Vec<Limb> {
    let mut picked = Vec::with_capacity(a.len());
    for idx in 0..a.len() {
        picked.push(Limb::conditional_select(&a[idx], &b[idx], choice));
    }
    picked
}
/// Computes `a + b` reduced by at most one subtraction of `n`.
///
/// The raw sum and the sum minus `n` are both computed; the selection between
/// them goes through `select_limbs` instead of a branch. A single subtraction
/// only yields a fully reduced result when both operands are already below
/// `n`.
fn add_mod_limbs(n: &[Limb], a: &[Limb], b: &[Limb]) -> Vec<Limb> {
    let (raw_sum, overflow) = adc_limbs(a, b, 0);
    let (reduced, underflow) = sbb_limbs(&raw_sum, n, 0);
    // The reduced value is wanted when the addition carried out of the top
    // limb, or when `raw_sum - n` did not borrow (so `raw_sum >= n`).
    let did_not_underflow = underflow ^ 1;
    let take_reduced = Choice::from((overflow | did_not_underflow) as u8);
    select_limbs(&reduced, &raw_sum, take_reduced)
}
/// Computes `a - b`, adding `n` back once when the subtraction borrows.
///
/// Both candidates are computed and the selection between them goes through
/// `select_limbs` instead of a branch. A single add-back only yields a fully
/// reduced result when both operands are already below `n`.
fn sub_mod_limbs(n: &[Limb], a: &[Limb], b: &[Limb]) -> Vec<Limb> {
    let (raw_diff, underflow) = sbb_limbs(a, b, 0);
    // The carry out of the add-back is dropped on purpose: it cancels the
    // borrow of the subtraction above whenever the corrected value is used.
    let (corrected, _carry) = adc_limbs(&raw_diff, n, 0);
    let went_negative = Choice::from(underflow as u8);
    select_limbs(&corrected, &raw_diff, went_negative)
}
/// Precomputed Montgomery-arithmetic context for one heap-allocated modulus.
///
/// All fields are filled in by `BoxedMontModulus::new`, which rejects zero
/// and even moduli.
#[derive(Clone, Debug)]
pub struct BoxedMontModulus {
/// Limbs of the modulus, resized to exactly `limbs` entries.
n: Vec<Limb>,
/// Negated inverse of `n[0]` modulo 2^64
/// (`inv_mod_2_64(n[0]).wrapping_neg()`); used to pick the per-limb
/// reduction multiplier in Montgomery multiplication.
n_prime: Limb,
/// The value 1 doubled modulo `n` a total of `2 * 64 * limbs` times, which
/// for `n > 1` is `R^2 mod n` with `R = 2^(64 * limbs)`.
r2: Vec<Limb>,
/// Number of significant limbs of the modulus; the working width of every
/// operation on this context.
limbs: usize,
}
impl BoxedMontModulus {
    /// Precomputes the Montgomery constants for `modulus`.
    ///
    /// The modulus is resized to its significant limbs; that limb count
    /// becomes the working width of every operation on this context.
    ///
    /// # Panics
    ///
    /// Panics if `modulus` is zero or if its lowest limb is even.
    pub fn new(modulus: &BoxedUint) -> Self {
        assert!(
            !modulus.is_zero(),
            "BoxedMontModulus::new: modulus must be nonzero"
        );
        let limbs = modulus.significant_limbs();
        let n = modulus.limbs_resized(limbs);
        assert!((n[0] & 1) == 1, "Montgomery modulus must be odd");
        let n_prime = inv_mod_2_64(n[0]).wrapping_neg();

        // R^2 mod n for R = 2^(64 * limbs): start from 1 and double it
        // 2 * 64 * limbs times, reducing after each doubling.
        let mut seed = vec![0 as Limb; limbs];
        seed[0] = 1;
        let doublings = 2 * 64 * limbs;
        let r2 = (0..doublings).fold(seed, |acc, _| add_mod_limbs(&n, &acc, &acc));

        Self {
            n,
            n_prime,
            r2,
            limbs,
        }
    }

    /// Returns the number of limbs in the modulus, which is also the width
    /// operands are resized to.
    #[inline]
    pub fn limbs(&self) -> usize {
        self.limbs
    }

    /// Returns the value 1 as a `self.limbs`-wide limb vector.
    fn one_limbs(&self) -> Vec<Limb> {
        let mut one = vec![0 as Limb; self.limbs];
        one[0] = 1;
        one
    }

    /// Montgomery product of two limb slices: `a * b * R^-1 mod n` with
    /// `R = 2^(64 * limbs)`.
    ///
    /// Each outer step adds `a * b[i]` into the accumulator, then adds the
    /// multiple of `n` that clears the lowest limb and shifts down by one
    /// limb. A final selection subtracts `n` once if the accumulator is not
    /// below it.
    ///
    /// `a` must hold at least `self.limbs` limbs (indexing panics otherwise);
    /// only the first `self.limbs` limbs of `b` are consumed.
    ///
    /// NOTE(review): the single final subtraction presumably relies on the
    /// operands being below `n` — confirm against callers.
    fn mont_mul_limbs(&self, a: &[Limb], b: &[Limb]) -> Vec<Limb> {
        let width = self.limbs;
        let modulus = &self.n;
        // `acc` holds the low `width` limbs of the running value and `top`
        // the limb directly above them.
        let mut acc = vec![0 as Limb; width];
        let mut top: Limb = 0;

        for &multiplier in b.iter().take(width) {
            // acc += a * multiplier
            let mut carry = 0;
            for idx in 0..width {
                let (lo, hi) = mac(acc[idx], a[idx], multiplier, carry);
                acc[idx] = lo;
                carry = hi;
            }
            let (top_sum, top_overflow) = adc(top, carry, 0);

            // Choose `q` so that `acc + q * n` has a zero lowest limb, add
            // it, and store every limb one position lower (divide by 2^64).
            let q = acc[0].wrapping_mul(self.n_prime);
            let (_, mut carry) = mac(acc[0], q, modulus[0], 0);
            for idx in 1..width {
                let (lo, hi) = mac(acc[idx], q, modulus[idx], carry);
                acc[idx - 1] = lo;
                carry = hi;
            }
            let (highest, spill) = adc(top_sum, carry, 0);
            acc[width - 1] = highest;
            top = top_overflow + spill;
        }

        // Subtract the modulus across `acc` and `top`; a missing final
        // borrow means the accumulator was at least `n`.
        let (reduced, low_borrow) = sbb_limbs(&acc, modulus, 0);
        let (_, borrow) = sbb(top, 0, low_borrow);
        let at_least_modulus = Choice::from((borrow ^ 1) as u8);
        select_limbs(&reduced, &acc, at_least_modulus)
    }

    /// Converts limbs into Montgomery form by Montgomery-multiplying with
    /// the stored `r2`.
    fn to_mont_limbs(&self, x: &[Limb]) -> Vec<Limb> {
        self.mont_mul_limbs(x, &self.r2)
    }

    /// Converts limbs out of Montgomery form by Montgomery-multiplying with
    /// the plain value 1.
    fn demont_limbs(&self, x: &[Limb]) -> Vec<Limb> {
        let one = self.one_limbs();
        self.mont_mul_limbs(x, &one)
    }

    /// Returns a copy of the modulus as a `BoxedUint`.
    pub fn modulus(&self) -> BoxedUint {
        let limbs = self.n.clone();
        BoxedUint::from_limbs(limbs)
    }

    /// Converts `x` into Montgomery form.
    ///
    /// `x` is first resized to the modulus width with `limbs_resized`.
    pub fn to_mont(&self, x: &BoxedUint) -> BoxedUint {
        let plain = x.limbs_resized(self.limbs);
        BoxedUint::from_limbs(self.to_mont_limbs(&plain))
    }

    /// Converts `x` out of Montgomery form.
    ///
    /// `x` is first resized to the modulus width with `limbs_resized`.
    pub fn from_mont(&self, x: &BoxedUint) -> BoxedUint {
        let mont = x.limbs_resized(self.limbs);
        BoxedUint::from_limbs(self.demont_limbs(&mont))
    }

    /// Montgomery product of `a` and `b`, i.e. `a * b * R^-1 mod n`.
    ///
    /// When both operands are in Montgomery form, so is the result. Both are
    /// first resized to the modulus width with `limbs_resized`.
    pub fn mont_mul(&self, a: &BoxedUint, b: &BoxedUint) -> BoxedUint {
        let lhs = a.limbs_resized(self.limbs);
        let rhs = b.limbs_resized(self.limbs);
        BoxedUint::from_limbs(self.mont_mul_limbs(&lhs, &rhs))
    }

    /// Modular product `a * b mod n` of two values in plain (non-Montgomery)
    /// form.
    ///
    /// The first Montgomery multiplication leaves a stray `R^-1` factor; the
    /// second one, against `r2`, cancels it.
    pub fn mul_mod(&self, a: &BoxedUint, b: &BoxedUint) -> BoxedUint {
        let lhs = a.limbs_resized(self.limbs);
        let rhs = b.limbs_resized(self.limbs);
        let scaled_down = self.mont_mul_limbs(&lhs, &rhs);
        BoxedUint::from_limbs(self.to_mont_limbs(&scaled_down))
    }

    /// Modular exponentiation `base^exp mod n` using left-to-right
    /// square-and-multiply.
    ///
    /// Every exponent bit performs both the squaring and the multiplication;
    /// which result is kept is decided through `select_limbs` rather than a
    /// branch. The exponent is scanned over
    /// `max(exp.significant_limbs(), self.limbs)` limbs, so exponents wider
    /// than the modulus are scanned in full. `base` is resized to the
    /// modulus width.
    pub fn pow(&self, base: &BoxedUint, exp: &BoxedUint) -> BoxedUint {
        let base_mont = self.to_mont_limbs(&base.limbs_resized(self.limbs));
        // The accumulator starts at 1 in Montgomery form.
        let mut acc = self.to_mont_limbs(&self.one_limbs());

        let exp_width = exp.significant_limbs().max(self.limbs);
        let exp_limbs = exp.limbs_resized(exp_width);

        // Walk the exponent from its last limb down, highest bit first.
        for &limb in exp_limbs.iter().rev() {
            for bit in (0..64).rev() {
                let squared = self.mont_mul_limbs(&acc, &acc);
                let multiplied = self.mont_mul_limbs(&squared, &base_mont);
                let bit_is_set = Choice::from(((limb >> bit) & 1) as u8);
                acc = select_limbs(&multiplied, &squared, bit_is_set);
            }
        }

        BoxedUint::from_limbs(self.demont_limbs(&acc))
    }

    /// Modular sum of `a` and `b`, reduced by at most one subtraction of the
    /// modulus.
    ///
    /// Both operands are first resized to the modulus width.
    pub fn add_mod(&self, a: &BoxedUint, b: &BoxedUint) -> BoxedUint {
        let lhs = a.limbs_resized(self.limbs);
        let rhs = b.limbs_resized(self.limbs);
        BoxedUint::from_limbs(add_mod_limbs(&self.n, &lhs, &rhs))
    }

    /// Modular difference `a - b`, with the modulus added back once when the
    /// subtraction borrows.
    ///
    /// Both operands are first resized to the modulus width.
    pub fn sub_mod(&self, a: &BoxedUint, b: &BoxedUint) -> BoxedUint {
        let lhs = a.limbs_resized(self.limbs);
        let rhs = b.limbs_resized(self.limbs);
        BoxedUint::from_limbs(sub_mod_limbs(&self.n, &lhs, &rhs))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bignum::{MontModulus, Uint};

    /// Reference right-to-left square-and-multiply carried out in `u128`,
    /// for moduli that fit in a single 64-bit limb.
    fn reference_modexp(base: u64, exp: u64, modulus: u64) -> u64 {
        let nn = u128::from(modulus);
        let mut result: u128 = 1 % nn;
        let mut square = u128::from(base) % nn;
        let mut remaining = exp;
        while remaining > 0 {
            if remaining & 1 == 1 {
                result = result * square % nn;
            }
            square = square * square % nn;
            remaining >>= 1;
        }
        result as u64
    }

    /// `pow` agrees with the `u128` reference on single-limb moduli.
    #[test]
    fn modexp_matches_u128() {
        let moduli: [u64; 3] = [0xFFFF_FFFF_FFFF_FFFF, 0x8000_0000_0000_0001, 1_000_003];
        let bases: [u64; 3] = [2, 3, 0x1234_5678_9abc_def1];
        let exps: [u64; 3] = [1, 17, 0xdead_beef];

        for &nv in &moduli {
            let m = BoxedMontModulus::new(&BoxedUint::from_u64(nv));
            for &b in &bases {
                for &e in &exps {
                    let reduced_base = b % nv;
                    let got = m
                        .pow(&BoxedUint::from_u64(reduced_base), &BoxedUint::from_u64(e))
                        .to_be_bytes(8);
                    let expected = reference_modexp(reduced_base, e, nv).to_be_bytes();
                    assert_eq!(got, expected, "n={nv} b={b} e={e}");
                }
            }
        }
    }

    /// Classic toy RSA parameters: n = 3233, e = 17, d = 2753.
    #[test]
    fn textbook_rsa() {
        let m = BoxedMontModulus::new(&BoxedUint::from_u64(3233));
        let public_exp = BoxedUint::from_u64(17);
        let private_exp = BoxedUint::from_u64(2753);
        let msg = BoxedUint::from_u64(65);

        let ct = m.pow(&msg, &public_exp);
        assert_eq!(ct, BoxedUint::from_u64(2790));
        assert_eq!(m.pow(&ct, &private_exp), msg);
    }

    /// A zero modulus is rejected at construction time.
    #[test]
    #[should_panic(expected = "modulus must be nonzero")]
    fn new_zero_modulus_panics() {
        let _ = BoxedMontModulus::new(&BoxedUint::zero(2));
    }

    /// An exponent wider than the modulus must be used in full: 3^(2^64)
    /// is 3 squared 64 times, whereas dropping the high limb would turn the
    /// exponent into zero and the result into 1.
    #[test]
    fn pow_does_not_truncate_overwide_exponent() {
        let n: u64 = 0xFFFF_FFFF_FFFF_FFC5;
        let m = BoxedMontModulus::new(&BoxedUint::from_u64(n));
        let exp = BoxedUint::from_limbs(vec![0, 1]);
        let got = m.pow(&BoxedUint::from_u64(3), &exp).to_be_bytes(8);

        let wide_n = u128::from(n);
        let r = (0..64).fold(3u128, |acc, _| (acc * acc) % wide_n);
        let expected = (r as u64).to_be_bytes();

        assert_eq!(got, expected);
        assert_ne!(got, [0, 0, 0, 0, 0, 0, 0, 1]);
    }

    /// The boxed implementation agrees with the const-generic one on a
    /// 256-bit modulus.
    #[test]
    fn matches_const_generic_256bit() {
        let n4 = Uint::<4>::from_limbs([
            0x1234_5678_9abc_def1,
            0xfedc_ba98_7654_3211,
            0x0f0f_0f0f_0f0f_0f0f,
            0x8000_0000_0000_0001,
        ]);
        let mut n_bytes = [0u8; 32];
        n4.write_be_bytes(&mut n_bytes);

        // Fixed-width reference result.
        let base4 = Uint::<4>::from_u64(0xdead_beef);
        let exp4 = Uint::<4>::from_u64(65537);
        let fixed = MontModulus::new(n4).pow(&base4, &exp4);
        let mut fixed_bytes = [0u8; 32];
        fixed.write_be_bytes(&mut fixed_bytes);

        // Same computation through the boxed implementation.
        let boxed_modulus = BoxedMontModulus::new(&BoxedUint::from_be_bytes(&n_bytes));
        let boxed = boxed_modulus.pow(
            &BoxedUint::from_u64(0xdead_beef),
            &BoxedUint::from_u64(65537),
        );

        assert_eq!(boxed.to_be_bytes(32), fixed_bytes);
    }
}