use alloc::vec::Vec;
use core::arch::x86_64::*;
use core::fmt::Debug;
use core::iter::{Product, Sum};
use core::mem::transmute;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use p3_field::exponentiation::exp_10540996611094048183;
use p3_field::interleave::{interleave_u64, interleave_u128, interleave_u256};
use p3_field::op_assign_macros::{
impl_add_assign, impl_add_base_field, impl_div_methods, impl_mul_base_field, impl_mul_methods,
impl_packed_value, impl_rng, impl_sub_assign, impl_sub_base_field, impl_sum_prod_base_field,
ring_sum,
};
use p3_field::{
Algebra, Field, InjectiveMonomial, PackedField, PackedFieldPow2, PackedValue,
PermutationMonomial, PrimeCharacteristicRing, PrimeField64, dispatch_chunked_mixed_dot_product,
impl_packed_field_pow_2,
};
use p3_util::reconstitute_from_base;
use rand::distr::{Distribution, StandardUniform};
use rand::{Rng, RngExt};
use crate::{Goldilocks, P};
/// Number of Goldilocks lanes in one packed value: eight 64-bit elements fill one 512-bit vector.
const WIDTH: usize = 8;
/// Eight Goldilocks field elements processed together with AVX-512 instructions.
///
/// The value is a `#[repr(transparent)]` wrapper around `[Goldilocks; WIDTH]`, so it has the
/// layout and alignment of that array rather than the 64-byte alignment of `__m512i`.
/// Conversion to and from `__m512i` goes through `to_vector` / `from_vector`.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
#[repr(transparent)] #[must_use]
pub struct PackedGoldilocksAVX512(pub [Goldilocks; WIDTH]);
impl PackedGoldilocksAVX512 {
    /// Reinterprets the eight lanes as one 512-bit integer vector.
    #[inline]
    #[must_use]
    pub(crate) fn to_vector(self) -> __m512i {
        // SAFETY: `transmute` only compiles when source and destination have the same size,
        // and every bit pattern is a valid `__m512i`.
        unsafe { transmute::<Self, __m512i>(self) }
    }

    /// Reinterprets a 512-bit integer vector as eight packed lanes.
    #[inline]
    pub(crate) fn from_vector(vector: __m512i) -> Self {
        // SAFETY: sizes are checked by `transmute` at compile time.
        // NOTE(review): this assumes every 64-bit pattern is a valid `Goldilocks` value
        // (i.e. non-canonical representatives are allowed) — confirm against the scalar type.
        unsafe { transmute::<__m512i, Self>(vector) }
    }

    /// Builds a packed value with `value` copied into every lane.
    #[inline]
    const fn broadcast(value: Goldilocks) -> Self {
        let lanes = [value; WIDTH];
        Self(lanes)
    }
}
impl From<Goldilocks> for PackedGoldilocksAVX512 {
fn from(x: Goldilocks) -> Self {
Self::broadcast(x)
}
}
impl Add for PackedGoldilocksAVX512 {
    type Output = Self;

    /// Lane-wise field addition, delegated to the vector-level `add` routine.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        let left = self.to_vector();
        let right = rhs.to_vector();
        Self::from_vector(add(left, right))
    }
}
impl Sub for PackedGoldilocksAVX512 {
    type Output = Self;

    /// Lane-wise field subtraction, delegated to the vector-level `sub` routine.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let left = self.to_vector();
        let right = rhs.to_vector();
        Self::from_vector(sub(left, right))
    }
}
impl Neg for PackedGoldilocksAVX512 {
    type Output = Self;

    /// Lane-wise field negation, delegated to the vector-level `neg` routine.
    #[inline]
    fn neg(self) -> Self {
        let lanes = self.to_vector();
        Self::from_vector(neg(lanes))
    }
}
impl Mul for PackedGoldilocksAVX512 {
    type Output = Self;

    /// Lane-wise field multiplication, delegated to the vector-level `mul` routine.
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let left = self.to_vector();
        let right = rhs.to_vector();
        Self::from_vector(mul(left, right))
    }
}
// Trait impls generated by the shared macros imported from `p3_field::op_assign_macros`.
// NOTE(review): judging by the macro names these provide the `+=` / `-=` operators, the
// multiplication helpers, summation, and random sampling — confirm against the macro definitions.
impl_add_assign!(PackedGoldilocksAVX512);
impl_sub_assign!(PackedGoldilocksAVX512);
impl_mul_methods!(PackedGoldilocksAVX512);
ring_sum!(PackedGoldilocksAVX512);
impl_rng!(PackedGoldilocksAVX512);
impl PrimeCharacteristicRing for PackedGoldilocksAVX512 {
    type PrimeSubfield = Goldilocks;

