/// The Goldilocks prime `P = 2^64 - 2^32 + 1`, the modulus of [`Goldilocks`].
pub const P: u64 = 0xFFFF_FFFF_0000_0001;
use serde::{Deserialize, Serialize};
/// An element of the Goldilocks prime field, stored as a raw `u64`.
///
/// The inner value is public, so a `Goldilocks` is not guaranteed to be
/// canonical (`< P`); `is_valid` reports whether it is.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct Goldilocks(pub u64);

/// Namespace type for batched Goldilocks operations, with kernels selected by
/// compile-time target features.
pub struct GoldilocksSIMD;
#[cfg(target_feature = "avx512f")]
impl GoldilocksSIMD {
    /// Element-wise field addition of `a` and `b` into `result`, eight lanes at a
    /// time with AVX-512.
    ///
    /// # Panics
    /// If the three slices differ in length or the length is not a multiple of 8.
    ///
    /// # Errors
    /// The current implementation always returns `Ok(())`; the `Result` is kept so
    /// existing callers continue to compile.
    pub fn add_avx512(
        a: &[Goldilocks],
        b: &[Goldilocks],
        result: &mut [Goldilocks],
    ) -> Result<(), crate::error::FieldError> {
        use std::arch::x86_64::*;
        assert_eq!(a.len(), b.len());
        assert_eq!(a.len(), result.len());
        assert_eq!(a.len() % 8, 0, "Length must be multiple of 8 for AVX-512");
        let chunks = a
            .chunks_exact(8)
            .zip(b.chunks_exact(8))
            .zip(result.chunks_exact_mut(8));
        for ((a_chunk, b_chunk), res_chunk) in chunks {
            // `chunks_exact` yields exactly 8 elements, so these copies cannot fail.
            let a_vals: [u64; 8] = std::array::from_fn(|i| a_chunk[i].0);
            let b_vals: [u64; 8] = std::array::from_fn(|i| b_chunk[i].0);
            let mut sum_vals = [0u64; 8];
            // SAFETY: this impl is compiled only when `avx512f` is enabled, and each
            // unaligned load/store accesses exactly the 64 bytes of a local `[u64; 8]`.
            unsafe {
                let a_vec = _mm512_loadu_epi64(a_vals.as_ptr() as *const i64);
                let b_vec = _mm512_loadu_epi64(b_vals.as_ptr() as *const i64);
                let sum_vec = _mm512_add_epi64(a_vec, b_vec);
                _mm512_storeu_epi64(sum_vals.as_mut_ptr() as *mut i64, sum_vec);
            }
            for (k, slot) in res_chunk.iter_mut().enumerate() {
                // `_mm512_add_epi64` wraps modulo 2^64, silently dropping a carry that
                // is worth 2^32 - 1 in the field. A lane wrapped exactly when its sum
                // is smaller than an addend; route those lanes through the scalar adder.
                *slot = if sum_vals[k] < a_vals[k] {
                    Goldilocks(a_vals[k]).add(Goldilocks(b_vals[k]))
                } else {
                    Goldilocks::new(sum_vals[k])
                };
            }
        }
        Ok(())
    }

