use alloc::vec::Vec;
use core::arch::x86_64::*;
use core::fmt::Debug;
use core::iter::{Product, Sum};
use core::mem::transmute;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use p3_field::exponentiation::exp_10540996611094048183;
use p3_field::interleave::{interleave_u64, interleave_u128};
use p3_field::op_assign_macros::{
impl_add_assign, impl_add_base_field, impl_div_methods, impl_mul_base_field, impl_mul_methods,
impl_packed_value, impl_rng, impl_sub_assign, impl_sub_base_field, impl_sum_prod_base_field,
ring_sum,
};
use p3_field::{
Algebra, Field, InjectiveMonomial, PackedField, PackedFieldPow2, PackedValue,
PermutationMonomial, PrimeCharacteristicRing, PrimeField64, dispatch_chunked_mixed_dot_product,
impl_packed_field_pow_2,
};
use p3_util::reconstitute_from_base;
use rand::distr::{Distribution, StandardUniform};
use rand::{Rng, RngExt};
use crate::{Goldilocks, P};
/// Number of `Goldilocks` elements held by one packed value (one per 64-bit lane of a
/// 256-bit AVX2 vector).
const WIDTH: usize = 4;
/// `WIDTH` Goldilocks field elements packed together so that arithmetic can be performed
/// lane-wise with AVX2 instructions.
///
/// The type is `#[repr(transparent)]` over `[Goldilocks; WIDTH]`; the `to_vector` /
/// `from_vector` helpers in this file transmute between it and `__m256i`.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
#[repr(transparent)] #[must_use]
pub struct PackedGoldilocksAVX2(pub [Goldilocks; WIDTH]);
impl PackedGoldilocksAVX2 {
    /// Reinterprets the packed elements as a raw 256-bit AVX2 vector (lane `i` holds
    /// element `i`).
    #[inline]
    #[must_use]
    pub(crate) fn to_vector(self) -> __m256i {
        // SAFETY: `Self` is `#[repr(transparent)]` over `[Goldilocks; WIDTH]`, `transmute`
        // checks at compile time that both types have the same size, and `__m256i` accepts
        // every bit pattern.
        unsafe { transmute::<Self, __m256i>(self) }
    }

    /// Reinterprets a raw 256-bit AVX2 vector as four packed field elements.
    #[inline]
    pub(crate) fn from_vector(vector: __m256i) -> Self {
        // SAFETY: sizes match (checked by `transmute`). This assumes every 64-bit pattern
        // is a valid `Goldilocks` value; that type is defined outside this file.
        unsafe { transmute::<__m256i, Self>(vector) }
    }

    /// Returns a packed value with `value` copied into every lane.
    #[inline]
    const fn broadcast(value: Goldilocks) -> Self {
        Self([value; WIDTH])
    }
}
impl From<Goldilocks> for PackedGoldilocksAVX2 {
/// Builds a packed value whose lanes all hold `x` by delegating to `broadcast`.
fn from(x: Goldilocks) -> Self {
Self::broadcast(x)
}
}
impl Add for PackedGoldilocksAVX2 {
    type Output = Self;

    /// Lane-wise field addition, delegated to the vector-level `add` routine.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        let lhs_lanes = self.to_vector();
        let rhs_lanes = rhs.to_vector();
        Self::from_vector(add(lhs_lanes, rhs_lanes))
    }
}
impl Sub for PackedGoldilocksAVX2 {
    type Output = Self;

    /// Lane-wise field subtraction, delegated to the vector-level `sub` routine.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let lhs_lanes = self.to_vector();
        let rhs_lanes = rhs.to_vector();
        Self::from_vector(sub(lhs_lanes, rhs_lanes))
    }
}
impl Neg for PackedGoldilocksAVX2 {
    type Output = Self;

    /// Lane-wise field negation, delegated to the vector-level `neg` routine.
    #[inline]
    fn neg(self) -> Self {
        let negated = neg(self.to_vector());
        Self::from_vector(negated)
    }
}
impl Mul for PackedGoldilocksAVX2 {
    type Output = Self;

