use crate::GcdResult;
use alloc::vec::Vec;
use core::{
cmp::{Eq, Ordering, PartialEq, PartialOrd},
fmt::{self, Debug, Display},
iter::{Product, Sum},
mem::swap,
ops::{
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Shl, Shr, Sub,
SubAssign,
},
};
use openssl::bn::{BigNum, BigNumContext, BigNumRef};
use rand::RngCore;
use subtle::{Choice, ConstantTimeEq};
use zeroize::Zeroize;
/// Arbitrary-precision signed integer backed by an OpenSSL [`BigNum`].
///
/// Ordering is delegated to the wrapped `BigNum`.
#[derive(PartialOrd, Ord)]
pub struct Bn(pub(crate) BigNum);
/// Converts a signed machine-word integer into an OpenSSL [`BigNum`].
///
/// The magnitude is encoded big-endian and the sign flag is applied afterwards.
/// `unsigned_abs` is used instead of `-d` so that `isize::MIN`, whose negation
/// does not fit in an `isize`, converts correctly instead of overflowing.
fn from_isize(d: isize) -> BigNum {
    let mut b = BigNum::from_slice(&d.unsigned_abs().to_be_bytes()).unwrap();
    b.set_negative(d < 0);
    b
}
// Clone by copying the big-endian magnitude into a fresh BigNum, then
// restoring the sign flag (the byte copy carries no sign).
clone_impl!(|b: &Bn| {
    let magnitude = b.0.to_vec();
    let mut copy = BigNum::from_slice(&magnitude).unwrap();
    copy.set_negative(b.0.is_negative());
    copy
});
// `Default` yields zero: `BigNum::new()` is also what `Bn::zero()` uses.
default_impl!(|| BigNum::new().unwrap());
// Formatting impls generated by `display_impl!` (expansion not visible here).
display_impl!();
// Equality impls generated by `eq_impl!` (expansion not visible here).
eq_impl!();
from_impl!(
|d: usize| BigNum::from_slice(&d.to_be_bytes()).unwrap(),
usize
);
#[cfg(target_pointer_width = "64")]
from_impl!(
|d: i128| {
if d < 0 {
let mut b = BigNum::from_slice(&(-d).to_be_bytes()).unwrap();
b.set_negative(true);
b
} else {
BigNum::from_slice(&d.to_be_bytes()).unwrap()
}
},
i128
);
#[cfg(target_pointer_width = "64")]
from_impl!(
|d: u128| BigNum::from_slice(&d.to_be_bytes()).unwrap(),
u128
);
from_impl!(|d: u64| BigNum::from_slice(&d.to_be_bytes()).unwrap(), u64);
from_impl!(|d: u32| BigNum::from_u32(d).unwrap(), u32);
from_impl!(|d: u16| BigNum::from_u32(d as u32).unwrap(), u16);
from_impl!(|d: u8| BigNum::from_u32(d as u32).unwrap(), u8);
from_impl!(from_isize, isize);
from_impl!(|d: i64| from_isize(d as isize), i64);
from_impl!(|d: i32| from_isize(d as isize), i32);
from_impl!(|d: i16| from_isize(d as isize), i16);
from_impl!(|d: i8| from_isize(d as isize), i8);
// Iterator-related impls generated by `iter_impl!`; presumably `Sum`/`Product`
// (both are imported at the top of the file) — expansion not visible here.
iter_impl!();
// Serialization hooks, in order:
//   1. human-readable encode: OpenSSL hex string
//   2. human-readable decode: parse that hex string
//   3. binary encode: one sign byte (1 = negative, 0 = non-negative) followed
//      by the big-endian magnitude
//   4. binary decode: inverse of (3); an empty input is rejected
serdes_impl!(
    |b: &Bn| b.0.to_hex_str().unwrap(),
    |s: &str| { BigNum::from_hex_str(s).ok() },
    |b: &Bn| {
        let magnitude = b.0.to_vec();
        let mut encoded = Vec::with_capacity(magnitude.len() + 1);
        encoded.push(u8::from(b.0.is_negative()));
        encoded.extend_from_slice(&magnitude);
        encoded
    },
    |s: &[u8]| -> Option<BigNum> {
        let (&sign, magnitude) = s.split_first()?;
        let value = BigNum::from_slice(magnitude).ok()?;
        Some(if sign == 1 { -value } else { value })
    }
);
// Zeroize by clearing the inner BigNum via `BigNum::clear`.
zeroize_impl!(|b: &mut Bn| b.0.clear());
impl<'a, 'b> Add<&'b Bn> for &'a Bn {
    type Output = Bn;
    /// `&a + &b` into a freshly allocated result.
    fn add(self, rhs: &Self::Output) -> Self::Output {
        let mut sum = BigNum::new().unwrap();
        sum.checked_add(&self.0, &rhs.0).unwrap();
        Bn(sum)
    }
}
impl<'a, 'b> Sub<&'b Bn> for &'a Bn {
    type Output = Bn;
    /// `&a - &b` into a freshly allocated result.
    fn sub(self, rhs: &Self::Output) -> Self::Output {
        let mut difference = BigNum::new().unwrap();
        difference.checked_sub(&self.0, &rhs.0).unwrap();
        Bn(difference)
    }
}
impl<'a, 'b> Mul<&'b Bn> for &'a Bn {
    type Output = Bn;
    /// `&a * &b`; a new OpenSSL context is created for each call.
    fn mul(self, rhs: &Self::Output) -> Self::Output {
        let mut product = BigNum::new().unwrap();
        let mut scratch = BigNumContext::new().unwrap();
        product.checked_mul(&self.0, &rhs.0, &mut scratch).unwrap();
        Bn(product)
    }
}
impl<'a, 'b> Div<&'b Bn> for &'a Bn {
    type Output = Bn;
    /// `&a / &b`; panics (via `unwrap`) if OpenSSL reports an error.
    fn div(self, rhs: &Self::Output) -> Self::Output {
        let mut quotient = BigNum::new().unwrap();
        let mut scratch = BigNumContext::new().unwrap();
        quotient.checked_div(&self.0, &rhs.0, &mut scratch).unwrap();
        Bn(quotient)
    }
}
impl<'a, 'b> Rem<&'b Bn> for &'a Bn {
    type Output = Bn;
    /// `&a % &b`; panics (via `unwrap`) if OpenSSL reports an error.
    fn rem(self, rhs: &Self::Output) -> Self::Output {
        let mut remainder = BigNum::new().unwrap();
        let mut scratch = BigNumContext::new().unwrap();
        remainder.checked_rem(&self.0, &rhs.0, &mut scratch).unwrap();
        Bn(remainder)
    }
}
impl<'b> AddAssign<&'b Bn> for Bn {
fn add_assign(&mut self, rhs: &'b Bn) {
let b = self.clone();
BigNumRef::checked_add(&mut self.0, &b.0, &rhs.0).unwrap();
}
}
impl<'b> SubAssign<&'b Bn> for Bn {
fn sub_assign(&mut self, rhs: &'b Bn) {
let b = self.clone();
BigNumRef::checked_sub(&mut self.0, &b.0, &rhs.0).unwrap();
}
}
impl<'b> MulAssign<&'b Bn> for Bn {
fn mul_assign(&mut self, rhs: &'b Bn) {
let mut ctx = BigNumContext::new().unwrap();
let b = self.clone();
BigNumRef::checked_mul(&mut self.0, &b.0, &rhs.0, &mut ctx).unwrap();
}
}
impl<'b> DivAssign<&'b Bn> for Bn {
fn div_assign(&mut self, rhs: &'b Bn) {
let mut ctx = BigNumContext::new().unwrap();
let b = self.clone();
BigNumRef::checked_div(&mut self.0, &b.0, &rhs.0, &mut ctx).unwrap();
}
}
impl<'b> RemAssign<&'b Bn> for Bn {
fn rem_assign(&mut self, rhs: &'b Bn) {
let mut ctx = BigNumContext::new().unwrap();
let b = self.clone();
BigNumRef::checked_rem(&mut self.0, &b.0, &rhs.0, &mut ctx).unwrap();
}
}
// Each call presumably generates the remaining owned/borrowed operand
// combinations for one operator from the `&Bn op &Bn` and `Bn op= &Bn` impls
// defined above. The macro body is not visible here — confirm against its definition.
ops_impl!(Add, add, AddAssign, add_assign, +, +=);
ops_impl!(Sub, sub, SubAssign, sub_assign, -, -=);
ops_impl!(Mul, mul, MulAssign, mul_assign, *, *=);
ops_impl!(Div, div, DivAssign, div_assign, /, /=);
ops_impl!(Rem, rem, RemAssign, rem_assign, %, %=);
// Negation: copy the magnitude, then flip the sign flag of the copy.
neg_impl!(|b: &BigNum| {
    let magnitude = b.to_vec();
    let mut negated = BigNum::from_slice(&magnitude).unwrap();
    negated.set_negative(!b.is_negative());
    Bn(negated)
});
// `<<` via OpenSSL. A shift of exactly 1 uses the dedicated `lshift1` routine.
// NOTE(review): `rhs as i32` truncates if `rhs` is wider than i32; the type of
// `rhs` is fixed by `shift_impl!`, which is not visible here.
shift_impl!(Shl, shl, |lhs: &BigNum, rhs| {
    let mut n = BigNum::new().unwrap();
    if rhs == 1 {
        BigNumRef::lshift1(&mut n, lhs).unwrap();
    } else {
        BigNumRef::lshift(&mut n, lhs, rhs as i32).unwrap();
    }
    Bn(n)
});
// `>>` via OpenSSL. A shift of exactly 1 uses the dedicated `rshift1` routine.
// NOTE(review): same `rhs as i32` truncation caveat as `<<` above.
shift_impl!(Shr, shr, |lhs: &BigNum, rhs| {
    let mut n = BigNum::new().unwrap();
    if rhs == 1 {
        BigNumRef::rshift1(&mut n, lhs).unwrap();
    } else {
        BigNumRef::rshift(&mut n, lhs, rhs as i32).unwrap();
    }
    Bn(n)
});
impl ConstantTimeEq for Bn {
    /// Equality check that takes both sign and magnitude into account.
    ///
    /// The previous implementation used `ucmp`, which ignores the sign, so
    /// `5` and `-5` compared equal here while being unequal under `Eq`.
    ///
    /// The magnitudes are compared byte-wise with `subtle`'s slice comparison.
    /// That comparison exits early when the byte lengths differ. It is
    /// therefore constant-time only with respect to magnitudes of equal
    /// byte length.
    fn ct_eq(&self, other: &Self) -> Choice {
        let same_sign = u8::from(self.0.is_negative()).ct_eq(&u8::from(other.0.is_negative()));
        let lhs = self.0.to_vec();
        let rhs = other.0.to_vec();
        let same_magnitude = lhs.as_slice().ct_eq(rhs.as_slice());
        same_sign & same_magnitude
    }
}
impl Bn {
    /// Computes `self^exponent mod n`.
    ///
    /// A negative `exponent` is handled by inverting `self` modulo `n` and
    /// raising that inverse to `-exponent`. If `self` has no inverse modulo
    /// `n` (see [`Bn::invert`]), the result is zero.
    ///
    /// # Panics
    ///
    /// Panics if the underlying OpenSSL exponentiation reports an error.
    pub fn modpow(&self, exponent: &Self, n: &Self) -> Self {
        let mut ctx = BigNumContext::new().unwrap();
        let mut bn = BigNum::new().unwrap();
        if exponent.0.is_negative() {
            // When no inverse exists, `bn` is left at zero.
            if let Some(a) = self.invert(n) {
                let e = -exponent.clone();
                BigNumRef::mod_exp(&mut bn, &a.0, &e.0, &n.0, &mut ctx).unwrap();
            }
        } else {
            BigNumRef::mod_exp(&mut bn, &self.0, &exponent.0, &n.0, &mut ctx).unwrap();
        }
        Self(bn)
    }