    // Each ring constant is the corresponding scalar constant copied into every lane.
    const ZERO: Self = Self::broadcast(Goldilocks::ZERO);
    const ONE: Self = Self::broadcast(Goldilocks::ONE);
    const TWO: Self = Self::broadcast(Goldilocks::TWO);
    const NEG_ONE: Self = Self::broadcast(Goldilocks::NEG_ONE);

    /// Lifts a prime-subfield element by broadcasting it to all lanes.
    #[inline]
    fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
        Self::from(f)
    }

    /// Halves every lane using the vector-level `halve` routine.
    #[inline]
    fn halve(&self) -> Self {
        let halved = halve(self.to_vector());
        Self::from_vector(halved)
    }

    /// Squares every lane using the vector-level `square` routine.
    #[inline]
    fn square(&self) -> Self {
        let squared = square(self.to_vector());
        Self::from_vector(squared)
    }

    /// Returns `len` packed zeros, built from a scalar zero vector of `len * WIDTH` elements.
    #[inline]
    fn zero_vec(len: usize) -> Vec<Self> {
        let scalars = Goldilocks::zero_vec(len * WIDTH);
        // SAFETY: `Self` is `#[repr(transparent)]` over `[Goldilocks; WIDTH]` and the buffer
        // holds a multiple of `WIDTH` scalars.
        // NOTE(review): the exact contract of `reconstitute_from_base` lives in `p3_util` —
        // confirm these are the conditions it requires.
        unsafe { reconstitute_from_base(scalars) }
    }
}
// Mixed packed/scalar operator impls generated by the shared `p3_field` macros.
// NOTE(review): judging by the macro names these cover `+`, `-`, `*`, `/` and `Sum`/`Product`
// with a `Goldilocks` scalar operand — confirm against the macro definitions.
impl_add_base_field!(PackedGoldilocksAVX512, Goldilocks);
impl_sub_base_field!(PackedGoldilocksAVX512, Goldilocks);
impl_mul_base_field!(PackedGoldilocksAVX512, Goldilocks);
impl_div_methods!(PackedGoldilocksAVX512, Goldilocks);
impl_sum_prod_base_field!(PackedGoldilocksAVX512, Goldilocks);
impl Algebra<Goldilocks> for PackedGoldilocksAVX512 {
    /// Chunk size handed to the chunked dot-product dispatcher below.
    const BATCHED_LC_CHUNK: usize = 4;

