use crate::bytes::FieldBytes;
use crate::error::Error;
use crate::field::Field;
/// Field modulus `p = 2^64 - 2^32 + 1` (often called the Goldilocks prime).
const P: u64 = 0xFFFF_FFFF_0000_0001;

/// An element of the prime field `F_p` with `p = 2^64 - 2^32 + 1`.
///
/// The wrapped `u64` is kept strictly below `p` by every operation in this module,
/// so derived equality and hashing compare canonical representatives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BFieldElement(u64);

impl BFieldElement {
    /// Creates an element from any `u64`, reducing it modulo `p`.
    #[must_use]
    pub fn new(value: u64) -> Self {
        let canonical = value % P;
        Self(canonical)
    }

    /// Returns the canonical representative, a value in `0..p`.
    #[must_use]
    pub fn value(self) -> u64 {
        let Self(inner) = self;
        inner
    }

    /// Reduces a 128-bit intermediate (e.g. a widened sum or product) modulo `p`.
    fn reduce(x: u128) -> u64 {
        let remainder = x % u128::from(P);
        // `remainder < p < 2^64`, so the conversion always succeeds and the
        // default fallback is never taken.
        u64::try_from(remainder).unwrap_or_default()
    }
}
impl core::fmt::Display for BFieldElement {
    /// Formats the element as the decimal value of its canonical representative.
    ///
    /// Delegating to `u64`'s own `Display::fmt` (rather than re-entering `write!`)
    /// forwards the caller's width, fill, alignment and sign flags, so e.g.
    /// `format!("{:>5}", x)` is padded as expected.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Display::fmt(&self.0, f)
    }
}
impl std::ops::Add for BFieldElement {
    type Output = Self;

    /// Field addition: `(self + rhs) mod p`.
    fn add(self, rhs: Self) -> Self {
        // Widen both operands first so their sum cannot overflow.
        let [left, right] = [self.0, rhs.0].map(u128::from);
        let reduced = Self::reduce(left + right);
        Self(reduced)
    }
}
impl std::ops::Sub for BFieldElement {
    type Output = Self;

    /// Field subtraction: `(self - rhs) mod p`.
    fn sub(self, rhs: Self) -> Self {
        // Add `p` to the minuend first so the 128-bit difference never underflows.
        let lifted = u128::from(P) + u128::from(self.0);
        let difference = lifted - u128::from(rhs.0);
        Self(Self::reduce(difference))
    }
}
impl std::ops::Mul for BFieldElement {
    type Output = Self;

    /// Field multiplication: `(self * rhs) mod p`.
    fn mul(self, rhs: Self) -> Self {
        // A full 64x64-bit product needs up to 128 bits before reduction.
        let wide_product = u128::from(self.0) * u128::from(rhs.0);
        let reduced = Self::reduce(wide_product);
        Self(reduced)
    }
}
impl std::ops::Neg for BFieldElement {
    type Output = Self;

