/// Scale factor equal to 2^-64.
///
/// `u64::MAX as f64` rounds up to 2^64, and adding `1.0` does not change that value at
/// `f64` precision, so the reciprocal is exactly 2^-64.
/// NOTE(review): presumably multiplied with a raw `u64` draw to scale it towards the unit
/// interval — confirm against the callers, which are outside this view.
pub const FSCALE64: f64 = 1.0 / (u64::MAX as f64 + 1.0);
/// Scale factor equal to 2^-32.
///
/// `u32::MAX as f32` rounds up to 2^32, and adding `1.0` does not change that value at
/// `f32` precision, so the reciprocal is exactly 2^-32.
/// NOTE(review): presumably the 32-bit counterpart of `FSCALE64` — confirm against callers.
pub const FSCALE32: f32 = 1.0 / (u32::MAX as f32 + 1.0);
/// Overwrites `count` consecutive elements starting at `out` with successive results of
/// `next`, in ascending address order.
///
/// Each slot is assigned (not raw-written), so the value previously stored there is dropped.
///
/// # Safety
///
/// `out` and `count` must satisfy the requirements of [`std::slice::from_raw_parts_mut`]:
/// the range must be valid for reads and writes, properly aligned, initialized, and not
/// aliased for the duration of the call.
#[inline(always)]
pub(crate) unsafe fn fill_with<T, F: FnMut() -> T>(out: *mut T, count: usize, mut next: F) {
    // SAFETY: the caller guarantees that `out..out + count` forms a valid exclusive slice.
    let slots = unsafe { std::slice::from_raw_parts_mut(out, count) };
    slots.iter_mut().for_each(|slot| *slot = next());
}
/// Fills `chunk` with values produced by `generate`, one `N`-element batch at a time.
///
/// Batches land in `chunk` in the order they are generated. While at least `4 * N` slots
/// remain, four batches are generated and then copied together; after that full batches are
/// copied one by one. If fewer than `N` slots are left at the end, one more batch is
/// generated and only its leading elements are used. An empty `chunk` never calls
/// `generate`.
///
/// # Panics
///
/// Panics if `N == 0`. A zero-length batch can never advance through `chunk`, and without
/// this check the copy loops below would spin forever.
///
/// # Safety
///
/// NOTE(review): the body only relies on `chunk` being a valid `&mut [T]`, which the type
/// system already guarantees; the `unsafe` marker is kept so existing callers are unchanged.
#[inline(always)]
pub(crate) unsafe fn fill_chunk<T: Copy, const N: usize, F: FnMut() -> [T; N]>(
    chunk: &mut [T],
    mut generate: F,
) {
    // `N` is a const generic, so this check folds away for every valid instantiation.
    assert!(N > 0, "fill_chunk requires a batch size N > 0");

    let mut out_ptr = chunk.as_mut_ptr();
    let mut remaining = chunk.len();

    // Unrolled path: produce four batches first, then copy them back to back.
    while remaining >= N * 4 {
        let v0 = generate();
        let v1 = generate();
        let v2 = generate();
        let v3 = generate();
        // SAFETY: at least `4 * N` elements of `chunk` remain at `out_ptr`, and the batches
        // are fresh locals that cannot overlap `chunk`.
        unsafe {
            std::ptr::copy_nonoverlapping(v0.as_ptr(), out_ptr, N);
            std::ptr::copy_nonoverlapping(v1.as_ptr(), out_ptr.add(N), N);
            std::ptr::copy_nonoverlapping(v2.as_ptr(), out_ptr.add(N * 2), N);
            std::ptr::copy_nonoverlapping(v3.as_ptr(), out_ptr.add(N * 3), N);
            out_ptr = out_ptr.add(N * 4);
        }
        remaining -= N * 4;
    }

    // Up to three full batches may still fit.
    while remaining >= N {
        let v = generate();
        // SAFETY: at least `N` elements of `chunk` remain at `out_ptr`.
        unsafe {
            std::ptr::copy_nonoverlapping(v.as_ptr(), out_ptr, N);
            out_ptr = out_ptr.add(N);
        }
        remaining -= N;
    }

    // Partial tail: generate one more batch and keep only its first `remaining` elements.
    if remaining > 0 {
        let v = generate();
        // SAFETY: `remaining < N`, so the read stays inside `v` and the write inside `chunk`.
        unsafe { std::ptr::copy_nonoverlapping(v.as_ptr(), out_ptr, remaining) };
    }
}
/// Fills `chunk` from `generate` like `fill_chunk`, but writes full batches with AVX2
/// streaming stores (`_mm256_stream_si256`) when that is possible.
///
/// * If a batch (`N * size_of::<T>()` bytes) is not a non-zero multiple of 32 bytes, the
///   call is forwarded to `fill_chunk` unchanged.
/// * If the start of `chunk` is 32-byte aligned, every full batch is written as a sequence
///   of 32-byte streaming stores, followed by a single `_mm_sfence`.
/// * Otherwise full batches are written with ordinary copies.
///
/// In the last two cases a trailing partial batch is written with an ordinary copy of the
/// leading elements of one extra generated batch.
///
/// # Safety
///
/// The function is compiled with the `avx2` target feature enabled; the caller must make
/// sure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn fill_chunk_nt<T: Copy, const N: usize, F: FnMut() -> [T; N]>(
    chunk: &mut [T],
    mut generate: F,
) {
    use std::arch::x86_64::*;

    /// Width of one 256-bit vector store, in bytes.
    const VECTOR_BYTES: usize = 32;

    let batch_bytes = N * size_of::<T>();
    let vectors_per_batch = batch_bytes / VECTOR_BYTES;
    if vectors_per_batch == 0 || batch_bytes % VECTOR_BYTES != 0 {
        // A batch cannot be expressed as whole 32-byte vectors: use the plain path.
        return fill_chunk(chunk, generate);
    }

    let mut dst = chunk.as_mut_ptr();
    let mut left = chunk.len();

    // Every batch advances `dst` by a multiple of 32 bytes, so alignment only needs to be
    // checked once, on the first address.
    let start_is_aligned = (dst as usize) % VECTOR_BYTES == 0;
    if start_is_aligned {
        while left >= N {
            let batch = generate();
            let src_vec = batch.as_ptr().cast::<__m256i>();
            let dst_vec = dst.cast::<__m256i>();
            for lane in 0..vectors_per_batch {
                // The batch lives on the stack with `T`'s alignment, hence the unaligned load.
                _mm256_stream_si256(dst_vec.add(lane), _mm256_loadu_si256(src_vec.add(lane)));
            }
            dst = dst.add(N);
            left -= N;
        }
        // Fence the streaming stores before any further writes and before returning.
        _mm_sfence();
    } else {
        while left >= N {
            let batch = generate();
            std::ptr::copy_nonoverlapping(batch.as_ptr(), dst, N);
            dst = dst.add(N);
            left -= N;
        }
    }

    if left > 0 {
        // Partial tail: only the first `left` elements of one more batch are kept.
        let batch = generate();
        std::ptr::copy_nonoverlapping(batch.as_ptr(), dst, left);
    }
}
/// Buffer size in bytes (24 MiB) that a fill must exceed before `prefer_nt` reports `true`.
///
/// The parallel fill helpers in this file pass that result as the `nt` flag of
/// `fill_chunk_auto`.
pub(crate) const NT_THRESHOLD_BYTES: usize = 24 << 20;

