use super::traits::{
BitManipulation, ConvertFrom, Converter, Modulo, NumberTests, Samplable, ZeroizeBN, EGCD,
};
use super::BigInt;
use super::HexError;
use getrandom::getrandom;
use gmp::mpz::Mpz;
use std::borrow::Borrow;
use std::sync::atomic;
impl ZeroizeBN for Mpz {
fn zeroize_bn(&mut self) {
drop(self);
atomic::fence(atomic::Ordering::SeqCst);
atomic::compiler_fence(atomic::Ordering::SeqCst);
}
}
/// Conversions between `Mpz` and byte / hexadecimal representations.
impl Converter for Mpz {
    /// Converts `value` into its byte-vector form.
    fn to_vec(value: &Mpz) -> Vec<u8> {
        let borrowed: &Mpz = value.borrow();
        borrowed.into()
    }

    /// Renders the number as a base-16 string.
    fn to_hex(&self) -> String {
        let radix = 16;
        self.to_str_radix(radix)
    }

    /// Parses a base-16 string, returning an error when parsing fails.
    fn from_hex(value: &str) -> Result<Mpz, HexError> {
        let radix = 16;
        BigInt::from_str_radix(value, radix)
    }

    /// Builds a number from a byte slice.
    fn from_bytes(bytes: &[u8]) -> Mpz {
        let parsed = BigInt::from(bytes);
        parsed
    }

    /// Converts the number into its byte-vector form.
    fn to_bytes(&self) -> Vec<u8> {
        Into::<Vec<u8>>::into(self)
    }
}
/// Modular arithmetic helpers built on top of `Mpz` primitives.
impl Modulo for Mpz {
    /// Computes `base ^ exponent mod modulus`.
    fn mod_pow(base: &Self, exponent: &Self, modulus: &Self) -> Self {
        let power = base.powm(exponent, modulus);
        power
    }

    /// Computes `a * b mod modulus`, reducing both operands first.
    fn mod_mul(a: &Self, b: &Self, modulus: &Self) -> Self {
        let lhs = a.mod_floor(modulus);
        let rhs = b.mod_floor(modulus);
        (lhs * rhs).mod_floor(modulus)
    }

    /// Computes `a - b mod modulus`.
    ///
    /// Both operands are reduced first; `modulus` is then added to the
    /// difference before the final reduction.
    fn mod_sub(a: &Self, b: &Self, modulus: &Self) -> Self {
        let difference = a.mod_floor(modulus) - b.mod_floor(modulus);
        let shifted = difference + modulus;
        shifted.mod_floor(modulus)
    }

    /// Computes `a + b mod modulus`, reducing both operands first.
    fn mod_add(a: &Self, b: &Self, modulus: &Self) -> Self {
        let lhs = a.mod_floor(modulus);
        let rhs = b.mod_floor(modulus);
        (lhs + rhs).mod_floor(modulus)
    }

    /// Computes the inverse of `a` modulo `modulus`.
    ///
    /// # Panics
    /// Panics when `invert` does not yield a value.
    fn mod_inv(a: &Self, modulus: &Self) -> Self {
        let inverse = a.invert(modulus);
        inverse.unwrap()
    }
}
/// Random sampling of big integers, using `getrandom` as the entropy source.
impl Samplable for Mpz {
    /// Returns a random value in `[0, upper)` using rejection sampling.
    ///
    /// # Panics
    /// Panics if `upper` is not strictly positive.
    fn sample_below(upper: &Self) -> Self {
        assert!(*upper > Mpz::zero());
        let bits = upper.bit_length();
        loop {
            let n = Self::sample(bits);
            if n < *upper {
                return n;
            }
        }
    }

    /// Returns a random value in `[lower, upper)`.
    ///
    /// # Panics
    /// Panics if `upper <= lower`.
    fn sample_range(lower: &Self, upper: &Self) -> Self {
        assert!(upper > lower);
        lower + Self::sample_below(&(upper - lower))
    }

