use crate::rng::Rng64;
use crate::rng64::SplitMix64;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// State of the scalar 64-bit generator: three state words plus a counter.
///
/// The struct uses a C-compatible field layout and is aligned to 64 bytes
/// (presumably so one generator occupies its own cache line — confirm).
#[repr(C, align(64))]
pub struct Sfc64 {
// First state word; summed into every output.
a: u64,
// Second state word; summed into every output.
b: u64,
// Third state word; feeds the next value of `b`.
c: u64,
// Step counter; summed into every output and incremented (wrapping) per step.
counter: u64,
}
impl Sfc64 {
    /// Creates a generator from a single 64-bit `seed`.
    ///
    /// The seed is expanded through a `SplitMix64` instance: its first three
    /// `nextu()` results become the state words `a`, `b` and `c`, in that
    /// order. The counter always starts at 1.
    pub fn new(seed: u64) -> Self {
        let mut expander = SplitMix64::new(seed);
        // The draw order is part of the seeding scheme: a, then b, then c.
        let a = expander.nextu();
        let b = expander.nextu();
        let c = expander.nextu();
        Self { a, b, c, counter: 1 }
    }
}
impl Rng64 for Sfc64 {
    /// Returns the next 64-bit output and advances the generator by one step.
    ///
    /// The output is the wrapping sum `a + b + counter` of the current state.
    #[inline(always)]
    fn nextu(&mut self) -> u64 {
        // Snapshot the current state so every update below reads old values.
        let Self { a, b, c, counter } = *self;
        let output = a.wrapping_add(b).wrapping_add(counter);

        // NOTE(review): the published sfc64 reference updates the third word
        // as `c.rotate_left(24) + output`; here it is replaced by
        // `output.rotate_left(24)`. Confirm the deviation is intentional.
        *self = Self {
            a: b ^ (b >> 11),
            b: c.wrapping_add(c << 3),
            c: output.rotate_left(24),
            counter: counter.wrapping_add(1),
        };
        output
    }
}
/// Eight independent generator states stored lane-wise in 512-bit vectors.
///
/// Each field holds the corresponding state word of all eight lanes as
/// 64-bit elements. Only compiled for x86_64 targets; the struct uses a
/// C-compatible layout and 64-byte alignment.
#[cfg(target_arch = "x86_64")]
#[repr(C, align(64))]
pub struct Sfc64x8 {
// First state word of each lane.
a: __m512i,
// Second state word of each lane.
b: __m512i,
// Third state word of each lane.
c: __m512i,
// Per-lane step counter.
counter: __m512i,
}
#[cfg(target_arch = "x86_64")]
impl Sfc64x8 {
#[inline(always)]
pub fn new(seed: u64) -> Self {
let mut a = [0u64; 8];
let mut b = [0u64; 8];
let mut c = [0u64; 8];
let mut sg = SplitMix64::new(seed);
for i in 0..8 {
a[i] = sg.nextu();
b[i] = sg.nextu();
c[i] = sg.nextu();
}
unsafe {
Self {
a: _mm512_loadu_si512(a.as_ptr() as *const __m512i),
b: _mm512_loadu_si512(b.as_ptr() as *const __m512i),
c: _mm512_loadu_si512(c.as_ptr() as *const __m512i),
counter: _mm512_set1_epi64(1),
}
}
}
#[inline(always)]
pub unsafe fn nextu(&mut self) -> [u64; 8] {
unsafe {
let one = _mm512_set1_epi64(1);
let res = _mm512_add_epi64(_mm512_add_epi64(self.a, self.b), self.counter);
self.a = _mm512_xor_si512(self.b, _mm512_srli_epi64(self.b, 11));
self.b = _mm512_add_epi64(self.c, _mm512_slli_epi64(self.c, 3));
self.c = _mm512_or_si512(_mm512_slli_epi64(res, 24), _mm512_srli_epi64(res, 40));
self.counter = _mm512_add_epi64(self.counter, one);
let mut out = [0u64; 8];
_mm512_storeu_si512(out.as_mut_ptr() as *mut __m512i, res);
out
}
}
#[inline(always)]
pub unsafe fn nextf(&mut self) -> [f64; 8] {
unsafe {
let u = self.nextu();
let mut out = [0f64; 8];
let scale = 1.0 / (u64::MAX as f64 + 1.0);
for i in 0..8 {
out[i] = u[i] as f64 * scale;
}
out
}
}
#[inline(always)]
pub unsafe fn randi(&mut self, min: i64, max: i64) -> [i64; 8] {
unsafe {
let u = self.nextu();
let range = (max as i128 - min as i128 + 1) as u128;
let mut out = [0i64; 8];
for i in 0..8 {
out[i] = ((u[i] as u128 * range) >> 64) as i64 + min;
}
out
}
}
#[inline(always)]
pub unsafe fn randf(&mut self, min: f64, max: f64) -> [f64; 8] {
unsafe {
let u = self.nextu();
let range = max - min;
let scale = range * (1.0 / (u64::MAX as f64 + 1.0));
let mut out = [0f64; 8];
for i in 0..8 {
out[i] = (u[i] as f64 * scale) + min;
}
out
}
}
}
// Gated on x86_64 as well as `test`: `Sfc64x8` is only defined for that
// architecture, so this module must not be compiled on other targets.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;

    // The test items are generated by the crate-level `unsafe_test!` macro.
    crate::unsafe_test!(Sfc64x8);
}