use crate::rng::Rng64;
use crate::rng64::{Sfc64, Sfc64x8, SplitMix64};
use rayon::prelude::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use std::slice::from_raw_parts_mut;
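// SplitMix64's golden-gamma increment (2^64 / phi); spaces out per-chunk seeds.
const GOLDEN_GAMMA: u64 = 0x9E3779B97F4A7C15;
/// Allocates a heap-backed `Sfc64` seeded from `seed`; release it with
/// `sfc64_free`.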
#[unsafe(no_mangle)]
pub extern "C" fn sfc64_new(seed: u64) -> *mut Sfc64 {
Box::into_raw(Box::new(Sfc64::new(seed)))
}
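/// # Safety
/// `ptr` must be null or a pointer obtained from `sfc64_new` that has not
/// already been freed.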
#[unsafe(no_mangle)]
pub extern "C" fn sfc64_free(ptr: *mut Sfc64) {
if !ptr.is_null() {
unsafe { drop(Box::from_raw(ptr)) };
}
}
// Elements filled per rayon task (65,536 u64s = 512 KiB).
const SFC64_PAR_CHUNK: usize = 0x10000;
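/// Fills `out[..count]` with random `u64`s, parallelized across rayon tasks
/// that each run an independently seeded `Sfc64`.
///
/// # Safety
/// `ptr` must point to a live `Sfc64`, and `out` must be valid for writes of
/// `count` `u64`s.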
#[unsafe(no_mangle)]
pub extern "C" fn sfc64_next_u64s(ptr: *mut Sfc64, out: *mut u64, count: usize) {
    if count == 0 {
        return;
    }
    unsafe {
        let rng = &mut *ptr;
        let buffer = from_raw_parts_mut(out, count);
        // Advance the shared generator once so each call gets its own stream.
        let seed = rng.nextu();
buffer
.par_chunks_mut(SFC64_PAR_CHUNK)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
                let seed = SplitMix64::compute(
                    seed.wrapping_add((chunk_idx as u64).wrapping_mul(GOLDEN_GAMMA)),
                );
let mut rng = Sfc64::new(seed);
for v in chunk {
*v = rng.nextu();
}
});
}
}
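/// Fills `out[..count]` with random `f64`s from `Rng64::nextf`, parallelized
/// the same way as `sfc64_next_u64s`.
///
/// # Safety
/// `ptr` must point to a live `Sfc64`, and `out` must be valid for writes of
/// `count` `f64`s.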
#[unsafe(no_mangle)]
pub extern "C" fn sfc64_next_f64s(ptr: *mut Sfc64, out: *mut f64, count: usize) {
    if count == 0 {
        return;
    }
    unsafe {
let rng = &mut *ptr;
let buffer = from_raw_parts_mut(out, count);
let seed = rng.nextu();
buffer
.par_chunks_mut(SFC64_PAR_CHUNK)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
                let seed = SplitMix64::compute(
                    seed.wrapping_add((chunk_idx as u64).wrapping_mul(GOLDEN_GAMMA)),
                );
let mut rng = Sfc64::new(seed);
for v in chunk {
*v = rng.nextf();
}
});
}
}
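/// Fills `out[..count]` with `i64`s drawn via `Rng64::randi(min, max)`.
///
/// # Safety
/// `ptr` must point to a live `Sfc64`, and `out` must be valid for writes of
/// `count` `i64`s.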
#[unsafe(no_mangle)]
pub extern "C" fn sfc64_rand_i64s(
ptr: *mut Sfc64,
out: *mut i64,
count: usize,
min: i64,
max: i64,
) {
    if count == 0 {
        return;
    }
    unsafe {
let rng = &mut *ptr;
let buffer = from_raw_parts_mut(out, count);
let seed = rng.nextu();
buffer
.par_chunks_mut(SFC64_PAR_CHUNK)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
                let seed = SplitMix64::compute(
                    seed.wrapping_add((chunk_idx as u64).wrapping_mul(GOLDEN_GAMMA)),
                );
let mut rng = Sfc64::new(seed);
for v in chunk {
*v = rng.randi(min, max);
}
});
}
}
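/// Fills `out[..count]` with `f64`s drawn via `Rng64::randf(min, max)`.
///
/// # Safety
/// `ptr` must point to a live `Sfc64`, and `out` must be valid for writes of
/// `count` `f64`s.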
#[unsafe(no_mangle)]
pub extern "C" fn sfc64_rand_f64s(
ptr: *mut Sfc64,
out: *mut f64,
count: usize,
min: f64,
max: f64,
) {
    if count == 0 {
        return;
    }
    unsafe {
let rng = &mut *ptr;
let buffer = from_raw_parts_mut(out, count);
let seed = rng.nextu();
buffer
.par_chunks_mut(SFC64_PAR_CHUNK)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
                let seed = SplitMix64::compute(
                    seed.wrapping_add((chunk_idx as u64).wrapping_mul(GOLDEN_GAMMA)),
                );
let mut rng = Sfc64::new(seed);
for v in chunk {
*v = rng.randf(min, max);
}
});
}
}
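/// Allocates a heap-backed eight-lane `Sfc64x8` seeded from `seed`; release
/// it with `sfc64x8_free`.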
#[unsafe(no_mangle)]
pub extern "C" fn sfc64x8_new(seed: u64) -> *mut Sfc64x8 {
Box::into_raw(Box::new(Sfc64x8::new(seed)))
}
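/// # Safety
/// `ptr` must be null or a pointer obtained from `sfc64x8_new` that has not
/// already been freed.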
#[unsafe(no_mangle)]
pub extern "C" fn sfc64x8_free(ptr: *mut Sfc64x8) {
if !ptr.is_null() {
unsafe { drop(Box::from_raw(ptr)) };
}
}
// Elements filled per rayon task for the AVX-512 kernel (4096 u64s = 32 KiB).
const SFC64X8_PAR_CHUNK: usize = 0x1000;
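// Fills one parallel chunk using 48 interleaved SFC64 streams (six groups of
// eight zmm lanes), each seeded independently via SplitMix64 from `seed` and
// `chunk_idx`. Every lane follows the scalar SFC64 recurrence:
//   out = a + b + counter
//   a   = b ^ (b >> 11)
//   b   = c + (c << 3)
//   c   = rotl(out, 24)
//   counter += 1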
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[allow(unsafe_op_in_unsafe_fn, unused_assignments)]
unsafe fn sfc64x8_next_u64s_chunk(chunk_idx: usize, chunk: &mut [u64], seed: u64) {
    let chunk_base = seed.wrapping_add((chunk_idx as u64).wrapping_mul(GOLDEN_GAMMA));
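    // Builds one group of eight lanes: SplitMix64 expands each lane's seed
    // into the (a, b, c) state vectors, with all counters starting at 1.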
macro_rules! make_state {
($group:expr) => {{
let mut va = [0; 8];
let mut vb = [0; 8];
let mut vc = [0; 8];
for i in 0usize..8 {
                let seed = SplitMix64::compute(
                    chunk_base.wrapping_add(((($group << 3) + i) as u64).wrapping_mul(GOLDEN_GAMMA)),
                );
let mut sm = SplitMix64::new(seed);
va[i] = sm.nextu();
vb[i] = sm.nextu();
vc[i] = sm.nextu();
}
(
_mm512_loadu_si512(va.as_ptr() as *const __m512i),
_mm512_loadu_si512(vb.as_ptr() as *const __m512i),
_mm512_loadu_si512(vc.as_ptr() as *const __m512i),
_mm512_set1_epi64(1),
)
}};
}
let (mut a0, mut b0, mut c0, mut ctr0) = make_state!(0);
let (mut a1, mut b1, mut c1, mut ctr1) = make_state!(1);
let (mut a2, mut b2, mut c2, mut ctr2) = make_state!(2);
let (mut a3, mut b3, mut c3, mut ctr3) = make_state!(3);
let (mut a4, mut b4, mut c4, mut ctr4) = make_state!(4);
let (mut a5, mut b5, mut c5, mut ctr5) = make_state!(5);
let one = _mm512_set1_epi64(1i64);
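    // One SFC64 step across all six groups, yielding 48 output values.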
macro_rules! step6 {
() => {{
let res0 = _mm512_add_epi64(_mm512_add_epi64(a0, b0), ctr0);
let res1 = _mm512_add_epi64(_mm512_add_epi64(a1, b1), ctr1);
let res2 = _mm512_add_epi64(_mm512_add_epi64(a2, b2), ctr2);
let res3 = _mm512_add_epi64(_mm512_add_epi64(a3, b3), ctr3);
let res4 = _mm512_add_epi64(_mm512_add_epi64(a4, b4), ctr4);
let res5 = _mm512_add_epi64(_mm512_add_epi64(a5, b5), ctr5);
a0 = _mm512_xor_si512(b0, _mm512_srli_epi64(b0, 11));
a1 = _mm512_xor_si512(b1, _mm512_srli_epi64(b1, 11));
a2 = _mm512_xor_si512(b2, _mm512_srli_epi64(b2, 11));
a3 = _mm512_xor_si512(b3, _mm512_srli_epi64(b3, 11));
a4 = _mm512_xor_si512(b4, _mm512_srli_epi64(b4, 11));
a5 = _mm512_xor_si512(b5, _mm512_srli_epi64(b5, 11));
b0 = _mm512_add_epi64(c0, _mm512_slli_epi64(c0, 3));
b1 = _mm512_add_epi64(c1, _mm512_slli_epi64(c1, 3));
b2 = _mm512_add_epi64(c2, _mm512_slli_epi64(c2, 3));
b3 = _mm512_add_epi64(c3, _mm512_slli_epi64(c3, 3));
b4 = _mm512_add_epi64(c4, _mm512_slli_epi64(c4, 3));
b5 = _mm512_add_epi64(c5, _mm512_slli_epi64(c5, 3));
c0 = _mm512_rol_epi64(res0, 24);
c1 = _mm512_rol_epi64(res1, 24);
c2 = _mm512_rol_epi64(res2, 24);
c3 = _mm512_rol_epi64(res3, 24);
c4 = _mm512_rol_epi64(res4, 24);
c5 = _mm512_rol_epi64(res5, 24);
ctr0 = _mm512_add_epi64(ctr0, one);
ctr1 = _mm512_add_epi64(ctr1, one);
ctr2 = _mm512_add_epi64(ctr2, one);
ctr3 = _mm512_add_epi64(ctr3, one);
ctr4 = _mm512_add_epi64(ctr4, one);
ctr5 = _mm512_add_epi64(ctr5, one);
(res0, res1, res2, res3, res4, res5)
}};
}
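    // Each iteration writes 48 u64s as six 64-byte stores. If the chunk is
    // 64-byte aligned, every store in the batch is too, so non-temporal
    // stores keep this bulk output from evicting useful cache lines.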
let is_aligned = (chunk.as_ptr() as usize) & 63 == 0;
let mut chunks_exact = chunk.chunks_exact_mut(48);
if is_aligned {
for dst in chunks_exact.by_ref() {
let (r0, r1, r2, r3, r4, r5) = step6!();
let p = dst.as_mut_ptr();
_mm512_stream_si512(p as *mut __m512i, r0);
_mm512_stream_si512(p.add(8) as *mut __m512i, r1);
_mm512_stream_si512(p.add(16) as *mut __m512i, r2);
_mm512_stream_si512(p.add(24) as *mut __m512i, r3);
_mm512_stream_si512(p.add(32) as *mut __m512i, r4);
_mm512_stream_si512(p.add(40) as *mut __m512i, r5);
        }
        // Non-temporal stores are weakly ordered; fence so the written data
        // is globally visible before this rayon task returns.
        _mm_sfence();
    } else {
for dst in chunks_exact.by_ref() {
let (r0, r1, r2, r3, r4, r5) = step6!();
let p = dst.as_mut_ptr();
_mm512_storeu_si512(p as *mut __m512i, r0);
_mm512_storeu_si512(p.add(8) as *mut __m512i, r1);
_mm512_storeu_si512(p.add(16) as *mut __m512i, r2);
_mm512_storeu_si512(p.add(24) as *mut __m512i, r3);
_mm512_storeu_si512(p.add(32) as *mut __m512i, r4);
_mm512_storeu_si512(p.add(40) as *mut __m512i, r5);
}
}
    // Tail (< 48 values): generate one full batch into a scratch buffer and
    // copy only the prefix that is needed.
    let rem = chunks_exact.into_remainder();
    if !rem.is_empty() {
        let mut tmp = [0u64; 48];
        let (r0, r1, r2, r3, r4, r5) = step6!();
        _mm512_storeu_si512(tmp.as_mut_ptr() as *mut __m512i, r0);
        _mm512_storeu_si512(tmp.as_mut_ptr().add(8) as *mut __m512i, r1);
        _mm512_storeu_si512(tmp.as_mut_ptr().add(16) as *mut __m512i, r2);
        _mm512_storeu_si512(tmp.as_mut_ptr().add(24) as *mut __m512i, r3);
        _mm512_storeu_si512(tmp.as_mut_ptr().add(32) as *mut __m512i, r4);
        _mm512_storeu_si512(tmp.as_mut_ptr().add(40) as *mut __m512i, r5);
        rem.copy_from_slice(&tmp[..rem.len()]);
    }
}
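/// Fills `out[..count]` with random `u64`s using the AVX-512 kernel, then
/// reseeds `*ptr` so the next call produces a decorrelated stream.
///
/// # Safety
/// `ptr` must point to a live `Sfc64x8`, and `out` must be valid for writes
/// of `count` `u64`s. The CPU must support AVX-512F; this is not checked at
/// runtime.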
#[unsafe(no_mangle)]
pub extern "C" fn sfc64x8_next_u64s(ptr: *mut Sfc64x8, out: *mut u64, count: usize) {
if count == 0 {
return;
}
unsafe {
let rng = &mut *ptr;
        // Fold one eight-lane output into a single base seed for this call.
        let vals = rng.nextu();
        let seed = vals.iter().copied().fold(0, u64::wrapping_add);
let buffer = from_raw_parts_mut(out, count);
buffer
.par_chunks_mut(SFC64X8_PAR_CHUNK)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
sfc64x8_next_u64s_chunk(chunk_idx, chunk, seed);
});
        // Reseed the shared state so the next call starts from a fresh stream.
        let seed =
            SplitMix64::compute(seed.wrapping_add((count as u64).wrapping_mul(GOLDEN_GAMMA)));
*rng = Sfc64x8::new(seed);
}
}