    /// Element-wise field multiplication over the common prefix of the slices.
    ///
    /// Currently a scalar loop: only the first
    /// `min(a.len(), b.len(), result.len())` slots of `result` are written, each
    /// via `Goldilocks::mul`.
    pub fn mul_avx512(a: &[Goldilocks], b: &[Goldilocks], result: &mut [Goldilocks]) {
        for ((slot, &x), &y) in result.iter_mut().zip(a).zip(b) {
            *slot = x.mul(y);
        }
    }
}
#[cfg(target_feature = "avx2")]
impl GoldilocksSIMD {
    /// Element-wise field addition of `a` and `b` into `result`, four lanes at a
    /// time with AVX2.
    ///
    /// # Panics
    /// If the three slices differ in length or the length is not a multiple of 4.
    ///
    /// # Errors
    /// The current implementation always returns `Ok(())`; the `Result` is kept so
    /// existing callers continue to compile.
    pub fn add_avx2(
        a: &[Goldilocks],
        b: &[Goldilocks],
        result: &mut [Goldilocks],
    ) -> Result<(), crate::error::FieldError> {
        use std::arch::x86_64::*;
        assert_eq!(a.len(), b.len());
        assert_eq!(a.len(), result.len());
        assert_eq!(a.len() % 4, 0, "Length must be multiple of 4 for AVX2");
        let chunks = a
            .chunks_exact(4)
            .zip(b.chunks_exact(4))
            .zip(result.chunks_exact_mut(4));
        for ((a_chunk, b_chunk), res_chunk) in chunks {
            // `chunks_exact` yields exactly 4 elements, so these copies cannot fail.
            let a_vals: [u64; 4] = std::array::from_fn(|i| a_chunk[i].0);
            let b_vals: [u64; 4] = std::array::from_fn(|i| b_chunk[i].0);
            let mut sum_vals = [0u64; 4];
            // SAFETY: this impl is compiled only when `avx2` is enabled (which implies
            // AVX, required by the unaligned 256-bit load/store), and each load/store
            // accesses exactly the 32 bytes of a local `[u64; 4]`. The `si256`
            // load/store are used instead of `_mm256_loadu_epi64`, which needs
            // AVX-512F + AVX-512VL and would fault on AVX2-only CPUs.
            unsafe {
                let a_vec = _mm256_loadu_si256(a_vals.as_ptr() as *const __m256i);
                let b_vec = _mm256_loadu_si256(b_vals.as_ptr() as *const __m256i);
                let sum_vec = _mm256_add_epi64(a_vec, b_vec);
                _mm256_storeu_si256(sum_vals.as_mut_ptr() as *mut __m256i, sum_vec);
            }
            for (k, slot) in res_chunk.iter_mut().enumerate() {
                // `_mm256_add_epi64` wraps modulo 2^64, silently dropping a carry that
                // is worth 2^32 - 1 in the field. A lane wrapped exactly when its sum
                // is smaller than an addend; route those lanes through the scalar adder.
                *slot = if sum_vals[k] < a_vals[k] {
                    Goldilocks(a_vals[k]).add(Goldilocks(b_vals[k]))
                } else {
                    Goldilocks::new(sum_vals[k])
                };
            }
        }
        Ok(())
    }
}
#[cfg(not(any(target_feature = "avx512f", target_feature = "avx2")))]
impl GoldilocksSIMD {
    /// Portable element-wise addition, available when no SIMD target feature is
    /// enabled.
    ///
    /// Writes only the first `min(a.len(), b.len(), result.len())` slots of
    /// `result`; any remaining slots are left untouched.
    pub fn add_batch(a: &[Goldilocks], b: &[Goldilocks], result: &mut [Goldilocks]) {
        result
            .iter_mut()
            .zip(a.iter().zip(b))
            .for_each(|(slot, (&lhs, &rhs))| *slot = lhs.add(rhs));
    }

    /// Portable element-wise multiplication, available when no SIMD target feature
    /// is enabled.
    ///
    /// Writes only the first `min(a.len(), b.len(), result.len())` slots of
    /// `result`; any remaining slots are left untouched.
    pub fn mul_batch(a: &[Goldilocks], b: &[Goldilocks], result: &mut [Goldilocks]) {
        result
            .iter_mut()
            .zip(a.iter().zip(b))
            .for_each(|(slot, (&lhs, &rhs))| *slot = lhs.mul(rhs));
    }
}
impl GoldilocksSIMD {
    /// Reports whether this build was compiled with the `avx512f` target feature.
    ///
    /// This is a compile-time property, not a runtime CPU probe.
    pub fn has_avx512() -> bool {
        cfg!(target_feature = "avx512f")
    }

    /// Reports whether this build was compiled with the `avx2` target feature.
    ///
    /// This is a compile-time property, not a runtime CPU probe.
    pub fn has_avx2() -> bool {
        cfg!(target_feature = "avx2")
    }

    /// Element-wise field addition `a[i] + b[i]`.
    ///
    /// The returned vector always has length `a.len()`; if `b` is shorter, the
    /// trailing slots stay `Goldilocks(0)`.
    ///
    /// A SIMD kernel is used only when both slices have the same length and that
    /// length is a multiple of the kernel width. Every other input goes through
    /// the scalar loop, so no slot is left unfilled whichever target features are
    /// enabled.
    pub fn add_vectors(a: &[Goldilocks], b: &[Goldilocks]) -> Vec<Goldilocks> {
        let mut result = vec![Goldilocks(0); a.len()];
        // Equal lengths are required because the SIMD kernels assert on them.
        // If a kernel returns `Err`, the scalar loop below overwrites every slot.
        #[cfg(target_feature = "avx512f")]
        if Self::has_avx512()
            && a.len() == b.len()
            && a.len() % 8 == 0
            && Self::add_avx512(a, b, &mut result).is_ok()
        {
            return result;
        }
        #[cfg(target_feature = "avx2")]
        if Self::has_avx2()
            && a.len() == b.len()
            && a.len() % 4 == 0
            && Self::add_avx2(a, b, &mut result).is_ok()
        {
            return result;
        }
        // Scalar fallback. Deliberately not `add_batch`, which is only compiled
        // when no SIMD feature is enabled.
        for ((slot, &x), &y) in result.iter_mut().zip(a).zip(b) {
            *slot = x.add(y);
        }
        result
    }

