use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
/// The field modulus `p = 2^64 - 2^32 + 1`.
///
/// Spelled as `(2^64 - 1) - 2^32 + 2` so the structure of the prime is visible and every
/// intermediate value fits in a `u64`.
pub const MODULUS: u64 = u64::MAX - (1u64 << 32) + 2;
/// `2^128 mod p`, which equals `p - 2^32`.
///
/// Multiplying a plain integer by this value and applying one Montgomery reduction yields
/// the Montgomery form used inside `Fp`.
const R2: u64 = MODULUS - (1u64 << 32);
/// The largest `s` such that `2^s` divides `p - 1`.
const TWO_ADICITY: u32 = (MODULUS - 1).trailing_zeros();
/// Starting value of `z` in the square-root loop of `Fp::sqrt`, stored as a plain
/// (non-Montgomery) integer and converted on use.
///
/// NOTE(review): presumably a primitive `2^TWO_ADICITY`-th root of unity, as the
/// Tonelli–Shanks algorithm requires — confirm against the reference implementation.
const POWER_OF_TWO_GENERATOR: u64 = 7_277_203_076_849_721_926;
/// An element of the prime field with modulus [`MODULUS`].
///
/// The wrapped `u64` is not the plain integer value: `from_u64_reduce` stores the value
/// multiplied by `2^64` modulo `MODULUS` (Montgomery form) and `to_u64` converts it back.
///
/// NOTE(review): the derived `PartialEq`/`Eq`/`Hash` compare the raw word, which presumably
/// relies on that word always being kept below `MODULUS` — confirm for code in the parent
/// module that can reach the `pub(super)` field directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Fp(pub(super) u64);
impl Fp {
pub const ZERO: Self = Self::from_u64_reduce(0);
pub const ONE: Self = Self::from_u64_reduce(1);
pub const MINUS_ONE: Self = Self::from_u64_reduce(MODULUS - 1);
#[inline(always)]
const fn montyred(x: u128) -> u64 {
let xl = x as u64;
let xh = (x >> 64) as u64;
let (a, e) = xl.overflowing_add(xl << 32);
let b = a.wrapping_sub(a >> 32).wrapping_sub(e as u64);
let (r, c) = xh.overflowing_sub(b);
r.wrapping_sub(0u32.wrapping_sub(c as u32) as u64)
}
#[inline(always)]
pub const fn from_u64_reduce(v: u64) -> Self {
Self(Self::montyred((v as u128) * (R2 as u128)))
}
#[inline(always)]
pub fn from_u64_canonical(v: u64) -> Option<Self> {
if v < MODULUS {
Some(Self::from_u64_reduce(v))
} else {
None
}
}
#[inline(always)]
pub const fn to_u64(self) -> u64 {
Self::montyred(self.0 as u128)
}
#[inline]
pub fn try_from_le_bytes(bytes: [u8; 8]) -> Option<Self> {
Self::from_u64_canonical(u64::from_le_bytes(bytes))
}
#[inline]
pub fn to_le_bytes(self) -> [u8; 8] {
self.to_u64().to_le_bytes()
}
#[inline(always)]
pub const fn is_zero(self) -> bool {
self.0 == 0
}
#[inline(always)]
pub const fn ct_eq(self, rhs: Self) -> u64 {
let t = self.0 ^ rhs.0;
!((((t | t.wrapping_neg()) as i64) >> 63) as u64)
}
#[inline(always)]
#[must_use]
pub const fn ct_select(mask: u64, a: Self, b: Self) -> Self {
Self(a.0 ^ (mask & (a.0 ^ b.0)))
}
#[inline(always)]
const fn add_inner(self, rhs: Self) -> Self {
let (x1, c1) = self.0.overflowing_sub(MODULUS - rhs.0);
let adj = 0u32.wrapping_sub(c1 as u32);
Self(x1.wrapping_sub(adj as u64))
}
#[inline(always)]
const fn sub_inner(self, rhs: Self) -> Self {
let (x1, c1) = self.0.overflowing_sub(rhs.0);
let adj = 0u32.wrapping_sub(c1 as u32);
Self(x1.wrapping_sub(adj as u64))
}
#[inline(always)]
const fn neg_inner(self) -> Self {
Self::ZERO.sub_inner(self)
}
#[inline(always)]
const fn mul_inner(self, rhs: Self) -> Self {
Self(Self::montyred((self.0 as u128) * (rhs.0 as u128)))
}
#[inline(always)]
#[must_use]
pub const fn square(self) -> Self {
self.mul_inner(self)
}
#[inline]
#[must_use]
pub fn msquare(self, n: u32) -> Self {
let mut x = self;
for _ in 0..n {
x = x.square();
}
x
}
#[must_use]
pub fn invert(self) -> Self {
let x = self;
let x2 = x * x.square();
let x4 = x2 * x2.msquare(2);
let x5 = x * x4.square();
let x10 = x5 * x5.msquare(5);
let x15 = x5 * x10.msquare(5);
let x16 = x * x15.square();
let x31 = x15 * x16.msquare(15);
let x32 = x * x31.square();
x32 * x31.msquare(33)
}
#[must_use]
pub fn pow(self, mut exp: u64) -> Self {
let mut result = Self::ONE;
let mut base = self;
while exp != 0 {
if exp & 1 == 1 {
result *= base;
}
base = base.square();
exp >>= 1;
}
result
}
#[must_use]
pub fn sqrt(self) -> Option<Self> {
if self.is_zero() {
return Some(Self::ZERO);
}
let qr = self.pow((MODULUS - 1) >> 1);
if qr == Self::MINUS_ONE {
return None;
}
debug_assert_eq!(qr, Self::ONE);
let t: u64 = (1u64 << (64 - TWO_ADICITY)) - 1;
let mut z = Self::from_u64_reduce(POWER_OF_TWO_GENERATOR);
let mut w = self.pow((t - 1) >> 1);
let mut x = self * w;
let mut b = x * w;
let mut v = TWO_ADICITY;
while b != Self::ONE {
let mut k = 0u32;
let mut b2k = b;
while b2k != Self::ONE {
b2k = b2k.square();
k += 1;
}
let j = v - k - 1;
w = z.msquare(j);
z = w.square();
b *= z;
x *= w;
v = k;
}
Some(x)
}
}
impl Default for Fp {
#[inline]
fn default() -> Self {
Self::ZERO
}
}
impl Add for Fp {
type Output = Self;
#[inline(always)]
fn add(self, rhs: Self) -> Self {
self.add_inner(rhs)
}
}
impl AddAssign for Fp {
#[inline(always)]
fn add_assign(&mut self, rhs: Self) {
*self = self.add_inner(rhs);
}
}
impl Sub for Fp {
type Output = Self;
#[inline(always)]
fn sub(self, rhs: Self) -> Self {
self.sub_inner(rhs)
}
}
impl SubAssign for Fp {
#[inline(always)]
fn sub_assign(&mut self, rhs: Self) {
*self = self.sub_inner(rhs);
}
}
impl Neg for Fp {
type Output = Self;
#[inline(always)]
fn neg(self) -> Self {
self.neg_inner()
}
}
impl Mul for Fp {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: Self) -> Self {
self.mul_inner(rhs)
}
}
impl MulAssign for Fp {
#[inline(always)]
fn mul_assign(&mut self, rhs: Self) {
*self = self.mul_inner(rhs);
}
}
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rstest::rstest;
    use serde::Deserialize;

    use super::*;
    use crate::signing::fixtures::{arb_fp, arb_fp_nonzero, hex_to_bytes};

    /// Reference vectors, embedded at compile time from the crate's `test_data` directory.
    const VECTORS_JSON: &str = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/test_data/signing_field_goldilocks_vectors.json",
    ));

    /// Top-level shape of the vector file.
    #[derive(Debug, Deserialize)]
    struct Vectors {
        vectors: Vec<Vector>,
    }

    /// One reference case. `e` is parsed by `parse_u64`; every other string field is hex
    /// for eight little-endian bytes and is decoded by `decode_le8`.
    #[derive(Debug, Deserialize)]
    struct Vector {
        a: String,
        b: String,
        e: String,
        add: String,
        sub: String,
        mul: String,
        neg_a: String,
        inv_a: String,
        pow_a_e: String,
        a_eq_b: bool,
    }

    /// Decodes `hex` into exactly eight bytes, panicking on any other length.
    fn decode_le8(hex: &str) -> [u8; 8] {
        let bytes = hex_to_bytes(hex);
        assert_eq!(bytes.len(), 8, "expected 8 bytes, was {}", bytes.len());
        let mut out = [0u8; 8];
        out.copy_from_slice(&bytes);
        out
    }

    /// Parses a `u64` written either as `0x`-prefixed hex or as decimal.
    fn parse_u64(s: &str) -> u64 {
        if let Some(stripped) = s.strip_prefix("0x") {
            u64::from_str_radix(stripped, 16).unwrap()
        } else {
            s.parse::<u64>().unwrap()
        }
    }

    #[rstest]
    fn modulus_constant_is_goldilocks_prime() {
        assert_eq!(u128::from(MODULUS), (1u128 << 64) - (1u128 << 32) + 1);
    }

    #[rstest]
    fn round_trip_le_bytes_canonical() {
        for v in [0u64, 1, 42, MODULUS - 1] {
            let f = Fp::from_u64_canonical(v).unwrap();
            assert_eq!(f.to_u64(), v);
            let bytes = f.to_le_bytes();
            assert_eq!(Fp::try_from_le_bytes(bytes).unwrap(), f);
        }
    }

    #[rstest]
    fn rejects_non_canonical_decoding() {
        let bad = MODULUS.to_le_bytes();
        assert!(Fp::try_from_le_bytes(bad).is_none());
        let worse = u64::MAX.to_le_bytes();
        assert!(Fp::try_from_le_bytes(worse).is_none());
    }

    #[rstest]
    fn invert_zero_returns_zero() {
        assert_eq!(Fp::ZERO.invert(), Fp::ZERO);
    }

    #[rstest]
    fn sqrt_round_trip_for_known_squares() {
        for v in [1u64, 2, 4, 9, 16, 100, 1_000_000] {
            let x = Fp::from_u64_reduce(v);
            let xs = x.square();
            let s = xs.sqrt().expect("known squares are residues");
            assert_eq!(s.square(), xs);
        }
    }

    #[rstest]
    fn sqrt_zero_returns_zero() {
        assert_eq!(Fp::ZERO.sqrt(), Some(Fp::ZERO));
    }

    #[rstest]
    fn sqrt_returns_none_for_non_square() {
        // The search is bounded so that a regression in `pow` fails this test instead of
        // spinning through the whole `u64` range.
        let non_residue = (2u64..=1_000)
            .map(Fp::from_u64_reduce)
            .find(|x| x.pow((MODULUS - 1) >> 1) == Fp::MINUS_ONE)
            .expect("a quadratic non-residue exists among small integers");
        assert_eq!(non_residue.sqrt(), None);
    }

    #[rstest]
    fn ct_eq_matches_partial_eq() {
        let a = Fp::from_u64_reduce(123);
        let b = Fp::from_u64_reduce(123);
        let c = Fp::from_u64_reduce(124);
        assert_eq!(a.ct_eq(b), u64::MAX);
        assert_eq!(a.ct_eq(c), 0);
    }

    #[rstest]
    fn ct_select_picks_branch_by_mask() {
        let a = Fp::from_u64_reduce(123);
        let b = Fp::from_u64_reduce(456);
        assert_eq!(Fp::ct_select(0, a, b), a);
        assert_eq!(Fp::ct_select(u64::MAX, a, b), b);
    }

    // Property tests: field axioms, encoding round trips and the constant-time helpers.
    proptest! {
        #[rstest]
        fn prop_add_commutative(a in arb_fp(), b in arb_fp()) {
            prop_assert_eq!(a + b, b + a);
        }

        #[rstest]
        fn prop_add_associative(a in arb_fp(), b in arb_fp(), c in arb_fp()) {
            prop_assert_eq!((a + b) + c, a + (b + c));
        }

        #[rstest]
        fn prop_distributive(a in arb_fp(), b in arb_fp(), c in arb_fp()) {
            prop_assert_eq!(a * (b + c), a * b + a * c);
        }

        #[rstest]
        fn prop_mul_commutative(a in arb_fp(), b in arb_fp()) {
            prop_assert_eq!(a * b, b * a);
        }

        #[rstest]
        fn prop_mul_associative(a in arb_fp(), b in arb_fp(), c in arb_fp()) {
            prop_assert_eq!((a * b) * c, a * (b * c));
        }

        #[rstest]
        fn prop_neg_round_trip(a in arb_fp()) {
            prop_assert_eq!(a + (-a), Fp::ZERO);
        }

        #[rstest]
        fn prop_sub_via_add_neg(a in arb_fp(), b in arb_fp()) {
            prop_assert_eq!(a - b, a + (-b));
        }

        #[rstest]
        fn prop_sub_round_trip(a in arb_fp(), b in arb_fp()) {
            prop_assert_eq!((a + b) - b, a);
        }

        #[rstest]
        fn prop_square_matches_self_mul(a in arb_fp()) {
            prop_assert_eq!(a.square(), a * a);
        }

        #[rstest]
        fn prop_invert_round_trip(a in arb_fp_nonzero()) {
            prop_assert_eq!(a * a.invert(), Fp::ONE);
        }

        #[rstest]
        fn prop_fermat_little(a in arb_fp_nonzero()) {
            prop_assert_eq!(a.pow(MODULUS - 1), Fp::ONE);
        }

        #[rstest]
        fn prop_sqrt_round_trip(a in arb_fp()) {
            let sq = a.square();
            let s = sq.sqrt().expect("squares are quadratic residues");
            prop_assert_eq!(s.square(), sq);
        }

        #[rstest]
        fn prop_le_bytes_round_trip(a in arb_fp()) {
            let bytes = a.to_le_bytes();
            prop_assert_eq!(Fp::try_from_le_bytes(bytes).unwrap(), a);
        }

        #[rstest]
        fn prop_from_u64_canonical_accepts_in_range(v in 0u64..MODULUS) {
            let f = Fp::from_u64_canonical(v).expect("in-range value");
            prop_assert_eq!(f.to_u64(), v);
        }

        #[rstest]
        fn prop_from_u64_canonical_rejects_out_of_range(v in MODULUS..=u64::MAX) {
            prop_assert!(Fp::from_u64_canonical(v).is_none());
        }

        #[rstest]
        fn prop_msquare_matches_iterated_square(a in arb_fp(), n in 0u32..16) {
            let mut iter = a;
            for _ in 0..n {
                iter = iter.square();
            }
            prop_assert_eq!(a.msquare(n), iter);
        }

        #[rstest]
        fn prop_ct_eq_matches_partial_eq(a in arb_fp(), b in arb_fp()) {
            let ct = a.ct_eq(b);
            if a == b {
                prop_assert_eq!(ct, u64::MAX);
            } else {
                prop_assert_eq!(ct, 0);
            }
        }

        #[rstest]
        fn prop_ct_select_picks_branch(a in arb_fp(), b in arb_fp()) {
            prop_assert_eq!(Fp::ct_select(0, a, b), a);
            prop_assert_eq!(Fp::ct_select(u64::MAX, a, b), b);
        }
    }

    /// Checks every operation against the embedded reference vectors.
    #[rstest]
    fn matches_go_reference_vectors() {
        let suite: Vectors = serde_json::from_str(VECTORS_JSON).expect("parse vectors");
        assert!(!suite.vectors.is_empty(), "vector file is empty");
        for (i, v) in suite.vectors.iter().enumerate() {
            let a = Fp::try_from_le_bytes(decode_le8(&v.a))
                .unwrap_or_else(|| panic!("vector {i}: decode a"));
            let b = Fp::try_from_le_bytes(decode_le8(&v.b))
                .unwrap_or_else(|| panic!("vector {i}: decode b"));
            let e = parse_u64(&v.e);
            assert_eq!((a + b).to_le_bytes(), decode_le8(&v.add), "vector {i}: add");
            assert_eq!((a - b).to_le_bytes(), decode_le8(&v.sub), "vector {i}: sub");
            assert_eq!((a * b).to_le_bytes(), decode_le8(&v.mul), "vector {i}: mul");
            assert_eq!((-a).to_le_bytes(), decode_le8(&v.neg_a), "vector {i}: neg");
            assert_eq!(
                a.invert().to_le_bytes(),
                decode_le8(&v.inv_a),
                "vector {i}: inv"
            );
            assert_eq!(
                a.pow(e).to_le_bytes(),
                decode_le8(&v.pow_a_e),
                "vector {i}: pow"
            );
            assert_eq!(a == b, v.a_eq_b, "vector {i}: eq");
        }
    }
}