extern crate num_bigint;
extern crate num_traits;
extern crate rand;
use num_bigint::{BigUint, BigInt, RandBigInt, ToBigInt};
use num_traits::{One,Zero};
use rand::thread_rng;
use num_prime::RandPrime;
use num_prime::PrimalityTestConfig;
use sha2::{Sha256, Digest};
pub struct Polynomial {
pub coefficients: Vec<BigUint>,
}
impl Polynomial {
    /// Creates a random polynomial of the given `degree`.
    ///
    /// All `degree + 1` coefficients (including the constant term) are drawn
    /// independently from `[1, 2^max_bit_size)`, so none of them is zero.
    ///
    /// # Panics
    ///
    /// Panics if `max_bit_size == 0`, because the sampling range is then empty.
    pub fn new(degree: usize, max_bit_size: usize) -> Self {
        let mut rng = thread_rng();
        let lower = BigUint::one();
        // Exclusive upper bound 2^max_bit_size, computed once for every coefficient.
        let upper = BigUint::one() << max_bit_size;
        let coefficients = (0..=degree)
            .map(|_| rng.gen_biguint_range(&lower, &upper))
            .collect();
        Polynomial { coefficients }
    }

    /// Creates a polynomial for Shamir secret sharing with the given `threshold`.
    ///
    /// The constant term is `secret`; the remaining `threshold - 1` coefficients
    /// are random values in `[1, 2^secret_bits)`, giving a polynomial of degree
    /// `threshold - 1`. A `threshold` of 0 or 1 yields the constant polynomial
    /// `secret`. No modular reduction happens here.
    ///
    /// # Panics
    ///
    /// Panics if `secret_bits == 0` and `threshold > 1`, because the sampling
    /// range is then empty.
    pub fn new_for_shamir(threshold: usize, secret_bits: usize, secret: &BigUint) -> Self {
        let mut rng = thread_rng();
        let lower = BigUint::one();
        // Loop-invariant bound, hoisted out of the sampling loop.
        let upper = BigUint::one() << secret_bits;
        let mut coefficients = Vec::with_capacity(threshold.max(1));
        coefficients.push(secret.clone());
        coefficients.extend((1..threshold).map(|_| rng.gen_biguint_range(&lower, &upper)));
        Polynomial { coefficients }
    }

    /// Evaluates the polynomial at `x` over the (unbounded) integers.
    ///
    /// Uses Horner's rule: one multiplication per coefficient, and the unused
    /// power `x^len` that a naive power-sum computes is never built.
    /// An empty polynomial evaluates to zero.
    pub fn evaluate(&self, x: &BigUint) -> BigUint {
        self.coefficients
            .iter()
            .rev()
            .fold(BigUint::zero(), |acc, coef| acc * x + coef)
    }
}