    /// Dot product of `N` packed values with `N` scalars, delegated to
    /// `dispatch_chunked_mixed_dot_product` with this type's `BATCHED_LC_CHUNK`.
    #[inline(always)]
    fn mixed_dot_product<const N: usize>(a: &[Self; N], f: &[Goldilocks; N]) -> Self {
        let chunk = <Self as Algebra<Goldilocks>>::BATCHED_LC_CHUNK;
        dispatch_chunked_mixed_dot_product::<Self, Goldilocks, N>(a, f, chunk)
    }
}
// Marker impl with an empty body for the monomial of degree 7; the matching root computation is
// provided by the `PermutationMonomial<7>` impl.
impl InjectiveMonomial<7> for PackedGoldilocksAVX512 {}
impl PermutationMonomial<7> for PackedGoldilocksAVX512 {
    /// Inverts `x -> x^7` lane-wise by raising to the power `10540996611094048183`.
    ///
    /// For the Goldilocks order `p = 2^64 - 2^32 + 1` this exponent `d` satisfies
    /// `7 * d = 4 * (p - 1) + 1`, i.e. `d` is the inverse of 7 modulo `p - 1`.
    fn injective_exp_root_n(&self) -> Self {
        let base = *self;
        exp_10540996611094048183(base)
    }
}
// NOTE(review): presumably generates the `PackedValue` impl treating this type as `WIDTH`
// `Goldilocks` values — confirm against the macro definition in `p3_field`.
impl_packed_value!(PackedGoldilocksAVX512, Goldilocks, WIDTH);
// Declares `Goldilocks` as the scalar field of this packing.
// NOTE(review): `PackedField` is an unsafe trait whose contract is defined in `p3_field`;
// presumably it needs the packed type to be layout-compatible with an array of scalars, which
// `#[repr(transparent)]` over `[Goldilocks; WIDTH]` provides — confirm.
unsafe impl PackedField for PackedGoldilocksAVX512 {
type Scalar = Goldilocks;
}
// Pairs the sizes 1, 2 and 4 with the 64-, 128- and 256-bit interleave routines.
// NOTE(review): presumably the first tuple entry is the interleave block length in lanes and the
// macro emits the `PackedFieldPow2` impl — confirm against the macro definition.
impl_packed_field_pow_2!(
PackedGoldilocksAVX512;
[
(1, interleave_u64),
(2, interleave_u128),
(4, interleave_u256),
],
WIDTH
);
/// The field order `Goldilocks::ORDER_U64` replicated in every 64-bit lane.
const FIELD_ORDER: __m512i = unsafe { transmute([Goldilocks::ORDER_U64; WIDTH]) };
/// `2^64 - ORDER_U64` (the two's-complement negation of the order) in every lane.
/// For the Goldilocks order `2^64 - 2^32 + 1` this is `2^32 - 1`, which is why it is also used
/// as a low-32-bit mask and as the residue of `2^64` modulo the order.
const EPSILON: __m512i = unsafe { transmute([Goldilocks::ORDER_U64.wrapping_neg(); WIDTH]) };
/// Subtracts `FIELD_ORDER` once from every lane that is `>= FIELD_ORDER`; other lanes are
/// returned unchanged.
///
/// # Safety
///
/// The AVX-512F instructions used here must be available on the executing CPU.
#[inline]
unsafe fn canonicalize(x: __m512i) -> __m512i {
    unsafe {
        let needs_reduction = _mm512_cmpge_epu64_mask(x, FIELD_ORDER);
        _mm512_mask_sub_epi64(x, needs_reduction, x, FIELD_ORDER)
    }
}
/// Lane-wise `x + y` with a single modular correction.
///
/// When the 64-bit addition wraps, `FIELD_ORDER` is subtracted (modulo `2^64`) from that lane,
/// which compensates for the lost `2^64`. The result is only congruent to `x + y` modulo the
/// order when `x + y < 2^64 + FIELD_ORDER`, so that one correction suffices.
///
/// # Safety
///
/// The AVX-512F instructions used here must be available on the executing CPU.
#[inline]
unsafe fn add_no_double_overflow_64_64(x: __m512i, y: __m512i) -> __m512i {
    unsafe {
        let wrapped_sum = _mm512_add_epi64(x, y);
        // A wrapped sum is strictly smaller than either operand.
        let overflowed = _mm512_cmplt_epu64_mask(wrapped_sum, y);
        _mm512_mask_sub_epi64(wrapped_sum, overflowed, wrapped_sum, FIELD_ORDER)
    }
}
/// Lane-wise `x - y` with a single modular correction.
///
/// When `x < y` the 64-bit subtraction wraps, and `FIELD_ORDER` is added back (modulo `2^64`)
/// to that lane. The result is only congruent to `x - y` modulo the order when
/// `y - x <= FIELD_ORDER`, so that one correction suffices.
///
/// # Safety
///
/// The AVX-512F instructions used here must be available on the executing CPU.
#[inline]
unsafe fn sub_no_double_overflow_64_64(x: __m512i, y: __m512i) -> __m512i {
    unsafe {
        // Lanes where the subtraction is about to underflow.
        let underflows = _mm512_cmplt_epu64_mask(x, y);
        let wrapped_diff = _mm512_sub_epi64(x, y);
        _mm512_mask_add_epi64(wrapped_diff, underflows, wrapped_diff, FIELD_ORDER)
    }
}
/// Lane-wise field addition: `y` is reduced below `FIELD_ORDER` first so that the
/// single-correction add cannot overflow twice.
#[inline]
fn add(x: __m512i, y: __m512i) -> __m512i {
    unsafe {
        let y_reduced = canonicalize(y);
        add_no_double_overflow_64_64(x, y_reduced)
    }
}
/// Lane-wise field subtraction: `y` is reduced below `FIELD_ORDER` first so that the
/// single-correction subtract cannot underflow twice.
#[inline]
fn sub(x: __m512i, y: __m512i) -> __m512i {
    unsafe {
        let y_reduced = canonicalize(y);
        sub_no_double_overflow_64_64(x, y_reduced)
    }
}
/// Lane-wise field negation, computed as `FIELD_ORDER - canonicalize(y)`.
///
/// A zero lane therefore yields `FIELD_ORDER` itself rather than `0`.
#[inline]
fn neg(y: __m512i) -> __m512i {
    unsafe {
        let y_reduced = canonicalize(y);
        _mm512_sub_epi64(FIELD_ORDER, y_reduced)
    }
}
/// Halves every 64-bit lane modulo `P`.
///
/// Each lane is shifted right by one bit; lanes whose lowest bit was set additionally get
/// `ceil(P / 2)` added, which equals `(lane + P) / 2` for odd lanes.
#[inline(always)]
pub(crate) fn halve(input: __m512i) -> __m512i {
    const LOW_BIT: __m512i = unsafe { transmute([1_i64; 8]) };
    unsafe {
        let odd_correction = _mm512_set1_epi64(P.div_ceil(2) as i64);
        // Mask bit is set for lanes holding an odd value.
        let is_odd = _mm512_test_epi64_mask(input, LOW_BIT);
        let shifted = _mm512_srli_epi64::<1>(input);
        _mm512_mask_add_epi64(shifted, is_odd, shifted, odd_correction)
    }
}
/// Blend mask with a bit set for every even-indexed 32-bit element, i.e. the low half of each
/// 64-bit lane.
// The lint is silenced because the transmute is flagged as a no-op by clippy here.
#[allow(clippy::useless_transmute)]
const LO_32_BITS_MASK: __mmask16 = unsafe { transmute(0b0101010101010101u16) };
/// Full 64x64 -> 128-bit unsigned multiplication in every lane.
///
/// Returns `(high, low)`: the upper and lower 64 bits of each lane's product.
// NOTE(review): this safe function calls AVX-512 intrinsics; soundness relies on the module only
// being compiled when those instructions are available — confirm the cfg gate at the module
// declaration.
#[inline]
fn mul64_64(x: __m512i, y: __m512i) -> (__m512i, __m512i) {
unsafe {
// Copy the high 32 bits of each lane into the low 32 bits. `_mm512_mul_epu32` reads only the
// low 32 bits of each lane, so the duplicated upper half is irrelevant. The float casts only
// reinterpret bits.
let x_hi = _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(x)));
let y_hi = _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(y)));
// The four 32x32 -> 64-bit partial products.
let mul_ll = _mm512_mul_epu32(x, y);
let mul_lh = _mm512_mul_epu32(x, y_hi);
let mul_hl = _mm512_mul_epu32(x_hi, y);
let mul_hh = _mm512_mul_epu32(x_hi, y_hi);
// Schoolbook accumulation. Adding a 32-bit value to a 32x32 product cannot overflow 64 bits:
// (2^32 - 1)^2 + (2^32 - 1) < 2^64.
let mul_ll_hi = _mm512_srli_epi64::<32>(mul_ll);
let t0 = _mm512_add_epi64(mul_hl, mul_ll_hi);
// `EPSILON` serves as the low-32-bit mask here (it is 2^32 - 1 for the Goldilocks order).
let t0_lo = _mm512_and_si512(t0, EPSILON);
let t0_hi = _mm512_srli_epi64::<32>(t0);
let t1 = _mm512_add_epi64(mul_lh, t0_lo);
let t2 = _mm512_add_epi64(mul_hh, t0_hi);
// Carry the upper half of the middle column into the high word.
let t1_hi = _mm512_srli_epi64::<32>(t1);
let res_hi = _mm512_add_epi64(t2, t1_hi);
// Low word: low 32 bits of `t1` moved to the upper half, combined with the low 32 bits of
// `mul_ll` (the blend takes `mul_ll` wherever the mask bit is set).
let t1_lo = _mm512_castps_si512(_mm512_moveldup_ps(_mm512_castsi512_ps(t1)));
let res_lo = _mm512_mask_blend_epi32(LO_32_BITS_MASK, t1_lo, mul_ll);
(res_hi, res_lo)
}
}
/// Full 64-bit -> 128-bit squaring in every lane.
///
/// Returns `(high, low)`: the upper and lower 64 bits of each lane's square. Uses three partial
/// products instead of four, since `x^2 = hh * 2^64 + lh * 2^33 + ll`.
// NOTE(review): this safe function calls AVX-512 intrinsics; soundness relies on the module only
// being compiled when those instructions are available — confirm the cfg gate at the module
// declaration.
#[inline]
fn square64(x: __m512i) -> (__m512i, __m512i) {
unsafe {
// High 32 bits of each lane copied into the low 32 bits (the multiply reads only those).
let x_hi = _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(x)));
let mul_ll = _mm512_mul_epu32(x, x);
let mul_lh = _mm512_mul_epu32(x, x_hi);
let mul_hh = _mm512_mul_epu32(x_hi, x_hi);
// The cross term has weight 2^33, so the low product is aligned by shifting 33 bits and the
// carry into the high word is the sum shifted by the remaining 31 bits.
let mul_ll_hi = _mm512_srli_epi64::<33>(mul_ll);
let t0 = _mm512_add_epi64(mul_lh, mul_ll_hi);
let t0_hi = _mm512_srli_epi64::<31>(t0);
let res_hi = _mm512_add_epi64(mul_hh, t0_hi);
// Low word: `ll + (lh << 33)` with wrapping 64-bit arithmetic.
let mul_lh_lo = _mm512_slli_epi64::<33>(mul_lh);
let res_lo = _mm512_add_epi64(mul_ll, mul_lh_lo);
(res_hi, res_lo)
}
}
/// Reduces a lane-wise 128-bit value, given as `(high, low)` 64-bit halves, to 64 bits.
///
/// Computes `low - (high >> 32) + (high & 0xFFFF_FFFF) * EPSILON` with one modular correction
/// per step. For the Goldilocks order `2^64 - 2^32 + 1` this uses `2^64 ≡ EPSILON` and
/// `2^96 ≡ -1`.
#[inline]
fn reduce128(x: (__m512i, __m512i)) -> __m512i {
    let (high, low) = x;
    unsafe {
        // The top 32 bits of `high` carry weight 2^96, so they are subtracted.
        let high_upper = _mm512_srli_epi64::<32>(high);
        let low_adjusted = sub_no_double_overflow_64_64(low, high_upper);
        // `_mm512_mul_epu32` reads only the low 32 bits of each lane of `high`.
        let high_lower_scaled = _mm512_mul_epu32(high, EPSILON);
        add_no_double_overflow_64_64(low_adjusted, high_lower_scaled)
    }
}
/// Lane-wise field multiplication: full 128-bit product followed by reduction.
#[inline]
fn mul(x: __m512i, y: __m512i) -> __m512i {
    let wide_product = mul64_64(x, y);
    reduce128(wide_product)
}
/// Lane-wise field squaring: full 128-bit square followed by reduction.
#[inline]
fn square(x: __m512i) -> __m512i {
    let wide_square = square64(x);
    reduce128(wide_square)
}
// Runs the shared packed-field test suite from `p3_field_testing` against this packing.
#[cfg(test)]
mod tests {
use p3_field_testing::test_packed_field;
use super::{Goldilocks, PackedGoldilocksAVX512, WIDTH};
// Boundary inputs: values at and around 0xFFFF_FFFF_0000_0001, the largest u64, and small values.
const SPECIAL_VALS: [Goldilocks; WIDTH] = Goldilocks::new_array([
0xFFFF_FFFF_0000_0001,
0xFFFF_FFFF_0000_0000,
0xFFFF_FFFE_FFFF_FFFF,
0xFFFF_FFFF_FFFF_FFFF,
0x0000_0000_0000_0000,
0x0000_0000_0000_0001,
0x0000_0000_0000_0002,
0x0FFF_FFFF_F000_0000,
]);
// Lanes alternate between 0 and 0xFFFF_FFFF_0000_0001.
// NOTE(review): presumably the two representations of zero (0 and the field order) — confirm
// that `Goldilocks::new_array` keeps the raw value.
const ZEROS: PackedGoldilocksAVX512 = PackedGoldilocksAVX512(Goldilocks::new_array([
0x0000_0000_0000_0000,
0xFFFF_FFFF_0000_0001,
0x0000_0000_0000_0000,
0xFFFF_FFFF_0000_0001,
0x0000_0000_0000_0000,
0xFFFF_FFFF_0000_0001,
0x0000_0000_0000_0000,
0xFFFF_FFFF_0000_0001,
]));
// Lanes alternate between 1 and 0xFFFF_FFFF_0000_0002 (the `ZEROS` pattern plus one).
const ONES: PackedGoldilocksAVX512 = PackedGoldilocksAVX512(Goldilocks::new_array([
0x0000_0000_0000_0001,
0xFFFF_FFFF_0000_0002,
0x0000_0000_0000_0001,
0xFFFF_FFFF_0000_0002,
0x0000_0000_0000_0001,
0xFFFF_FFFF_0000_0002,
0x0000_0000_0000_0001,
0xFFFF_FFFF_0000_0002,
]));
// Arguments: the packed type, the zero vectors, the one vectors, and the special-value vector.
test_packed_field!(
crate::PackedGoldilocksAVX512,
&[super::ZEROS],
&[super::ONES],
crate::PackedGoldilocksAVX512(super::SPECIAL_VALS)
);
}