use num_bigint::BigUint;
/// Namespace type for modular arithmetic on [`BigUint`] values.
///
/// It carries no state: every operation is an associated function that takes
/// the modulus `p` explicitly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct FiniteField {}
/// Errors produced by the `FiniteField` operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FiniteFieldError {
    /// An input was rejected, e.g. an operand that is not strictly less than
    /// the modulus. The payload is a human-readable description.
    InvalidArgument(String),
    /// A computed result was rejected. The payload is a human-readable description.
    InvalidResult(String),
}

impl std::fmt::Display for FiniteFieldError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FiniteFieldError::InvalidArgument(msg) => write!(f, "invalid argument: {}", msg),
            FiniteFieldError::InvalidResult(msg) => write!(f, "invalid result: {}", msg),
        }
    }
}

// Lets callers box this error as `dyn std::error::Error` or propagate it with `?`
// into generic error types.
impl std::error::Error for FiniteFieldError {}
impl FiniteField {
pub fn add(a: &BigUint, b: &BigUint, p: &BigUint) -> Result<BigUint, FiniteFieldError> {
FiniteField::check_less_than(a, p)?;
FiniteField::check_less_than(b, p)?;
Ok((a + b).modpow(&BigUint::from(1u32), p))
}
pub fn mult(a: &BigUint, b: &BigUint, p: &BigUint) -> Result<BigUint, FiniteFieldError> {
FiniteField::check_less_than(a, p)?;
FiniteField::check_less_than(b, p)?;
Ok((a * b).modpow(&BigUint::from(1u32), p))
}
pub fn inv_add(a: &BigUint, p: &BigUint) -> Result<BigUint, FiniteFieldError> {
FiniteField::check_less_than(a, p)?;
if *a == BigUint::from(0u32) {
return Ok(a.clone());
}
Ok(p - a)
}
pub fn subtract(a: &BigUint, b: &BigUint, p: &BigUint) -> Result<BigUint, FiniteFieldError> {
FiniteField::check_less_than(a, p)?;
FiniteField::check_less_than(b, p)?;
let b_inv = FiniteField::inv_add(b, p)?;
FiniteField::add(a, &b_inv, p)
}
pub fn inv_mult_prime(a: &BigUint, p: &BigUint) -> Result<BigUint, FiniteFieldError> {
FiniteField::check_less_than(a, p)?;
Ok(a.modpow(&(p - BigUint::from(2u32)), p))
}
pub fn divide(a: &BigUint, b: &BigUint, p: &BigUint) -> Result<BigUint, FiniteFieldError> {
FiniteField::check_less_than(a, p)?;
FiniteField::check_less_than(b, p)?;
let d_inv = FiniteField::inv_mult_prime(b, p)?;
FiniteField::mult(a, &d_inv, p)
}
pub fn check_less_than(a: &BigUint, b: &BigUint) -> Result<(), FiniteFieldError> {
if a >= b {
return Err(FiniteFieldError::InvalidArgument(format!("{} >= {}", a, b)));
}
Ok(())
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a `BigUint` from a small literal to keep the assertions compact.
    fn big(v: u32) -> BigUint {
        BigUint::from(v)
    }

    #[test]
    fn test_add() {
        // Each row is (lhs, rhs, modulus, expected sum).
        let cases = [(4u32, 10u32, 11u32, 3u32), (10, 1, 11, 0), (4, 10, 31, 14)];
        for &(lhs, rhs, modulus, want) in cases.iter() {
            let got = FiniteField::add(&big(lhs), &big(rhs), &big(modulus)).unwrap();
            assert_eq!(got, big(want));
        }
    }

    #[test]
    fn test_multiply() {
        let (lhs, rhs) = (big(4), big(10));
        assert_eq!(FiniteField::mult(&lhs, &rhs, &big(11)).unwrap(), big(7));
        assert_eq!(FiniteField::mult(&lhs, &rhs, &big(51)).unwrap(), big(40));
    }

    #[test]
    fn test_inv_add() {
        let modulus = big(51);
        assert_eq!(FiniteField::inv_add(&big(4), &modulus).unwrap(), big(47));
        assert_eq!(FiniteField::inv_add(&big(0), &modulus).unwrap(), big(0));

        // Out-of-range operand must be rejected with the exact comparison message.
        let too_big = big(52);
        let expected: Result<BigUint, FiniteFieldError> = Err(
            FiniteFieldError::InvalidArgument(format!("{} >= {}", too_big, modulus)),
        );
        assert_eq!(FiniteField::inv_add(&too_big, &modulus), expected);

        // x + (-x) must come back to zero.
        let four = big(4);
        let neg_four = FiniteField::inv_add(&four, &modulus);
        assert_eq!(neg_four, Ok(big(47)));
        assert_eq!(
            FiniteField::add(&four, &neg_four.unwrap(), &modulus),
            Ok(big(0))
        );
    }

    #[test]
    fn test_subtract() {
        let four = big(4);
        assert_eq!(FiniteField::subtract(&four, &four, &big(51)), Ok(big(0)));
    }

    #[test]
    fn test_inv_mult() {
        let four = big(4);
        let modulus = big(11);
        let inverse = FiniteField::inv_mult_prime(&four, &modulus);
        assert_eq!(inverse, Ok(big(3)));
        // x * x^-1 must come back to one.
        assert_eq!(
            FiniteField::mult(&four, &inverse.unwrap(), &modulus),
            Ok(big(1))
        );
    }

    #[test]
    fn test_divide() {
        let four = big(4);
        assert_eq!(FiniteField::divide(&four, &four, &big(11)), Ok(big(1)));
    }
}