use crate::bytes::FieldBytes;
use crate::error::Error;
use crate::field::Field;
/// Field modulus: 2^31 − 1 = 2_147_483_647.
///
/// NOTE(review): 2^31 − 1 is the Mersenne-31 prime, whereas the prime commonly called
/// "BabyBear" is 15·2^27 + 1 = 2_013_265_921. Confirm which field this type is meant to
/// model before relying on interoperability with other BabyBear implementations.
const P: u64 = 2_147_483_647;

/// An element of the prime field of integers modulo `P`.
///
/// The wrapped `u64` is kept in the canonical range `0..P` by every constructor and
/// operator in this module. As a result, the derived `PartialEq`/`Eq`/`Hash` compare
/// field elements rather than arbitrary integers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BabyBear(u64);

impl BabyBear {
    /// Creates a field element from any `u64`, reducing it modulo `P`.
    #[must_use]
    pub fn new(value: u64) -> Self {
        Self(value % P)
    }

    /// Returns the canonical representative of this element, in `0..P`.
    #[must_use]
    pub fn value(self) -> u64 {
        self.0
    }
}

impl core::fmt::Display for BabyBear {
    /// Formats the canonical representative in decimal.
    ///
    /// Delegates to `u64`'s `Display` so formatter options such as width, fill,
    /// alignment and zero-padding (`{:>10}`, `{:08}`) are honoured. Re-formatting via
    /// `write!(f, "{}", ..)` would silently discard them.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Display::fmt(&self.0, f)
    }
}
impl std::ops::Add for BabyBear {
type Output = Self;
fn add(self, rhs: Self) -> Self {
Self((self.0 + rhs.0) % P)
}
}
impl std::ops::Sub for BabyBear {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
Self((self.0 + P - rhs.0) % P)
}
}
impl std::ops::Mul for BabyBear {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
Self((self.0 * rhs.0) % P)
}
}
impl std::ops::Neg for BabyBear {
type Output = Self;
fn neg(self) -> Self {
Self((P - self.0) % P)
}
}
impl Field for BabyBear {
fn zero() -> Self {
Self(0)
}
fn one() -> Self {
Self(1)
}
fn inv(&self) -> Result<Self, Error> {
if self.0 == 0 {
Err(Error::DivisionByZero)
} else {
Ok(Self(pow_mod(self.0, P - 2, P)))
}
}
}
impl FieldBytes for BabyBear {
fn to_le_bytes(&self) -> Vec<u8> {
u32::try_from(self.0)
.map(|n| n.to_le_bytes().to_vec())
.unwrap_or_default()
}
fn from_le_bytes(bytes: &[u8]) -> Result<Self, Error> {
bytes
.get(..4)
.ok_or(Error::InvalidFieldEncoding)
.and_then(|slice| <[u8; 4]>::try_from(slice).map_err(|_| Error::InvalidFieldEncoding))
.map(u32::from_le_bytes)
.map(|n| Self::new(u64::from(n)))
}
}
/// Computes `base^exp mod modulus` by right-to-left binary exponentiation.
///
/// Intermediate products are widened to `u128`, so the result is correct for any
/// `modulus` that fits in a `u64`, not only for moduli below 2^32. The result is always
/// reduced, so a `modulus` of 1 yields 0 even when `exp == 0`.
///
/// # Panics
///
/// Panics if `modulus` is zero (remainder by zero).
fn pow_mod(base: u64, exp: u64, modulus: u64) -> u64 {
    let m = u128::from(modulus);
    // Only the squarings up to the highest set bit of `exp` can contribute a factor.
    let significant_bits = u64::BITS - exp.leading_zeros();
    let reduced = std::iter::successors(Some(u128::from(base) % m), |&sq| Some(sq * sq % m))
        .zip(0..significant_bits)
        .filter(|&(_, i)| (exp >> i) & 1 == 1)
        .map(|(sq, _)| sq)
        // Every operand is < m <= u64::MAX, so each u128 product is < 2^128.
        .fold(1 % m, |acc, sq| acc * sq % m);
    u64::try_from(reduced).expect("a value reduced modulo a u64 modulus fits in u64")
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_is_additive_identity() {
        let a = BabyBear::new(123_456);
        assert_eq!(a + BabyBear::zero(), a);
        assert_eq!(BabyBear::zero() + a, a);
    }

    #[test]
    fn one_is_multiplicative_identity() {
        let a = BabyBear::new(123_456);
        assert_eq!(a * BabyBear::one(), a);
        assert_eq!(BabyBear::one() * a, a);
    }

    #[test]
    fn additive_inverse() {
        let a = BabyBear::new(999_999);
        assert_eq!(a + (-a), BabyBear::zero());
    }

    #[test]
    fn multiplicative_inverse() -> Result<(), Error> {
        let a = BabyBear::new(42);
        let a_inv = a.inv()?;
        assert_eq!(a * a_inv, BabyBear::one());
        Ok(())
    }

    #[test]
    fn inverse_of_zero_fails() {
        let result = BabyBear::zero().inv();
        assert!(result.is_err());
    }

    #[test]
    fn sample_inverses() -> Result<(), Error> {
        let samples = [1u64, 2, 7, 100, 1_000_000, P - 1, P - 2];
        samples.iter().try_for_each(|&v| {
            let a = BabyBear::new(v);
            let a_inv = a.inv()?;
            assert_eq!(a * a_inv, BabyBear::one(), "failed for {v}");
            Ok(())
        })
    }

    #[test]
    fn subtraction_is_add_neg() {
        let a = BabyBear::new(1_000_000);
        let b = BabyBear::new(500_000);
        assert_eq!(a - b, a + (-b));
    }

    #[test]
    fn multiplication_is_commutative() {
        let a = BabyBear::new(12_345);
        let b = BabyBear::new(67_890);
        assert_eq!(a * b, b * a);
    }

    #[test]
    fn distributivity() {
        let a = BabyBear::new(111);
        let b = BabyBear::new(222);
        let c = BabyBear::new(333);
        assert_eq!(a * (b + c), a * b + a * c);
    }

    /// Operations on the largest representative must wrap back into `0..P`.
    #[test]
    fn arithmetic_wraps_at_modulus() {
        let max = BabyBear::new(P - 1);
        assert_eq!(max + BabyBear::one(), BabyBear::zero());
        assert_eq!(max + max, BabyBear::new(P - 2));
        assert_eq!(BabyBear::zero() - BabyBear::one(), max);
        assert_eq!(max * max, BabyBear::one());
        assert_eq!(-BabyBear::zero(), BabyBear::zero());
    }

    #[test]
    fn new_reduces_mod_p() {
        assert_eq!(BabyBear::new(P), BabyBear::new(0));
        assert_eq!(BabyBear::new(P + 1), BabyBear::new(1));
        assert_eq!(BabyBear::new(2 * P), BabyBear::new(0));
        assert_eq!(BabyBear::new(P + 3).value(), 3);
    }

    #[test]
    fn display_prints_representative() {
        assert_eq!(BabyBear::new(42).to_string(), "42");
        assert_eq!(BabyBear::new(P).to_string(), "0");
    }

    #[test]
    fn bytes_roundtrip() -> Result<(), Error> {
        let a = BabyBear::new(1_234_567);
        let bytes = a.to_le_bytes();
        let b = BabyBear::from_le_bytes(&bytes)?;
        assert_eq!(a, b);
        Ok(())
    }

    #[test]
    fn bytes_zero_roundtrip() -> Result<(), Error> {
        let a = BabyBear::zero();
        let b = BabyBear::from_le_bytes(&a.to_le_bytes())?;
        assert_eq!(a, b);
        Ok(())
    }

    /// The encoding is always exactly four little-endian bytes, even for the largest element.
    #[test]
    fn to_le_bytes_is_four_bytes() {
        assert_eq!(BabyBear::new(P - 1).to_le_bytes(), vec![0xFE, 0xFF, 0xFF, 0x7F]);
    }

    #[test]
    fn bytes_empty_fails() {
        let result = BabyBear::from_le_bytes(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn bytes_short_input_fails() {
        assert!(BabyBear::from_le_bytes(&[1, 2, 3]).is_err());
    }
}