use num::Integer;
use num_bigint::{BigInt, BigUint, ToBigInt};
use num_traits::{One, Zero};
use rand::{distributions::Standard, rngs::OsRng, Rng};
use std::cmp::Ordering;
use std::ops::{Add, Div, Mul, Sub};
/// Draws a random integer from `OsRng` and reduces it modulo `modulus`.
///
/// `bits` is rounded up to a whole number of 32-bit words, so the raw draw can
/// hold more than `bits` random bits (e.g. `bits = 10` still draws one full
/// 32-bit word). Because the draw is reduced with `mod_floor`, the result is
/// only exactly uniform over `[0, modulus)` when `modulus` divides the size of
/// the drawn range. NOTE(review): confirm this bias is acceptable for callers.
pub(crate) fn get_random_number(bits: usize, modulus: &BigUint) -> BigUint {
    let word_count = (bits + 31) / 32;
    let mut rng = OsRng;
    // Least-significant word first, as `BigUint::from_slice` expects.
    let words: Vec<u32> = (0..word_count).map(|_| rng.sample(Standard)).collect();
    let raw = BigUint::from_slice(&words);
    raw.mod_floor(modulus)
}
fn modular_inverse(number: &BigUint, modulus: &BigUint) -> BigUint {
if modulus == &One::one() {
return One::one();
}
let mut number = number
.to_bigint()
.expect("Conversion to big integer failed.");
let mut modulus = modulus
.to_bigint()
.expect("Conversion to big integer failed.");
let original_modulus = modulus.clone();
let mut x: BigInt = Zero::zero();
let mut inverse: BigInt = One::one();
while number > One::one() {
let (dividend, remainder) = number.div_rem(&modulus);
inverse -= dividend * &x;
number = remainder;
std::mem::swap(&mut number, &mut modulus);
std::mem::swap(&mut x, &mut inverse)
}
if inverse < Zero::zero() {
inverse += original_modulus
}
inverse
.to_biguint()
.expect("Conversion to unsigned big integer failed.")
}
/// An integer `value` paired with the `modulus` it is taken against.
///
/// `PartialEq` and `Ord` are implemented by hand and compare only `value`, so
/// elements carrying different moduli can still compare equal.
#[derive(Clone, Debug, Eq)]
pub(crate) struct FiniteFieldElement {
    /// Integer representative of the element. Not every constructor reduces it
    /// below `modulus` (e.g. `new` stores the decoded bytes as given).
    pub value: BigUint,
    /// Modulus of the field. NOTE(review): presumably prime so every nonzero
    /// value is invertible — confirm with the callers.
    pub modulus: BigUint,
}
impl FiniteFieldElement {
    /// Creates an element from a little-endian byte string.
    ///
    /// Byte `i` contributes `bytes[i] * 256^i`; an empty slice yields 0. The
    /// value is stored as-is and is not reduced modulo `modulus`.
    /// NOTE(review): callers presumably pass an encoding of a number below
    /// `modulus` — confirm, since `get_bytes` panics on values that do not fit.
    pub fn new(bytes: &[u8], modulus: &BigUint) -> Self {
        FiniteFieldElement {
            // `from_bytes_le` consumes every byte, including a trailing group
            // shorter than a full 4-byte word, so no input data is dropped.
            value: BigUint::from_bytes_le(bytes),
            modulus: modulus.clone(),
        }
    }

    /// Creates a random element from `num_bits` of OS randomness (rounded up
    /// to whole 32-bit words), reduced modulo `modulus`.
    pub fn new_random(num_bits: usize, modulus: &BigUint) -> Self {
        FiniteFieldElement {
            value: get_random_number(num_bits, modulus),
            modulus: modulus.clone(),
        }
    }

    /// Creates an element holding `number` (not reduced modulo `modulus`).
    pub fn new_integer(number: u32, modulus: &BigUint) -> Self {
        FiniteFieldElement {
            value: BigUint::from(number),
            modulus: modulus.clone(),
        }
    }