    /// Element-wise field multiplication `a[i] * b[i]`.
    ///
    /// The returned vector always has length `a.len()`; if `b` is shorter, the
    /// trailing slots stay `Goldilocks(0)`.
    pub fn mul_vectors(a: &[Goldilocks], b: &[Goldilocks]) -> Vec<Goldilocks> {
        let mut result = vec![Goldilocks(0); a.len()];
        #[cfg(target_feature = "avx512f")]
        if Self::has_avx512() && a.len() % 8 == 0 {
            Self::mul_avx512(a, b, &mut result);
            return result;
        }
        // Scalar fallback. Deliberately not `mul_batch`, which is only compiled
        // when no SIMD feature is enabled.
        for ((slot, &x), &y) in result.iter_mut().zip(a).zip(b) {
            *slot = x.mul(y);
        }
        result
    }
}
impl Goldilocks {
    /// Builds a field element from any `u64`, reducing it into `[0, P)`.
    #[inline]
    pub fn new(val: u64) -> Self {
        Self::reduce_u64(val)
    }

    /// Reinterprets the bits of `val` as a `u64` and reduces modulo `P`.
    ///
    /// Negative inputs are *not* negated: `-1` becomes `u64::MAX mod P`.
    /// Use [`Goldilocks::from_i64`] for signed semantics.
    #[inline]
    pub fn from_field_bits(val: i64) -> Self {
        Self::reduce_u64(val as u64)
    }

    /// Maps a signed integer into the field: negative values become the additive
    /// inverse of their magnitude.
    pub fn from_i64(val: i64) -> Self {
        // `unsigned_abs` handles `i64::MIN` without overflow.
        let magnitude = Self::reduce_u64(val.unsigned_abs());
        if val >= 0 {
            magnitude
        } else {
            magnitude.neg()
        }
    }

    /// Reduces a `u64` into `[0, P)`.
    ///
    /// Because `u64::MAX < 2 * P`, a single conditional subtraction always suffices.
    #[inline]
    fn reduce_u64(val: u64) -> Self {
        if val >= P {
            Self(val - P)
        } else {
            Self(val)
        }
    }

    /// Branch-free reduction of a `u64` into `[0, P)` using a mask instead of `if`.
    #[inline]
    pub fn reduce_constant_time(val: u64) -> Self {
        // All-ones when `val >= P`, zero otherwise.
        let mask = ((val >= P) as u64).wrapping_neg();
        Self(val.wrapping_sub(P & mask))
    }

    /// Reduces a 128-bit value modulo `P`.
    ///
    /// Splits `x` into 32-bit limbs `x0..x3` and folds them using
    /// `2^64 ≡ 2^32 - 1` and `2^96 ≡ -1 (mod P)`.
    #[inline]
    pub fn reduce128(x: u128) -> Self {
        let x0 = (x & 0xFFFF_FFFF) as i128;
        let x1 = ((x >> 32) & 0xFFFF_FFFF) as i128;
        let x2 = ((x >> 64) & 0xFFFF_FFFF) as i128;
        let x3 = ((x >> 96) & 0xFFFF_FFFF) as i128;
        // The folded value lies in [-(2^32 - 1), 2P - 2], so each loop below runs
        // at most once.
        let mut folded = (x1 + x2) * 0x1_0000_0000 + (x0 - x2 - x3);
        while folded < 0 {
            folded += P as i128;
        }
        while folded >= P as i128 {
            folded -= P as i128;
        }
        Self(folded as u64)
    }

    /// Field addition. Accepts non-canonical inputs and returns a canonical result.
    pub fn add(self, other: Self) -> Self {
        // 2^64 mod P: the value of a carry dropped by 64-bit wrapping addition.
        const EPSILON: u64 = 0xFFFF_FFFF;
        let (sum, carry) = self.0.overflowing_add(other.0);
        if !carry {
            return Self::reduce_u64(sum);
        }
        // Fold the dropped 2^64 back in. A second carry is only possible when an
        // input is non-canonical, and the result after it is < 2^32, so adding
        // EPSILON again cannot overflow.
        let (folded, carry) = sum.overflowing_add(EPSILON);
        Self::reduce_u64(if carry { folded + EPSILON } else { folded })
    }

    /// Field subtraction. Inputs are canonicalized first, so non-canonical values
    /// cannot cause an integer underflow.
    pub fn sub(self, other: Self) -> Self {
        let lhs = Self::reduce_u64(self.0).0;
        let rhs = Self::reduce_u64(other.0).0;
        if lhs >= rhs {
            Self(lhs - rhs)
        } else {
            Self(P - (rhs - lhs))
        }
    }

    /// Field multiplication via a full 128-bit product and [`Goldilocks::reduce128`].
    pub fn mul(self, other: Self) -> Self {
        Self::reduce128(self.0 as u128 * other.0 as u128)
    }

