use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
use papaya::HashMap as PapayaMap;
use svod_dtype::DeviceSpec;
use crate::Tensor;
/// High 32-bit seed word for each device index.
///
/// `build_fresh_state` looks up the entry at the device index it is given and
/// uses it as the first element of the per-device seed tensor. The table has
/// exactly 16 entries; indices past the end make `build_fresh_state` panic.
///
/// NOTE(review): the panic message there calls this a "hardcoded SHA256
/// table", so the values are presumably derived from SHA-256 of the device
/// index — confirm the derivation before extending the table.
const SEED_HI_BY_DEVICE_IDX: [u32; 16] = [
    0x14b81119, 0x764f528d, 0xdb12cf1a, 0x0f79f7d3,
    0x900b0937, 0x3eaa7569, 0x7c0172cb, 0x28a3a23b,
    0x15036943, 0xdfa285d3, 0xeae31389, 0xf888ebe4,
    0x9a1c7059, 0x4273a435, 0x5a3336e3, 0x0f41de35,
];
/// RNG bookkeeping kept for a single device.
///
/// Instances are built by `build_fresh_state` and shared through the map
/// returned by `rng_states`.
pub(crate) struct DeviceRngState {
    /// Two-element seed tensor built from `[seed_hi, seed_lo]`.
    pub seed: Tensor,
    /// Running counter; `next_counter` advances it with an atomic `fetch_add`
    /// and hands the previous value to the caller.
    pub counter: AtomicU64,
    /// Value of `SEED_EPOCH` this state was built for. `get_or_init_state`
    /// replaces the state when this no longer matches the current epoch.
    seed_epoch: u64,
}
/// Lazily created map from device to its shared RNG state; access it through
/// `rng_states()`.
static RNG_STATES: OnceLock<PapayaMap<DeviceSpec, Arc<DeviceRngState>>> = OnceLock::new();
/// Global seed value. Written by `manual_seed`; if the RNG is used before any
/// manual seed, `current_seed_lo` tries to fill it with the wall-clock time in
/// seconds. Only the low 32 bits are read back.
static GLOBAL_SEED: AtomicU64 = AtomicU64::new(0);
/// Seeding generation counter. Zero means "never seeded"; `manual_seed`
/// increments it, and per-device states stamped with an older value are
/// rebuilt by `get_or_init_state`.
static SEED_EPOCH: AtomicU64 = AtomicU64::new(0);
/// Next index handed out when a fresh per-device state is built (used to index
/// `SEED_HI_BY_DEVICE_IDX`). Reset to zero by `manual_seed`.
static DEVICE_INDEX_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Returns the process-wide device → RNG state map, creating an empty map on
/// the first call.
fn rng_states() -> &'static PapayaMap<DeviceSpec, Arc<DeviceRngState>> {
    RNG_STATES.get_or_init(PapayaMap::new)
}
/// Sets the global RNG seed and invalidates every existing per-device state.
///
/// Three globals are updated, in this order:
/// 1. `GLOBAL_SEED` receives `seed`.
/// 2. `DEVICE_INDEX_COUNTER` is reset to zero so device indices are handed out
///    from the start again.
/// 3. `SEED_EPOCH` is incremented; states stamped with the previous epoch are
///    rebuilt the next time they are requested.
pub fn manual_seed(seed: u64) {
    GLOBAL_SEED.store(seed, Ordering::Release);
    DEVICE_INDEX_COUNTER.store(0, Ordering::Release);
    // The epoch bump comes last so the new seed and the reset index counter
    // are already stored by the time a reader can observe the new epoch.
    SEED_EPOCH.fetch_add(1, Ordering::AcqRel);
}
/// Returns the low 32 bits of the global seed, seeding from the wall clock
/// first if no seeding has happened yet.
///
/// While `SEED_EPOCH` is still zero, this tries to replace a zero
/// `GLOBAL_SEED` with the current Unix time in seconds (zero if the clock is
/// before the Unix epoch) and to move `SEED_EPOCH` from 0 to 1. Both updates
/// are compare-and-swap operations whose failures are ignored, so a value
/// installed concurrently by another thread is kept.
fn current_seed_lo() -> u32 {
    let never_seeded = SEED_EPOCH.load(Ordering::Acquire) == 0;
    if never_seeded {
        let wall_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs(),
            Err(_) => 0,
        };
        // Only a still-zero seed is replaced; a non-zero seed already present wins.
        let _ = GLOBAL_SEED.compare_exchange(0, wall_secs, Ordering::SeqCst, Ordering::SeqCst);
        // Mark seeding as done unless the epoch has already moved on.
        let _ = SEED_EPOCH.compare_exchange(0, 1, Ordering::SeqCst, Ordering::SeqCst);
    }
    let seed = GLOBAL_SEED.load(Ordering::Acquire);
    (seed & 0xFFFF_FFFF) as u32
}
/// Builds a new per-device RNG state with a zeroed counter.
///
/// The seed tensor is `[seed_hi, seed_lo]`, where `seed_hi` comes from
/// `SEED_HI_BY_DEVICE_IDX[device_index]` and `seed_lo` from
/// `current_seed_lo()`. The state is stamped with `seed_epoch` so that a later
/// reseed can be detected.
///
/// # Panics
///
/// Panics if `device_index` is outside the 16-entry seed table.
fn build_fresh_state(device_index: usize, seed_epoch: u64) -> DeviceRngState {
    // Resolve the table entry first so an out-of-range index fails before the
    // global seed is touched.
    let Some(&seed_hi) = SEED_HI_BY_DEVICE_IDX.get(device_index) else {
        panic!("svod_tensor::rand: device index {device_index} exceeds hardcoded SHA256 table (16 entries)")
    };
    let seed_lo = current_seed_lo();

    DeviceRngState {
        seed: Tensor::from_slice([seed_hi, seed_lo]),
        counter: AtomicU64::new(0),
        seed_epoch,
    }
}
/// Returns the RNG state for `device`, creating it on first use and replacing
/// it when it was built for an older seed epoch.
///
/// A read-only lookup serves the common case. Otherwise an atomic `compute`
/// on the map inserts a fresh state, unless a state for the current epoch has
/// appeared in the meantime, in which case that one is returned.
///
/// # Panics
///
/// Panics (via `build_fresh_state`) when a fresh state is needed and the next
/// device index is outside the 16-entry seed table.
fn get_or_init_state(device: &DeviceSpec) -> Arc<DeviceRngState> {
    use papaya::{Compute, Operation};

    // Called for its side effect only: guarantees lazy seeding has run, so the
    // epoch read below is non-zero.
    let _ = current_seed_lo();
    let current_epoch = SEED_EPOCH.load(Ordering::Acquire);

    let pinned = rng_states().pin();

    // Fast path: an up-to-date state already exists.
    if let Some(state) = pinned.get(device)
        && state.seed_epoch == current_epoch
    {
        return Arc::clone(state);
    }

    // The fresh state is cached outside the closure so that, should the map
    // invoke the closure more than once, at most one device index is consumed
    // by this call.
    let mut prepared: Option<Arc<DeviceRngState>> = None;
    let outcome = pinned.compute(device.clone(), |entry| match entry {
        // Someone else installed a current-epoch state first: hand that exact
        // state back through the abort value instead of looking it up again.
        Some((_, state)) if state.seed_epoch == current_epoch => {
            Operation::Abort(Arc::clone(state))
        }
        // Missing or stale entry: insert (or overwrite with) a fresh state.
        _ => {
            let fresh = prepared
                .get_or_insert_with(|| {
                    let device_index =
                        DEVICE_INDEX_COUNTER.fetch_add(1, Ordering::AcqRel) as usize;
                    Arc::new(build_fresh_state(device_index, current_epoch))
                })
                .clone();
            Operation::Insert(fresh)
        }
    });

    match outcome {
        Compute::Aborted(existing) => existing,
        Compute::Inserted(_, state) => Arc::clone(state),
        Compute::Updated { new: (_, state), .. } => Arc::clone(state),
        Compute::Removed(_, _) => unreachable!("compute closure never returns Remove"),
    }
}
pub(crate) fn next_counter(device: &DeviceSpec, num: u64) -> (Tensor, u64) {
let state = get_or_init_state(device);
let counter = state.counter.fetch_add(num, Ordering::AcqRel);
(state.seed.clone(), counter)
}