field-cat 0.1.0

Finite field algebra shared across plonkish-cat, proof-cat, and stark-cat
Documentation
//! The [`BFieldElement`] prime field: integers modulo
//! `p = 2^64 - 2^32 + 1` (the Goldilocks prime).
//!
//! Goldilocks is the base field used by Triton VM, Risc0, and
//! Plonky3 in its wider-field modes.  The modulus is chosen so
//! that:
//!
//! - elements fit in a single `u64`,
//! - reduction has a particularly efficient form on 64-bit
//!   hardware (not yet exploited here),
//! - two-adicity is 32 (`p - 1` is divisible by `2^32`), so
//!   power-of-two NTT domains of every size up to `2^32` exist,
//!   which is large enough for any practical use.
//!
//! This implementation is **naive but correct**: arithmetic uses
//! `u128` intermediates and a `%` reduction.  A future revision
//! will swap in the Montgomery / single-shot Goldilocks reduction
//! used by Plonky3 and twenty-first for ~3-5x speedup.

use crate::bytes::FieldBytes;
use crate::error::Error;
use crate::field::Field;

/// The Goldilocks modulus: `2^64 - 2^32 + 1 = 18_446_744_069_414_584_321`.
///
/// Written in hexadecimal, the upper 32 bits are all ones and the
/// lower 32 bits hold the value `1`.  The constant fits in a `u64`,
/// which is why field elements can be stored in a single machine word.
const P: u64 = 0xFFFF_FFFF_0000_0001;

/// A field element in the Goldilocks prime field (mod `2^64 - 2^32 + 1`).
///
/// Stored as a canonical `u64` value in `[0, p)`.  Arithmetic
/// promotes to `u128` to avoid `u64` overflow.
///
/// Because the wrapped integer is kept canonical, the derived
/// `PartialEq`, `Eq` and `Hash` implementations compare and hash
/// field elements by value: two elements are equal exactly when
/// their stored integers are equal.
///
/// The name `BFieldElement` (Base Field Element) follows the
/// convention from the Triton VM / twenty-first ecosystem.  In
/// the Plonky3 ecosystem the same field is called `Goldilocks`.
///
/// # Examples
///
/// ```
/// use field_cat::{BFieldElement, Field};
///
/// let a = BFieldElement::new(42);
/// let b = BFieldElement::new(7);
///
/// assert_eq!((a * b).value(), 294);
///
/// // Multiplicative inverse via Fermat's little theorem.
/// let a_inv = a.inv()?;
/// assert_eq!(a * a_inv, BFieldElement::one());
/// # Ok::<(), field_cat::Error>(())
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BFieldElement(u64);

impl BFieldElement {
    /// Create a new field element, reducing `value` modulo `p`.
    ///
    /// Every `u64` is accepted; inputs in `[p, 2^64)` wrap around to
    /// their canonical representative.  As a `const fn` this can also
    /// be used to initialise `const` and `static` items.
    #[must_use]
    pub const fn new(value: u64) -> Self {
        Self(value % P)
    }

    /// The underlying integer value in `[0, p)`.
    ///
    /// Usable in constant contexts, like [`BFieldElement::new`].
    #[must_use]
    pub const fn value(self) -> u64 {
        self.0
    }

    /// Reduce a `u128` intermediate to a canonical `u64` in `[0, p)`.
    ///
    /// The `try_from` is mathematically infallible (the reduced
    /// value is strictly less than `p < 2^64`), so the
    /// `unwrap_or_default()` only guards against a Rust-level
    /// failure that the math rules out.
    fn reduce(x: u128) -> u64 {
        let p_128 = u128::from(P);
        u64::try_from(x % p_128).unwrap_or_default()
    }
}

impl core::fmt::Display for BFieldElement {
    /// Format the canonical integer value in decimal.
    ///
    /// Delegates to the `u64` implementation so that formatter
    /// options (width, fill, alignment, zero padding, e.g. `{:>8}`
    /// or `{:08}`) are honoured.  Going through `write!(f, "{}", ..)`
    /// would silently discard them.  Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Display::fmt(&self.0, f)
    }
}