    /// Field multiplication that avoids data-dependent branches and division.
    ///
    /// Produces the same result as [`Goldilocks::mul`]. The reduction uses only
    /// arithmetic and sign masks. NOTE: Rust/LLVM give no formal constant-time
    /// guarantee; verify the generated code if that property is security-critical.
    pub fn mul_constant_time(self, other: Self) -> Self {
        let product = self.0 as u128 * other.0 as u128;
        let x0 = (product & 0xFFFF_FFFF) as i128;
        let x1 = ((product >> 32) & 0xFFFF_FFFF) as i128;
        let x2 = ((product >> 64) & 0xFFFF_FFFF) as i128;
        let x3 = ((product >> 96) & 0xFFFF_FFFF) as i128;
        let p = P as i128;
        // Same limb folding as `reduce128`. For any 32-bit limbs the value lies in
        // [-(2^32 - 1), 2P - 2], so one masked add and one masked subtract of P
        // bring it into [0, P). This avoids the variable-latency i128 division.
        let mut res = (x1 + x2) * 0x1_0000_0000 + (x0 - x2 - x3);
        // `>> 127` is an arithmetic shift: all-ones for negative values, else zero.
        res += (res >> 127) & p;
        // `res - p` is negative exactly when `res < P`; in that case add P back.
        let diff = res - p;
        res = diff + ((diff >> 127) & p);
        Self(res as u64)
    }

    /// Raises `self` to `exp` by square-and-multiply. `x.pow(0)` is 1 for every `x`.
    pub fn pow(self, mut exp: u64) -> Self {
        let mut acc = Self(1);
        let mut base = self;
        while exp != 0 {
            if exp & 1 == 1 {
                acc = acc.mul(base);
            }
            base = base.mul(base);
            exp >>= 1;
        }
        acc
    }

    /// Field division `self / other`.
    ///
    /// # Panics
    /// If `other` is zero (see [`Goldilocks::inv`]).
    #[inline]
    pub fn div(self, other: Self) -> Self {
        self.mul(other.inv())
    }

    /// Additive inverse. Non-canonical inputs are canonicalized first, so this
    /// cannot underflow.
    #[inline]
    pub fn neg(self) -> Self {
        let v = Self::reduce_u64(self.0).0;
        if v == 0 {
            Self(0)
        } else {
            Self(P - v)
        }
    }

    /// Multiplicative inverse via Fermat's little theorem (`self^(P-2)`).
    ///
    /// # Panics
    /// If `self.0 == 0`.
    pub fn inv(self) -> Self {
        assert_ne!(self.0, 0, "Modular inverse of zero is undefined.");
        self.pow(P - 2)
    }

    /// Returns `true` if the stored value is canonical (`< P`).
    #[inline]
    pub fn is_valid(&self) -> bool {
        self.0 < P
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `from_field_bits` must bring an input just above `P` back into range.
    #[test]
    fn test_from_field_bits_reduces() {
        let reduced = Goldilocks::from_field_bits((P + 1) as i64);
        assert!(reduced.0 < P, "from_field_bits must reduce values >= P");
        assert_eq!(reduced.0, 1);
    }

    /// `new` is the identity below `P` and wraps at `P` and `P + 1`.
    #[test]
    fn test_boundary_values() {
        for (input, expected) in [(P - 1, P - 1), (P, 0), (P + 1, 1)] {
            assert_eq!(Goldilocks::new(input).0, expected);
        }
    }

    /// `(P - 1) + (P - 1)` overflows a `u64` and must still reduce to `P - 2`.
    #[test]
    fn test_addition_overflow() {
        let largest = Goldilocks(P - 1);
        let total = largest.add(largest);
        assert!(total.0 < P);
        assert_eq!(total.0, P - 2);
    }

    /// `0 - 1` must wrap to `P - 1`.
    #[test]
    fn test_subtraction_underflow() {
        let difference = Goldilocks(0).sub(Goldilocks(1));
        assert_eq!(difference.0, P - 1);
    }

    /// `mul_constant_time` agrees with `mul` on a handful of representative pairs.
    #[test]
    fn test_constant_time_multiplication() {
        let pairs = [(0, 0), (1, 1), (2, 3), (P - 1, P - 1), (12345, 67890)]
            .map(|(x, y)| (Goldilocks::new(x), Goldilocks::new(y)));
        for (a, b) in pairs {
            let regular = a.mul(b);
            let constant_time = a.mul_constant_time(b);
            assert_eq!(
                regular, constant_time,
                "Constant-time mul should match regular mul: {} * {} = {} vs {}",
                a.0, b.0, regular.0, constant_time.0
            );
        }
    }
}