    /// Returns a random value strictly between `lower` and `upper`.
    ///
    /// # Panics
    /// Panics if `upper <= lower`, or if `upper == lower + 1` (no integer
    /// lies strictly between the bounds; the previous retry loop would spin
    /// forever in that case).
    fn strict_sample_range(lower: &Self, upper: &Self) -> Self {
        assert!(upper > lower);
        let width = upper - lower;
        assert!(
            width > Mpz::from(1),
            "strict_sample_range: no integer lies strictly between the bounds"
        );
        // An offset in [0, width - 1) shifted by lower + 1 lands in
        // (lower, upper) without needing a retry loop.
        let offset = Self::sample_below(&(width - Mpz::from(1)));
        lower + Mpz::from(1) + offset
    }

    /// Returns a random non-negative value of at most `bit_size` bits.
    ///
    /// A `bit_size` of zero yields zero; previously it underflowed while
    /// computing the byte count.
    ///
    /// # Panics
    /// Panics if the system random source fails.
    fn sample(bit_size: usize) -> Self {
        if bit_size == 0 {
            return Mpz::zero();
        }
        // Smallest number of whole bytes that holds `bit_size` bits.
        let bytes = (bit_size - 1) / 8 + 1;
        let mut buf: Vec<u8> = vec![0; bytes];
        getrandom(&mut buf).expect("system random number generator failed");
        // Shift out the surplus bits so the result fits in `bit_size` bits.
        Self::from(&*buf) >> (bytes * 8 - bit_size)
    }

    /// Returns a random value whose reported bit length is exactly `bit_size`.
    ///
    /// A `bit_size` of zero yields zero immediately instead of entering the
    /// retry loop.
    fn strict_sample(bit_size: usize) -> Self {
        if bit_size == 0 {
            return Mpz::zero();
        }
        loop {
            let n = Self::sample(bit_size);
            if n.bit_length() == bit_size {
                return n;
            }
        }
    }
}
/// Simple predicates and size queries on `Mpz` values.
impl NumberTests for Mpz {
    /// Reports whether `me` equals zero.
    fn is_zero(me: &Self) -> bool {
        let zero = me.is_zero();
        zero
    }

    /// Reports whether `me` is a multiple of two.
    fn is_even(me: &Self) -> bool {
        let two = Mpz::from(2);
        me.is_multiple_of(&two)
    }

    /// Reports whether `me` is strictly less than zero.
    fn is_negative(me: &Self) -> bool {
        Mpz::from(0) > *me
    }

    /// Returns the bit length reported by `Mpz::bit_length`.
    fn bits(me: &Self) -> usize {
        let length = me.bit_length();
        length
    }
}
/// Extended GCD, delegated to `Mpz::gcdext`.
impl EGCD for Mpz {
    /// Returns the triple produced by `a.gcdext(b)`.
    ///
    /// NOTE(review): presumably `(g, s, t)` with `a*s + b*t = g`, as in GMP's
    /// `mpz_gcdext` — confirm against the binding.
    fn egcd(a: &Self, b: &Self) -> (Self, Self, Self) {
        let (gcd, coeff_a, coeff_b) = a.gcdext(b);
        (gcd, coeff_a, coeff_b)
    }
}
/// Single-bit accessors for `Mpz`.
impl BitManipulation for Mpz {
    /// Sets bit number `bit` when `bit_val` is true, clears it otherwise.
    fn set_bit(&mut self, bit: usize, bit_val: bool) {
        match bit_val {
            true => self.setbit(bit),
            false => self.clrbit(bit),
        }
    }

    /// Returns whether bit number `bit` is set.
    fn test_bit(&self, bit: usize) -> bool {
        let is_set = self.tstbit(bit);
        is_set
    }
}
/// Conversion of an `Mpz` into a `u64`.
impl ConvertFrom<Mpz> for u64 {
    /// Converts `x` into a `u64`.
    ///
    /// # Panics
    /// Panics when the optional conversion yields `None`.
    fn _from(x: &Mpz) -> u64 {
        Into::<Option<u64>>::into(x).unwrap()
    }
}
#[cfg(test)]
mod tests {
    use super::Converter;
    use super::Modulo;
    use super::Mpz;
    use super::Samplable;
    use std::cmp;