    /// Serializes the value as little-endian bytes, zero-padded to the byte
    /// length of the modulus, `ceil(bits(modulus) / 8)`.
    ///
    /// # Panics
    ///
    /// Panics if the value's little-endian encoding is longer than that byte
    /// length, which is only possible for a value that is not below the
    /// modulus (e.g. an unreduced value passed to `new`).
    pub fn get_bytes(&self) -> Vec<u8> {
        // Round up so a modulus whose bit length is not a multiple of 8 still
        // leaves room for every value below it.
        let byte_length = ((self.modulus.bits() + 7) >> 3) as usize;
        let mut bytes = self.value.to_bytes_le();
        assert!(
            bytes.len() <= byte_length,
            "Value does not fit in the byte length of the modulus."
        );
        bytes.resize(byte_length, 0);
        bytes
    }
}
impl PartialOrd for FiniteFieldElement {
    /// Delegates to the total order from `Ord`, so this never yields `None`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for FiniteFieldElement {
    /// Orders elements by their integer `value`; `modulus` is not consulted.
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (&self.value, &other.value);
        lhs.cmp(rhs)
    }
}
impl PartialEq for FiniteFieldElement {
    /// Two elements are equal when their `value`s match; `modulus` is ignored.
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.value, &other.value)
    }
}
impl Add for FiniteFieldElement {
    type Output = Self;
    /// Returns `(self + other) mod modulus`, keeping `self`'s modulus.
    fn add(self, other: Self) -> Self {
        let value = (self.value + other.value).mod_floor(&self.modulus);
        // `self` is owned, so its modulus is moved into the result rather
        // than cloned.
        Self {
            value,
            modulus: self.modulus,
        }
    }
}
impl Sub for FiniteFieldElement {
    type Output = Self;
    /// Returns `(self - other) mod modulus`, keeping `self`'s modulus.
    ///
    /// The subtrahend is reduced first and the modulus is added before
    /// subtracting, so the unsigned intermediate can never underflow. The
    /// final reduction maps `a - a` to 0 instead of to `modulus`.
    fn sub(self, other: FiniteFieldElement) -> Self {
        let subtrahend = other.value.mod_floor(&self.modulus);
        let value = (self.value + &self.modulus - subtrahend).mod_floor(&self.modulus);
        Self {
            value,
            modulus: self.modulus,
        }
    }
}
impl Mul for FiniteFieldElement {
    type Output = Self;
    /// Returns `(self * other) mod modulus`, keeping `self`'s modulus.
    fn mul(self, other: Self) -> Self {
        let value = (self.value * other.value).mod_floor(&self.modulus);
        // `self` is owned, so its modulus is moved into the result rather
        // than cloned.
        Self {
            value,
            modulus: self.modulus,
        }
    }
}
impl Div for FiniteFieldElement {
    type Output = Self;
    /// Returns `self * other^-1 mod modulus`, keeping `self`'s modulus.
    ///
    /// # Panics
    ///
    /// Panics if `other` is congruent to zero modulo the modulus. It may also
    /// panic inside `modular_inverse` if `other` is otherwise not invertible
    /// (only possible for a non-prime modulus).
    // Division is implemented as multiplication by the inverse, which is what
    // `suspicious_arithmetic_impl` flags; the `*` below is intentional.
    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, other: Self) -> Self {
        assert!(
            !other.value.mod_floor(&self.modulus).is_zero(),
            "Division by zero in finite field."
        );
        let inverse_value = modular_inverse(&other.value, &self.modulus);
        let value = (self.value * inverse_value).mod_floor(&self.modulus);
        Self {
            value,
            modulus: self.modulus,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::secret_sharing::MODULUS_ARRAY_256;

    /// How many random trials each randomized test performs.
    const NUM_TEST_RUNS: u32 = 100;

    /// The 256-bit modulus every test in this module works with.
    fn test_modulus() -> BigUint {
        BigUint::from_slice(&MODULUS_ARRAY_256)
    }

    /// A fresh element built from 256 random bits, reduced by `modulus`.
    fn random_element(modulus: &BigUint) -> FiniteFieldElement {
        FiniteFieldElement::new_random(256, modulus)
    }

    /// `number * modular_inverse(number)` must reduce to one.
    #[test]
    fn test_modular_inverse() {
        let modulus = test_modulus();
        for _ in 0..NUM_TEST_RUNS {
            let number = get_random_number(256, &modulus);
            let inverse = modular_inverse(&number, &modulus);
            let product = (number * inverse).mod_floor(&modulus);
            assert_eq!(product, BigUint::one());
        }
    }

    /// Addition matches a single conditional subtraction of the modulus.
    #[test]
    fn test_finite_field_addition() {
        let modulus = test_modulus();
        for _ in 0..NUM_TEST_RUNS {
            let lhs = random_element(&modulus);
            let rhs = random_element(&modulus);
            let mut expected = &lhs.value + &rhs.value;
            if expected >= modulus {
                expected -= &modulus;
            }
            assert_eq!((lhs + rhs).value, expected);
        }
    }

    /// Subtraction wraps around by adding the modulus when `lhs < rhs`.
    #[test]
    fn test_finite_field_subtraction() {
        let modulus = test_modulus();
        for _ in 0..NUM_TEST_RUNS {
            let lhs = random_element(&modulus);
            let rhs = random_element(&modulus);
            let expected = if lhs >= rhs {
                &lhs.value - &rhs.value
            } else {
                &lhs.value + &modulus - &rhs.value
            };
            assert_eq!((lhs - rhs).value, expected);
        }
    }

    /// Multiplication equals the reduced integer product.
    #[test]
    fn test_finite_field_multiplication() {
        let modulus = test_modulus();
        for _ in 0..NUM_TEST_RUNS {
            let lhs = random_element(&modulus);
            let rhs = random_element(&modulus);
            let expected = (&lhs.value * &rhs.value).mod_floor(&modulus);
            assert_eq!((lhs * rhs).value, expected);
        }
    }

    /// Division agrees with multiplying by the modular inverse, in either order.
    #[test]
    fn test_finite_field_division() {
        let modulus = test_modulus();
        for _ in 0..NUM_TEST_RUNS {
            let first = random_element(&modulus);
            let second = random_element(&modulus);
            let divisor = random_element(&modulus);
            let expected = (&first.value
                * &second.value
                * modular_inverse(&divisor.value, &modulus))
                .mod_floor(&modulus);
            let product_then_quotient = first.clone() * second.clone() / divisor.clone();
            assert_eq!(product_then_quotient.value, expected);
            let quotient_then_product = first / divisor * second;
            assert_eq!(quotient_then_product.value, expected);
        }
    }

    /// Serialized length depends only on the modulus, not on the random draw size.
    #[test]
    fn test_correct_byte_length() {
        let modulus = test_modulus();
        let mut rng = rand::thread_rng();
        for _ in 0..NUM_TEST_RUNS {
            let num_bits = rng.gen_range(10..256);
            let element = FiniteFieldElement::new_random(num_bits, &modulus);
            assert_eq!(element.get_bytes().len(), 256 >> 3);
        }
    }
}