use crate::primitives::big_number::BigNumber;
use crate::primitives::k256::k256_reduce_limbs;
use crate::primitives::montgomery::Montgomery;
use std::sync::Arc;
/// Arithmetic context for integers modulo a fixed modulus `m`.
///
/// `add`, `sub` and `neg` assume their operands already lie in `[0, m)`;
/// results of the binary operations carry the `red` tag of their first
/// operand.
#[derive(Debug)]
pub struct ReductionContext {
    /// The modulus all arithmetic is performed against.
    pub m: BigNumber,
    /// Optional specialised reducer used by [`ReductionContext::imod`].
    prime: Option<Box<dyn MersennePrime>>,
    /// Montgomery parameters; populated only by [`ReductionContext::k256`].
    pub mont: Option<Montgomery>,
    /// `true` only when `m` is the secp256k1 field prime. That is the sole
    /// case where the limb-level `k256_reduce_limbs` fast path is valid.
    k256_fast_path: bool,
}

/// A modulus that provides its own in-place reduction routine.
pub trait MersennePrime: std::fmt::Debug + Send + Sync {
    /// Reduces `num` in place; used by `ReductionContext::imod` as the
    /// modular reduction against [`MersennePrime::p`].
    fn ireduce(&self, num: &mut BigNumber);
    /// Returns the modulus this reducer works against.
    fn p(&self) -> &BigNumber;
}

impl ReductionContext {
    /// Creates a context for an arbitrary modulus `m`.
    ///
    /// The context uses generic `umod` reduction and has no fast paths.
    pub fn new(m: BigNumber) -> Arc<Self> {
        Arc::new(ReductionContext {
            m,
            prime: None,
            mont: None,
            k256_fast_path: false,
        })
    }

    /// Creates a context for the secp256k1 field prime.
    ///
    /// This enables the `K256` reducer, Montgomery parameters, and the
    /// 4-limb multiplication fast path used by `mul` and `sqr`.
    pub fn k256() -> Arc<Self> {
        let k = crate::primitives::k256::K256::new();
        let m = k.p().clone();
        let mont = Montgomery::new(&m);
        Arc::new(ReductionContext {
            m,
            prime: Some(Box::new(k)),
            mont: Some(mont),
            k256_fast_path: true,
        })
    }

    /// Creates a context whose reductions are delegated to `prime.ireduce`.
    ///
    /// The secp256k1 limb fast path is deliberately not enabled here. An
    /// arbitrary `MersennePrime` need not be the secp256k1 prime, and
    /// `k256_reduce_limbs` would produce wrong results for any other modulus.
    pub fn with_prime(prime: Box<dyn MersennePrime>) -> Arc<Self> {
        let m = prime.p().clone();
        Arc::new(ReductionContext {
            m,
            prime: Some(prime),
            mont: None,
            k256_fast_path: false,
        })
    }

    /// Reduces `a` modulo `m`.
    ///
    /// Uses the specialised reducer when one is configured; otherwise falls
    /// back to `umod`, yielding zero if `umod` reports an error.
    pub fn imod(&self, a: &BigNumber) -> BigNumber {
        match &self.prime {
            Some(prime) => {
                let mut r = a.clone();
                prime.ireduce(&mut r);
                r
            }
            None => a.umod(&self.m).unwrap_or_else(|_| BigNumber::zero()),
        }
    }

    /// Maps `num` into `[0, m)` via `umod` (zero if `umod` errors).
    ///
    /// Does not itself assign a `red` tag.
    pub fn convert_to(&self, num: &BigNumber) -> BigNumber {
        num.umod(&self.m).unwrap_or_else(|_| BigNumber::zero())
    }

    /// Returns a copy of `num` with its `red` tag cleared; the value is unchanged.
    pub fn convert_from(&self, num: &BigNumber) -> BigNumber {
        let mut r = num.clone();
        r.red = None;
        r
    }

    /// Computes `-a mod m`, i.e. `m - a` for nonzero `a` (assumed `< m`).
    ///
    /// The result carries `a`'s `red` tag, matching the zero case and the
    /// other arithmetic methods.
    pub fn neg(&self, a: &BigNumber) -> BigNumber {
        if a.is_zero() {
            return a.clone();
        }
        let mut res = self.m.sub(a);
        res.red = a.red.clone();
        res
    }