    /// Lane-wise field multiplication, delegated to the vector-level `mul` routine.
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let lhs_lanes = self.to_vector();
        let rhs_lanes = rhs.to_vector();
        Self::from_vector(mul(lhs_lanes, rhs_lanes))
    }
}
// Boilerplate impls generated by shared `p3_field` macros. The macro bodies are not part
// of this file, so the notes below are inferred from the macro names only.
// NOTE(review): presumably `AddAssign` expressed through `Add` — confirm.
impl_add_assign!(PackedGoldilocksAVX2);
// NOTE(review): presumably `SubAssign` expressed through `Sub` — confirm.
impl_sub_assign!(PackedGoldilocksAVX2);
// NOTE(review): presumably `MulAssign` and `Product` expressed through `Mul` — confirm.
impl_mul_methods!(PackedGoldilocksAVX2);
// NOTE(review): presumably `Sum` for the packed type — confirm.
ring_sum!(PackedGoldilocksAVX2);
// NOTE(review): presumably random sampling via `StandardUniform` — confirm.
impl_rng!(PackedGoldilocksAVX2);
impl PrimeCharacteristicRing for PackedGoldilocksAVX2 {
    type PrimeSubfield = Goldilocks;

    // Each ring constant is the corresponding scalar constant copied into every lane.
    const ZERO: Self = Self::broadcast(Goldilocks::ZERO);
    const ONE: Self = Self::broadcast(Goldilocks::ONE);
    const TWO: Self = Self::broadcast(Goldilocks::TWO);
    const NEG_ONE: Self = Self::broadcast(Goldilocks::NEG_ONE);

    /// Embeds a scalar by copying it into every lane.
    #[inline]
    fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
        Self::broadcast(f)
    }

    /// Lane-wise division by two, using the vector-level `halve` routine.
    #[inline]
    fn halve(&self) -> Self {
        let lanes = self.to_vector();
        Self::from_vector(halve(lanes))
    }

    /// Lane-wise squaring, using the dedicated vector-level `square` routine.
    #[inline]
    fn square(&self) -> Self {
        let lanes = self.to_vector();
        Self::from_vector(square(lanes))
    }

    /// Returns `len` packed zeros, built from a vector of `len * WIDTH` scalar zeros.
    #[inline]
    fn zero_vec(len: usize) -> Vec<Self> {
        let scalars = Goldilocks::zero_vec(len * WIDTH);
        // SAFETY: `Self` is `#[repr(transparent)]` over `[Goldilocks; WIDTH]` and the
        // scalar vector's length is a multiple of `WIDTH`.
        // NOTE(review): the exact contract of `reconstitute_from_base` lives in `p3_util`
        // — confirm these are the only requirements.
        unsafe { reconstitute_from_base(scalars) }
    }
}
// Marker impl with no overridden items: declares that the map `x -> x^7` is injective on
// this type. NOTE(review): the trait's precise contract is defined in `p3_field` — confirm.
impl InjectiveMonomial<7> for PackedGoldilocksAVX2 {}
impl PermutationMonomial<7> for PackedGoldilocksAVX2 {
    /// Lane-wise inverse of the map `x -> x^7`.
    ///
    /// The exponent in the helper's name satisfies
    /// `7 * 10540996611094048183 = 4 * (2^64 - 2^32) + 1`, i.e. it is the inverse of 7
    /// modulo `p - 1` for the Goldilocks prime `p = 2^64 - 2^32 + 1`.
    fn injective_exp_root_n(&self) -> Self {
        let base: Self = *self;
        exp_10540996611094048183(base)
    }
}
// Mixed packed/scalar arithmetic generated by shared `p3_field` macros. The macro bodies
// are not part of this file, so the notes below are inferred from the macro names only.
// NOTE(review): presumably `Add<Goldilocks>` / `AddAssign<Goldilocks>` — confirm.
impl_add_base_field!(PackedGoldilocksAVX2, Goldilocks);
// NOTE(review): presumably `Sub<Goldilocks>` / `SubAssign<Goldilocks>` — confirm.
impl_sub_base_field!(PackedGoldilocksAVX2, Goldilocks);
// NOTE(review): presumably `Mul<Goldilocks>` / `MulAssign<Goldilocks>` — confirm.
impl_mul_base_field!(PackedGoldilocksAVX2, Goldilocks);
// NOTE(review): presumably `Div<Goldilocks>` / `DivAssign<Goldilocks>` — confirm.
impl_div_methods!(PackedGoldilocksAVX2, Goldilocks);
// NOTE(review): presumably `Sum<Goldilocks>` / `Product<Goldilocks>` — confirm.
impl_sum_prod_base_field!(PackedGoldilocksAVX2, Goldilocks);
impl Algebra<Goldilocks> for PackedGoldilocksAVX2 {
    /// Chunk size handed to the chunked dot-product dispatcher below.
    const BATCHED_LC_CHUNK: usize = 32;

