#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use super::field::FieldElement;
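// Limb masks for the radix-2^25.5 packed representation: even limbs hold
// 26 bits, odd limbs 25.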
const LOW_25_BITS: i64 = (1 << 25) - 1;
const LOW_26_BITS: i64 = (1 << 26) - 1;
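// `_mm256_blend_epi32` mask selecting the D element (32-bit lanes 5 and 7).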
const D_BLEND: i32 = 0b1010_0000;
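/// Four field elements A, B, C, D packed across five AVX2 vectors.
///
/// Vector `i` holds limbs `2i` and `2i + 1` of all four elements in the
/// 32-bit lane order `(a_2i, b_2i, a_2i+1, b_2i+1, c_2i, d_2i, c_2i+1, d_2i+1)`.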
#[derive(Clone, Copy)]
#[cfg(target_arch = "x86_64")]
pub(crate) struct FieldElement2625x4(pub(crate) [__m256i; 5]);
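/// Permutations of the four packed elements; each name lists the source of
/// each output lane, so `BADC` sends (A, B, C, D) to (B, A, D, C).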
#[derive(Clone, Copy)]
#[repr(u8)]
#[allow(clippy::upper_case_acronyms, dead_code)]
pub(crate) enum Shuffle {
ABCD,
BADC,
BACD,
ABDC,
AAAA,
BBBB,
CACA,
DBBD,
ADDA,
CBCB,
ABAB,
}
impl Shuffle {
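    /// Expand the pattern into a control vector for `vpermd` over the
    /// interleaved lane layout of `FieldElement2625x4`.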
#[inline(always)]
fn control(self) -> [i32; 8] {
match self {
Self::ABCD => [0, 1, 2, 3, 4, 5, 6, 7],
Self::BADC => [1, 0, 3, 2, 5, 4, 7, 6],
Self::BACD => [1, 0, 3, 2, 4, 5, 6, 7],
Self::ABDC => [0, 1, 2, 3, 5, 4, 7, 6],
Self::AAAA => [0, 0, 2, 2, 0, 0, 2, 2],
Self::BBBB => [1, 1, 3, 3, 1, 1, 3, 3],
Self::CACA => [4, 0, 6, 2, 4, 0, 6, 2],
Self::DBBD => [5, 1, 7, 3, 1, 5, 3, 7],
Self::ADDA => [0, 5, 2, 7, 5, 0, 7, 2],
Self::CBCB => [4, 1, 6, 3, 4, 1, 6, 3],
Self::ABAB => [0, 1, 2, 3, 0, 1, 2, 3],
}
}
}
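/// Blend masks for `_mm256_blend_epi32` selecting whole elements: A occupies
/// 32-bit lanes 0 and 2, B lanes 1 and 3, C lanes 4 and 6, D lanes 5 and 7.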
#[derive(Clone, Copy)]
#[repr(u8)]
#[allow(clippy::upper_case_acronyms, dead_code)]
pub(crate) enum Lanes {
A = 0b0000_0101,
B = 0b0000_1010,
C = 0b0101_0000,
D = 0b1010_0000,
AB = 0b0000_1111,
AC = 0b0101_0101,
AD = 0b1010_0101,
BC = 0b0101_1010,
BCD = 0b1111_1010,
CD = 0b1111_0000,
ABCD = 0b1111_1111,
}
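/// Zero-extend the eight 32-bit lanes of `v` into two vectors of 64-bit
/// lanes: `lo` holds the even limbs `(a_2i, b_2i, c_2i, d_2i)` and `hi` the
/// odd limbs, ready for 64-bit accumulation.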
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn unpack_pair(v: __m256i) -> (__m256i, __m256i) {
let zero = _mm256_setzero_si256();
let lo = _mm256_unpacklo_epi32(v, zero);
let hi = _mm256_unpackhi_epi32(v, zero);
(lo, hi)
}
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn repack_pair(lo: __m256i, hi: __m256i) -> __m256i {
let lo_packed = _mm256_shuffle_epi32::<0b10_00_10_00>(lo);
let hi_packed = _mm256_shuffle_epi32::<0b10_00_10_00>(hi);
_mm256_blend_epi32::<0b1100_1100>(lo_packed, hi_packed)
}
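/// Multiply the low 32 bits of each 64-bit lane (`vpmuludq`), giving full
/// 64-bit products.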
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn mul32(a: __m256i, b: __m256i) -> __m256i {
_mm256_mul_epu32(a, b)
}
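/// Lane-wise 64-bit addition.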
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn add64(a: __m256i, b: __m256i) -> __m256i {
_mm256_add_epi64(a, b)
}
#[cfg(target_arch = "x86_64")]
impl FieldElement2625x4 {
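    /// The packing of four zero field elements.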
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn zero() -> Self {
Self([_mm256_setzero_si256(); 5])
}
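    /// Pack four radix-2^51 field elements, splitting each 51-bit limb into a
    /// 26-bit low half and a 25-bit high half, then reducing.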
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn new(a: &FieldElement, b: &FieldElement, c: &FieldElement, d: &FieldElement) -> Self {
let al = a.limbs();
let bl = b.limbs();
let cl = c.limbs();
let dl = d.limbs();
let mask = LOW_26_BITS as u64;
let out = [
Self::pack_limb_pair(al[0], bl[0], cl[0], dl[0], mask),
Self::pack_limb_pair(al[1], bl[1], cl[1], dl[1], mask),
Self::pack_limb_pair(al[2], bl[2], cl[2], dl[2], mask),
Self::pack_limb_pair(al[3], bl[3], cl[3], dl[3], mask),
Self::pack_limb_pair(al[4], bl[4], cl[4], dl[4], mask),
];
Self(out).reduce()
}
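    /// Interleave one limb from each element as the 32-bit lanes
    /// `(a_lo, b_lo, a_hi, b_hi, c_lo, d_lo, c_hi, d_hi)`.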
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn pack_limb_pair(al: u64, bl: u64, cl: u64, dl: u64, mask: u64) -> __m256i {
_mm256_setr_epi32(
(al & mask) as i32,
(bl & mask) as i32,
(al >> 26) as i32,
(bl >> 26) as i32,
(cl & mask) as i32,
(dl & mask) as i32,
(cl >> 26) as i32,
(dl >> 26) as i32,
)
}
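    /// Unpack into four scalar field elements, recombining each 26/25-bit
    /// limb pair into a single radix-2^51 limb.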
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn split(&self) -> [FieldElement; 4] {
let mut al = [0u64; 5];
let mut bl = [0u64; 5];
let mut cl = [0u64; 5];
let mut dl = [0u64; 5];
for ((((a_out, b_out), c_out), d_out), vec) in al
.iter_mut()
.zip(bl.iter_mut())
.zip(cl.iter_mut())
.zip(dl.iter_mut())
.zip(self.0.iter())
{
let mut tmp = [0u32; 8];
_mm256_storeu_si256(tmp.as_mut_ptr().cast(), *vec);
*a_out = u64::from(tmp[0]) | (u64::from(tmp[2]) << 26);
*b_out = u64::from(tmp[1]) | (u64::from(tmp[3]) << 26);
*c_out = u64::from(tmp[4]) | (u64::from(tmp[6]) << 26);
*d_out = u64::from(tmp[5]) | (u64::from(tmp[7]) << 26);
}
[
FieldElement::from_limbs(al),
FieldElement::from_limbs(bl),
FieldElement::from_limbs(cl),
FieldElement::from_limbs(dl),
]
}
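    /// Lane-wise addition with no carry propagation; the caller is expected
    /// to `reduce()` before limbs can overflow.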
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn add(&self, rhs: &Self) -> Self {
Self([
_mm256_add_epi32(self.0[0], rhs.0[0]),
_mm256_add_epi32(self.0[1], rhs.0[1]),
_mm256_add_epi32(self.0[2], rhs.0[2]),
_mm256_add_epi32(self.0[3], rhs.0[3]),
_mm256_add_epi32(self.0[4], rhs.0[4]),
])
}
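    /// Lane-wise subtraction computed as `self + 2p - rhs`, which keeps every
    /// limb non-negative; the bias vectors below hold the limbs of 2p in
    /// radix 2^25.5. The result is left unreduced.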
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn sub(&self, rhs: &Self) -> Self {
        let p2_limb0 = (2 * ((1i64 << 26) - 19)) as i32; // 2 * p[0]
        let p2_even = (2 * ((1i64 << 26) - 1)) as i32; // 2 * p[2k], k > 0
        let p2_odd = (2 * ((1i64 << 25) - 1)) as i32; // 2 * p[2k+1]
        let bias_0 = _mm256_setr_epi32(
            p2_limb0, p2_limb0, p2_odd, p2_odd, p2_limb0, p2_limb0, p2_odd, p2_odd,
        );
        let bias_n = _mm256_setr_epi32(
            p2_even, p2_even, p2_odd, p2_odd, p2_even, p2_even, p2_odd, p2_odd,
        );
Self([
_mm256_sub_epi32(_mm256_add_epi32(self.0[0], bias_0), rhs.0[0]),
_mm256_sub_epi32(_mm256_add_epi32(self.0[1], bias_n), rhs.0[1]),
_mm256_sub_epi32(_mm256_add_epi32(self.0[2], bias_n), rhs.0[2]),
_mm256_sub_epi32(_mm256_add_epi32(self.0[3], bias_n), rhs.0[3]),
_mm256_sub_epi32(_mm256_add_epi32(self.0[4], bias_n), rhs.0[4]),
])
}
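    /// Negate as `2p - self` via `zero().sub(self)`; "lazy" because the
    /// result is not reduced.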
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn negate_lazy(&self) -> Self {
Self::zero().sub(self)
}
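    /// Permute the packed elements according to `pattern` with a full-width
    /// 32-bit lane permute (`vpermd`).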
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn shuffle(&self, pattern: Shuffle) -> Self {
let ctrl = pattern.control();
let c = _mm256_setr_epi32(ctrl[0], ctrl[1], ctrl[2], ctrl[3], ctrl[4], ctrl[5], ctrl[6], ctrl[7]);
Self([
_mm256_permutevar8x32_epi32(self.0[0], c),
_mm256_permutevar8x32_epi32(self.0[1], c),
_mm256_permutevar8x32_epi32(self.0[2], c),
_mm256_permutevar8x32_epi32(self.0[3], c),
_mm256_permutevar8x32_epi32(self.0[4], c),
])
}
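    /// Replace the lanes selected by `lanes` with those of `other`. The blend
    /// immediate must be a compile-time constant, hence the macro dispatch.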
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn blend(&self, other: &Self, lanes: Lanes) -> Self {
macro_rules! do_blend {
($imm:expr) => {
Self([
_mm256_blend_epi32::<$imm>(self.0[0], other.0[0]),
_mm256_blend_epi32::<$imm>(self.0[1], other.0[1]),
_mm256_blend_epi32::<$imm>(self.0[2], other.0[2]),
_mm256_blend_epi32::<$imm>(self.0[3], other.0[3]),
_mm256_blend_epi32::<$imm>(self.0[4], other.0[4]),
])
};
}
match lanes {
Lanes::A => do_blend!(0b0000_0101),
Lanes::B => do_blend!(0b0000_1010),
Lanes::C => do_blend!(0b0101_0000),
Lanes::D => do_blend!(0b1010_0000),
Lanes::AB => do_blend!(0b0000_1111),
Lanes::AC => do_blend!(0b0101_0101),
Lanes::AD => do_blend!(0b1010_0101),
Lanes::BC => do_blend!(0b0101_1010),
Lanes::BCD => do_blend!(0b1111_1010),
Lanes::CD => do_blend!(0b1111_0000),
Lanes::ABCD => do_blend!(0b1111_1111),
}
}
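    /// Compute `(B - A, B + A, D - C, D + C)` in a single pass: shuffle to
    /// `BADC`, lazily negate the A and C lanes, and add.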
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn diff_sum(&self) -> Self {
        let swapped = self.shuffle(Shuffle::BADC);
        let negated = self.negate_lazy();
        let neg_ac = self.blend(&negated, Lanes::AC);
        swapped.add(&neg_ac)
    }
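    /// Carry-propagate so every limb fits back within its 26- or 25-bit bound.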
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn reduce(&self) -> Self {
let (z0, z1) = unpack_pair(self.0[0]);
let (z2, z3) = unpack_pair(self.0[1]);
let (z4, z5) = unpack_pair(self.0[2]);
let (z6, z7) = unpack_pair(self.0[3]);
let (z8, z9) = unpack_pair(self.0[4]);
let mut z = [z0, z1, z2, z3, z4, z5, z6, z7, z8, z9];
Self::reduce64(&mut z)
}
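    /// Carry propagation over the 64-bit accumulators: two interleaved carry
    /// chains (limbs 0..=4 and 4..=8) followed by the wrap of limb 9's carry,
    /// scaled by 19, back into limbs 0 and 1.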
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn reduce64(z: &mut [__m256i; 10]) -> Self {
let mask_26 = _mm256_set1_epi64x(LOW_26_BITS);
let mask_25 = _mm256_set1_epi64x(LOW_25_BITS);
let v19 = _mm256_set1_epi64x(19);
macro_rules! carry_even {
($i:expr) => {
let carry = _mm256_srli_epi64::<26>(z[$i]);
z[$i] = _mm256_and_si256(z[$i], mask_26);
z[$i + 1] = add64(z[$i + 1], carry);
};
}
macro_rules! carry_odd {
($i:expr) => {
let carry = _mm256_srli_epi64::<25>(z[$i]);
z[$i] = _mm256_and_si256(z[$i], mask_25);
z[$i + 1] = add64(z[$i + 1], carry);
};
}
carry_even!(0);
carry_even!(4);
carry_odd!(1);
carry_odd!(5);
carry_even!(2);
carry_even!(6);
carry_odd!(3);
carry_odd!(7);
carry_even!(4);
carry_even!(8);
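        // The carry out of limb 9 wraps around scaled by 19 (2^255 = 19 mod p).
        // It can exceed 32 bits, so split it at bit 26: the low part folds into
        // limb 0 and the high part into limb 1, which carries weight 2^26.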
let carry9 = _mm256_srli_epi64::<25>(z[9]);
z[9] = _mm256_and_si256(z[9], mask_25);
let c0 = _mm256_and_si256(carry9, mask_26);
let c1 = _mm256_srli_epi64::<26>(carry9);
z[0] = add64(z[0], _mm256_mul_epu32(c0, v19));
z[1] = add64(z[1], _mm256_mul_epu32(c1, v19));
carry_even!(0);
Self([
repack_pair(z[0], z[1]),
repack_pair(z[2], z[3]),
repack_pair(z[4], z[5]),
repack_pair(z[6], z[7]),
repack_pair(z[8], z[9]),
])
}
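    /// Four parallel field multiplications: 10-limb schoolbook with the high
    /// products folded back through the factor 19 (2^255 = 19 mod p), and
    /// cross terms where both limbs have odd index pre-doubled, as radix
    /// 2^25.5 requires.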
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn mul(&self, rhs: &Self) -> Self {
let v19 = _mm256_set1_epi64x(19);
let (x0, x1) = unpack_pair(self.0[0]);
let (x2, x3) = unpack_pair(self.0[1]);
let (x4, x5) = unpack_pair(self.0[2]);
let (x6, x7) = unpack_pair(self.0[3]);
let (x8, x9) = unpack_pair(self.0[4]);
let x1_2 = add64(x1, x1);
let x3_2 = add64(x3, x3);
let x5_2 = add64(x5, x5);
let x7_2 = add64(x7, x7);
let x9_2 = add64(x9, x9);
let (y0, y1) = unpack_pair(rhs.0[0]);
let (y2, y3) = unpack_pair(rhs.0[1]);
let (y4, y5) = unpack_pair(rhs.0[2]);
let (y6, y7) = unpack_pair(rhs.0[3]);
let (y8, y9) = unpack_pair(rhs.0[4]);
let y1_19 = mul32(y1, v19);
let y2_19 = mul32(y2, v19);
let y3_19 = mul32(y3, v19);
let y4_19 = mul32(y4, v19);
let y5_19 = mul32(y5, v19);
let y6_19 = mul32(y6, v19);
let y7_19 = mul32(y7, v19);
let y8_19 = mul32(y8, v19);
let y9_19 = mul32(y9, v19);
let z0 = add64(
add64(
add64(mul32(x0, y0), mul32(x1_2, y9_19)),
add64(mul32(x2, y8_19), mul32(x3_2, y7_19)),
),
add64(
add64(mul32(x4, y6_19), mul32(x5_2, y5_19)),
add64(
add64(mul32(x6, y4_19), mul32(x7_2, y3_19)),
add64(mul32(x8, y2_19), mul32(x9_2, y1_19)),
),
),
);
let z1 = add64(
add64(
add64(mul32(x0, y1), mul32(x1, y0)),
add64(mul32(x2, y9_19), mul32(x3, y8_19)),
),
add64(
add64(mul32(x4, y7_19), mul32(x5, y6_19)),
add64(
add64(mul32(x6, y5_19), mul32(x7, y4_19)),
add64(mul32(x8, y3_19), mul32(x9, y2_19)),
),
),
);
let z2 = add64(
add64(
add64(mul32(x0, y2), mul32(x1_2, y1)),
add64(mul32(x2, y0), mul32(x3_2, y9_19)),
),
add64(
add64(mul32(x4, y8_19), mul32(x5_2, y7_19)),
add64(
add64(mul32(x6, y6_19), mul32(x7_2, y5_19)),
add64(mul32(x8, y4_19), mul32(x9_2, y3_19)),
),
),
);
let z3 = add64(
add64(add64(mul32(x0, y3), mul32(x1, y2)), add64(mul32(x2, y1), mul32(x3, y0))),
add64(
add64(mul32(x4, y9_19), mul32(x5, y8_19)),
add64(
add64(mul32(x6, y7_19), mul32(x7, y6_19)),
add64(mul32(x8, y5_19), mul32(x9, y4_19)),
),
),
);
let z4 = add64(
add64(
add64(mul32(x0, y4), mul32(x1_2, y3)),
add64(mul32(x2, y2), mul32(x3_2, y1)),
),
add64(
add64(mul32(x4, y0), mul32(x5_2, y9_19)),
add64(
add64(mul32(x6, y8_19), mul32(x7_2, y7_19)),
add64(mul32(x8, y6_19), mul32(x9_2, y5_19)),
),
),
);
let z5 = add64(
add64(add64(mul32(x0, y5), mul32(x1, y4)), add64(mul32(x2, y3), mul32(x3, y2))),
add64(
add64(mul32(x4, y1), mul32(x5, y0)),
add64(
add64(mul32(x6, y9_19), mul32(x7, y8_19)),
add64(mul32(x8, y7_19), mul32(x9, y6_19)),
),
),
);
let z6 = add64(
add64(
add64(mul32(x0, y6), mul32(x1_2, y5)),
add64(mul32(x2, y4), mul32(x3_2, y3)),
),
add64(
add64(mul32(x4, y2), mul32(x5_2, y1)),
add64(
add64(mul32(x6, y0), mul32(x7_2, y9_19)),
add64(mul32(x8, y8_19), mul32(x9_2, y7_19)),
),
),
);
let z7 = add64(
add64(add64(mul32(x0, y7), mul32(x1, y6)), add64(mul32(x2, y5), mul32(x3, y4))),
add64(
add64(mul32(x4, y3), mul32(x5, y2)),
add64(
add64(mul32(x6, y1), mul32(x7, y0)),
add64(mul32(x8, y9_19), mul32(x9, y8_19)),
),
),
);
let z8 = add64(
add64(
add64(mul32(x0, y8), mul32(x1_2, y7)),
add64(mul32(x2, y6), mul32(x3_2, y5)),
),
add64(
add64(mul32(x4, y4), mul32(x5_2, y3)),
add64(
add64(mul32(x6, y2), mul32(x7_2, y1)),
add64(mul32(x8, y0), mul32(x9_2, y9_19)),
),
),
);
let z9 = add64(
add64(add64(mul32(x0, y9), mul32(x1, y8)), add64(mul32(x2, y7), mul32(x3, y6))),
add64(
add64(mul32(x4, y5), mul32(x5, y4)),
add64(add64(mul32(x6, y3), mul32(x7, y2)), add64(mul32(x8, y1), mul32(x9, y0))),
),
);
let mut z = [z0, z1, z2, z3, z4, z5, z6, z7, z8, z9];
Self::reduce64(&mut z)
}
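    /// Shared squaring kernel: returns the unreduced 64-bit accumulators,
    /// exploiting symmetry so each cross product is computed only once.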
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn square_accum(&self) -> [__m256i; 10] {
let v19 = _mm256_set1_epi64x(19);
let (x0, x1) = unpack_pair(self.0[0]);
let (x2, x3) = unpack_pair(self.0[1]);
let (x4, x5) = unpack_pair(self.0[2]);
let (x6, x7) = unpack_pair(self.0[3]);
let (x8, x9) = unpack_pair(self.0[4]);
let x0_2 = add64(x0, x0);
let x1_2 = add64(x1, x1);
let x2_2 = add64(x2, x2);
let x3_2 = add64(x3, x3);
let x4_2 = add64(x4, x4);
let x5_2 = add64(x5, x5);
let x6_2 = add64(x6, x6);
let x7_2 = add64(x7, x7);
let x5_19 = mul32(x5, v19);
let x6_19 = mul32(x6, v19);
let x7_19 = mul32(x7, v19);
let x8_19 = mul32(x8, v19);
let x9_19 = mul32(x9, v19);
let x1_4 = _mm256_slli_epi64::<2>(x1);
let x3_4 = _mm256_slli_epi64::<2>(x3);
let x5_4 = _mm256_slli_epi64::<2>(x5);
let x7_4 = _mm256_slli_epi64::<2>(x7);
let z0 = add64(
add64(
add64(mul32(x0, x0), mul32(x1_4, x9_19)),
add64(mul32(x2_2, x8_19), mul32(x3_4, x7_19)),
),
add64(mul32(x4_2, x6_19), mul32(x5_2, x5_19)),
);
let z1 = add64(
add64(mul32(x0_2, x1), mul32(x2_2, x9_19)),
add64(add64(mul32(x3_2, x8_19), mul32(x4_2, x7_19)), mul32(x5_2, x6_19)),
);
let z2 = add64(
add64(
add64(mul32(x0_2, x2), mul32(x1_2, x1)),
add64(mul32(x3_4, x9_19), mul32(x4_2, x8_19)),
),
add64(mul32(x5_4, x7_19), mul32(x6, x6_19)),
);
let z3 = add64(
add64(mul32(x0_2, x3), mul32(x1_2, x2)),
add64(add64(mul32(x4_2, x9_19), mul32(x5_2, x8_19)), mul32(x6_2, x7_19)),
);
let z4 = add64(
add64(
add64(mul32(x0_2, x4), mul32(x1_4, x3)),
add64(mul32(x2, x2), mul32(x5_4, x9_19)),
),
add64(mul32(x6_2, x8_19), mul32(x7_2, x7_19)),
);
let z5 = add64(
add64(mul32(x0_2, x5), mul32(x1_2, x4)),
add64(mul32(x2_2, x3), add64(mul32(x6_2, x9_19), mul32(x7_2, x8_19))),
);
let z6 = add64(
add64(
add64(mul32(x0_2, x6), mul32(x1_4, x5)),
add64(mul32(x2_2, x4), mul32(x3_2, x3)),
),
add64(mul32(x7_4, x9_19), mul32(x8, x8_19)),
);
let x8_2 = add64(x8, x8);
let z7 = add64(
add64(mul32(x0_2, x7), mul32(x1_2, x6)),
add64(add64(mul32(x2_2, x5), mul32(x3_2, x4)), mul32(x8_2, x9_19)),
);
let x9_2 = add64(x9, x9);
let z8 = add64(
add64(
add64(mul32(x0_2, x8), mul32(x1_4, x7)),
add64(mul32(x2_2, x6), mul32(x3_4, x5)),
),
add64(mul32(x4, x4), mul32(x9_2, x9_19)),
);
let z9 = add64(
add64(mul32(x0_2, x9), mul32(x1_2, x8)),
add64(add64(mul32(x2_2, x7), mul32(x3_2, x6)), mul32(x4_2, x5)),
);
[z0, z1, z2, z3, z4, z5, z6, z7, z8, z9]
}
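    /// Square all four elements; compiled only for tests, which compare it
    /// against `mul` and the scalar implementation.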
#[cfg(test)]
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn square(&self) -> Self {
let mut z = self.square_accum();
Self::reduce64(&mut z)
}
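    /// Square all four elements and negate the D lane of the result before
    /// the final reduction.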
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn square_and_negate_d(&self) -> Self {
let mut z = self.square_accum();
Self::negate_d_accum(&mut z);
Self::reduce64(&mut z)
}
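    /// Negate the D lane of the squaring accumulators in place by computing
    /// `bias - z` and blending only the D lane back in.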
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn negate_d_accum(z: &mut [__m256i; 10]) {
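        // Limbs of p, each shifted left by 37 bits: together they represent
        // p * 2^37, which dominates every accumulator value, so the
        // subtraction cannot underflow.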
let bias_even_0 = _mm256_set1_epi64x(((1i64 << 26) - 19) << 37);
let bias_even = _mm256_set1_epi64x(((1i64 << 26) - 1) << 37);
let bias_odd = _mm256_set1_epi64x(((1i64 << 25) - 1) << 37);
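        // 32-bit lanes 6 and 7 form the 64-bit D lane of each accumulator.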
const D_U64: i32 = 0b1100_0000;
macro_rules! neg_d {
($idx:expr, $bias:expr) => {
let negated = _mm256_sub_epi64($bias, z[$idx]);
z[$idx] = _mm256_blend_epi32::<D_U64>(z[$idx], negated);
};
}
neg_d!(0, bias_even_0);
neg_d!(1, bias_odd);
neg_d!(2, bias_even);
neg_d!(3, bias_odd);
neg_d!(4, bias_even);
neg_d!(5, bias_odd);
neg_d!(6, bias_even);
neg_d!(7, bias_odd);
neg_d!(8, bias_even);
neg_d!(9, bias_odd);
}
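    /// 32-bit variant of `negate_d_accum`: negates the D lane of a packed,
    /// reduced element in place, adding 2p before subtracting so the limbs
    /// stay non-negative. Currently unused (`dead_code`).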
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn, dead_code)]
unsafe fn negate_d_lane(fe: &mut Self) {
        let p2_limb0_even = (2 * ((1i64 << 26) - 19)) as i32;
        let p2_limb_even = (2 * ((1i64 << 26) - 1)) as i32;
        let p2_limb_odd = (2 * ((1i64 << 25) - 1)) as i32;
let bias_0 = _mm256_setr_epi32(0, 0, 0, 0, 0, p2_limb0_even, 0, p2_limb_odd);
let bias_n = _mm256_setr_epi32(0, 0, 0, 0, 0, p2_limb_even, 0, p2_limb_odd);
let neg0 = _mm256_sub_epi32(bias_0, fe.0[0]);
fe.0[0] = _mm256_blend_epi32::<D_BLEND>(fe.0[0], neg0);
let neg1 = _mm256_sub_epi32(bias_n, fe.0[1]);
fe.0[1] = _mm256_blend_epi32::<D_BLEND>(fe.0[1], neg1);
let neg2 = _mm256_sub_epi32(bias_n, fe.0[2]);
fe.0[2] = _mm256_blend_epi32::<D_BLEND>(fe.0[2], neg2);
let neg3 = _mm256_sub_epi32(bias_n, fe.0[3]);
fe.0[3] = _mm256_blend_epi32::<D_BLEND>(fe.0[3], neg3);
let neg4 = _mm256_sub_epi32(bias_n, fe.0[4]);
fe.0[4] = _mm256_blend_epi32::<D_BLEND>(fe.0[4], neg4);
}
}
#[cfg(test)]
#[cfg(target_arch = "x86_64")]
mod tests {
use super::{FieldElement, *};
fn avx512ifma_available_for_tests() -> bool {
!cfg!(miri) && std::arch::is_x86_feature_detected!("avx512ifma")
}
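    /// Four field elements with large, distinct limb values.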
fn test_field_elements() -> [FieldElement; 4] {
let a = FieldElement::from_limbs([
1_234_567_890_123,
987_654_321_012,
111_222_333_444,
555_666_777_888,
999_000_111_222,
]);
let b = FieldElement::from_limbs([
2_111_222_333_444,
1_555_666_777_888,
333_444_555_666,
777_888_999_000,
100_200_300_400,
]);
let c = FieldElement::from_limbs([
42_000_000_001,
123_456_789_012,
2_000_000_000_000,
1_500_000_000_000,
750_000_000_000,
]);
let d = FieldElement::from_limbs([
1_999_999_999_999,
888_777_666_555,
444_333_222_111,
1_111_222_333_444,
2_222_111_000_999,
]);
[a, b, c, d]
}
fn small_field_elements() -> [FieldElement; 4] {
[
FieldElement::from_limbs([100, 200, 300, 400, 500]),
FieldElement::from_limbs([600, 700, 800, 900, 1000]),
FieldElement::from_limbs([1100, 1200, 1300, 1400, 1500]),
FieldElement::from_limbs([1600, 1700, 1800, 1900, 2000]),
]
}
#[test]
fn pack_unpack_roundtrip() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let [a, b, c, d] = test_field_elements();
unsafe {
let packed = FieldElement2625x4::new(&a, &b, &c, &d);
let [ra, rb, rc, rd] = packed.split();
assert_eq!(ra.normalize(), a.normalize(), "A roundtrip failed");
assert_eq!(rb.normalize(), b.normalize(), "B roundtrip failed");
assert_eq!(rc.normalize(), c.normalize(), "C roundtrip failed");
assert_eq!(rd.normalize(), d.normalize(), "D roundtrip failed");
}
}
#[test]
fn add_matches_scalar() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let [a, b, c, d] = test_field_elements();
let [e, f, g, h] = small_field_elements();
unsafe {
let lhs = FieldElement2625x4::new(&a, &b, &c, &d);
let rhs = FieldElement2625x4::new(&e, &f, &g, &h);
let sum = lhs.add(&rhs).reduce();
let [ra, rb, rc, rd] = sum.split();
assert_eq!(ra.normalize(), a.add(&e).normalize(), "A add failed");
assert_eq!(rb.normalize(), b.add(&f).normalize(), "B add failed");
assert_eq!(rc.normalize(), c.add(&g).normalize(), "C add failed");
assert_eq!(rd.normalize(), d.add(&h).normalize(), "D add failed");
}
}
#[test]
fn sub_matches_scalar() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let [a, b, c, d] = test_field_elements();
let [e, f, g, h] = small_field_elements();
unsafe {
let lhs = FieldElement2625x4::new(&a, &b, &c, &d);
let rhs = FieldElement2625x4::new(&e, &f, &g, &h);
let diff = lhs.sub(&rhs).reduce();
let [ra, rb, rc, rd] = diff.split();
assert_eq!(ra.normalize(), a.sub(&e).normalize(), "A sub failed");
assert_eq!(rb.normalize(), b.sub(&f).normalize(), "B sub failed");
assert_eq!(rc.normalize(), c.sub(&g).normalize(), "C sub failed");
assert_eq!(rd.normalize(), d.sub(&h).normalize(), "D sub failed");
}
}
#[test]
fn mul_matches_scalar() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let [a, b, c, d] = test_field_elements();
let [e, f, g, h] = small_field_elements();
unsafe {
let lhs = FieldElement2625x4::new(&a, &b, &c, &d);
let rhs = FieldElement2625x4::new(&e, &f, &g, &h);
let product = lhs.mul(&rhs);
let [ra, rb, rc, rd] = product.split();
assert_eq!(ra.normalize(), a.mul(&e).normalize(), "A mul failed");
assert_eq!(rb.normalize(), b.mul(&f).normalize(), "B mul failed");
assert_eq!(rc.normalize(), c.mul(&g).normalize(), "C mul failed");
assert_eq!(rd.normalize(), d.mul(&h).normalize(), "D mul failed");
}
}
#[test]
fn square_matches_scalar() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let [a, b, c, d] = test_field_elements();
unsafe {
let packed = FieldElement2625x4::new(&a, &b, &c, &d);
let squared = packed.square();
let [ra, rb, rc, rd] = squared.split();
assert_eq!(ra.normalize(), a.square().normalize(), "A square failed");
assert_eq!(rb.normalize(), b.square().normalize(), "B square failed");
assert_eq!(rc.normalize(), c.square().normalize(), "C square failed");
assert_eq!(rd.normalize(), d.square().normalize(), "D square failed");
}
}
#[test]
fn square_matches_mul_self() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let [a, b, c, d] = test_field_elements();
unsafe {
let packed = FieldElement2625x4::new(&a, &b, &c, &d);
let sq = packed.square();
let mm = packed.mul(&packed);
let [sa, sb, sc, sd] = sq.split();
let [ma, mb, mc, md] = mm.split();
assert_eq!(sa.normalize(), ma.normalize(), "A square vs mul mismatch");
assert_eq!(sb.normalize(), mb.normalize(), "B square vs mul mismatch");
assert_eq!(sc.normalize(), mc.normalize(), "C square vs mul mismatch");
assert_eq!(sd.normalize(), md.normalize(), "D square vs mul mismatch");
}
}
#[test]
fn square_and_negate_d_matches_scalar() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let [a, b, c, d] = test_field_elements();
unsafe {
let packed = FieldElement2625x4::new(&a, &b, &c, &d);
let result = packed.square_and_negate_d();
let [ra, rb, rc, rd] = result.split();
assert_eq!(ra.normalize(), a.square().normalize(), "A square_neg_d failed");
assert_eq!(rb.normalize(), b.square().normalize(), "B square_neg_d failed");
assert_eq!(rc.normalize(), c.square().normalize(), "C square_neg_d failed");
assert_eq!(rd.normalize(), d.square().neg().normalize(), "D should be negated");
}
}
#[test]
fn shuffle_badc_swaps_pairs() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let [a, b, c, d] = test_field_elements();
unsafe {
let packed = FieldElement2625x4::new(&a, &b, &c, &d);
let shuffled = packed.shuffle(Shuffle::BADC);
let [ra, rb, rc, rd] = shuffled.split();
assert_eq!(ra.normalize(), b.normalize(), "BADC: A should be B");
assert_eq!(rb.normalize(), a.normalize(), "BADC: B should be A");
assert_eq!(rc.normalize(), d.normalize(), "BADC: C should be D");
assert_eq!(rd.normalize(), c.normalize(), "BADC: D should be C");
}
}
#[test]
fn diff_sum_matches_scalar() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let [a, b, c, d] = small_field_elements();
unsafe {
let packed = FieldElement2625x4::new(&a, &b, &c, &d);
let ds = packed.diff_sum().reduce();
let [r0, r1, r2, r3] = ds.split();
assert_eq!(r0.normalize(), b.sub(&a).normalize(), "diff_sum[0] = B-A");
assert_eq!(r1.normalize(), a.add(&b).normalize(), "diff_sum[1] = A+B");
assert_eq!(r2.normalize(), d.sub(&c).normalize(), "diff_sum[2] = D-C");
assert_eq!(r3.normalize(), c.add(&d).normalize(), "diff_sum[3] = C+D");
}
}
#[test]
fn blend_ab_cd() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let [a, b, c, d] = test_field_elements();
let [e, f, g, h] = [
FieldElement::from_limbs([1, 2, 3, 4, 5]),
FieldElement::from_limbs([6, 7, 8, 9, 10]),
FieldElement::from_limbs([11, 12, 13, 14, 15]),
FieldElement::from_limbs([16, 17, 18, 19, 20]),
];
unsafe {
let lhs = FieldElement2625x4::new(&a, &b, &c, &d);
let rhs = FieldElement2625x4::new(&e, &f, &g, &h);
let blended = lhs.blend(&rhs, Lanes::AB);
let [ra, rb, rc, rd] = blended.split();
assert_eq!(ra.normalize(), e.normalize(), "AB blend: A should be from rhs");
assert_eq!(rb.normalize(), f.normalize(), "AB blend: B should be from rhs");
assert_eq!(rc.normalize(), c.normalize(), "AB blend: C should be from self");
assert_eq!(rd.normalize(), d.normalize(), "AB blend: D should be from self");
}
}
#[test]
fn ifma_mul_matches_avx2() {
if !std::arch::is_x86_feature_detected!("avx2") || !avx512ifma_available_for_tests() {
return;
}
let [a, b, c, d] = test_field_elements();
let [e, f, g, h] = small_field_elements();
unsafe {
use super::super::field_ifma::FieldElement51x4;
let avx2_lhs = FieldElement2625x4::new(&a, &b, &c, &d);
let avx2_rhs = FieldElement2625x4::new(&e, &f, &g, &h);
let avx2_result = avx2_lhs.mul(&avx2_rhs);
let [a2, b2, c2, d2] = avx2_result.split();
let ifma_lhs = FieldElement51x4::new(&a, &b, &c, &d);
let ifma_rhs = FieldElement51x4::new(&e, &f, &g, &h);
let ifma_result = ifma_lhs.mul(&ifma_rhs);
let [ai, bi, ci, di] = ifma_result.split();
assert_eq!(a2.normalize(), ai.normalize(), "IFMA mul A mismatch");
assert_eq!(b2.normalize(), bi.normalize(), "IFMA mul B mismatch");
assert_eq!(c2.normalize(), ci.normalize(), "IFMA mul C mismatch");
assert_eq!(d2.normalize(), di.normalize(), "IFMA mul D mismatch");
}
}
#[test]
fn ifma_square_matches_avx2() {
if !std::arch::is_x86_feature_detected!("avx2") || !avx512ifma_available_for_tests() {
return;
}
let [a, b, c, d] = test_field_elements();
unsafe {
use super::super::field_ifma::FieldElement51x4;
let avx2_packed = FieldElement2625x4::new(&a, &b, &c, &d);
let avx2_result = avx2_packed.square();
let [a2, b2, c2, d2] = avx2_result.split();
let ifma_packed = FieldElement51x4::new(&a, &b, &c, &d);
let ifma_result = ifma_packed.square();
let [ai, bi, ci, di] = ifma_result.split();
assert_eq!(a2.normalize(), ai.normalize(), "IFMA square A mismatch");
assert_eq!(b2.normalize(), bi.normalize(), "IFMA square B mismatch");
assert_eq!(c2.normalize(), ci.normalize(), "IFMA square C mismatch");
assert_eq!(d2.normalize(), di.normalize(), "IFMA square D mismatch");
}
}
}