    /// Computes `(a + b) mod m` for operands in `[0, m)`.
    pub fn add(&self, a: &BigNumber, b: &BigNumber) -> BigNumber {
        let mut res = a.add(b);
        // Subtract m speculatively and undo it if that went below zero.
        res = res.sub(&self.m);
        if res.is_neg() {
            res = res.add(&self.m);
        }
        res.red = a.red.clone();
        res
    }

    /// Computes `(a - b) mod m` for operands in `[0, m)`.
    pub fn sub(&self, a: &BigNumber, b: &BigNumber) -> BigNumber {
        let mut res = a.sub(b);
        if res.is_neg() {
            res = res.add(&self.m);
        }
        res.red = a.red.clone();
        res
    }

    /// Multiplies two 4-limb operands and reduces the 8-limb product with
    /// the secp256k1-specific `k256_reduce_limbs`.
    ///
    /// Only valid when `k256_fast_path` is set.
    fn k256_mul_limbs(a4: &[u64; 4], b4: &[u64; 4]) -> BigNumber {
        let prod8 = crate::primitives::big_number::mul_4x4(a4, b4);
        let reduced = k256_reduce_limbs(&prod8);
        BigNumber::from_raw_limbs(&reduced)
    }

    /// Computes `(a * b) mod m`; the result carries `a`'s `red` tag.
    ///
    /// For the secp256k1 context with two full 4-limb operands, a
    /// limb-level multiply/reduce is used. Everything else goes through
    /// `BigNumber::mul` followed by [`ReductionContext::imod`].
    pub fn mul(&self, a: &BigNumber, b: &BigNumber) -> BigNumber {
        let a_limbs = a.get_limbs();
        let b_limbs = b.get_limbs();
        let mut result = if self.k256_fast_path && a_limbs.len() == 4 && b_limbs.len() == 4 {
            let a4: [u64; 4] = [a_limbs[0], a_limbs[1], a_limbs[2], a_limbs[3]];
            let b4: [u64; 4] = [b_limbs[0], b_limbs[1], b_limbs[2], b_limbs[3]];
            Self::k256_mul_limbs(&a4, &b4)
        } else {
            self.imod(&a.mul(b))
        };
        result.red = a.red.clone();
        result
    }

    /// Computes `a^2 mod m`; the result carries `a`'s `red` tag.
    ///
    /// Uses the same secp256k1-only fast path as [`ReductionContext::mul`].
    pub fn sqr(&self, a: &BigNumber) -> BigNumber {
        let a_limbs = a.get_limbs();
        let mut result = if self.k256_fast_path && a_limbs.len() == 4 {
            let a4: [u64; 4] = [a_limbs[0], a_limbs[1], a_limbs[2], a_limbs[3]];
            Self::k256_mul_limbs(&a4, &a4)
        } else {
            self.imod(&a.sqr())
        };
        result.red = a.red.clone();
        result
    }

    /// Computes the modular inverse of `a`.
    ///
    /// Returns zero when `BigNumber::invm` reports an error (presumably
    /// because `a` is not invertible mod `m`).
    pub fn invm(&self, a: &BigNumber) -> BigNumber {
        let inv = a.invm(&self.m).unwrap_or_else(|_| BigNumber::zero());
        let mut result = self.imod(&inv);
        result.red = a.red.clone();
        result
    }

    /// Computes `a^exp mod m` using left-to-right square-and-multiply.
    ///
    /// A zero exponent yields one (tagged with `a`'s `red`).
    pub fn pow(&self, a: &BigNumber, exp: &BigNumber) -> BigNumber {
        if exp.is_zero() {
            let mut one = BigNumber::one();
            one.red = a.red.clone();
            return one;
        }
        let mut result = BigNumber::one();
        result.red = a.red.clone();
        let base = a.clone();
        let bits = exp.bit_length();
        for i in (0..bits).rev() {
            result = self.sqr(&result);
            if exp.testn(i) {
                result = self.mul(&result, &base);
            }
        }
        result
    }