    /// Dot product of `N` packed values with `N` scalar coefficients.
    ///
    /// The work is forwarded to `dispatch_chunked_mixed_dot_product` together with
    /// `BATCHED_LC_CHUNK`. NOTE(review): the chunking strategy itself is implemented in
    /// `p3_field` and is not visible here.
    #[inline(always)]
    fn mixed_dot_product<const N: usize>(a: &[Self; N], f: &[Goldilocks; N]) -> Self {
        // Fully qualified so the constant is read from this `Algebra<Goldilocks>` impl.
        let chunk = <Self as Algebra<Goldilocks>>::BATCHED_LC_CHUNK;
        dispatch_chunked_mixed_dot_product::<Self, Goldilocks, N>(a, f, chunk)
    }
}
// NOTE(review): presumably generates the `PackedValue` impl (scalar type `Goldilocks`,
// `WIDTH` lanes); the macro body lives in `p3_field` — confirm.
impl_packed_value!(PackedGoldilocksAVX2, Goldilocks, WIDTH);
// Registers this type as the packed form of the `Goldilocks` scalar field.
// NOTE(review): `PackedField` is an `unsafe` trait whose safety contract is defined in
// `p3_field`; presumably it relies on the `#[repr(transparent)]` layout over
// `[Goldilocks; WIDTH]` declared on the struct — confirm against the trait docs.
unsafe impl PackedField for PackedGoldilocksAVX2 {
type Scalar = Goldilocks;
}
// NOTE(review): presumably generates the `PackedFieldPow2` impl, with each `(n, f)` pair
// selecting the interleave routine `f` for block length `n` (64-bit granularity for 1,
// 128-bit granularity for 2); the macro body lives in `p3_field` — confirm.
impl_packed_field_pow_2!(
PackedGoldilocksAVX2;
[
(1, interleave_u64),
(2, interleave_u128),
],
WIDTH
);
/// Every lane holds `1 << 63`. XOR-ing with it ("shifting") maps unsigned order onto signed
/// order, which is what the signed-only AVX2 64-bit comparison needs.
const SIGN_BIT: __m256i = unsafe { transmute([i64::MIN; WIDTH]) };
/// Every lane holds the field order with its top bit flipped, i.e. the order in the same
/// "shifted" representation produced by `shift`.
const SHIFTED_FIELD_ORDER: __m256i =
unsafe { transmute([Goldilocks::ORDER_U64 ^ (i64::MIN as u64); WIDTH]) };
/// Every lane holds `2^64 - ORDER` (the wrapping negation of the field order), which is
/// `2^32 - 1` for the Goldilocks prime `2^64 - 2^32 + 1`. Adding it modulo `2^64` is the
/// same as subtracting the order; `mul64_64` also uses it as a low-32-bit mask.
const EPSILON: __m256i = unsafe { transmute([Goldilocks::ORDER_U64.wrapping_neg(); WIDTH]) };
/// Flips the top bit of every 64-bit lane (XOR with `SIGN_BIT`).
///
/// This converts between the plain unsigned representation and the "shifted" one used by
/// the `_s` helpers in this file, and is its own inverse.
#[inline]
pub fn shift(x: __m256i) -> __m256i {
// SAFETY: a single register-only AVX2 intrinsic. NOTE(review): assumes this module is
// only compiled for targets with AVX2 enabled; the gating is not visible in this file.
unsafe { _mm256_xor_si256(x, SIGN_BIT) }
}
/// Reduces every lane into the canonical range `[0, ORDER)`.
///
/// Both the argument and the result are in the shifted representation (top bit flipped,
/// see `shift`).
///
/// # Safety
///
/// Uses AVX2 intrinsics. NOTE(review): assumes the module is only built with AVX2
/// enabled; the gating is not visible in this file.
#[inline]
unsafe fn canonicalize_s(x_s: __m256i) -> __m256i {
    unsafe {
        // All ones where the lane is already below the order, zero otherwise.
        let is_canonical = _mm256_cmpgt_epi64(SHIFTED_FIELD_ORDER, x_s);
        // EPSILON for lanes that need reducing, 0 for the rest.
        let correction = _mm256_andnot_si256(is_canonical, EPSILON);
        // Adding EPSILON modulo 2^64 subtracts the order.
        _mm256_add_epi64(x_s, correction)
    }
}
/// Lane-wise `x + y` with a single wrap-around correction.
///
/// `x` is a plain 64-bit value, `y_s` is in the shifted representation (top bit flipped),
/// and the result is shifted as well. When the 64-bit addition overflows, `2^32 - 1` is
/// added back. As the name says, only one overflow is corrected, so callers must keep the
/// corrected sum below `2^64` (the `add` routine in this file canonicalizes `y` first).
///
/// # Safety
///
/// Uses AVX2 intrinsics; the CPU must support AVX2.
#[inline]
unsafe fn add_no_double_overflow_64_64s_s(x: __m256i, y_s: __m256i) -> __m256i {
    unsafe {
        let wrapped_sum_s = _mm256_add_epi64(x, y_s);
        // All ones in lanes where the sum wrapped around (it ended up below `y_s`).
        let overflowed = _mm256_cmpgt_epi64(y_s, wrapped_sum_s);
        // Shifting the all-ones mask right by 32 leaves 0xffff_ffff in overflowed lanes.
        let correction = _mm256_srli_epi64::<32>(overflowed);
        _mm256_add_epi64(wrapped_sum_s, correction)
    }
}
/// Lane-wise field addition of two vectors in the plain (unshifted) representation.
///
/// `y` is canonicalized first so that the single overflow correction performed by
/// `add_no_double_overflow_64_64s_s` is sufficient; `x` may be any 64-bit value.
#[inline]
fn add(x: __m256i, y: __m256i) -> __m256i {
    // SAFETY: only register-level AVX2 helpers are called. NOTE(review): assumes the
    // module is only built with AVX2 enabled; the gating is not visible in this file.
    unsafe {
        let y_canonical_s = canonicalize_s(shift(y));
        let sum_s = add_no_double_overflow_64_64s_s(x, y_canonical_s);
        // Undo the shift so the caller gets the plain representation back.
        shift(sum_s)
    }
}
/// Lane-wise field subtraction `x - y` of two vectors in the plain representation.
///
/// `y` is canonicalized first. Both operands are then shifted only for the comparison;
/// the shifts cancel in the difference, so the result is already unshifted.
#[inline]
fn sub(x: __m256i, y: __m256i) -> __m256i {
    // SAFETY: only register-level AVX2 intrinsics and helpers are used. NOTE(review):
    // assumes the module is only built with AVX2 enabled; the gating is not visible here.
    unsafe {
        let subtrahend_s = canonicalize_s(shift(y));
        let minuend_s = shift(x);
        // All ones in lanes where `y > x`, i.e. where the subtraction borrows.
        let underflowed = _mm256_cmpgt_epi64(subtrahend_s, minuend_s);
        // 0xffff_ffff in borrowing lanes, 0 elsewhere.
        let correction = _mm256_srli_epi64::<32>(underflowed);
        let wrapped_diff = _mm256_sub_epi64(minuend_s, subtrahend_s);
        _mm256_sub_epi64(wrapped_diff, correction)
    }
}
/// Lane-wise field negation, computed as `ORDER - canonical(y)`.
///
/// Both operands of the subtraction are in the shifted representation, so the shifts
/// cancel and the result is unshifted. A zero lane yields `ORDER` itself, which is a
/// non-canonical encoding of zero.
#[inline]
fn neg(y: __m256i) -> __m256i {
    // SAFETY: only register-level AVX2 intrinsics and helpers are used. NOTE(review):
    // assumes the module is only built with AVX2 enabled; the gating is not visible here.
    unsafe {
        let y_canonical_s = canonicalize_s(shift(y));
        _mm256_sub_epi64(SHIFTED_FIELD_ORDER, y_canonical_s)
    }
}
/// Lane-wise division by two.
///
/// Even lanes become `val >> 1`. Odd lanes become `(val >> 1) + ceil(P / 2)`, which equals
/// `(val + P) / 2` when `P` is odd.
/// NOTE(review): `P` comes from the crate root and is presumably the field modulus —
/// confirm.
#[inline(always)]
pub(crate) fn halve(input: __m256i) -> __m256i {
    // SAFETY: only register-level AVX2 intrinsics are used. NOTE(review): assumes the
    // module is only built with AVX2 enabled; the gating is not visible in this file.
    unsafe {
        const LOW_BIT: __m256i = unsafe { transmute([1_i64; 4]) };
        const ALL_ZERO: __m256i = unsafe { transmute([0_i64; 4]) };
        let half_rounded_up = _mm256_set1_epi64x(P.div_ceil(2) as i64);

        // 1 in odd lanes, 0 in even lanes.
        let parity = _mm256_and_si256(input, LOW_BIT);
        let floor_half = _mm256_srli_epi64::<1>(input);
        // `0 - parity` turns the parity bit into an all-ones / all-zeros lane mask.
        let odd_mask = _mm256_sub_epi64(ALL_ZERO, parity);
        let correction = _mm256_and_si256(half_rounded_up, odd_mask);
        _mm256_add_epi64(floor_half, correction)
    }
}
/// Full 64 x 64 -> 128-bit unsigned multiplication in every lane.
///
/// Returns `(high, low)`: the upper and lower 64 bits of each lane's product. The product
/// is assembled schoolbook-style from four 32 x 32 -> 64-bit partial products.
#[inline]
fn mul64_64(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
    // SAFETY: only register-level AVX/AVX2 intrinsics are used. NOTE(review): assumes the
    // module is only built with AVX2 enabled; the gating is not visible in this file.
    unsafe {
        // Copy the upper 32 bits of each lane into the lower 32 bits, where
        // `_mm256_mul_epu32` reads its operands. The float casts only reinterpret bits.
        let x_top = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x)));
        let y_top = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(y)));

        // The four partial products, named <x half>_<y half>.
        let lo_lo = _mm256_mul_epu32(x, y);
        let lo_hi = _mm256_mul_epu32(x, y_top);
        let hi_lo = _mm256_mul_epu32(x_top, y);
        let hi_hi = _mm256_mul_epu32(x_top, y_top);

        // Carry propagation: fold the top half of `lo_lo` into `hi_lo`, then split it.
        let mid = _mm256_add_epi64(hi_lo, _mm256_srli_epi64::<32>(lo_lo));
        // EPSILON doubles as the low-32-bit mask here.
        let mid_low = _mm256_and_si256(mid, EPSILON);
        let mid_high = _mm256_srli_epi64::<32>(mid);
        let cross = _mm256_add_epi64(lo_hi, mid_low);
        let upper = _mm256_add_epi64(hi_hi, mid_high);
        let high = _mm256_add_epi64(upper, _mm256_srli_epi64::<32>(cross));

        // Low word: bottom half of `lo_lo` below, bottom half of `cross` on top.
        let cross_dup = _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(cross)));
        let low = _mm256_blend_epi32::<0xaa>(lo_lo, cross_dup);

        (high, low)
    }
}
/// Full 64-bit -> 128-bit unsigned squaring in every lane.
///
/// Returns `(high, low)`: the upper and lower 64 bits of each lane's square. Because the
/// two cross terms are equal, only three 32 x 32 partial products are needed; the cross
/// term is weighted by `2^33` instead of `2^32`.
#[inline]
fn square64(x: __m256i) -> (__m256i, __m256i) {
    // SAFETY: only register-level AVX/AVX2 intrinsics are used; the CPU must support AVX2.
    unsafe {
        // Copy the upper 32 bits of each lane into the lower 32 bits, where
        // `_mm256_mul_epu32` reads its operands.
        let x_top = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x)));

        let lo_sq = _mm256_mul_epu32(x, x);
        let cross = _mm256_mul_epu32(x, x_top);
        let hi_sq = _mm256_mul_epu32(x_top, x_top);

        // High word: align `lo_sq` with the 2^33-weighted cross term, add, then carry the
        // part at or above bit 64 (a further shift by 31) into `hi_sq`.
        let lo_sq_above_33 = _mm256_srli_epi64::<33>(lo_sq);
        let mid = _mm256_add_epi64(cross, lo_sq_above_33);
        let carry_into_high = _mm256_srli_epi64::<31>(mid);
        let high = _mm256_add_epi64(hi_sq, carry_into_high);

        // Low word: `lo_sq` plus the low 31 bits of the cross term moved to the top.
        let cross_low = _mm256_slli_epi64::<33>(cross);
        let low = _mm256_add_epi64(lo_sq, cross_low);

        (high, low)
    }
}
/// Lane-wise addition of a "small" plain value `y` to a shifted value `x_s`; the result is
/// shifted as well. When the 64-bit addition wraps, `2^32 - 1` is added back.
///
/// Overflow is detected with a 32-bit comparison, and only the verdict for the upper
/// 32-bit half of each lane is used. NOTE(review): this is exact only while `y` is small
/// enough that a wrap always changes the upper half (presumably
/// `y <= 0xffff_ffff_0000_0000`) — confirm the intended bound.
///
/// # Safety
///
/// Uses AVX2 intrinsics; the CPU must support AVX2.
#[inline]
unsafe fn add_small_64s_64_s(x_s: __m256i, y: __m256i) -> __m256i {
    unsafe {
        let wrapped_sum_s = _mm256_add_epi64(x_s, y);
        // Per 32-bit element: all ones where `x_s` exceeds the wrapped sum.
        let overflowed = _mm256_cmpgt_epi32(x_s, wrapped_sum_s);
        // Keep only the upper-half verdict, moved down: 0xffff_ffff on overflow, else 0.
        let correction = _mm256_srli_epi64::<32>(overflowed);
        _mm256_add_epi64(wrapped_sum_s, correction)
    }
}
/// Lane-wise subtraction of a "small" plain value `y` from a shifted value `x_s`; the
/// result is shifted as well. When the 64-bit subtraction borrows, `2^32 - 1` is
/// subtracted again.
///
/// Underflow is detected with a 32-bit comparison, and only the verdict for the upper
/// 32-bit half of each lane is used. NOTE(review): this is exact only while `y` is small
/// enough that a borrow always changes the upper half (presumably
/// `y <= 0xffff_ffff_0000_0000`) — confirm the intended bound.
///
/// # Safety
///
/// Uses AVX2 intrinsics; the CPU must support AVX2.
#[inline]
unsafe fn sub_small_64s_64_s(x_s: __m256i, y: __m256i) -> __m256i {
    unsafe {
        let wrapped_diff_s = _mm256_sub_epi64(x_s, y);
        // Per 32-bit element: all ones where the wrapped difference exceeds `x_s`.
        let underflowed = _mm256_cmpgt_epi32(wrapped_diff_s, x_s);
        // Keep only the upper-half verdict, moved down: 0xffff_ffff on underflow, else 0.
        let correction = _mm256_srli_epi64::<32>(underflowed);
        _mm256_sub_epi64(wrapped_diff_s, correction)
    }
}
/// Reduces a 128-bit value per lane, given as `(high, low)` 64-bit halves, to a 64-bit
/// field representative (not necessarily canonical).
///
/// Writing `high = a * 2^32 + b`, the result is `low - a + b * EPSILON`. For the
/// Goldilocks prime `p = 2^64 - 2^32 + 1` this is congruent to the input, because
/// `2^64 = EPSILON (mod p)` and `2^96 = -1 (mod p)`.
#[inline]
fn reduce128(x: (__m256i, __m256i)) -> __m256i {
    // SAFETY: only register-level AVX2 intrinsics and helpers are used. NOTE(review):
    // assumes the module is only built with AVX2 enabled; the gating is not visible here.
    unsafe {
        let (high, low) = x;
        let low_s = shift(low);
        // Subtract the top 32 bits of the high word (the 2^96 contribution).
        let high_top = _mm256_srli_epi64::<32>(high);
        let after_sub_s = sub_small_64s_64_s(low_s, high_top);
        // `_mm256_mul_epu32` reads only the bottom 32 bits of each lane, giving
        // b * EPSILON (the 2^64 contribution).
        let high_bottom_times_eps = _mm256_mul_epu32(high, EPSILON);
        let reduced_s = add_small_64s_64_s(after_sub_s, high_bottom_times_eps);
        shift(reduced_s)
    }
}
/// Lane-wise field multiplication: full 128-bit product followed by reduction.
#[inline]
fn mul(x: __m256i, y: __m256i) -> __m256i {
    let wide_product = mul64_64(x, y);
    reduce128(wide_product)
}
/// Lane-wise field squaring: full 128-bit square followed by reduction.
#[inline]
fn square(x: __m256i) -> __m256i {
    let wide_square = square64(x);
    reduce128(wide_square)
}
// Runs the shared packed-field test suite from `p3_field_testing` against this type.
#[cfg(test)]
mod tests {
use p3_field_testing::test_packed_field;
use super::{Goldilocks, PackedGoldilocksAVX2, WIDTH};
// Boundary inputs around the Goldilocks prime p = 2^64 - 2^32 + 1 = 0xFFFF_FFFF_0000_0001:
// p - 1, u64::MAX (the largest raw value), 1, and p itself.
const SPECIAL_VALS: [Goldilocks; WIDTH] = Goldilocks::new_array([
0xFFFF_FFFF_0000_0000,
0xFFFF_FFFF_FFFF_FFFF,
0x0000_0000_0000_0001,
0xFFFF_FFFF_0000_0001,
]);
// Zero in every lane, alternating between the raw values 0 and p.
const ZEROS: PackedGoldilocksAVX2 = PackedGoldilocksAVX2(Goldilocks::new_array([
0x0000_0000_0000_0000,
0xFFFF_FFFF_0000_0001,
0x0000_0000_0000_0000,
0xFFFF_FFFF_0000_0001,
]));
// One in every lane, alternating between the raw values 1 and p + 1.
const ONES: PackedGoldilocksAVX2 = PackedGoldilocksAVX2(Goldilocks::new_array([
0x0000_0000_0000_0001,
0xFFFF_FFFF_0000_0002,
0x0000_0000_0000_0001,
0xFFFF_FFFF_0000_0002,
]));
// NOTE(review): the `super::` paths presumably resolve because the macro expands into a
// nested module — confirm against the macro definition in `p3_field_testing`.
test_packed_field!(
crate::PackedGoldilocksAVX2,
&[super::ZEROS],
&[super::ONES],
crate::PackedGoldilocksAVX2(super::SPECIAL_VALS)
);
}