impl std::ops::Add for BFieldElement {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Self(Self::reduce(u128::from(self.0) + u128::from(rhs.0)))
    }
}

impl std::ops::Sub for BFieldElement {
    type Output = Self;

    /// Field subtraction: `(self - rhs) mod p`.
    fn sub(self, rhs: Self) -> Self {
        // Offsetting the left operand by P keeps the u128 difference
        // non-negative, because rhs is at most P - 1.
        let minuend = u128::from(self.0) + u128::from(P);
        let subtrahend = u128::from(rhs.0);
        Self(Self::reduce(minuend - subtrahend))
    }
}

impl std::ops::Mul for BFieldElement {
    type Output = Self;

    /// Field multiplication: `(self * rhs) mod p`.
    ///
    /// The product of two `u64` values always fits in a `u128`, so
    /// widening first makes the raw multiplication overflow-free.
    fn mul(self, rhs: Self) -> Self {
        let (a, b) = (u128::from(self.0), u128::from(rhs.0));
        Self(Self::reduce(a * b))
    }
}

impl std::ops::Neg for BFieldElement {
    type Output = Self;

    /// Additive inverse: the element `x` with `self + x == 0`.
    fn neg(self) -> Self {
        match self.0 {
            // Zero is its own negation; `P - 0` would fall outside `[0, p)`.
            0 => self,
            nonzero => Self(P - nonzero),
        }
    }
}

impl Field for BFieldElement {
    fn zero() -> Self {
        Self(0)
    }

    fn one() -> Self {
        Self(1)
    }

    fn inv(&self) -> Result<Self, Error> {
        if self.0 == 0 {
            Err(Error::DivisionByZero)
        } else {
            // Fermat's little theorem: a^(p-2) is the inverse for prime p.
            Ok(pow(*self, P - 2))
        }
    }
}

impl FieldBytes for BFieldElement {
    fn to_le_bytes(&self) -> Vec<u8> {
        self.0.to_le_bytes().to_vec()
    }

    fn from_le_bytes(bytes: &[u8]) -> Result<Self, Error> {
        bytes
            .get(..8)
            .ok_or(Error::InvalidFieldEncoding)
            .and_then(|slice| <[u8; 8]>::try_from(slice).map_err(|_| Error::InvalidFieldEncoding))
            .map(u64::from_le_bytes)
            .map(Self::new)
    }
}

/// Modular exponentiation `base^exp` in the Goldilocks field.
///
/// Square-and-multiply over the bits of `exp`, least significant
/// first.  A single fold carries the running product together with
/// the current square `base^(2^i)`; the product absorbs the square
/// whenever bit `i` of `exp` is set.  Always walks all `u64::BITS`
/// bit positions, so the cost is linear in the bit width of `exp`.
fn pow(base: BFieldElement, exp: u64) -> BFieldElement {
    let (result, _) = (0..u64::BITS).fold(
        (BFieldElement::one(), base),
        |(acc, square), bit| {
            let acc = if (exp >> bit) & 1 == 1 {
                acc * square
            } else {
                acc
            };
            (acc, square * square)
        },
    );
    result
}

#[cfg(test)]
mod tests {
    //! Unit tests for the Goldilocks field: field axioms on sample
    //! values, boundary values around `p`, and byte encoding.

    use super::*;

    #[test]
    fn modulus_constant_is_correct() {
        // Rebuild p = 2^64 - 2^32 + 1 in u128 (2^64 does not fit in a
        // u64), so the check is independent of the literal defining `P`.
        let rebuilt = (1u128 << 64) - (1u128 << 32) + 1;
        assert_eq!(u128::from(P), rebuilt);
        assert_eq!(P, 18_446_744_069_414_584_321);
    }

    #[test]
    fn zero_is_additive_identity() {
        let a = BFieldElement::new(123_456_789);
        assert_eq!(a + BFieldElement::zero(), a);
        assert_eq!(BFieldElement::zero() + a, a);
    }

    #[test]
    fn one_is_multiplicative_identity() {
        let a = BFieldElement::new(123_456_789);
        assert_eq!(a * BFieldElement::one(), a);
        assert_eq!(BFieldElement::one() * a, a);
    }