    // A non-positive upper bound must be rejected.
    #[test]
    #[should_panic]
    fn sample_below_zero_test() {
        Mpz::sample_below(&Mpz::from(-1));
    }

    #[test]
    fn sample_below_test() {
        let upper_bound = Mpz::from(10);
        for _ in 1..100 {
            let r = Mpz::sample_below(&upper_bound);
            assert!(r < upper_bound);
        }
    }

    // An upper bound below the lower bound must be rejected.
    #[test]
    #[should_panic]
    fn invalid_range_test() {
        Mpz::sample_range(&Mpz::from(10), &Mpz::from(9));
    }

    #[test]
    fn sample_range_test() {
        let upper_bound = Mpz::from(10);
        let lower_bound = Mpz::from(5);
        for _ in 1..100 {
            let r = Mpz::sample_range(&lower_bound, &upper_bound);
            assert!(r < upper_bound && r >= lower_bound);
        }
    }

    #[test]
    fn strict_sample_range_test() {
        let len = 249;
        for _ in 1..100 {
            let a = Mpz::sample(len);
            let b = Mpz::sample(len);
            let lower_bound = cmp::min(a.clone(), b.clone());
            let upper_bound = cmp::max(a, b);
            // Skip the (astronomically unlikely) draws where no integer lies
            // strictly between the two bounds.
            if (&upper_bound - &lower_bound) <= Mpz::from(1) {
                continue;
            }
            let r = Mpz::strict_sample_range(&lower_bound, &upper_bound);
            // Both bounds are exclusive for the strict variant.
            assert!(r < upper_bound && r > lower_bound);
        }
    }

    #[test]
    fn strict_sample_test() {
        let len = 249;
        for _ in 1..100 {
            let a = Mpz::strict_sample(len);
            assert_eq!(a.bit_length(), len);
        }
    }

    #[test]
    fn test_mod_sub_modulo() {
        let a = Mpz::from(10);
        let b = Mpz::from(5);
        let modulo = Mpz::from(3);
        let res = Mpz::from(2);
        assert_eq!(res, Mpz::mod_sub(&a, &b, &modulo));
    }

    #[test]
    fn test_mod_sub_negative_modulo() {
        let a = Mpz::from(5);
        let b = Mpz::from(10);
        let modulo = Mpz::from(3);
        let res = Mpz::from(1);
        assert_eq!(res, Mpz::mod_sub(&a, &b, &modulo));
    }

    #[test]
    fn test_mod_add() {
        let a = Mpz::from(4);
        let b = Mpz::from(5);
        let modulo = Mpz::from(3);
        let res = Mpz::from(0);
        assert_eq!(res, Mpz::mod_add(&a, &b, &modulo));
    }

    #[test]
    fn test_mod_mul() {
        let a = Mpz::from(4);
        let b = Mpz::from(5);
        let modulo = Mpz::from(3);
        let res = Mpz::from(2);
        assert_eq!(res, Mpz::mod_mul(&a, &b, &modulo));
    }

    #[test]
    fn test_mod_pow() {
        let a = Mpz::from(2);
        let b = Mpz::from(3);
        let modulo = Mpz::from(3);
        let res = Mpz::from(2);
        assert_eq!(res, Mpz::mod_pow(&a, &b, &modulo));
    }

    #[test]
    fn test_mod_inv() {
        // 3 * 5 = 15 = 1 (mod 7)
        let a = Mpz::from(3);
        let modulo = Mpz::from(7);
        let res = Mpz::from(5);
        assert_eq!(res, Mpz::mod_inv(&a, &modulo));
    }

    #[test]
    fn test_to_hex() {
        let b = Mpz::from(11);
        assert_eq!("b", b.to_hex());
    }

    #[test]
    fn test_from_hex() {
        let a = Mpz::from(11);
        assert_eq!(Mpz::from_hex(&a.to_hex()).unwrap(), a);
    }
}