#![allow(unsafe_code)]
use crate::xoshiro::Xoshiro256PlusPlus;
/// Derives `K` xoshiro256++ lane states from the current state of `rng`,
/// without advancing `rng`.
///
/// Construction of lane `k`:
///
/// 1. Four words are drawn from SplitMix64, seeded with `state[0] ^ salt_k`.
/// 2. Word `j` is XOR-ed with `state[j]` rotated left by `13 * (k + 1)` bits
///    (mod 64), so every word of the source state feeds every lane.
///
/// Lanes `0..4` take their salt from the fixed `LANE_SALT` table. Lanes past
/// the table (only when `K > 4`) get a salt derived from it with SplitMix64,
/// instead of indexing out of bounds.
#[inline]
fn derive_lanes<const K: usize>(
    rng: &Xoshiro256PlusPlus,
) -> [[u64; 4]; K] {
    const LANE_SALT: [u64; 4] = [
        0xA076_1D64_78BD_642F,
        0xE703_7ED1_A0B4_28DB,
        0x8EBC_6AF0_9C88_C6E3,
        0x5899_65CC_7537_4CC3,
    ];
    let base = rng.state_snapshot();
    let mut out = [[0u64; 4]; K];
    for (k, lane) in out.iter_mut().enumerate() {
        let salt = match LANE_SALT.get(k) {
            Some(&s) => s,
            None => {
                // Mix the lane index into a table entry so extra lanes get
                // distinct, well-scrambled salts.
                let mut seed = LANE_SALT[k % LANE_SALT.len()] ^ k as u64;
                splitmix64(&mut seed)
            }
        };
        let mut sm = base[0] ^ salt;
        for slot in lane.iter_mut() {
            *slot = splitmix64(&mut sm);
        }
        // Wrapping arithmetic keeps this panic-free for any `k`. 64 divides
        // 2^32, so the result equals `(k + 1) * 13 % 64`.
        let rot = (k as u32).wrapping_add(1).wrapping_mul(13) % 64;
        for (slot, &b) in lane.iter_mut().zip(base.iter()) {
            *slot ^= b.rotate_left(rot);
        }
    }
    out
}
/// One SplitMix64 step.
///
/// Advances `state` by the golden-ratio increment and returns the mixed
/// 64-bit output, using the Stafford "Mix13" finalizer constants.
#[inline]
fn splitmix64(state: &mut u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    *state = state.wrapping_add(GOLDEN_GAMMA);
    let x = *state;
    let x = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    let x = (x ^ (x >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    x ^ (x >> 31)
}
/// Shortest buffer, in bytes, that is routed to a SIMD backend. Anything
/// shorter is filled by the scalar generator.
const SIMD_THRESHOLD: usize = 64;
/// Fills `dest` with random bytes.
///
/// Buffers of at least [`SIMD_THRESHOLD`] bytes use four NEON xoshiro256++
/// lanes, so their byte stream differs from `fill_bytes_scalar`.
///
/// The NEON path is taken only when the build enables the `neon` target
/// feature, as the standard hard-float aarch64 targets do. Otherwise (e.g. on
/// `*-softfloat` targets, where SIMD registers may be unusable) the scalar
/// generator is used. This mirrors the AVX2 availability check on x86_64.
#[cfg(target_arch = "aarch64")]
#[inline]
pub fn fill_bytes(rng: &mut Xoshiro256PlusPlus, dest: &mut [u8]) {
    if dest.len() < SIMD_THRESHOLD || !cfg!(target_feature = "neon") {
        rng.fill_bytes_scalar(dest);
        return;
    }
    aarch64::fill_bytes_neon(rng, dest);
}
/// Fills `dest` with random bytes.
///
/// Buffers of at least `SIMD_THRESHOLD` bytes use four AVX2 xoshiro256++
/// lanes when the CPU supports AVX2. Everything else goes through the scalar
/// generator.
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn fill_bytes(rng: &mut Xoshiro256PlusPlus, dest: &mut [u8]) {
    let long_enough = dest.len() >= SIMD_THRESHOLD;
    if long_enough && is_avx2_available() {
        // SAFETY: the backend is `#[target_feature(enable = "avx2")]` and
        // must only run on AVX2-capable CPUs; that was just checked.
        unsafe { x86_64::fill_bytes_avx2(rng, dest) };
    } else {
        rng.fill_bytes_scalar(dest);
    }
}
/// Fills `dest` with random bytes using the scalar generator. No SIMD
/// backend exists for this target architecture.
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
#[inline]
pub fn fill_bytes(rng: &mut Xoshiro256PlusPlus, dest: &mut [u8]) {
    Xoshiro256PlusPlus::fill_bytes_scalar(rng, dest);
}
/// Reports whether the AVX2 backend may be used, probing the running CPU at
/// runtime via `std`.
#[inline]
#[cfg(all(feature = "std", target_arch = "x86_64"))]
fn is_avx2_available() -> bool {
    std::is_x86_feature_detected!("avx2")
}
/// Reports whether the AVX2 backend may be used. Without `std` there is no
/// runtime probe, so this reflects only whether the build itself enables
/// `target_feature = "avx2"`.
#[inline]
#[cfg(all(not(feature = "std"), target_arch = "x86_64"))]
fn is_avx2_available() -> bool {
    cfg!(target_feature = "avx2")
}
#[cfg(target_arch = "aarch64")]
mod aarch64 {
    use super::Xoshiro256PlusPlus;
    use core::arch::aarch64::*;

    /// Two xoshiro256++ generators stepped in lock-step. `s[j]` holds state
    /// word `j` of the generator in vector lane 0 and of the one in lane 1.
    struct Lanes {
        s: [uint64x2_t; 4],
    }

    impl Lanes {
        /// Packs `lane0` into vector lane 0 and `lane1` into vector lane 1.
        #[inline]
        fn from_pair(lane0: [u64; 4], lane1: [u64; 4]) -> Self {
            // SAFETY: register-only NEON operations; NEON is assumed to be
            // enabled for this aarch64 build.
            unsafe {
                let mut s = [vdupq_n_u64(0); 4];
                for (j, slot) in s.iter_mut().enumerate() {
                    let r = vsetq_lane_u64::<0>(lane0[j], *slot);
                    *slot = vsetq_lane_u64::<1>(lane1[j], r);
                }
                Self { s }
            }
        }

        /// Returns the current state of the generator held in vector lane 0.
        #[inline]
        fn lane0_state(&self) -> [u64; 4] {
            // SAFETY: register-only NEON lane reads; NEON is assumed to be
            // enabled for this aarch64 build.
            unsafe {
                [
                    vgetq_lane_u64::<0>(self.s[0]),
                    vgetq_lane_u64::<0>(self.s[1]),
                    vgetq_lane_u64::<0>(self.s[2]),
                    vgetq_lane_u64::<0>(self.s[3]),
                ]
            }
        }

        /// Advances both generators by one xoshiro256++ step and returns
        /// their outputs (lane 0 and lane 1).
        ///
        /// # Safety
        /// NEON must be available.
        #[inline]
        unsafe fn step(&mut self) -> uint64x2_t {
            let s = &mut self.s;
            let sum = vaddq_u64(s[0], s[3]);
            let res = vaddq_u64(rotl::<23, 41>(sum), s[0]);
            let t = vshlq_n_u64::<17>(s[1]);
            s[2] = veorq_u64(s[2], s[0]);
            s[3] = veorq_u64(s[3], s[1]);
            s[1] = veorq_u64(s[1], s[2]);
            s[0] = veorq_u64(s[0], s[3]);
            s[2] = veorq_u64(s[2], t);
            s[3] = rotl::<45, 19>(s[3]);
            res
        }
    }

    /// Per-lane 64-bit rotate left by `N`. `N_INV` must equal `64 - N`.
    ///
    /// # Safety
    /// NEON must be available.
    #[inline]
    unsafe fn rotl<const N: i32, const N_INV: i32>(
        x: uint64x2_t,
    ) -> uint64x2_t {
        vorrq_u64(vshlq_n_u64::<N>(x), vshrq_n_u64::<N_INV>(x))
    }

    /// Writes the 16 bytes of `v` to `dst` (on little-endian: lane 0's bytes,
    /// then lane 1's).
    ///
    /// Goes through `vst1q_u8`, which needs only byte alignment, so `dst` may
    /// point anywhere inside a `u8` buffer. A `*mut u64` store would instead
    /// have to be made from a pointer of arbitrary alignment.
    ///
    /// # Safety
    /// `dst` must be valid for writes of 16 bytes, and NEON must be available.
    #[inline]
    unsafe fn store16(dst: *mut u8, v: uint64x2_t) {
        vst1q_u8(dst, vreinterpretq_u8_u64(v));
    }

    /// Fills `dest` from four NEON xoshiro256++ lanes seeded by
    /// `derive_lanes`. Afterwards `rng` is left in lane 0's advanced state,
    /// and any tail shorter than 16 bytes is finished by the scalar generator.
    pub(super) fn fill_bytes_neon(
        rng: &mut Xoshiro256PlusPlus,
        dest: &mut [u8],
    ) {
        let four = super::derive_lanes::<4>(rng);
        // `lanes_a` carries lanes 0/1, `lanes_b` carries lanes 2/3.
        let mut lanes_a = Lanes::from_pair(four[0], four[1]);
        let mut lanes_b = Lanes::from_pair(four[2], four[3]);
        let dest_len = dest.len();
        let out = dest.as_mut_ptr();
        let mut i = 0;
        while i + 64 <= dest_len {
            // SAFETY: `i + 64 <= dest_len`, so the four 16-byte stores at
            // offsets i, i+16, i+32 and i+48 all stay inside `dest`.
            unsafe {
                let oa0 = lanes_a.step();
                let ob0 = lanes_b.step();
                let oa1 = lanes_a.step();
                let ob1 = lanes_b.step();
                let p = out.add(i);
                store16(p, oa0);
                store16(p.add(16), ob0);
                store16(p.add(32), oa1);
                store16(p.add(48), ob1);
            }
            i += 64;
        }
        while i + 32 <= dest_len {
            // SAFETY: `i + 32 <= dest_len`, so both 16-byte stores stay
            // inside `dest`.
            unsafe {
                let oa = lanes_a.step();
                let ob = lanes_b.step();
                let p = out.add(i);
                store16(p, oa);
                store16(p.add(16), ob);
            }
            i += 32;
        }
        while i + 16 <= dest_len {
            // SAFETY: `i + 16 <= dest_len`, so the 16-byte store stays
            // inside `dest`.
            unsafe {
                let v = lanes_a.step();
                store16(out.add(i), v);
            }
            i += 16;
        }
        rng.set_state(lanes_a.lane0_state());
        if i < dest_len {
            rng.fill_bytes_scalar(&mut dest[i..]);
        }
    }
}
#[cfg(target_arch = "x86_64")]
mod x86_64 {
    use super::Xoshiro256PlusPlus;
    use core::arch::x86_64::*;

    /// Four xoshiro256++ generators stepped in lock-step. `s[j]` holds state
    /// word `j` of all four, with generator `k` in 64-bit element `k`.
    struct Lanes {
        s: [__m256i; 4],
    }

    impl Lanes {
        /// Seeds the four generators from `rng` via `derive_lanes::<4>`.
        ///
        /// # Safety
        /// The CPU must support AVX2.
        #[target_feature(enable = "avx2")]
        unsafe fn from_rng(rng: &Xoshiro256PlusPlus) -> Self {
            let lane_states = super::derive_lanes::<4>(rng);
            let mut s = [_mm256_setzero_si256(); 4];
            for (j, slot) in s.iter_mut().enumerate() {
                // `_mm256_set_epi64x` takes elements highest-first, so
                // generator 0 lands in element 0.
                *slot = _mm256_set_epi64x(
                    lane_states[3][j] as i64,
                    lane_states[2][j] as i64,
                    lane_states[1][j] as i64,
                    lane_states[0][j] as i64,
                );
            }
            Self { s }
        }

        /// Advances all four generators by one xoshiro256++ step and returns
        /// their outputs (generator `k` in element `k`).
        ///
        /// # Safety
        /// The CPU must support AVX2.
        #[inline]
        #[target_feature(enable = "avx2")]
        unsafe fn step(&mut self) -> __m256i {
            let s = &mut self.s;
            let sum = _mm256_add_epi64(s[0], s[3]);
            let res = _mm256_add_epi64(rotl::<23, 41>(sum), s[0]);
            let t = _mm256_slli_epi64::<17>(s[1]);
            s[2] = _mm256_xor_si256(s[2], s[0]);
            s[3] = _mm256_xor_si256(s[3], s[1]);
            s[1] = _mm256_xor_si256(s[1], s[2]);
            s[0] = _mm256_xor_si256(s[0], s[3]);
            s[2] = _mm256_xor_si256(s[2], t);
            s[3] = rotl::<45, 19>(s[3]);
            res
        }

        /// Returns the state of generator 0 (element 0 of each state word).
        ///
        /// # Safety
        /// The CPU must support AVX2.
        #[target_feature(enable = "avx2")]
        unsafe fn lane0_state(&self) -> [u64; 4] {
            let mut out = [0u64; 4];
            for (slot, word) in out.iter_mut().zip(self.s.iter()) {
                // Read element 0 directly instead of spilling the whole
                // 256-bit register to memory.
                *slot = _mm256_extract_epi64::<0>(*word) as u64;
            }
            out
        }
    }

    /// Per-element 64-bit rotate left by `N`. `N_INV` must equal `64 - N`.
    ///
    /// # Safety
    /// The CPU must support AVX2.
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn rotl<const N: i32, const N_INV: i32>(
        x: __m256i,
    ) -> __m256i {
        _mm256_or_si256(
            _mm256_slli_epi64::<N>(x),
            _mm256_srli_epi64::<N_INV>(x),
        )
    }

    /// Fills `dest` 32 bytes at a time from four AVX2 xoshiro256++ lanes
    /// seeded by `derive_lanes`. Afterwards `rng` is left in lane 0's
    /// advanced state, and any tail shorter than 32 bytes is finished by the
    /// scalar generator.
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports AVX2.
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn fill_bytes_avx2(
        rng: &mut Xoshiro256PlusPlus,
        dest: &mut [u8],
    ) {
        let mut lanes = Lanes::from_rng(rng);
        let mut blocks = dest.chunks_exact_mut(32);
        for block in &mut blocks {
            let out = lanes.step();
            // `block` is exactly 32 writable bytes, and `storeu` has no
            // alignment requirement.
            _mm256_storeu_si256(block.as_mut_ptr() as *mut __m256i, out);
        }
        rng.set_state(lanes.lane0_state());
        let tail = blocks.into_remainder();
        if !tail.is_empty() {
            rng.fill_bytes_scalar(tail);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "alloc")]
    use alloc::vec;
    #[cfg(all(not(feature = "alloc"), feature = "std"))]
    use std::vec;

    /// Backing buffer with a known 16-byte alignment, so `&buf.0[1..]` is
    /// misaligned for every SIMD store width used (8, 16 and 32 bytes).
    #[repr(align(16))]
    struct Aligned16([u8; 192]);

    /// Whether `fill_bytes` takes a SIMD path for long buffers in this build
    /// on this machine. When it does not, SIMD-vs-scalar comparisons are moot.
    #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
    fn simd_path_active() -> bool {
        #[cfg(target_arch = "x86_64")]
        let active = is_avx2_available();
        #[cfg(target_arch = "aarch64")]
        let active = cfg!(target_feature = "neon");
        active
    }

    #[test]
    fn fill_produces_uniform_bytes() {
        let mut rng = Xoshiro256PlusPlus::from_u64_seed(0xC0DE_BEEF);
        let mut buf = [0u8; 64 * 1024];
        fill_bytes(&mut rng, &mut buf);
        let mut counts = [0u32; 256];
        for &b in &buf[..] {
            counts[b as usize] += 1;
        }
        let mean = (buf.len() / 256) as f64;
        let chi2: f64 = counts
            .iter()
            .map(|&c| {
                let diff = c as f64 - mean;
                diff * diff / mean
            })
            .sum();
        // With 255 degrees of freedom χ² has mean 255 and s.d. ≈ 22.6, so
        // 500 is a very loose bound that only catches gross bias.
        assert!(chi2 < 500.0, "χ² = {chi2} too high");
    }

    #[test]
    fn fill_handles_short_and_unaligned_lengths() {
        let mut rng = Xoshiro256PlusPlus::from_u64_seed(1);
        for &len in
            &[0usize, 1, 7, 15, 16, 17, 31, 33, 63, 65, 127, 129]
        {
            let mut buf = vec![0u8; len];
            fill_bytes(&mut rng, &mut buf);
            if len > 4 {
                assert!(
                    buf.iter().any(|&b| b != 0),
                    "no entropy at len {len}"
                );
            }
        }
    }

    #[test]
    fn fill_is_independent_of_destination_alignment() {
        // 64 * 2 + 32 + 16 + 5 exercises every NEON loop plus the scalar
        // tail. On x86_64 it is 5 AVX2 blocks plus a 21-byte tail.
        const LEN: usize = 181;
        let mut a = Xoshiro256PlusPlus::from_u64_seed(7);
        let mut b = Xoshiro256PlusPlus::from_u64_seed(7);
        let mut aligned = Aligned16([0u8; 192]);
        let mut shifted = Aligned16([0u8; 192]);
        fill_bytes(&mut a, &mut aligned.0[..LEN]);
        fill_bytes(&mut b, &mut shifted.0[1..=LEN]);
        assert_eq!(&aligned.0[..LEN], &shifted.0[1..=LEN]);
        // Both generators must also be left in the same state.
        let mut next_a = [0u8; 32];
        let mut next_b = [0u8; 32];
        a.fill_bytes_scalar(&mut next_a);
        b.fill_bytes_scalar(&mut next_b);
        assert_eq!(next_a, next_b);
    }

    #[test]
    #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
    fn simd_diverges_from_scalar() {
        if !simd_path_active() {
            // E.g. x86_64 without AVX2 (or a no-`std` build without
            // `target_feature = "avx2"`): `fill_bytes` *is* the scalar path,
            // so identical output is expected there.
            return;
        }
        let mut a = Xoshiro256PlusPlus::from_u64_seed(42);
        let mut b = Xoshiro256PlusPlus::from_u64_seed(42);
        let mut sa = [0u8; 256];
        let mut sb = [0u8; 256];
        fill_bytes(&mut a, &mut sa);
        b.fill_bytes_scalar(&mut sb);
        assert_ne!(sa, sb, "SIMD and scalar must diverge");
    }
}