    #[test]
    fn additive_inverse() {
        let a = BFieldElement::new(999_999_999_999);
        assert_eq!(a + (-a), BFieldElement::zero());
    }

    #[test]
    fn multiplicative_inverse() -> Result<(), Error> {
        let a = BFieldElement::new(42);
        let a_inv = a.inv()?;
        assert_eq!(a * a_inv, BFieldElement::one());
        Ok(())
    }

    #[test]
    fn inverse_of_zero_fails() {
        let result = BFieldElement::zero().inv();
        assert!(result.is_err());
    }

    #[test]
    fn sample_inverses() -> Result<(), Error> {
        let samples = [1u64, 2, 7, 100, 1_000_000, 1_000_000_000_000, P - 1, P - 2];
        samples.iter().try_for_each(|&v| {
            let a = BFieldElement::new(v);
            let a_inv = a.inv()?;
            assert_eq!(a * a_inv, BFieldElement::one(), "failed for {v}");
            Ok(())
        })
    }

    #[test]
    fn subtraction_is_add_neg() {
        let a = BFieldElement::new(123_456_789_000);
        let b = BFieldElement::new(987_654_321);
        assert_eq!(a - b, a + (-b));
    }

    #[test]
    fn subtraction_wraps_below_zero() {
        // 0 - 1 must wrap to p - 1 rather than underflow.
        let diff = BFieldElement::zero() - BFieldElement::one();
        assert_eq!(diff, BFieldElement::new(P - 1));
        assert_eq!(diff + BFieldElement::one(), BFieldElement::zero());
    }

    #[test]
    fn multiplication_is_commutative() {
        let a = BFieldElement::new(12_345_678);
        let b = BFieldElement::new(98_765_432);
        assert_eq!(a * b, b * a);
    }

    #[test]
    fn distributivity() {
        let a = BFieldElement::new(111_111);
        let b = BFieldElement::new(222_222);
        let c = BFieldElement::new(333_333);
        assert_eq!(a * (b + c), a * b + a * c);
    }

    #[test]
    fn new_reduces_mod_p() {
        assert_eq!(BFieldElement::new(P), BFieldElement::new(0));
        assert_eq!(BFieldElement::new(P + 1), BFieldElement::new(1));
        // Largest possible input: u64::MAX - p = 2^32 - 2.
        assert_eq!(BFieldElement::new(u64::MAX).value(), 0xFFFF_FFFE);
    }

    #[test]
    fn p_minus_one_squared_is_one() -> Result<(), Error> {
        // (p - 1) ≡ -1 (mod p), and (-1)^2 = 1.
        let neg_one = BFieldElement::new(P - 1);
        assert_eq!(neg_one * neg_one, BFieldElement::one());
        // -1 is its own inverse.
        let neg_one_inv = neg_one.inv()?;
        assert_eq!(neg_one_inv, neg_one);
        Ok(())
    }

    #[test]
    fn high_u64_values_multiply_without_overflow() {
        // Both inputs near 2^64 stress the u128 intermediate path.
        let a = BFieldElement::new(P - 5);
        let b = BFieldElement::new(P - 7);
        // (p - 5) * (p - 7) ≡ 35 (mod p).
        assert_eq!(a * b, BFieldElement::new(35));
    }

    #[test]
    fn bytes_roundtrip() -> Result<(), Error> {
        let a = BFieldElement::new(0xDEAD_BEEF_1234_5678);
        let bytes = a.to_le_bytes();
        let b = BFieldElement::from_le_bytes(&bytes)?;
        assert_eq!(a, b);
        Ok(())
    }

    #[test]
    fn bytes_zero_roundtrip() -> Result<(), Error> {
        let a = BFieldElement::zero();
        let b = BFieldElement::from_le_bytes(&a.to_le_bytes())?;
        assert_eq!(a, b);
        Ok(())
    }

    #[test]
    fn bytes_too_short_fails() {
        let result = BFieldElement::from_le_bytes(&[1, 2, 3]);
        assert!(result.is_err());
    }

    #[test]
    fn bytes_empty_fails() {
        let result = BFieldElement::from_le_bytes(&[]);
        assert!(result.is_err());
    }
}