    /// Computes `(self + rhs) mod n`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying OpenSSL call reports an error.
    pub fn modadd(&self, rhs: &Self, n: &Self) -> Self {
        let mut ctx = BigNumContext::new().unwrap();
        let mut t = BigNum::new().unwrap();
        BigNumRef::mod_add(&mut t, &self.0, &rhs.0, &n.0, &mut ctx).unwrap();
        Bn(t)
    }

    /// Computes `(self - rhs) mod n`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying OpenSSL call reports an error.
    pub fn modsub(&self, rhs: &Self, n: &Self) -> Self {
        let mut ctx = BigNumContext::new().unwrap();
        let mut t = BigNum::new().unwrap();
        BigNumRef::mod_sub(&mut t, &self.0, &rhs.0, &n.0, &mut ctx).unwrap();
        Bn(t)
    }

    /// Computes `(self * rhs) mod n`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying OpenSSL call reports an error.
    pub fn modmul(&self, rhs: &Self, n: &Self) -> Self {
        let mut ctx = BigNumContext::new().unwrap();
        let mut t = BigNum::new().unwrap();
        BigNumRef::mod_mul(&mut t, &self.0, &rhs.0, &n.0, &mut ctx).unwrap();
        Bn(t)
    }

    /// Computes `self * rhs^-1 mod n`.
    ///
    /// # Panics
    ///
    /// Panics if `rhs` has no inverse modulo `n`, or if any other OpenSSL
    /// call fails. Use [`Bn::invert`] first when invertibility is not guaranteed.
    pub fn moddiv(&self, rhs: &Self, n: &Self) -> Self {
        let mut ctx = BigNumContext::new().unwrap();
        let mut s = BigNum::new().unwrap();
        let mut t = BigNum::new().unwrap();
        BigNumRef::mod_inverse(&mut s, &rhs.0, &n.0, &mut ctx).unwrap();
        BigNumRef::mod_mul(&mut t, &self.0, &s, &n.0, &mut ctx).unwrap();
        Bn(t)
    }

