#[cfg(target_feature = "avx512f")]
use core::arch::x86_64::__m512i;
use core::arch::x86_64::{self, __m128i, __m256i};
use core::mem::transmute;
/// Lane-wise `(a + b) mod p` on four packed `u32` lanes.
///
/// Every lane of `a` and `b` is expected to already lie in `[0, p)`, with
/// `p <= 2^31`, so the raw sum `a + b` cannot wrap past `2^32`. Under that
/// assumption, `sum - p` wraps to a value larger than `sum` exactly when
/// `sum < p`. Taking the unsigned minimum of the two therefore yields the
/// reduced value without a branch.
#[inline(always)]
#[must_use]
fn mm128_mod_add(a: __m128i, b: __m128i, p: __m128i) -> __m128i {
    // SAFETY: pure register arithmetic; no memory is accessed. `_mm_min_epu32`
    // is an SSE4.1 instruction — presumably the build enables it (confirm).
    unsafe {
        let sum = x86_64::_mm_add_epi32(a, b);
        x86_64::_mm_min_epu32(x86_64::_mm_sub_epi32(sum, p), sum)
    }
}
/// Lane-wise `(a - b) mod p` on four packed `u32` lanes.
///
/// Every lane of `a` and `b` is expected to already lie in `[0, p)`, with
/// `p <= 2^31`. When `a >= b` the wrapping difference is already reduced and
/// `diff + p` is strictly larger. When `a < b` the difference wraps to a huge
/// value while `diff + p` wraps back into `[0, p)`. Either way the unsigned
/// minimum selects the correct result.
#[inline(always)]
#[must_use]
fn mm128_mod_sub(a: __m128i, b: __m128i, p: __m128i) -> __m128i {
    // SAFETY: pure register arithmetic; no memory is accessed. `_mm_min_epu32`
    // is an SSE4.1 instruction — presumably the build enables it (confirm).
    unsafe {
        let diff = x86_64::_mm_sub_epi32(a, b);
        let diff_plus_p = x86_64::_mm_add_epi32(diff, p);
        x86_64::_mm_min_epu32(diff_plus_p, diff)
    }
}
/// Lane-wise `(lhs + rhs) mod p` on eight packed `u32` lanes.
///
/// All lanes of `lhs` and `rhs` are expected to be reduced into `[0, p)`, with
/// `p <= 2^31`, so their sum never overflows a `u32`. The reduction is the
/// branch-free "subtract `p`, keep the unsigned minimum" trick: `sum - p` is
/// only smaller than `sum` when it did not wrap, i.e. when `sum >= p`.
#[inline(always)]
#[must_use]
pub fn mm256_mod_add(lhs: __m256i, rhs: __m256i, p: __m256i) -> __m256i {
    // SAFETY: register-only AVX2 arithmetic; no memory is accessed. AVX2
    // support is presumably guaranteed by the build configuration (confirm).
    unsafe {
        let sum = x86_64::_mm256_add_epi32(lhs, rhs);
        let sum_minus_p = x86_64::_mm256_sub_epi32(sum, p);
        x86_64::_mm256_min_epu32(sum_minus_p, sum)
    }
}
/// Lane-wise `(lhs - rhs) mod p` on eight packed `u32` lanes.
///
/// All lanes of `lhs` and `rhs` are expected to be reduced into `[0, p)`, with
/// `p <= 2^31`. The wrapping difference is correct when `lhs >= rhs`.
/// Otherwise adding `p` wraps it back into range. The unsigned minimum of the
/// two candidates picks whichever one is the reduced value.
#[inline(always)]
#[must_use]
pub fn mm256_mod_sub(lhs: __m256i, rhs: __m256i, p: __m256i) -> __m256i {
    // SAFETY: register-only AVX2 arithmetic; no memory is accessed. AVX2
    // support is presumably guaranteed by the build configuration (confirm).
    unsafe {
        let diff = x86_64::_mm256_sub_epi32(lhs, rhs);
        x86_64::_mm256_min_epu32(x86_64::_mm256_add_epi32(diff, p), diff)
    }
}
#[cfg(target_feature = "avx512f")]
/// Lane-wise `(lhs + rhs) mod p` on sixteen packed `u32` lanes.
///
/// Lanes are expected to be reduced into `[0, p)` with `p <= 2^31`, so the
/// sum fits in a `u32`. `sum - p` is kept only when it is the smaller of the
/// two unsigned candidates, i.e. when the subtraction did not wrap.
#[inline(always)]
#[must_use]
pub fn mm512_mod_add(lhs: __m512i, rhs: __m512i, p: __m512i) -> __m512i {
    // SAFETY: register-only AVX-512F arithmetic; this item is only compiled
    // when the `avx512f` target feature is enabled.
    unsafe {
        let sum = x86_64::_mm512_add_epi32(lhs, rhs);
        let sum_minus_p = x86_64::_mm512_sub_epi32(sum, p);
        x86_64::_mm512_min_epu32(sum_minus_p, sum)
    }
}
#[cfg(target_feature = "avx512f")]
/// Lane-wise `(lhs - rhs) mod p` on sixteen packed `u32` lanes.
///
/// Lanes are expected to be reduced into `[0, p)` with `p <= 2^31`. The
/// wrapping difference and the difference plus `p` are both formed, and the
/// unsigned minimum of the two is the reduced result.
#[inline(always)]
#[must_use]
pub fn mm512_mod_sub(lhs: __m512i, rhs: __m512i, p: __m512i) -> __m512i {
    // SAFETY: register-only AVX-512F arithmetic; this item is only compiled
    // when the `avx512f` target feature is enabled.
    unsafe {
        let diff = x86_64::_mm512_sub_epi32(lhs, rhs);
        x86_64::_mm512_min_epu32(x86_64::_mm512_add_epi32(diff, p), diff)
    }
}
/// Computes `res[i] = (a[i] + b[i]) mod p` for every lane of a `WIDTH`-wide packed value.
///
/// Lanes are consumed greedily:
/// 1. as many 8-lane chunks as fit, via the AVX2 helper `mm256_mod_add`;
/// 2. then at most one 4-lane chunk, via the SSE helper `mm128_mod_add`;
/// 3. then the remaining (fewer than four) lanes, via `scalar_add`.
///
/// For widths 1, 4, 5 and 8 this is the same vector/scalar split used
/// previously. Every other `WIDTH`, including 0, is handled the same way.
///
/// The vector helpers require every input lane to be reduced into `[0, p)` and
/// `p <= 2^31`. `scalar_add` is expected to compute the same single-lane
/// operation so that results do not depend on which path a lane takes.
#[inline(always)]
pub fn packed_mod_add<const WIDTH: usize>(
    a: &[u32; WIDTH],
    b: &[u32; WIDTH],
    res: &mut [u32; WIDTH],
    p: u32,
    scalar_add: fn(u32, u32) -> u32,
) {
    let mut i = 0;

    // Full 8-lane chunks through AVX2.
    while i + 8 <= WIDTH {
        let lhs: [u32; 8] = core::array::from_fn(|k| a[i + k]);
        let rhs: [u32; 8] = core::array::from_fn(|k| b[i + k]);
        // SAFETY: `[u32; 8]` and `__m256i` are both 32 bytes and every bit
        // pattern is a valid value of either, so the transmutes are sound.
        // The AVX2 intrinsics additionally need AVX2 support, which this
        // module already relied on for its 8-lane path (build config — confirm).
        let out: [u32; 8] = unsafe {
            // `as i32` only reinterprets the bits; lanes are treated as unsigned.
            let p_vec: __m256i = x86_64::_mm256_set1_epi32(p as i32);
            transmute(mm256_mod_add(transmute(lhs), transmute(rhs), p_vec))
        };
        res[i..i + 8].copy_from_slice(&out);
        i += 8;
    }

    // At most one 4-lane chunk through SSE.
    if i + 4 <= WIDTH {
        let lhs: [u32; 4] = core::array::from_fn(|k| a[i + k]);
        let rhs: [u32; 4] = core::array::from_fn(|k| b[i + k]);
        // SAFETY: `[u32; 4]` and `__m128i` are both 16 bytes and every bit
        // pattern is a valid value of either, so the transmutes are sound.
        let out: [u32; 4] = unsafe {
            let p_vec: __m128i = x86_64::_mm_set1_epi32(p as i32);
            transmute(mm128_mod_add(transmute(lhs), transmute(rhs), p_vec))
        };
        res[i..i + 4].copy_from_slice(&out);
        i += 4;
    }

    // Leftover lanes (0..=3) that do not fill a vector register.
    for ((r, &x), &y) in res[i..].iter_mut().zip(&a[i..]).zip(&b[i..]) {
        *r = scalar_add(x, y);
    }
}
/// Computes `res[i] = (a[i] - b[i]) mod p` for every lane of a `WIDTH`-wide packed value.
///
/// Lanes are consumed greedily:
/// 1. as many 8-lane chunks as fit, via the AVX2 helper `mm256_mod_sub`;
/// 2. then at most one 4-lane chunk, via the SSE helper `mm128_mod_sub`;
/// 3. then the remaining (fewer than four) lanes, via `scalar_sub`.
///
/// For widths 1, 4, 5 and 8 this is the same vector/scalar split used
/// previously. Every other `WIDTH`, including 0, is handled the same way.
///
/// The vector helpers require every input lane to be reduced into `[0, p)` and
/// `p <= 2^31`. `scalar_sub` is expected to compute the same single-lane
/// operation so that results do not depend on which path a lane takes.
#[inline(always)]
pub fn packed_mod_sub<const WIDTH: usize>(
    a: &[u32; WIDTH],
    b: &[u32; WIDTH],
    res: &mut [u32; WIDTH],
    p: u32,
    scalar_sub: fn(u32, u32) -> u32,
) {
    let mut i = 0;

    // Full 8-lane chunks through AVX2.
    while i + 8 <= WIDTH {
        let lhs: [u32; 8] = core::array::from_fn(|k| a[i + k]);
        let rhs: [u32; 8] = core::array::from_fn(|k| b[i + k]);
        // SAFETY: `[u32; 8]` and `__m256i` are both 32 bytes and every bit
        // pattern is a valid value of either, so the transmutes are sound.
        // The AVX2 intrinsics additionally need AVX2 support, which this
        // module already relied on for its 8-lane path (build config — confirm).
        let out: [u32; 8] = unsafe {
            // `as i32` only reinterprets the bits; lanes are treated as unsigned.
            let p_vec: __m256i = x86_64::_mm256_set1_epi32(p as i32);
            transmute(mm256_mod_sub(transmute(lhs), transmute(rhs), p_vec))
        };
        res[i..i + 8].copy_from_slice(&out);
        i += 8;
    }

    // At most one 4-lane chunk through SSE.
    if i + 4 <= WIDTH {
        let lhs: [u32; 4] = core::array::from_fn(|k| a[i + k]);
        let rhs: [u32; 4] = core::array::from_fn(|k| b[i + k]);
        // SAFETY: `[u32; 4]` and `__m128i` are both 16 bytes and every bit
        // pattern is a valid value of either, so the transmutes are sound.
        let out: [u32; 4] = unsafe {
            let p_vec: __m128i = x86_64::_mm_set1_epi32(p as i32);
            transmute(mm128_mod_sub(transmute(lhs), transmute(rhs), p_vec))
        };
        res[i..i + 4].copy_from_slice(&out);
        i += 4;
    }

    // Leftover lanes (0..=3) that do not fill a vector register.
    for ((r, &x), &y) in res[i..].iter_mut().zip(&a[i..]).zip(&b[i..]) {
        *r = scalar_sub(x, y);
    }
}