impl std::fmt::Display for Polynomial {
    /// Formats the polynomial as `c0 + c1x + c2x^2 + ...`, lowest degree first.
    ///
    /// Every coefficient is printed, including zeros; an empty polynomial
    /// formats as the empty string. `polynomial.to_string()` goes through this
    /// impl via the standard `ToString` blanket implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for (index, coef) in self.coefficients.iter().enumerate() {
            if index > 0 {
                f.write_str(" + ")?;
            }
            match index {
                0 => write!(f, "{}", coef)?,
                1 => write!(f, "{}x", coef)?,
                _ => write!(f, "{}x^{}", coef, index)?,
            }
        }
        Ok(())
    }
}
/// Draws a random integer from the half-open range `[1, modulus)` using the
/// thread-local RNG.
///
/// # Panics
///
/// Panics if `modulus <= 1`, because the range is then empty.
pub fn gen_rand(modulus: &BigUint) -> BigUint {
    let lower_bound = BigUint::one();
    thread_rng().gen_biguint_range(&lower_bound, modulus)
}
/// Generates a random prime sized by `bit_size`, using the thread-local RNG.
///
/// Primality is checked with `PrimalityTestConfig::default()`; see
/// `num_prime::RandPrime::gen_prime` for the exact bit-length guarantee.
pub fn generate_prime(bit_size: usize) -> BigUint {
    let test_config = Some(PrimalityTestConfig::default());
    thread_rng().gen_prime(bit_size, test_config)
}
/// Returns the SHA-256 digest of `data` as a freshly allocated 32-byte vector.
pub fn hash_data(data: &[u8]) -> Vec<u8> {
    Sha256::digest(data).to_vec()
}
pub fn mod_exp(base: &BigUint, exponent: &BigUint, modulus: &BigUint) -> BigUint {
base.modpow(exponent, modulus)
}
pub fn egcd(a: BigInt, b: BigInt) -> (BigInt, BigInt, BigInt) {
if a.is_zero() {
(b, Zero::zero(), One::one())
} else {
let (g, x, y) = egcd(b.clone() % a.clone(), a.clone());
(g, y - (b / a.clone()) * x.clone(), x)
}
}
pub fn mod_inv(a: &BigUint, m: &BigUint) -> Option<BigUint> {
let (g, x, _) = egcd(a.to_bigint().unwrap(), m.to_bigint().unwrap());
if g == One::one() {
let x_mod_m = ((x % m.to_bigint().unwrap()) + m.to_bigint().unwrap()) % m.to_bigint().unwrap();
Some(x_mod_m.to_biguint().unwrap())
} else {
None
}
}
/// Reconstructs `f(0) mod modulus` from points `(x_i, f(x_i))` by Lagrange
/// interpolation, e.g. to recover a Shamir secret from its shares.
///
/// Computes `sum_i y_i * prod_{j != i} (0 - x_j) / (x_i - x_j)` modulo
/// `modulus`. X-coordinates may be any size; they are reduced modulo
/// `modulus` first. An empty `points` slice yields `Some(0)`.
///
/// Returns `None` if `modulus` is zero, or if some denominator has no inverse
/// modulo `modulus` (for example two x-coordinates that are congruent modulo
/// `modulus`).
pub fn lagrange_interpolation_zero(points: &[(BigUint, BigUint)], modulus: &BigUint) -> Option<BigUint> {
    if modulus.is_zero() {
        return None;
    }
    // Reduce x-coordinates up front: BigUint subtraction panics on a negative
    // result, so `modulus - x_j` must never see an x_j larger than modulus.
    let xs: Vec<BigUint> = points.iter().map(|(x, _)| x % modulus).collect();

    let mut secret = BigUint::zero();
    for (i, (x_i, (_, y_i))) in xs.iter().zip(points).enumerate() {
        let mut numerator = BigUint::one();
        let mut denominator = BigUint::one();
        for (j, x_j) in xs.iter().enumerate() {
            if i == j {
                continue;
            }
            // (0 - x_j) mod modulus
            numerator = numerator * ((modulus - x_j) % modulus) % modulus;
            // (x_i - x_j) mod modulus; adding modulus first keeps it non-negative.
            denominator = denominator * ((x_i + modulus - x_j) % modulus) % modulus;
        }
        let inv_denominator = mod_inv(&denominator, modulus)?;
        let term = y_i * &numerator * inv_denominator % modulus;
        secret = (secret + term) % modulus;
    }
    Some(secret)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BigUint` from a small literal.
    fn big(n: u64) -> BigUint {
        BigUint::from(n)
    }

    #[test]
    fn test_polynomial_to_string() {
        let poly = Polynomial {
            coefficients: vec![big(1), big(2), big(3)],
        };
        assert_eq!(poly.to_string(), "1 + 2x + 3x^2");
    }

    #[test]
    fn test_evaluation() {
        let poly = Polynomial {
            coefficients: vec![big(90782), big(222234), big(123343)],
        };
        // At x = 1 every power is 1, so the value is the coefficient sum.
        let at_one = poly.evaluate(&big(1));
        let sum = poly
            .coefficients
            .iter()
            .fold(BigUint::zero(), |acc, coeff| acc + coeff);
        assert_eq!(at_one, sum);
    }

    #[test]
    fn test_evaluation_at_other_points() {
        let poly = Polynomial {
            coefficients: vec![big(1), big(2), big(3)],
        };
        assert_eq!(poly.evaluate(&big(0)), big(1));
        // 1 + 2*2 + 3*4 = 17
        assert_eq!(poly.evaluate(&big(2)), big(17));
        let empty = Polynomial {
            coefficients: Vec::new(),
        };
        assert_eq!(empty.evaluate(&big(5)), BigUint::zero());
    }

    #[test]
    fn test_polynomial_new_coefficient_range() {
        let poly = Polynomial::new(4, 16);
        let upper = big(1 << 16);
        assert_eq!(poly.coefficients.len(), 5);
        assert!(poly.coefficients.iter().all(|c| !c.is_zero() && *c < upper));
    }

    #[test]
    fn test_prime_generation() {
        let prime = generate_prime(128);
        assert!(prime.bits() <= 128);
        // The generator targets 128-bit values, so the result is an odd prime, not 2.
        assert_eq!(&prime % 2u32, BigUint::one());
    }

    #[test]
    fn test_hash_data() {
        let hash = hash_data(b"hello, world");
        assert_eq!(hash.len(), 32);
    }

    #[test]
    fn test_hash_data_known_vector() {
        // Well-known SHA-256 digest of the empty input.
        let expected: [u8; 32] = [
            0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f,
            0xb9, 0x24, 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b,
            0x78, 0x52, 0xb8, 0x55,
        ];
        assert_eq!(hash_data(b""), expected.to_vec());
    }

    #[test]
    fn test_mod_exp() {
        // 2^10 = 1024 ≡ 24 (mod 1000)
        assert_eq!(mod_exp(&big(2), &big(10), &big(1000)), big(24));
    }

    #[test]
    fn test_mod_inv() {
        // 3 * 4 = 12 ≡ 1 (mod 11)
        assert_eq!(mod_inv(&big(3), &big(11)), Some(big(4)));
    }

    #[test]
    fn test_mod_inv_without_inverse() {
        // gcd(6, 9) = 3, so 6 has no inverse modulo 9.
        assert_eq!(mod_inv(&big(6), &big(9)), None);
    }

    #[test]
    fn test_egcd_bezout_identity() {
        let (a, b) = (BigInt::from(240i64), BigInt::from(46i64));
        let (g, x, y) = egcd(a.clone(), b.clone());
        assert_eq!(g, BigInt::from(2i64));
        assert_eq!(x * a + y * b, g);
    }

    #[test]
    fn test_lagrange_interpolation_zero() {
        let points = [(big(1), big(90)), (big(2), big(87)), (big(3), big(678))];
        let secret = lagrange_interpolation_zero(&points, &big(1009)).unwrap();
        assert_eq!(secret, big(687));
    }

    #[test]
    fn test_lagrange_duplicate_x_returns_none() {
        // Equal x-coordinates make a denominator zero, which has no inverse.
        let points = [(big(1), big(5)), (big(1), big(7))];
        assert_eq!(lagrange_interpolation_zero(&points, &big(1009)), None);
    }

    #[test]
    fn test_shamir_round_trip() {
        let prime = big(2_305_843_009_213_693_951); // 2^61 - 1, a Mersenne prime
        let secret = big(123_456_789);
        let poly = Polynomial::new_for_shamir(3, 32, &secret);
        assert_eq!(poly.coefficients.len(), 3);
        assert_eq!(poly.coefficients[0], secret);

        let shares: Vec<(BigUint, BigUint)> = (1u64..=3)
            .map(|x| {
                let x = big(x);
                let y = poly.evaluate(&x) % &prime;
                (x, y)
            })
            .collect();
        assert_eq!(lagrange_interpolation_zero(&shares, &prime), Some(secret));
    }

    #[test]
    fn test_gen_rand_range() {
        let modulus = big(10);
        for _ in 0..50 {
            let r = gen_rand(&modulus);
            assert!(!r.is_zero() && r < modulus);
        }
    }
}