    /// Computes the additive inverse of `self` modulo `n`.
    ///
    /// For a positive `n`, the result lies in `[0, n)`.
    pub fn modneg(&self, n: &Self) -> Self {
        // `self % n` has magnitude below `n`, so `n - r` is positive and a
        // final reduction brings it into range.
        let r = self % n;
        let mut t = n - &r;
        t %= n;
        t
    }

    /// Reduces `self` modulo `n` to a non-negative result (OpenSSL `nnmod`).
    pub fn nmod(&self, n: &Self) -> Self {
        let mut ctx = BigNumContext::new().unwrap();
        let mut t = BigNum::new().unwrap();
        BigNumRef::nnmod(&mut t, &self.0, &n.0, &mut ctx).unwrap();
        Bn(t)
    }

    /// Returns the multiplicative inverse of `self` modulo `modulus`.
    ///
    /// Returns `None` in the following cases:
    /// - `self` is zero;
    /// - `|modulus| <= 1`;
    /// - no inverse exists;
    /// - OpenSSL otherwise fails to compute one.
    ///
    /// It never panics for non-invertible input.
    pub fn invert(&self, modulus: &Bn) -> Option<Bn> {
        // `num_bits() <= 1` covers moduli 0, 1 and -1.
        if self.is_zero() || modulus.0.num_bits() <= 1 {
            return None;
        }
        let mut ctx = BigNumContext::new().unwrap();
        let mut bn = BigNum::new().unwrap();
        // OpenSSL reports an error when gcd(self, modulus) != 1. That is the
        // "no inverse" case, so surface it as `None` instead of panicking.
        BigNumRef::mod_inverse(&mut bn, &self.0, &modulus.0, &mut ctx).ok()?;
        Some(Self(bn))
    }

