use crate::{
rng::{Rng32, Rng64},
rng64::SplitMix64,
};
use std::arch::x86_64::*;
pub struct Squares32 {
pub(crate) c: u64,
pub(crate) k: u64,
}
impl Squares32 {
#[inline]
pub fn new(seed: u64) -> Self {
let mut seedgen = SplitMix64::new(seed | 1);
Self {
c: 0,
k: seedgen.nextu(),
}
}
#[inline(always)]
pub fn compute_yz(y: u64, z: u64) -> u32 {
let mut x: u64;
x = y.wrapping_mul(y).wrapping_add(y);
x = x.rotate_left(32);
x = x.wrapping_mul(x).wrapping_add(z);
x = x.rotate_left(32);
x = x.wrapping_mul(x).wrapping_add(y);
x = x.rotate_left(32);
(x.wrapping_mul(x).wrapping_add(z) >> 32) as u32
}
#[inline(always)]
pub fn compute(ctr: u64, key: u64) -> u32 {
let y = ctr.wrapping_mul(key);
let z = y.wrapping_add(key);
Self::compute_yz(y, z)
}
}
impl Rng32 for Squares32 {
#[inline(always)]
fn nextu(&mut self) -> u32 {
let out = Self::compute(self.c, self.k);
self.c = self.c.wrapping_add(1);
out
}
}
#[allow(non_upper_case_globals)]
pub const SQUARES32x8: usize = 8;
#[cfg(target_arch = "x86_64")]
#[repr(C)]
#[repr(align(64))]
pub struct Squares32x8 {
pub c: __m512i,
pub k: __m512i,
}
#[cfg(target_arch = "x86_64")]
impl Squares32x8 {
#[target_feature(enable = "avx512f")]
pub unsafe fn new(seed: u32) -> Self {
let mut k = [0u64; SQUARES32x8];
let mut seedgen = SplitMix64::new(seed as u64 | 1);
k.iter_mut().for_each(|v| {
use crate::rng::Rng64;
*v = seedgen.nextu();
});
unsafe {
Self {
c: _mm512_setr_epi64(0, 1, 2, 3, 4, 5, 6, 7),
k: _mm512_loadu_si512(k.as_ptr() as *const _),
}
}
}
#[target_feature(enable = "avx512f,avx512dq")]
pub unsafe fn compute_yz(y: __m512i, z: __m512i) -> __m256i {
let mut x = _mm512_add_epi64(_mm512_mullo_epi64(y, y), y);
x = _mm512_or_si512(_mm512_slli_epi64(x, 32), _mm512_srli_epi64(x, 32));
x = _mm512_add_epi64(_mm512_mullo_epi64(x, x), z);
x = _mm512_or_si512(_mm512_slli_epi64(x, 32), _mm512_srli_epi64(x, 32));
x = _mm512_add_epi64(_mm512_mullo_epi64(x, x), y);
x = _mm512_or_si512(_mm512_slli_epi64(x, 32), _mm512_srli_epi64(x, 32));
_mm512_cvtepi64_epi32(_mm512_srli_epi64(
_mm512_add_epi64(_mm512_mullo_epi64(x, x), z),
32,
))
}
#[target_feature(enable = "avx512f,avx512dq")]
pub(crate) unsafe fn compute(c: __m512i, k: __m512i) -> __m256i {
unsafe {
let y = _mm512_mullo_epi64(c, k);
let z = _mm512_add_epi64(y, k);
Self::compute_yz(y, z)
}
}
#[target_feature(enable = "avx512f,avx512dq")]
pub unsafe fn nextu(&mut self) -> [u32; SQUARES32x8] {
unsafe {
let v = Self::compute(self.c, self.k);
self.c = _mm512_add_epi64(self.c, _mm512_set1_epi64(8));
let mut out = [0u32; SQUARES32x8];
_mm256_storeu_si256(out.as_mut_ptr() as *mut _, v);
out
}
}
#[target_feature(enable = "avx512f,avx512dq")]
pub unsafe fn nextf(&mut self) -> [f32; SQUARES32x8] {
let out = unsafe { self.nextu() };
let mut dst = [0f32; SQUARES32x8];
let scale = 1.0 / (u32::MAX as f32 + 1.0);
for i in 0..SQUARES32x8 {
dst[i] = out[i] as f32 * scale;
}
dst
}
}
#[repr(C)]
pub struct Squares32Simd([u8; 0]);
#[cfg(test)]
mod tests {
use super::*;
crate::safe_test!(Squares32);
crate::unsafe_test!(Squares32x8);
}