/// Returns `true` when `total_elems` values of type `T` occupy strictly more than
/// [`NT_THRESHOLD_BYTES`] bytes.
///
/// The byte count saturates at `usize::MAX` instead of overflowing, so an oversized
/// `total_elems` is reported as "large" rather than panicking (debug builds) or wrapping
/// around to a small number (release builds).
#[inline(always)]
pub(crate) fn prefer_nt<T>(total_elems: usize) -> bool {
    total_elems.saturating_mul(size_of::<T>()) > NT_THRESHOLD_BYTES
}

/// Same decision as [`prefer_nt`], with `T` inferred from a slice argument.
///
/// `_sample` is never read; it only lets the compiler pick the element type.
#[inline(always)]
pub(crate) fn prefer_nt_for<T>(total_elems: usize, _sample: &[T]) -> bool {
    prefer_nt::<T>(total_elems)
}
/// Fills `chunk` from `generate`, choosing between the streaming-store and the plain
/// implementation.
///
/// On `x86_64`, when `nt` is set and AVX2 is detected at runtime, the work is handed to
/// `fill_chunk_nt`. In every other case (flag clear, no AVX2, or another architecture) it
/// is handed to `fill_chunk`.
///
/// # Safety
///
/// NOTE(review): no caller-side invariant beyond a valid `&mut [T]` is visible here; the
/// AVX2 requirement of `fill_chunk_nt` is checked inside this function.
#[inline(always)]
pub(crate) unsafe fn fill_chunk_auto<T: Copy, const N: usize, F: FnMut() -> [T; N]>(
    chunk: &mut [T],
    nt: bool,
    generate: F,
) {
    #[cfg(target_arch = "x86_64")]
    if nt && std::arch::is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 availability was confirmed on the line above.
        unsafe { fill_chunk_nt(chunk, generate) };
        return;
    }

    // On targets without the branch above `nt` has no other use; discard it explicitly.
    let _ = nt;

    unsafe { fill_chunk(chunk, generate) }
}
/// Fills `buffer` in parallel, giving every chunk of `0x20000` elements its own generator.
///
/// For the chunk with index `i`, a generator is built with
/// `new_rng(chunk_seed32(base_seed, i))`, and the chunk is then filled with batches of 16
/// values obtained by calling `step` on that generator. A chunk's contents therefore depend
/// only on `base_seed`, the chunk index and the behaviour of `new_rng` / `step`, not on
/// how rayon schedules the work.
///
/// Whether streaming stores are requested is decided once for the whole buffer via
/// `prefer_nt`.
pub(crate) fn par_fill_reseed32<R, T, NF, SF>(
    buffer: &mut [T],
    base_seed: u32,
    new_rng: NF,
    step: SF,
) where
    T: Copy + Default + Send,
    NF: Fn(u32) -> R + Sync,
    SF: Fn(&mut R) -> T + Sync,
{
    use rayon::iter::{IndexedParallelIterator, ParallelIterator};
    use rayon::slice::ParallelSliceMut;

    /// Number of elements handled by one parallel task.
    const PAR_CHUNK: usize = 0x20000;
    /// Number of values produced per generator batch.
    const BATCH: usize = 16;

    let nt = prefer_nt::<T>(buffer.len());
    buffer
        .par_chunks_mut(PAR_CHUNK)
        .enumerate()
        .for_each(|(index, chunk)| {
            let seed = chunk_seed32(base_seed, index);
            let mut rng = new_rng(seed);
            let next_batch = || {
                let mut batch = [T::default(); BATCH];
                batch.iter_mut().for_each(|slot| *slot = step(&mut rng));
                batch
            };
            // SAFETY: `chunk` is an exclusive slice handed out by rayon, and
            // `fill_chunk_auto` performs its own AVX2 detection.
            unsafe { fill_chunk_auto(chunk, nt, next_batch) };
        });
}
/// Fills `buffer` in parallel, giving every chunk of `0x20000` elements its own generator.
///
/// For the chunk with index `i`, the seed is
/// `SplitMix64::compute(base_seed + i * 0x9E3779B97F4A7C15)` using wrapping arithmetic. A
/// generator is built from it with `new_rng`, and the chunk is then filled with batches of
/// 8 values obtained by calling `step` on that generator. A chunk's contents therefore
/// depend only on `base_seed`, the chunk index and the behaviour of `new_rng` / `step`,
/// not on how rayon schedules the work.
///
/// Whether streaming stores are requested is decided once for the whole buffer via
/// `prefer_nt`.
pub(crate) fn par_fill_reseed64<R, T, NF, SF>(
    buffer: &mut [T],
    base_seed: u64,
    new_rng: NF,
    step: SF,
) where
    T: Copy + Default + Send,
    NF: Fn(u64) -> R + Sync,
    SF: Fn(&mut R) -> T + Sync,
{
    use rayon::iter::{IndexedParallelIterator, ParallelIterator};
    use rayon::slice::ParallelSliceMut;

    /// Number of elements handled by one parallel task.
    const PAR_CHUNK: usize = 0x20000;
    /// Number of values produced per generator batch.
    const BATCH: usize = 8;
    /// Per-chunk increment applied to the base seed before mixing.
    const SEED_STRIDE: u64 = 0x9E3779B97F4A7C15;

    let nt = prefer_nt::<T>(buffer.len());
    buffer
        .par_chunks_mut(PAR_CHUNK)
        .enumerate()
        .for_each(|(index, chunk)| {
            let offset = (index as u64).wrapping_mul(SEED_STRIDE);
            let seed = crate::rng64::SplitMix64::compute(base_seed.wrapping_add(offset));
            let mut rng = new_rng(seed);
            let next_batch = || {
                let mut batch = [T::default(); BATCH];
                batch.iter_mut().for_each(|slot| *slot = step(&mut rng));
                batch
            };
            // SAFETY: `chunk` is an exclusive slice handed out by rayon, and
            // `fill_chunk_auto` performs its own AVX2 detection.
            unsafe { fill_chunk_auto(chunk, nt, next_batch) };
        });
}
/// Derives the 32-bit seed for chunk number `chunk_idx` from `base_seed`.
///
/// The chunk index is truncated to `u32`, multiplied by `0x9E37_79B9` and added to
/// `base_seed` (both wrapping). The sum is widened to 64 bits and run through two rounds of
/// "xor with self shifted right by 16, then wrapping multiply", followed by one last
/// xor-shift; the low 32 bits of the result are returned.
#[inline]
pub(crate) fn chunk_seed32(base_seed: u32, chunk_idx: usize) -> u32 {
    const INDEX_STRIDE: u32 = 0x9E37_79B9;
    const MIX_A: u64 = 0xFF51_AFD7_ED55_8CCD;
    const MIX_B: u64 = 0xC4CE_B9FE_1A85_EC53;

    let xorshift16 = |v: u64| v ^ (v >> 16);

    // Truncating the index is intentional here: only its low 32 bits take part.
    let offset = (chunk_idx as u32).wrapping_mul(INDEX_STRIDE);
    let start = u64::from(base_seed.wrapping_add(offset));

    let round_a = xorshift16(start).wrapping_mul(MIX_A);
    let round_b = xorshift16(round_a).wrapping_mul(MIX_B);
    xorshift16(round_b) as u32
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Value every test buffer is pre-filled with; the generators below never produce it
    /// for the buffer sizes used here.
    const UNTOUCHED: u32 = u32::MAX;

    /// Returns a generator that yields 0, 1, 2, ... in batches of `N`.
    fn counter_batches<const N: usize>() -> impl FnMut() -> [u32; N] {
        let mut next = 0u32;
        move || {
            let mut out = [0u32; N];
            for v in &mut out {
                *v = next;
                next += 1;
            }
            out
        }
    }

    /// Asserts that the first `len` elements of `buf` hold their own index.
    fn check_fill(buf: &[u32], len: usize) {
        for (i, &v) in buf[..len].iter().enumerate() {
            assert_eq!(v as usize, i, "element {i} wrong");
        }
    }

    /// Index of the first element of `storage` whose address is 32-byte aligned.
    ///
    /// `u32` storage is 4-byte aligned, so the result is always in `0..8`.
    #[cfg(target_arch = "x86_64")]
    fn first_aligned_index(storage: &[u32]) -> usize {
        let misalignment = storage.as_ptr() as usize % 32;
        ((32 - misalignment) % 32) / std::mem::size_of::<u32>()
    }

    /// Runs `fill_chunk_nt` on a `len`-element window that starts `extra_offset` elements
    /// past a 32-byte boundary, and checks both the window and its surroundings.
    #[cfg(target_arch = "x86_64")]
    fn check_nt_window(len: usize, extra_offset: usize) {
        // Up to 7 elements of alignment slack, plus the extra offset, plus one guard slot.
        let mut storage = vec![UNTOUCHED; len + 9];
        let start = first_aligned_index(&storage) + extra_offset;
        let end = start + len;
        unsafe { fill_chunk_nt::<u32, 16, _>(&mut storage[start..end], counter_batches()) };
        check_fill(&storage[start..end], len);
        assert!(
            storage[..start].iter().all(|&v| v == UNTOUCHED),
            "wrote before the window (len {len}, offset {extra_offset})"
        );
        assert!(
            storage[end..].iter().all(|&v| v == UNTOUCHED),
            "wrote past the window (len {len}, offset {extra_offset})"
        );
    }

    #[test]
    fn fill_chunk_writes_every_element() {
        for len in [0usize, 1, 7, 16, 63, 64, 65, 1000] {
            let mut buf = vec![UNTOUCHED; len];
            unsafe { fill_chunk::<u32, 16, _>(&mut buf, counter_batches()) };
            check_fill(&buf, len);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn fill_chunk_nt_writes_every_element() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        for len in [0usize, 1, 7, 16, 63, 64, 65, 1000] {
            let mut buf = vec![UNTOUCHED; len];
            unsafe { fill_chunk_nt::<u32, 16, _>(&mut buf, counter_batches()) };
            check_fill(&buf, len);
        }
    }

    /// A plain `Vec` is not guaranteed to start on a 32-byte boundary, so this test picks
    /// an aligned window explicitly to make sure the streaming-store branch actually runs.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn fill_chunk_nt_streams_into_aligned_window() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        for len in [0usize, 1, 7, 16, 63, 64, 65, 1000] {
            check_nt_window(len, 0);
        }
    }

    /// Counterpart of the aligned test: forces the ordinary-copy branch.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn fill_chunk_nt_copies_into_misaligned_window() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        for len in [0usize, 1, 7, 16, 63, 64, 65, 1000] {
            check_nt_window(len, 1);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn fill_chunk_nt_small_batch_falls_back() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let mut buf = vec![UNTOUCHED; 100];
        unsafe { fill_chunk_nt::<u32, 4, _>(&mut buf, counter_batches()) };
        check_fill(&buf, 100);
    }
}