    /// Returns zero.
    pub fn zero() -> Self {
        Self(BigNum::new().unwrap())
    }

    /// Returns one.
    pub fn one() -> Self {
        Self(BigNum::from_u32(1).unwrap())
    }

    /// Returns `true` if `self` is zero.
    pub fn is_zero(&self) -> bool {
        self.0.num_bits() == 0
    }

    /// Returns `true` if `self` is exactly `+1`.
    ///
    /// `-1` has the same magnitude, so the sign must be checked as well.
    pub fn is_one(&self) -> bool {
        self.0.num_bits() == 1 && self.0.is_bit_set(0) && !self.0.is_negative()
    }

    /// Number of significant bits in the magnitude of `self` (zero has length 0).
    pub fn bit_length(&self) -> usize {
        self.0.num_bits() as usize
    }

    /// Greatest common divisor of `self` and `other`.
    pub fn gcd(&self, other: &Bn) -> Self {
        let mut bn = BigNum::new().unwrap();
        let mut ctx = BigNumContext::new().unwrap();
        BigNumRef::gcd(&mut bn, &self.0, &other.0, &mut ctx).unwrap();
        Self(bn)
    }

    /// Least common multiple, computed as `self / gcd(self, other) * other`.
    ///
    /// Returns zero when both inputs are zero. NOTE(review): the sign is not
    /// normalized, so the result may be negative if exactly one input is negative.
    pub fn lcm(&self, other: &Bn) -> Self {
        if self.is_zero() && other.is_zero() {
            Self::zero()
        } else {
            self / self.gcd(other) * other
        }
    }

    /// Random value in `[0, n)` drawn from OpenSSL's RNG.
    pub fn random(n: &Self) -> Self {
        let mut b = BigNum::new().unwrap();
        BigNumRef::rand_range(&n.0, &mut b).unwrap();
        Self(b)
    }

