use num_bigint::{BigInt, Sign};
use num_traits::identities::{One, Zero};
use std::clone::Clone;
pub struct Util {}
impl Util {
pub fn extend_gcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
if a == &BigInt::zero() {
(b.clone(), BigInt::zero(), BigInt::one())
} else {
let (g, x, y) = Util::extend_gcd(&(b % a), a);
(g, y - (b / a) * &x, x)
}
}
pub fn mod_inverse(a: &BigInt, modular: &BigInt) -> Option<BigInt> {
let (g, x, _) = Util::extend_gcd(a, modular);
if g != BigInt::one() {
None
} else {
let result = (&x % modular + modular) % modular;
Some(result)
}
}
pub fn lagrange_coefficient(i: &i64, values: &[i64]) -> (BigInt, BigInt) {
if !values.contains(i) {
return (BigInt::zero(), BigInt::one());
}
let mut numerator = BigInt::one();
let mut denominator = BigInt::one();
let vec_to = values.to_vec();
let max = vec_to.iter().max().unwrap();
for j in 1..=*max {
if j != *i && values.contains(&j) {
numerator *= j;
denominator *= j - *i;
}
}
(numerator, denominator)
}
pub fn abs(n: &BigInt) -> BigInt {
match n.sign() {
Sign::Minus => BigInt::new(Sign::Plus, n.to_u32_digits().1),
Sign::Plus => n.clone(),
Sign::NoSign => n.clone(),
}
}
}
#[cfg(test)]
mod tests {
use super::Util;
use num_bigint::{BigInt, BigUint};
use num_traits::Num;
use num_traits::{One, Zero};
use sha2::{Digest, Sha256};
#[test]
fn test_extend_gcd() {
let a = BigInt::from(26);
let b = BigInt::from(3);
let (g, x, y) = Util::extend_gcd(&a, &b);
assert_eq!(g, BigInt::one());
assert_eq!(x, BigInt::from(-1));
assert_eq!(y, BigInt::from(9));
assert_eq!((a.clone() * x) + (b.clone() * y), g);
}
#[test]
fn test_mod_inverse() {
let does_exist = Util::mod_inverse(&BigInt::from(3), &BigInt::from(26));
let does_not_exist =
Util::mod_inverse(&BigInt::from(4), &BigInt::from(32));
match does_exist {
Some(x) => assert_eq!(x, BigInt::from(9)),
None => panic!("mod_inverse() didn't work as expected"),
}
match does_not_exist {
Some(x) => {
drop(x);
panic!("mod_inverse() found an inverse when it shouldn't have")
}
None => {}
}
}
#[test]
fn test_lagrange_coefficient() {
let i_array = [9, 1, 2, 3];
let values = [0, 1, 2, 3, 4, 5, 6];
let result = Util::lagrange_coefficient(&i_array[0], &values);
assert_eq!(result, (BigInt::zero(), BigInt::one()));
let result = Util::lagrange_coefficient(&i_array[1], &values);
assert_eq!(result, (BigInt::from(720), BigInt::from(120)));
let result = Util::lagrange_coefficient(&i_array[2], &values);
assert_eq!(result, (BigInt::from(360), BigInt::from(-24)));
let result = Util::lagrange_coefficient(&i_array[3], &values);
assert_eq!(result, (BigInt::from(240), BigInt::from(12)));
let result = Util::lagrange_coefficient(&3, &[1, 3, 4]);
assert_eq!(result, (BigInt::from(4), BigInt::from(-2)));
}
#[test]
fn test_abs() {
let minus = BigInt::from(-100);
assert_eq!(Util::abs(&minus), BigInt::from(100));
let minus = BigInt::from(-0);
assert_eq!(Util::abs(&minus), BigInt::from(0));
let plus = BigInt::from(0);
assert_eq!(Util::abs(&plus), BigInt::from(0));
let plus = BigInt::from(100);
assert_eq!(Util::abs(&plus), BigInt::from(100));
}
#[test]
fn test_xor() {
let a = BigUint::from(1337_u32);
let b = BigUint::from(42_u32);
let xor = a ^ b;
assert_eq!(xor, BigUint::from(1299_u32));
}
#[test]
fn test_hash() {
let mut sha256 = Sha256::new();
let value_1 =
BigInt::from_str_radix("43589072349864890574839", 10).unwrap();
let value_2 =
BigInt::from_str_radix("14735247304952934566", 10).unwrap();
let value_1_uint = value_1.to_biguint().unwrap();
let value_2_uint = value_2.to_biguint().unwrap();
sha256.update(value_1_uint.to_str_radix(10).as_bytes());
sha256.update(value_2_uint.to_str_radix(10).as_bytes());
let result = sha256.finalize();
let challenge_big_uint = BigUint::from_bytes_be(&result[..]);
assert_eq!(
challenge_big_uint.to_str_radix(16),
"e25e5b7edf4ea66e5238393fb4f183e0fc1593c69a522f9255a51bd0bc2b7ba7"
);
assert_eq!(
BigInt::from_str_radix(&challenge_big_uint.to_str_radix(16), 16),
BigInt::from_str_radix(
"102389418883295205726805934198606438410316463205994911160958467170744727731111",
10
)
);
}
}