use num_bigint::BigUint;
/// Errors returned by the modular-arithmetic helpers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FiniteFieldError {
    /// An input was outside the accepted range (e.g. an operand not strictly less than `p`).
    InvalidArgument(String),
    /// Not constructed anywhere in this module; presumably meant for results that fail
    /// validation — confirm with callers.
    InvalidResult(String),
}

impl std::fmt::Display for FiniteFieldError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FiniteFieldError::InvalidArgument(msg) => write!(f, "invalid argument: {}", msg),
            FiniteFieldError::InvalidResult(msg) => write!(f, "invalid result: {}", msg),
        }
    }
}

// Lets the error be propagated with `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for FiniteFieldError {}
/// Returns `(a + b) mod p`.
///
/// # Errors
/// Returns [`FiniteFieldError::InvalidArgument`] if `a` or `b` is not strictly less than `p`.
pub fn add(a: &BigUint, b: &BigUint, p: &BigUint) -> Result<BigUint, FiniteFieldError> {
    params_to_mod_check(a, b, p)?;
    // A plain remainder is the reduction we need; no exponentiation is involved.
    Ok((a + b) % p)
}
/// Returns `(a * b) mod p`.
///
/// # Errors
/// Returns [`FiniteFieldError::InvalidArgument`] if `a` or `b` is not strictly less than `p`.
pub fn multiplicate(a: &BigUint, b: &BigUint, p: &BigUint) -> Result<BigUint, FiniteFieldError> {
    params_to_mod_check(a, b, p)?;
    // A plain remainder is the reduction we need; no exponentiation is involved.
    Ok((a * b) % p)
}
/// Returns the additive inverse of `a` modulo `p`: the value `x` in `[0, p)` with
/// `a + x ≡ 0 (mod p)`.
///
/// # Errors
/// Returns [`FiniteFieldError::InvalidArgument`] if `a` is not strictly less than `p`.
pub fn inverse_add(a: &BigUint, p: &BigUint) -> Result<BigUint, FiniteFieldError> {
    params_to_mod_check_single_point(a, p)?;
    // Zero is its own inverse; `p - 0` would be `p`, which lies outside `[0, p)`.
    let negated = if a == &BigUint::from(0u32) {
        a.clone()
    } else {
        p - a
    };
    Ok(negated)
}
/// Returns `(a - b) mod p`, computed as `a + (-b)` using the additive inverse of `b`.
///
/// # Errors
/// Returns [`FiniteFieldError::InvalidArgument`] if `a` or `b` is not strictly less than `p`.
pub fn subtract(a: &BigUint, b: &BigUint, p: &BigUint) -> Result<BigUint, FiniteFieldError> {
    params_to_mod_check(a, b, p)?;
    // a - b ≡ a + (-b) (mod p)
    inverse_add(b, p).and_then(|neg_b| add(a, &neg_b, p))
}
pub fn inverse_multiplicate_prime(a: &BigUint, p: &BigUint) -> Result<BigUint, FiniteFieldError> {
params_to_mod_check_single_point(a, p)?;
Ok(a.modpow(&(p - BigUint::from(2u32)), p))
}
/// Returns `a / b` modulo the prime `p`, computed as `a * b^(-1)`.
///
/// # Errors
/// Returns [`FiniteFieldError::InvalidArgument`] if `a` or `b` is not strictly less than `p`.
/// Any error from [`inverse_multiplicate_prime`] is propagated unchanged.
pub fn divide(a: &BigUint, b: &BigUint, p: &BigUint) -> Result<BigUint, FiniteFieldError> {
    params_to_mod_check(a, b, p)?;
    // a / b ≡ a * b^(-1) (mod p)
    inverse_multiplicate_prime(b, p).and_then(|b_inv| multiplicate(a, &b_inv, p))
}
/// Returns `true` when `a` is strictly less than `b`.
pub fn check_is_less_than(a: &BigUint, b: &BigUint) -> bool {
    a < b
}
/// Validates that both operands are reduced residues modulo `p`, i.e. `a < p` and `b < p`.
///
/// # Errors
/// Returns [`FiniteFieldError::InvalidArgument`] if either operand is not strictly less than `p`.
pub fn params_to_mod_check(a: &BigUint, b: &BigUint, p: &BigUint) -> Result<(), FiniteFieldError> {
    let params_check = check_is_less_than(a, p) && check_is_less_than(b, p);
    if !params_check {
        // The enforced condition is `a < p && b < p`, so the message must say "less than".
        return Err(FiniteFieldError::InvalidArgument(format!(
            "a and b have to be less than p: {}, {}, {}",
            a, b, p
        )));
    }
    Ok(())
}
/// Validates that a single operand is a reduced residue modulo `p`, i.e. `a < p`.
///
/// # Errors
/// Returns [`FiniteFieldError::InvalidArgument`] if `a` is not strictly less than `p`.
pub fn params_to_mod_check_single_point(a: &BigUint, p: &BigUint) -> Result<(), FiniteFieldError> {
    let params_check = check_is_less_than(a, p);
    if !params_check {
        // Only one operand is checked here, and the condition is `a < p`.
        return Err(FiniteFieldError::InvalidArgument(format!(
            "a has to be less than p: {}, {}",
            a, p
        )));
    }
    Ok(())
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_add() {
        let a = BigUint::from(4u32);
        let b = BigUint::from(10u32);
        let p = BigUint::from(11u32);
        let res = add(&a, &b, &p).unwrap();
        assert_eq!(res, BigUint::from(3u32));

        let a = BigUint::from(10u32);
        let b = BigUint::from(1u32);
        let p = BigUint::from(11u32);
        let res = add(&a, &b, &p).unwrap();
        assert_eq!(res, BigUint::from(0u32));

        let a = BigUint::from(4u32);
        let b = BigUint::from(10u32);
        let p = BigUint::from(31u32);
        let res = add(&a, &b, &p).unwrap();
        assert_eq!(res, BigUint::from(14u32));
    }

    #[test]
    fn test_multiply() {
        let a = BigUint::from(4u32);
        let b = BigUint::from(10u32);
        let p = BigUint::from(11u32);
        let res = multiplicate(&a, &b, &p).unwrap();
        assert_eq!(res, BigUint::from(7u32));

        let p = BigUint::from(51u32);
        let res = multiplicate(&a, &b, &p).unwrap();
        assert_eq!(res, BigUint::from(40u32));
    }

    #[test]
    fn test_inv_add() {
        let a = BigUint::from(4u32);
        let p = BigUint::from(51u32);
        let res = inverse_add(&a, &p).unwrap();
        assert_eq!(res, BigUint::from(47u32));

        let a = BigUint::from(0u32);
        let p = BigUint::from(51u32);
        let res = inverse_add(&a, &p).unwrap();
        assert_eq!(res, BigUint::from(0u32));

        let a = BigUint::from(52u32);
        let p = BigUint::from(51u32);
        assert_eq!(
            inverse_add(&a, &p),
            Err(FiniteFieldError::InvalidArgument(format!(
                "a has to be less than p: {}, {}",
                a, p
            )))
        );

        let a = BigUint::from(4u32);
        let p = BigUint::from(51u32);
        let c_inv = inverse_add(&a, &p);
        assert_eq!(c_inv, Ok(BigUint::from(47u32)));
        assert_eq!(add(&a, &c_inv.unwrap(), &p), Ok(BigUint::from(0u32)));
    }

    #[test]
    fn test_subtract() {
        let a = BigUint::from(4u32);
        let p = BigUint::from(51u32);
        assert_eq!(subtract(&a, &a, &p), Ok(BigUint::from(0u32)));

        // 4 - 10 = -6 ≡ 5 (mod 11): exercises the wrap-around path.
        assert_eq!(
            subtract(&BigUint::from(4u32), &BigUint::from(10u32), &BigUint::from(11u32)),
            Ok(BigUint::from(5u32))
        );
    }

    #[test]
    fn test_inv_mult() {
        let a = BigUint::from(4u32);
        let p = BigUint::from(11u32);
        let c_inv = inverse_multiplicate_prime(&a, &p);
        assert_eq!(c_inv, Ok(BigUint::from(3u32)));
        assert_eq!(
            multiplicate(&a, &c_inv.unwrap(), &p),
            Ok(BigUint::from(1u32))
        );
    }

    #[test]
    fn test_divide() {
        let a = BigUint::from(4u32);
        let p = BigUint::from(11u32);
        assert_eq!(divide(&a, &a, &p), Ok(BigUint::from(1u32)));

        // 3 / 4 ≡ 3 * 3 ≡ 9 (mod 11), since 4 * 3 = 12 ≡ 1.
        assert_eq!(
            divide(&BigUint::from(3u32), &a, &p),
            Ok(BigUint::from(9u32))
        );
    }

    #[test]
    fn test_out_of_range_operands_rejected() {
        let p = BigUint::from(11u32);
        let one = BigUint::from(1u32);
        // Operands equal to p are not reduced residues and must be rejected.
        assert!(add(&p, &one, &p).is_err());
        assert!(multiplicate(&one, &p, &p).is_err());
        assert!(inverse_add(&p, &p).is_err());
    }
}