    /// Computes a square root of `a` modulo the (assumed odd prime) modulus.
    ///
    /// When `m ≡ 3 (mod 4)` this is the single exponentiation
    /// `a^((m + 1) / 4)`; otherwise Tonelli–Shanks is used.
    ///
    /// Neither path verifies up front that `a` is a quadratic residue:
    /// - The fast path then returns a value whose square is not `a`.
    /// - Tonelli–Shanks returns zero, which is unambiguous because `a == 0`
    ///   is handled first.
    ///
    /// Callers needing certainty should check `sqr(result) == a`.
    pub fn sqrt(&self, a: &BigNumber) -> BigNumber {
        if a.is_zero() {
            return a.clone();
        }

        // For odd m, bit 1 is set exactly when m ≡ 3 (mod 4).
        if self.m.andln(2) != 0 {
            let exp = self.m.addn(1).ushrn(2);
            return self.pow(a, &exp);
        }

        // Tonelli–Shanks: write m - 1 = q * 2^s with q odd.
        let mut q = self.m.subn(1);
        let mut s = 0usize;
        while !q.is_zero() && q.andln(1) == 0 {
            s += 1;
            q.iushrn(1);
        }

        let mut one_red = BigNumber::one();
        one_red.red = a.red.clone();
        let neg_one = self.neg(&one_red);

        // Find a quadratic non-residue z (z^((m-1)/2) == -1), stepping down
        // by one from 2 * bitlen(m)^2. The start value is reduced so every
        // operand handed to `add` lies in [0, m).
        let lpow = self.m.subn(1).ushrn(1);
        let zl = self.m.bit_length();
        let mut z = self.imod(&BigNumber::from_number(2 * (zl * zl) as i64));
        z.red = a.red.clone();
        while self.pow(&z, &lpow).cmp(&neg_one) != 0 {
            z = self.add(&z, &neg_one);
        }

        let mut c = self.pow(&z, &q);
        let mut r = self.pow(a, &q.addn(1).ushrn(1));
        let mut t = self.pow(a, &q);
        let mut m = s;
        while t.cmp(&one_red) != 0 {
            // Find the least i with t^(2^i) == 1. A quadratic residue always
            // gives i < m. Reaching the bound means `a` has no square root,
            // and stopping there keeps `m - i - 1` from underflowing.
            let mut tmp = t.clone();
            let mut i = 0usize;
            while i < m && tmp.cmp(&one_red) != 0 {
                tmp = self.sqr(&tmp);
                i += 1;
            }
            if i >= m {
                let mut zero = BigNumber::zero();
                zero.red = a.red.clone();
                return zero;
            }
            let mut shift = BigNumber::one();
            shift.iushln(m - i - 1);
            let b = self.pow(&c, &shift);
            r = self.mul(&r, &b);
            c = self.sqr(&b);
            t = self.mul(&t, &c);
            m = i;
        }
        r
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a small `BigNumber` from an integer literal.
    fn big(v: i64) -> BigNumber {
        BigNumber::from_number(v)
    }

    #[test]
    fn test_reduction_context_basic() {
        // 10 mod 7 == 3
        let ctx = ReductionContext::new(big(7));
        let reduced = ctx.imod(&big(10));
        assert_eq!(reduced.to_number(), Some(3));
    }

    #[test]
    fn test_reduction_context_add() {
        // (5 + 4) mod 7 == 2
        let ctx = ReductionContext::new(big(7));
        let sum = ctx.add(&big(5), &big(4));
        assert_eq!(sum.to_number(), Some(2));
    }

    #[test]
    fn test_reduction_context_sub() {
        // (3 - 5) mod 7 == 5
        let ctx = ReductionContext::new(big(7));
        let diff = ctx.sub(&big(3), &big(5));
        assert_eq!(diff.to_number(), Some(5));
    }

    #[test]
    fn test_reduction_context_mul() {
        // (3 * 4) mod 7 == 5
        let ctx = ReductionContext::new(big(7));
        let product = ctx.mul(&big(3), &big(4));
        assert_eq!(product.to_number(), Some(5));
    }

    #[test]
    fn test_reduction_context_invm() {
        // 3 * 3^-1 == 1 (mod 11)
        let ctx = ReductionContext::new(big(11));
        let three = big(3);
        let inverse = ctx.invm(&three);
        let roundtrip = ctx.mul(&three, &inverse);
        assert_eq!(roundtrip.to_number(), Some(1));
    }

    #[test]
    fn test_reduction_context_pow() {
        // 3^2 mod 7 == 2
        let ctx = ReductionContext::new(big(7));
        let power = ctx.pow(&big(3), &big(2));
        assert_eq!(power.to_number(), Some(2));
    }
}