    /// Additive inverse: `(p - self) mod p`.
    fn neg(self) -> Self {
        match self.0 {
            // Zero is its own inverse; `p - 0` would not be a canonical value.
            0 => self,
            nonzero => Self(P - nonzero),
        }
    }
}
impl Field for BFieldElement {
fn zero() -> Self {
Self(0)
}
fn one() -> Self {
Self(1)
}
fn inv(&self) -> Result<Self, Error> {
if self.0 == 0 {
Err(Error::DivisionByZero)
} else {
Ok(pow(*self, P - 2))
}
}
}
impl FieldBytes for BFieldElement {
fn to_le_bytes(&self) -> Vec<u8> {
self.0.to_le_bytes().to_vec()
}
fn from_le_bytes(bytes: &[u8]) -> Result<Self, Error> {
bytes
.get(..8)
.ok_or(Error::InvalidFieldEncoding)
.and_then(|slice| <[u8; 8]>::try_from(slice).map_err(|_| Error::InvalidFieldEncoding))
.map(u64::from_le_bytes)
.map(Self::new)
}
}
/// Raises `base` to the power `exp` by binary (square-and-multiply) exponentiation.
///
/// Scans `exp` from its least significant bit upwards. `square` holds
/// `base^(2^i)` at bit `i` and is multiplied into the accumulator whenever that
/// bit is set. `pow(x, 0)` is one for every `x`, including zero.
fn pow(base: BFieldElement, exp: u64) -> BFieldElement {
    let mut acc = BFieldElement::one();
    let mut square = base;
    let mut bits = exp;
    while bits != 0 {
        if bits & 1 == 1 {
            acc = acc * square;
        }
        bits >>= 1;
        square = square * square;
    }
    acc
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the modulus against its defining structure, `2^64 - 2^32 + 1`,
    /// instead of repeating the literal it is declared with.
    #[test]
    fn modulus_constant_is_correct() {
        assert_eq!(u128::from(P), (1u128 << 64) - (1u128 << 32) + 1);
        assert_eq!(P, 18_446_744_069_414_584_321);
    }
    #[test]
    fn zero_is_additive_identity() {
        let a = BFieldElement::new(123_456_789);
        assert_eq!(a + BFieldElement::zero(), a);
        assert_eq!(BFieldElement::zero() + a, a);
    }
    #[test]
    fn one_is_multiplicative_identity() {
        let a = BFieldElement::new(123_456_789);
        assert_eq!(a * BFieldElement::one(), a);
        assert_eq!(BFieldElement::one() * a, a);
    }
    #[test]
    fn additive_inverse() {
        let a = BFieldElement::new(999_999_999_999);
        assert_eq!(a + (-a), BFieldElement::zero());
    }
    #[test]
    fn negation_of_zero_is_zero() {
        assert_eq!(-BFieldElement::zero(), BFieldElement::zero());
        assert_eq!(-BFieldElement::one(), BFieldElement::new(P - 1));
    }
    #[test]
    fn addition_wraps_at_modulus() {
        let a = BFieldElement::new(P - 1);
        let b = BFieldElement::new(2);
        assert_eq!(a + b, BFieldElement::one());
    }
    #[test]
    fn subtraction_wraps_below_zero() {
        let a = BFieldElement::new(3);
        let b = BFieldElement::new(5);
        assert_eq!(a - b, BFieldElement::new(P - 2));
        assert_eq!(b - a, BFieldElement::new(2));
    }
    #[test]
    fn multiplicative_inverse() -> Result<(), Error> {
        let a = BFieldElement::new(42);
        let a_inv = a.inv()?;
        assert_eq!(a * a_inv, BFieldElement::one());
        Ok(())
    }
    #[test]
    fn inverse_of_zero_fails() {
        let result = BFieldElement::zero().inv();
        assert!(result.is_err());
    }
    #[test]
    fn sample_inverses() -> Result<(), Error> {
        let samples = [1u64, 2, 7, 100, 1_000_000, 1_000_000_000_000, P - 1, P - 2];
        samples.iter().try_for_each(|&v| {
            let a = BFieldElement::new(v);
            let a_inv = a.inv()?;
            assert_eq!(a * a_inv, BFieldElement::one(), "failed for {v}");
            Ok(())
        })
    }
    #[test]
    fn inverse_is_multiplicative() -> Result<(), Error> {
        let a = BFieldElement::new(1_234_567);
        let b = BFieldElement::new(P - 89);
        assert_eq!((a * b).inv()?, a.inv()? * b.inv()?);
        Ok(())
    }
    #[test]
    fn subtraction_is_add_neg() {
        let a = BFieldElement::new(123_456_789_000);
        let b = BFieldElement::new(987_654_321);
        assert_eq!(a - b, a + (-b));
    }
    #[test]
    fn multiplication_is_commutative() {
        let a = BFieldElement::new(12_345_678);
        let b = BFieldElement::new(98_765_432);
        assert_eq!(a * b, b * a);
    }
    #[test]
    fn distributivity() {
        let a = BFieldElement::new(111_111);
        let b = BFieldElement::new(222_222);
        let c = BFieldElement::new(333_333);
        assert_eq!(a * (b + c), a * b + a * c);
    }
    #[test]
    fn new_reduces_mod_p() {
        assert_eq!(BFieldElement::new(P), BFieldElement::new(0));
        assert_eq!(BFieldElement::new(P + 1), BFieldElement::new(1));
    }
    #[test]
    fn p_minus_one_squared_is_one() -> Result<(), Error> {
        let neg_one = BFieldElement::new(P - 1);
        assert_eq!(neg_one * neg_one, BFieldElement::one());
        let neg_one_inv = neg_one.inv()?;
        assert_eq!(neg_one_inv, neg_one);
        Ok(())
    }
    #[test]
    fn high_u64_values_multiply_without_overflow() {
        let a = BFieldElement::new(P - 5);
        let b = BFieldElement::new(P - 7);
        assert_eq!(a * b, BFieldElement::new(35));
    }
    /// `2^64 = p + 2^32 - 1`, so `2^32 * 2^32` must reduce to `2^32 - 1`.
    #[test]
    fn two_to_the_64_reduces_correctly() {
        let two_32 = BFieldElement::new(1 << 32);
        assert_eq!(two_32 * two_32, BFieldElement::new(0xFFFF_FFFF));
    }
    #[test]
    fn display_prints_canonical_value() {
        assert_eq!(BFieldElement::new(P + 7).to_string(), "7");
    }
    #[test]
    fn bytes_roundtrip() -> Result<(), Error> {
        let a = BFieldElement::new(0xDEAD_BEEF_1234_5678);
        let bytes = a.to_le_bytes();
        let b = BFieldElement::from_le_bytes(&bytes)?;
        assert_eq!(a, b);
        Ok(())
    }
    #[test]
    fn bytes_zero_roundtrip() -> Result<(), Error> {
        let a = BFieldElement::zero();
        let b = BFieldElement::from_le_bytes(&a.to_le_bytes())?;
        assert_eq!(a, b);
        Ok(())
    }
    #[test]
    fn bytes_max_canonical_roundtrip() -> Result<(), Error> {
        let a = BFieldElement::new(P - 1);
        let bytes = a.to_le_bytes();
        assert_eq!(bytes.len(), 8);
        assert_eq!(BFieldElement::from_le_bytes(&bytes)?, a);
        Ok(())
    }
    #[test]
    fn bytes_too_short_fails() {
        let result = BFieldElement::from_le_bytes(&[1, 2, 3]);
        assert!(result.is_err());
    }
    #[test]
    fn bytes_empty_fails() {
        let result = BFieldElement::from_le_bytes(&[]);
        assert!(result.is_err());
    }
}