#[cfg(target_arch = "x86_64")]
use crate::rng32::jsf::JSF32X16;
use crate::{
_internal::chunk_seed32,
rng::{Rng32, Rng32V512},
rng32::{Jsf32, jsf::Jsf32x16},
};
use rayon::prelude::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use std::{ptr, slice::from_raw_parts_mut};
/// Allocates a scalar JSF32 generator seeded with `seed` and transfers ownership to the caller.
///
/// The returned pointer is produced by `Box::into_raw`; release it with `jsf32_free`.
#[unsafe(no_mangle)]
pub extern "C" fn jsf32_new(seed: u32) -> *mut Jsf32 {
    let generator = Box::new(Jsf32::new(seed));
    Box::into_raw(generator)
}
#[unsafe(no_mangle)]
pub extern "C" fn jsf32_free(ptr: *mut Jsf32) {
if !ptr.is_null() {
unsafe {
drop(Box::from_raw(ptr));
}
}
}
/// Fills `out[..count]` with pseudo-random `u32` values.
///
/// Exactly one value is drawn from the generator behind `ptr` and used as a base seed. The fill
/// itself is delegated to `par_fill_reseed32`, which is handed `Jsf32::new` as the constructor
/// for its own generators. The state behind `ptr` therefore advances by one draw per call,
/// independent of `count`.
///
/// Caller contract: `ptr` must be a live generator from `jsf32_new` that is not used
/// concurrently, and `out` must be valid for `count` writes of `u32`. The call is a no-op when
/// `count == 0` or either pointer is null.
#[unsafe(no_mangle)]
pub extern "C" fn jsf32_next_u32s(ptr: *mut Jsf32, out: *mut u32, count: usize) {
    // `from_raw_parts_mut` with a null pointer is UB even for a non-zero length, so reject it
    // here (consistent with `jsf32_free` accepting null).
    if count == 0 || ptr.is_null() || out.is_null() {
        return;
    }
    // SAFETY: both pointers are non-null; the caller guarantees `ptr` is a live, unaliased
    // generator and `out` is valid for `count` writes.
    unsafe {
        let rng = &mut *ptr;
        let seed = rng.nextu();
        let buffer = from_raw_parts_mut(out, count);
        crate::_internal::par_fill_reseed32(buffer, seed, Jsf32::new, |r| r.nextu());
    }
}
/// Fills `out[..count]` with `f32` values produced by `Jsf32::nextf`.
///
/// Exactly one `u32` is drawn from the generator behind `ptr` as the base seed. The fill is then
/// delegated to `par_fill_reseed32`, which is given `Jsf32::new` to construct its generators.
///
/// Caller contract: `ptr` must be a live generator from `jsf32_new` that is not used
/// concurrently, and `out` must be valid for `count` writes of `f32`. The call is a no-op when
/// `count == 0` or either pointer is null.
#[unsafe(no_mangle)]
pub extern "C" fn jsf32_next_f32s(ptr: *mut Jsf32, out: *mut f32, count: usize) {
    // Null would make `&mut *ptr` / `from_raw_parts_mut` undefined behavior.
    if count == 0 || ptr.is_null() || out.is_null() {
        return;
    }
    // SAFETY: both pointers are non-null; validity and exclusivity are the caller's contract.
    unsafe {
        let rng = &mut *ptr;
        let base_seed = rng.nextu();
        let buffer = from_raw_parts_mut(out, count);
        crate::_internal::par_fill_reseed32(buffer, base_seed, Jsf32::new, |r| r.nextf());
    }
}
/// Fills `out[..count]` with integers produced by `Jsf32::randi(min, max)`.
///
/// See `randi` for the exact range and inclusivity. One `u32` is drawn from the generator
/// behind `ptr` as the base seed, and `par_fill_reseed32` (given `Jsf32::new`) performs the fill.
///
/// Caller contract: `ptr` must be a live generator from `jsf32_new` that is not used
/// concurrently, and `out` must be valid for `count` writes of `i32`. The call is a no-op when
/// `count == 0` or either pointer is null.
#[unsafe(no_mangle)]
pub extern "C" fn jsf32_rand_i32s(
    ptr: *mut Jsf32,
    out: *mut i32,
    count: usize,
    min: i32,
    max: i32,
) {
    // Null would make `&mut *ptr` / `from_raw_parts_mut` undefined behavior.
    if count == 0 || ptr.is_null() || out.is_null() {
        return;
    }
    // SAFETY: both pointers are non-null; validity and exclusivity are the caller's contract.
    unsafe {
        let rng = &mut *ptr;
        let base_seed = rng.nextu();
        let buffer = from_raw_parts_mut(out, count);
        crate::_internal::par_fill_reseed32(buffer, base_seed, Jsf32::new, |r| r.randi(min, max));
    }
}
/// Fills `out[..count]` with floats produced by `Jsf32::randf(min, max)`.
///
/// One `u32` is drawn from the generator behind `ptr` as the base seed, and
/// `par_fill_reseed32` (given `Jsf32::new`) performs the fill.
///
/// Caller contract: `ptr` must be a live generator from `jsf32_new` that is not used
/// concurrently, and `out` must be valid for `count` writes of `f32`. The call is a no-op when
/// `count == 0` or either pointer is null.
#[unsafe(no_mangle)]
pub extern "C" fn jsf32_rand_f32s(
    ptr: *mut Jsf32,
    out: *mut f32,
    count: usize,
    min: f32,
    max: f32,
) {
    // Null would make `&mut *ptr` / `from_raw_parts_mut` undefined behavior.
    if count == 0 || ptr.is_null() || out.is_null() {
        return;
    }
    // SAFETY: both pointers are non-null; validity and exclusivity are the caller's contract.
    unsafe {
        let rng = &mut *ptr;
        let base_seed = rng.nextu();
        let buffer = from_raw_parts_mut(out, count);
        crate::_internal::par_fill_reseed32(buffer, base_seed, Jsf32::new, |r| r.randf(min, max));
    }
}
/// Number of output elements each rayon task fills in the AVX-512 (`jsf32x16_*`) entry points.
#[cfg(target_arch = "x86_64")]
const JSF32X16_PAR_CHUNK: usize = 1024 * 1024;
/// Fills `chunk` with `u32` output from the 16-lane generator `rng`.
///
/// Full vectors are written directly: four per iteration in the unrolled loop, then one at a
/// time. A trailing partial vector is generated into a stack buffer and only its first
/// `remaining` lanes are copied out, so one whole vector is consumed for the tail.
///
/// When `nt` is set and `chunk` starts on a 64-byte boundary, non-temporal (streaming) stores
/// are used, followed by an `sfence`. Otherwise ordinary unaligned stores are used.
///
/// # Safety
/// The running CPU must support AVX-512F.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
// The body is pointer arithmetic and intrinsics throughout; allow it without nested `unsafe`.
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn jsf32x16_next_u32s_chunk(rng: &mut Jsf32x16, chunk: &mut [u32], nt: bool) {
    let mut out_ptr = chunk.as_mut_ptr();
    let mut remaining = chunk.len();
    // Streaming stores require 64-byte alignment.
    let aligned = nt && (out_ptr as usize & 63) == 0;
    const UNROLL: usize = JSF32X16 * 4;
    if aligned {
        while remaining >= UNROLL {
            let v0 = rng.nextuv();
            let v1 = rng.nextuv();
            let v2 = rng.nextuv();
            let v3 = rng.nextuv();
            _mm512_stream_si512(out_ptr as *mut _, v0);
            _mm512_stream_si512(out_ptr.add(JSF32X16) as *mut _, v1);
            _mm512_stream_si512(out_ptr.add(JSF32X16 * 2) as *mut _, v2);
            _mm512_stream_si512(out_ptr.add(JSF32X16 * 3) as *mut _, v3);
            out_ptr = out_ptr.add(UNROLL);
            remaining -= UNROLL;
        }
        while remaining >= JSF32X16 {
            let v = rng.nextuv();
            _mm512_stream_si512(out_ptr as *mut _, v);
            out_ptr = out_ptr.add(JSF32X16);
            remaining -= JSF32X16;
        }
        // Non-temporal stores are weakly ordered. The std::arch contract requires an sfence
        // before anything else (e.g. another thread after the rayon join) touches this memory.
        _mm_sfence();
    } else {
        while remaining >= UNROLL {
            let v0 = rng.nextuv();
            let v1 = rng.nextuv();
            let v2 = rng.nextuv();
            let v3 = rng.nextuv();
            _mm512_storeu_si512(out_ptr as *mut _, v0);
            _mm512_storeu_si512(out_ptr.add(JSF32X16) as *mut _, v1);
            _mm512_storeu_si512(out_ptr.add(JSF32X16 * 2) as *mut _, v2);
            _mm512_storeu_si512(out_ptr.add(JSF32X16 * 3) as *mut _, v3);
            out_ptr = out_ptr.add(UNROLL);
            remaining -= UNROLL;
        }
        while remaining >= JSF32X16 {
            let v = rng.nextuv();
            _mm512_storeu_si512(out_ptr as *mut _, v);
            out_ptr = out_ptr.add(JSF32X16);
            remaining -= JSF32X16;
        }
    }
    if remaining > 0 {
        // Tail shorter than one vector: generate into scratch, copy only what fits.
        let mut tmp = [0u32; JSF32X16];
        let v = rng.nextuv();
        _mm512_storeu_si512(tmp.as_mut_ptr() as *mut _, v);
        ptr::copy_nonoverlapping(tmp.as_ptr(), out_ptr, remaining);
    }
}
/// Fills `chunk` with `f32` output from `rng.nextfv(scale)`.
///
/// Full vectors are written directly (4x unrolled, then singly). A trailing partial vector is
/// generated into a stack buffer and truncated, consuming one whole vector.
///
/// When `nt` is set and `chunk` is 64-byte aligned, streaming stores are used, followed by an
/// `sfence`. Otherwise unaligned stores are used.
///
/// # Safety
/// The running CPU must support AVX-512F.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
// The body is pointer arithmetic and intrinsics throughout; allow it without nested `unsafe`.
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn jsf32x16_next_f32s_chunk(rng: &mut Jsf32x16, chunk: &mut [f32], nt: bool, scale: __m512) {
    let mut out_ptr = chunk.as_mut_ptr();
    let mut remaining = chunk.len();
    // Streaming stores require 64-byte alignment.
    let aligned = nt && (out_ptr as usize & 63) == 0;
    const UNROLL: usize = JSF32X16 * 4;
    if aligned {
        while remaining >= UNROLL {
            let v0 = rng.nextfv(scale);
            let v1 = rng.nextfv(scale);
            let v2 = rng.nextfv(scale);
            let v3 = rng.nextfv(scale);
            _mm512_stream_ps(out_ptr, v0);
            _mm512_stream_ps(out_ptr.add(JSF32X16), v1);
            _mm512_stream_ps(out_ptr.add(JSF32X16 * 2), v2);
            _mm512_stream_ps(out_ptr.add(JSF32X16 * 3), v3);
            out_ptr = out_ptr.add(UNROLL);
            remaining -= UNROLL;
        }
        while remaining >= JSF32X16 {
            let v = rng.nextfv(scale);
            _mm512_stream_ps(out_ptr, v);
            out_ptr = out_ptr.add(JSF32X16);
            remaining -= JSF32X16;
        }
        // Non-temporal stores must be fenced before any other access to this memory.
        _mm_sfence();
    } else {
        while remaining >= UNROLL {
            let v0 = rng.nextfv(scale);
            let v1 = rng.nextfv(scale);
            let v2 = rng.nextfv(scale);
            let v3 = rng.nextfv(scale);
            _mm512_storeu_ps(out_ptr, v0);
            _mm512_storeu_ps(out_ptr.add(JSF32X16), v1);
            _mm512_storeu_ps(out_ptr.add(JSF32X16 * 2), v2);
            _mm512_storeu_ps(out_ptr.add(JSF32X16 * 3), v3);
            out_ptr = out_ptr.add(UNROLL);
            remaining -= UNROLL;
        }
        while remaining >= JSF32X16 {
            let v = rng.nextfv(scale);
            _mm512_storeu_ps(out_ptr, v);
            out_ptr = out_ptr.add(JSF32X16);
            remaining -= JSF32X16;
        }
    }
    if remaining > 0 {
        // Tail shorter than one vector: generate into scratch, copy only what fits.
        let mut tmp = [0f32; JSF32X16];
        let v = rng.nextfv(scale);
        _mm512_storeu_ps(tmp.as_mut_ptr(), v);
        ptr::copy_nonoverlapping(tmp.as_ptr(), out_ptr, remaining);
    }
}
/// Fills `chunk` with integers from `rng.randiv(v_range, v_min)`.
///
/// Full vectors are written directly (4x unrolled, then singly). A trailing partial vector is
/// generated into a stack buffer and truncated, consuming one whole vector.
///
/// When `nt` is set and `chunk` is 64-byte aligned, streaming stores are used, followed by an
/// `sfence`. Otherwise unaligned stores are used.
///
/// # Safety
/// The running CPU must support AVX-512F.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
// The body is pointer arithmetic and intrinsics throughout; allow it without nested `unsafe`.
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn jsf32x16_rand_i32s_chunk(
    rng: &mut Jsf32x16,
    chunk: &mut [i32],
    nt: bool,
    v_range: __m512i,
    v_min: __m512i,
) {
    let mut out_ptr = chunk.as_mut_ptr();
    let mut remaining = chunk.len();
    // Streaming stores require 64-byte alignment.
    let aligned = nt && (out_ptr as usize & 63) == 0;
    const UNROLL: usize = JSF32X16 * 4;
    if aligned {
        while remaining >= UNROLL {
            let v0 = rng.randiv(v_range, v_min);
            let v1 = rng.randiv(v_range, v_min);
            let v2 = rng.randiv(v_range, v_min);
            let v3 = rng.randiv(v_range, v_min);
            _mm512_stream_si512(out_ptr as *mut _, v0);
            _mm512_stream_si512(out_ptr.add(JSF32X16) as *mut _, v1);
            _mm512_stream_si512(out_ptr.add(JSF32X16 * 2) as *mut _, v2);
            _mm512_stream_si512(out_ptr.add(JSF32X16 * 3) as *mut _, v3);
            out_ptr = out_ptr.add(UNROLL);
            remaining -= UNROLL;
        }
        while remaining >= JSF32X16 {
            let v = rng.randiv(v_range, v_min);
            _mm512_stream_si512(out_ptr as *mut _, v);
            out_ptr = out_ptr.add(JSF32X16);
            remaining -= JSF32X16;
        }
        // Non-temporal stores must be fenced before any other access to this memory.
        _mm_sfence();
    } else {
        while remaining >= UNROLL {
            let v0 = rng.randiv(v_range, v_min);
            let v1 = rng.randiv(v_range, v_min);
            let v2 = rng.randiv(v_range, v_min);
            let v3 = rng.randiv(v_range, v_min);
            _mm512_storeu_si512(out_ptr as *mut _, v0);
            _mm512_storeu_si512(out_ptr.add(JSF32X16) as *mut _, v1);
            _mm512_storeu_si512(out_ptr.add(JSF32X16 * 2) as *mut _, v2);
            _mm512_storeu_si512(out_ptr.add(JSF32X16 * 3) as *mut _, v3);
            out_ptr = out_ptr.add(UNROLL);
            remaining -= UNROLL;
        }
        while remaining >= JSF32X16 {
            let v = rng.randiv(v_range, v_min);
            _mm512_storeu_si512(out_ptr as *mut _, v);
            out_ptr = out_ptr.add(JSF32X16);
            remaining -= JSF32X16;
        }
    }
    if remaining > 0 {
        // Tail shorter than one vector: generate into scratch, copy only what fits.
        let mut tmp = [0i32; JSF32X16];
        let v = rng.randiv(v_range, v_min);
        _mm512_storeu_si512(tmp.as_mut_ptr() as *mut _, v);
        ptr::copy_nonoverlapping(tmp.as_ptr(), out_ptr, remaining);
    }
}
/// Fills `chunk` with floats from `rng.randfv(v_mult, v_min)`.
///
/// Full vectors are written directly (4x unrolled, then singly). A trailing partial vector is
/// generated into a stack buffer and truncated, consuming one whole vector.
///
/// When `nt` is set and `chunk` is 64-byte aligned, streaming stores are used, followed by an
/// `sfence`. Otherwise unaligned stores are used.
///
/// # Safety
/// The running CPU must support AVX-512F.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
// The body is pointer arithmetic and intrinsics throughout; allow it without nested `unsafe`.
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn jsf32x16_rand_f32s_chunk(
    rng: &mut Jsf32x16,
    chunk: &mut [f32],
    nt: bool,
    v_mult: __m512,
    v_min: __m512,
) {
    let mut out_ptr = chunk.as_mut_ptr();
    let mut remaining = chunk.len();
    // Streaming stores require 64-byte alignment.
    let aligned = nt && (out_ptr as usize & 63) == 0;
    const UNROLL: usize = JSF32X16 * 4;
    if aligned {
        while remaining >= UNROLL {
            let v0 = rng.randfv(v_mult, v_min);
            let v1 = rng.randfv(v_mult, v_min);
            let v2 = rng.randfv(v_mult, v_min);
            let v3 = rng.randfv(v_mult, v_min);
            _mm512_stream_ps(out_ptr, v0);
            _mm512_stream_ps(out_ptr.add(JSF32X16), v1);
            _mm512_stream_ps(out_ptr.add(JSF32X16 * 2), v2);
            _mm512_stream_ps(out_ptr.add(JSF32X16 * 3), v3);
            out_ptr = out_ptr.add(UNROLL);
            remaining -= UNROLL;
        }
        while remaining >= JSF32X16 {
            let v = rng.randfv(v_mult, v_min);
            _mm512_stream_ps(out_ptr, v);
            out_ptr = out_ptr.add(JSF32X16);
            remaining -= JSF32X16;
        }
        // Non-temporal stores must be fenced before any other access to this memory.
        _mm_sfence();
    } else {
        while remaining >= UNROLL {
            let v0 = rng.randfv(v_mult, v_min);
            let v1 = rng.randfv(v_mult, v_min);
            let v2 = rng.randfv(v_mult, v_min);
            let v3 = rng.randfv(v_mult, v_min);
            _mm512_storeu_ps(out_ptr, v0);
            _mm512_storeu_ps(out_ptr.add(JSF32X16), v1);
            _mm512_storeu_ps(out_ptr.add(JSF32X16 * 2), v2);
            _mm512_storeu_ps(out_ptr.add(JSF32X16 * 3), v3);
            out_ptr = out_ptr.add(UNROLL);
            remaining -= UNROLL;
        }
        while remaining >= JSF32X16 {
            let v = rng.randfv(v_mult, v_min);
            _mm512_storeu_ps(out_ptr, v);
            out_ptr = out_ptr.add(JSF32X16);
            remaining -= JSF32X16;
        }
    }
    if remaining > 0 {
        // Tail shorter than one vector: generate into scratch, copy only what fits.
        let mut tmp = [0f32; JSF32X16];
        let v = rng.randfv(v_mult, v_min);
        _mm512_storeu_ps(tmp.as_mut_ptr(), v);
        ptr::copy_nonoverlapping(tmp.as_ptr(), out_ptr, remaining);
    }
}
/// Allocates a 16-lane JSF32 generator seeded with `seed` and transfers ownership to the caller.
///
/// The returned pointer is produced by `Box::into_raw`; release it with `jsf32x16_free`.
#[unsafe(no_mangle)]
pub extern "C" fn jsf32x16_new(seed: u32) -> *mut Jsf32x16 {
    // NOTE(review): `Jsf32x16::new` is presumably an `unsafe` AVX-512 constructor, hence the
    // unsafe block — confirm the CPU-feature requirement at its definition.
    unsafe {
        let state = Jsf32x16::new(seed);
        Box::into_raw(Box::new(state))
    }
}
#[unsafe(no_mangle)]
pub extern "C" fn jsf32x16_free(ptr: *mut Jsf32x16) {
if !ptr.is_null() {
unsafe { drop(Box::from_raw(ptr)) };
}
}
/// Fills `out[..count]` with pseudo-random `u32` values using the AVX-512 generator.
///
/// Seeding and parallel fill:
/// - One vector is drawn from the generator behind `ptr`, and its lane 0 becomes the base seed.
///   The caller's state therefore advances by one vector draw per call.
/// - The buffer is split into `JSF32X16_PAR_CHUNK`-element chunks and processed with rayon.
/// - Each chunk uses a fresh `Jsf32x16` seeded with `chunk_seed32(base_seed, chunk_index)`.
///
/// Caller contract: `ptr` must be a live generator from `jsf32x16_new` that is not used
/// concurrently, `out` must be valid for `count` writes, and the CPU must support AVX-512F
/// (not checked at runtime here). The call is a no-op when `count == 0` or either pointer is
/// null.
// Gated like the helpers it calls (`JSF32X16`, `JSF32X16_PAR_CHUNK`, the `*_chunk` fns).
#[cfg(target_arch = "x86_64")]
#[unsafe(no_mangle)]
pub extern "C" fn jsf32x16_next_u32s(ptr: *mut Jsf32x16, out: *mut u32, count: usize) {
    // Null would make `&mut *ptr` / `from_raw_parts_mut` undefined behavior.
    if count == 0 || ptr.is_null() || out.is_null() {
        return;
    }
    // SAFETY: both pointers are non-null; validity, exclusivity and AVX-512F support are the
    // caller's contract. Chunks from `par_chunks_mut` are disjoint.
    unsafe {
        let rng = &mut *ptr;
        let mut tmp = [0u32; JSF32X16];
        let v = rng.nextuv();
        _mm512_storeu_si512(tmp.as_mut_ptr() as *mut _, v);
        let base_seed = tmp[0];
        let buffer = from_raw_parts_mut(out, count);
        buffer
            .par_chunks_mut(JSF32X16_PAR_CHUNK)
            .enumerate()
            .for_each(|(chunk_idx, chunk)| {
                let mut local_rng = Jsf32x16::new(chunk_seed32(base_seed, chunk_idx));
                jsf32x16_next_u32s_chunk(
                    &mut local_rng,
                    chunk,
                    crate::_internal::prefer_nt_for(count, chunk),
                );
            });
    }
}
/// Fills `out[..count]` with `f32` values from the AVX-512 generator's `nextfv`.
///
/// `nextfv` is passed a scale of `1 / 2^32`. Seeding and chunking work as follows:
/// - Lane 0 of one vector drawn from `ptr`'s generator is the base seed.
/// - The buffer is processed in rayon chunks of `JSF32X16_PAR_CHUNK` elements.
/// - Each chunk gets its own `Jsf32x16` seeded with `chunk_seed32(base_seed, chunk_index)`.
///
/// Caller contract: `ptr` must be a live generator from `jsf32x16_new` that is not used
/// concurrently, `out` must be valid for `count` writes, and the CPU must support AVX-512F
/// (not checked at runtime here). The call is a no-op when `count == 0` or either pointer is
/// null.
// Gated like the helpers it calls (`JSF32X16`, `JSF32X16_PAR_CHUNK`, the `*_chunk` fns).
#[cfg(target_arch = "x86_64")]
#[unsafe(no_mangle)]
pub extern "C" fn jsf32x16_next_f32s(ptr: *mut Jsf32x16, out: *mut f32, count: usize) {
    // Null would make `&mut *ptr` / `from_raw_parts_mut` undefined behavior.
    if count == 0 || ptr.is_null() || out.is_null() {
        return;
    }
    // SAFETY: both pointers are non-null; validity, exclusivity and AVX-512F support are the
    // caller's contract. Chunks from `par_chunks_mut` are disjoint.
    unsafe {
        let rng = &mut *ptr;
        let mut tmp = [0u32; JSF32X16];
        let v = rng.nextuv();
        _mm512_storeu_si512(tmp.as_mut_ptr() as *mut _, v);
        let base_seed = tmp[0];
        let buffer = from_raw_parts_mut(out, count);
        buffer
            .par_chunks_mut(JSF32X16_PAR_CHUNK)
            .enumerate()
            .for_each(|(chunk_idx, chunk)| {
                let mut local_rng = Jsf32x16::new(chunk_seed32(base_seed, chunk_idx));
                // In f32, `u32::MAX as f32 + 1.0` is exactly 2^32, so this is 2^-32.
                let scale = _mm512_set1_ps(1.0 / (u32::MAX as f32 + 1.0));
                jsf32x16_next_f32s_chunk(
                    &mut local_rng,
                    chunk,
                    crate::_internal::prefer_nt_for(count, chunk),
                    scale,
                );
            });
    }
}
/// Fills `out[..count]` with integers from the AVX-512 generator's `randiv`.
///
/// `randiv` receives `max - min + 1` (computed in `i64`, broadcast to 64-bit lanes) and `min`
/// (broadcast to 32-bit lanes). NOTE(review): the exact range/inclusivity and the behavior for
/// `min > max` are defined by `randiv` — confirm there.
///
/// Seeding and chunking work as follows:
/// - Lane 0 of one vector drawn from `ptr`'s generator is the base seed.
/// - The buffer is processed in rayon chunks of `JSF32X16_PAR_CHUNK` elements.
/// - Each chunk gets its own `Jsf32x16` seeded with `chunk_seed32(base_seed, chunk_index)`.
///
/// Caller contract: `ptr` must be a live generator from `jsf32x16_new` that is not used
/// concurrently, `out` must be valid for `count` writes, and the CPU must support AVX-512F
/// (not checked at runtime here). The call is a no-op when `count == 0` or either pointer is
/// null.
// Gated like the helpers it calls (`JSF32X16`, `JSF32X16_PAR_CHUNK`, the `*_chunk` fns).
#[cfg(target_arch = "x86_64")]
#[unsafe(no_mangle)]
pub extern "C" fn jsf32x16_rand_i32s(
    ptr: *mut Jsf32x16,
    out: *mut i32,
    count: usize,
    min: i32,
    max: i32,
) {
    // Null would make `&mut *ptr` / `from_raw_parts_mut` undefined behavior.
    if count == 0 || ptr.is_null() || out.is_null() {
        return;
    }
    // SAFETY: both pointers are non-null; validity, exclusivity and AVX-512F support are the
    // caller's contract. Chunks from `par_chunks_mut` are disjoint.
    unsafe {
        let rng = &mut *ptr;
        let mut tmp = [0u32; JSF32X16];
        let v = rng.nextuv();
        _mm512_storeu_si512(tmp.as_mut_ptr() as *mut _, v);
        let base_seed = tmp[0];
        let buffer = from_raw_parts_mut(out, count);
        buffer
            .par_chunks_mut(JSF32X16_PAR_CHUNK)
            .enumerate()
            .for_each(|(chunk_idx, chunk)| {
                let mut local_rng = Jsf32x16::new(chunk_seed32(base_seed, chunk_idx));
                // Computed in i64 so the full i32 span (2^32) does not overflow.
                let v_range = _mm512_set1_epi64(max as i64 - min as i64 + 1);
                let v_min = _mm512_set1_epi32(min);
                jsf32x16_rand_i32s_chunk(
                    &mut local_rng,
                    chunk,
                    crate::_internal::prefer_nt_for(count, chunk),
                    v_range,
                    v_min,
                );
            });
    }
}
/// Fills `out[..count]` with floats from the AVX-512 generator's `randfv`.
///
/// `randfv` receives a multiplier of `(max - min) / 2^32` and an offset of `min`.
///
/// Seeding and chunking work as follows:
/// - Lane 0 of one vector drawn from `ptr`'s generator is the base seed.
/// - The buffer is processed in rayon chunks of `JSF32X16_PAR_CHUNK` elements.
/// - Each chunk gets its own `Jsf32x16` seeded with `chunk_seed32(base_seed, chunk_index)`.
///
/// Caller contract: `ptr` must be a live generator from `jsf32x16_new` that is not used
/// concurrently, `out` must be valid for `count` writes, and the CPU must support AVX-512F
/// (not checked at runtime here). The call is a no-op when `count == 0` or either pointer is
/// null.
// Gated like the helpers it calls (`JSF32X16`, `JSF32X16_PAR_CHUNK`, the `*_chunk` fns).
#[cfg(target_arch = "x86_64")]
#[unsafe(no_mangle)]
pub extern "C" fn jsf32x16_rand_f32s(
    ptr: *mut Jsf32x16,
    out: *mut f32,
    count: usize,
    min: f32,
    max: f32,
) {
    // Null would make `&mut *ptr` / `from_raw_parts_mut` undefined behavior.
    if count == 0 || ptr.is_null() || out.is_null() {
        return;
    }
    // SAFETY: both pointers are non-null; validity, exclusivity and AVX-512F support are the
    // caller's contract. Chunks from `par_chunks_mut` are disjoint.
    unsafe {
        let rng = &mut *ptr;
        let mut tmp = [0u32; JSF32X16];
        let v = rng.nextuv();
        _mm512_storeu_si512(tmp.as_mut_ptr() as *mut _, v);
        let base_seed = tmp[0];
        let buffer = from_raw_parts_mut(out, count);
        buffer
            .par_chunks_mut(JSF32X16_PAR_CHUNK)
            .enumerate()
            .for_each(|(chunk_idx, chunk)| {
                let mut local_rng = Jsf32x16::new(chunk_seed32(base_seed, chunk_idx));
                let v_mult = _mm512_set1_ps((max - min) * (1.0 / (u32::MAX as f32 + 1.0)));
                let v_min = _mm512_set1_ps(min);
                jsf32x16_rand_f32s_chunk(
                    &mut local_rng,
                    chunk,
                    crate::_internal::prefer_nt_for(count, chunk),
                    v_mult,
                    v_min,
                );
            });
    }
}