    /// Random value in `[0, n)`.
    ///
    /// The supplied `_rng` is ignored; this delegates to [`Bn::random`], which
    /// uses OpenSSL's own RNG.
    pub fn from_rng(n: &Self, _rng: &mut impl RngCore) -> Self {
        Self::random(n)
    }

    /// Interprets the finalized digest output as an unsigned big-endian integer.
    pub fn from_digest<D>(hasher: D) -> Self
    where
        D: digest::Digest,
    {
        Self(BigNum::from_slice(hasher.finalize().as_slice()).unwrap())
    }

    /// Builds a non-negative value from big-endian magnitude bytes.
    pub fn from_slice<B>(b: B) -> Self
    where
        B: AsRef<[u8]>,
    {
        Self(BigNum::from_slice(b.as_ref()).unwrap())
    }

    /// Big-endian magnitude bytes of `self`; the sign is not encoded.
    pub fn to_bytes(&self) -> Vec<u8> {
        self.0.to_vec()
    }

    /// Extended Euclidean algorithm.
    ///
    /// Returns `gcd`, `x`, `y` such that `self * x + other * y == gcd`. The
    /// signs are flipped at the end if needed so that `gcd` is non-negative.
    pub fn extended_gcd(&self, other: &Bn) -> GcdResult {
        // Each pair holds (current, previous) terms of the recurrence.
        let mut s = (Self::zero(), Self::one());
        let mut t = (Self::one(), Self::zero());
        let mut r = (other.clone(), self.clone());
        while !r.0.is_zero() {
            let q = r.1.clone() / r.0.clone();
            // (cur, prev) -> (prev - q * cur, cur)
            let f = |mut r: (Self, Self)| {
                swap(&mut r.0, &mut r.1);
                r.0 -= q.clone() * r.1.clone();
                r
            };
            r = f(r);
            s = f(s);
            t = f(t);
        }
        if r.1 >= Self::zero() {
            GcdResult {
                gcd: r.1,
                x: s.1,
                y: t.1,
            }
        } else {
            GcdResult {
                gcd: Self::zero() - r.1,
                x: Self::zero() - s.1,
                y: Self::zero() - t.1,
            }
        }
    }

    /// Generates a random safe prime `p` of `size` bits, i.e. `(p - 1) / 2` is also prime.
    ///
    /// # Panics
    ///
    /// Panics if OpenSSL fails to generate the prime.
    pub fn safe_prime(size: usize) -> Self {
        let mut p = BigNum::new().unwrap();
        BigNumRef::generate_prime(&mut p, size as i32, true, None, None).unwrap();
        Self(p)
    }

    /// Generates a random prime of `size` bits.
    ///
    /// # Panics
    ///
    /// Panics if OpenSSL fails to generate the prime.
    pub fn prime(size: usize) -> Self {
        let mut p = BigNum::new().unwrap();
        BigNumRef::generate_prime(&mut p, size as i32, false, None, None).unwrap();
        Self(p)
    }

    /// Probabilistic primality test with 15 rounds of checks.
    pub fn is_prime(&self) -> bool {
        let mut ctx = BigNumContext::new().unwrap();
        BigNumRef::is_prime(&self.0, 15, &mut ctx).unwrap()
    }

    /// Returns `(self / other, self % other)` from a single OpenSSL division.
    ///
    /// # Panics
    ///
    /// Panics if the division fails (e.g. `other` is zero).
    pub fn div_rem(&self, other: &Self) -> (Self, Self) {
        let mut ctx = BigNumContext::new().unwrap();
        let mut div = BigNum::new().unwrap();
        let mut rem = BigNum::new().unwrap();
        BigNumRef::div_rem(&mut div, &mut rem, &self.0, &other.0, &mut ctx).unwrap();
        (Self(div), Self(rem))
    }
}
#[test]
fn safe_prime() {
    // A safe prime p satisfies: p is prime and (p - 1) / 2 == p >> 1 is prime.
    let p = Bn::safe_prime(1024);
    assert_eq!(p.0.num_bits(), 1024);
    assert!(p.is_prime());
    let q: Bn = p >> 1;
    assert!(q.is_prime());
}
#[test]
fn ct_eq() {
    // Two independently constructed equal values must compare equal.
    let lhs = Bn::from(8);
    let rhs = Bn::from(8);
    let verdict: u8 = lhs.ct_eq(&rhs).unwrap_u8();
    assert_eq!(verdict, 1u8);
}