svod-tensor 0.1.0-alpha.3

High-level lazy tensor API for the Svod ML compiler
Documentation
//! Per-device RNG state: seed Tensor + host-side AtomicU64 counter.
//!
//! - **Counter is host-side `AtomicU64`**, not a device BUFFER. Each
//!   `Tensor::rand` call does a lock-free `fetch_add(num)` and stamps the value
//!   as a CONST u64 into the THREEFRY graph. The graph stays non-foldable
//!   because the **seed Tensor is a BUFFER**, so rand output depends on a
//!   runtime read.
//!
//! - **`manual_seed` does not clear the map.** Existing rand graphs captured
//!   their seed Tensor via `Arc<Buffer>`, so they remain numerically valid
//!   forever. `manual_seed` only bumps a `SEED_EPOCH`; the next
//!   `next_counter(device, …)` notices the bump and atomically swaps that
//!   device's state for a fresh one (new seed Tensor, counter reset to 0,
//!   new epoch tag).

use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};

use papaya::HashMap as PapayaMap;

use svod_dtype::DeviceSpec;

use crate::Tensor;

/// SHA256(device_index_as_u32_big_endian).digest() reinterpreted as a 256-bit
/// big-endian integer, truncated to its low 32 bits. Used as `seed_hi` per
/// device — hardcoded for the first 16 device indices to avoid a `sha2`
/// dependency. Generated by:
///
/// ```text
/// python3 -c "import hashlib; \
///   [print(f'0x{int.from_bytes(hashlib.sha256(i.to_bytes(4, \"big\")).digest(), \"big\") & 0xFFFFFFFF:08x},') \
///    for i in range(16)]"
/// ```
///
/// The table is indexed by the device index handed out from
/// `DEVICE_INDEX_COUNTER`; `build_fresh_state` panics for any index outside
/// `0..16` rather than falling back to a computed value.
const SEED_HI_BY_DEVICE_IDX: [u32; 16] = [
    0x14b81119, 0x764f528d, 0xdb12cf1a, 0x0f79f7d3, // 0..3
    0x900b0937, 0x3eaa7569, 0x7c0172cb, 0x28a3a23b, // 4..7
    0x15036943, 0xdfa285d3, 0xeae31389, 0xf888ebe4, // 8..11
    0x9a1c7059, 0x4273a435, 0x5a3336e3, 0x0f41de35, // 12..15
];

/// Per-device RNG state. Immutable except for the lock-free counter advance.
/// Replaced wholesale on `manual_seed` via the epoch comparison in
/// `get_or_init_state`.
///
/// Instances are shared as `Arc<DeviceRngState>` through the `RNG_STATES` map,
/// so a caller that already holds a clone keeps using its state even after the
/// map entry has been swapped for a newer one.
pub(crate) struct DeviceRngState {
    /// `[seed_hi, seed_lo]` BUFFER-backed Tensor. The BUFFER is what prevents
    /// THREEFRY graphs from const-folding their output to a literal value.
    /// Built once in `build_fresh_state` and never mutated afterwards.
    pub seed: Tensor,
    /// Lock-free counter; advances by `num` per draw. Treated as a `(lo, hi)`
    /// uint32 pair inside the THREEFRY graph — high word is `(counter >> 32)`.
    /// Starts at 0 for every freshly built state.
    pub counter: AtomicU64,
    /// Snapshot of `SEED_EPOCH` at construction. Stale entries are replaced
    /// lazily on next access; in-flight Arc clones remain valid.
    seed_epoch: u64,
}

/// Process-wide map from device to its current RNG state. Created lazily on
/// first use through `rng_states()`; entries are inserted or replaced by
/// `get_or_init_state`.
static RNG_STATES: OnceLock<PapayaMap<DeviceSpec, Arc<DeviceRngState>>> = OnceLock::new();

/// Raw bits of the active seed. Written by `manual_seed`, or — when the user
/// never called it — once by `current_seed_lo`, which installs the wall-clock
/// time in seconds. Low 32 bits become `seed_lo` for every newly-derived
/// per-device seed Tensor.
static GLOBAL_SEED: AtomicU64 = AtomicU64::new(0);

/// Bumped on every `manual_seed` call; each `DeviceRngState` carries its own
/// snapshot, and stale snapshots trigger a lazy rebuild on next access.
///
/// Starts at 0, which means "never seeded". The first `current_seed_lo` call
/// that observes 0 seeds `GLOBAL_SEED` from the wall clock and advances the
/// epoch to 1. `get_or_init_state` always runs that lazy step before taking
/// its snapshot, so the snapshot it tags a new state with is at least 1.
static SEED_EPOCH: AtomicU64 = AtomicU64::new(0);

/// Monotonic index allocated to each first-seen device per epoch. Reset by
/// `manual_seed` so the first device after reseed gets index 0 again.
///
/// The allocated index selects the device's `seed_hi` from
/// `SEED_HI_BY_DEVICE_IDX`, so which value a device receives depends on the
/// order in which devices first draw random numbers after a (re)seed.
static DEVICE_INDEX_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns the process-wide per-device RNG state map, creating the empty map
/// on the first call.
fn rng_states() -> &'static PapayaMap<DeviceSpec, Arc<DeviceRngState>> {
    RNG_STATES.get_or_init(PapayaMap::new)
}

/// Set the global seed and bump the epoch. **Does not clear** the per-device
/// map: existing Tensors that captured a seed/counter from an earlier epoch
/// remain numerically valid (the relevant BUFFERs are held alive by their Arc).
/// Future `Tensor::rand` calls see the new epoch and lazy-rebuild per device.
///
/// Only the low 32 bits of `seed` end up in the per-device seed Tensor (see
/// `current_seed_lo`); the full value is stored as-is.
pub fn manual_seed(seed: u64) {
    // Publish the new seed and reset the device-index allocator *before* the
    // epoch bump: a thread that observes the new epoch with an Acquire load
    // then also observes both of these writes.
    GLOBAL_SEED.store(seed, Ordering::Release);
    DEVICE_INDEX_COUNTER.store(0, Ordering::Release);
    // The bump is what makes every existing per-device state stale.
    SEED_EPOCH.fetch_add(1, Ordering::AcqRel);
}

/// Returns the low 32 bits of `GLOBAL_SEED`, i.e. the `seed_lo` word.
///
/// While `SEED_EPOCH` is still 0 (no `manual_seed` and no earlier lazy init),
/// this first tries to install the current UNIX time in seconds as the seed
/// and then moves the epoch to 1. Both steps are compare-and-swap operations,
/// so concurrent callers are tolerated and only the first writer takes
/// effect. A clock that reads before the UNIX epoch contributes 0.
fn current_seed_lo() -> u32 {
    let never_seeded = SEED_EPOCH.load(Ordering::Acquire) == 0;
    if never_seeded {
        let wall_clock_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs(),
            Err(_) => 0,
        };
        // Seed first, epoch second: once the epoch reads non-zero the seed
        // has already been decided. Losing either race is fine.
        let _ = GLOBAL_SEED.compare_exchange(0, wall_clock_secs, Ordering::SeqCst, Ordering::SeqCst);
        let _ = SEED_EPOCH.compare_exchange(0, 1, Ordering::SeqCst, Ordering::SeqCst);
    }

    let raw_seed = GLOBAL_SEED.load(Ordering::Acquire);
    // Masking first makes the narrowing cast lossless by construction.
    (raw_seed & 0xFFFF_FFFF) as u32
}

/// Builds a brand-new RNG state for the device that was assigned
/// `device_index`, tagged with `seed_epoch`.
///
/// The seed Tensor is `[seed_hi, seed_lo]`: `seed_hi` comes from
/// `SEED_HI_BY_DEVICE_IDX[device_index]`, `seed_lo` from `current_seed_lo()`.
/// The counter starts at 0.
///
/// # Panics
///
/// Panics when `device_index` is outside the 16-entry table.
fn build_fresh_state(device_index: usize, seed_epoch: u64) -> DeviceRngState {
    // Resolve the table entry first so an out-of-range index fails before any
    // seed state is touched.
    let Some(&seed_hi) = SEED_HI_BY_DEVICE_IDX.get(device_index) else {
        panic!("svod_tensor::rand: device index {device_index} exceeds hardcoded SHA256 table (16 entries)")
    };
    let seed_lo = current_seed_lo();

    let seed = Tensor::from_slice([seed_hi, seed_lo]);
    let counter = AtomicU64::new(0);
    DeviceRngState { seed, counter, seed_epoch }
}

/// Returns the RNG state for `device`, building and publishing a fresh one
/// when the device has no entry yet or its entry predates the current seed
/// epoch.
///
/// An entry counts as usable when its `seed_epoch` is **at least** the epoch
/// snapshot taken at the top of this call. `SEED_EPOCH` only ever increases,
/// so an entry tagged with a newer epoch was built after a `manual_seed` that
/// raced with this call; it is the most recent state and must not be
/// overwritten by one tagged with our older snapshot.
///
/// # Panics
///
/// Panics (inside `build_fresh_state`) when the claimed device index falls
/// outside the 16-entry `SEED_HI_BY_DEVICE_IDX` table.
fn get_or_init_state(device: &DeviceSpec) -> Arc<DeviceRngState> {
    // Trigger lazy first-time init if needed, so the snapshot below is never
    // the "never seeded" epoch 0.
    let _ = current_seed_lo();
    let current_epoch = SEED_EPOCH.load(Ordering::Acquire);

    let pinned = rng_states().pin();

    // Fast path: existing entry that is not older than our snapshot.
    if let Some(state) = pinned.get(device)
        && state.seed_epoch >= current_epoch
    {
        return Arc::clone(state);
    }

    // Slow path: claim a device index only when we're actually going to insert.
    // Doing the `fetch_add` eagerly before `compute()` would burn an index on
    // every concurrent caller that loses the race, drifting the sequence and
    // eventually overflowing the SHA256 table.
    //
    // The closure is `FnMut` and papaya may retry it on contention — cache the
    // freshly-built state in a local `Option` so retries reuse one
    // `fetch_add`/`build_fresh_state` rather than burning a new index per retry.
    use papaya::Operation;
    let mut prepared: Option<Arc<DeviceRngState>> = None;
    let result = pinned.compute(device.clone(), |entry| match entry {
        // Someone else already published a usable state. Hand that exact Arc
        // back through the abort value instead of looking the key up again.
        Some((_, state)) if state.seed_epoch >= current_epoch => Operation::Abort(Arc::clone(state)),
        _ => {
            let fresh = prepared
                .get_or_insert_with(|| {
                    // The table lookup in `build_fresh_state` bounds-checks
                    // the index, so the widening/narrowing cast is only a
                    // type adaptation here.
                    let device_index = DEVICE_INDEX_COUNTER.fetch_add(1, Ordering::AcqRel) as usize;
                    Arc::new(build_fresh_state(device_index, current_epoch))
                })
                .clone();
            Operation::Insert(fresh)
        }
    });

    match result {
        papaya::Compute::Aborted(state) => state,
        papaya::Compute::Inserted(_, state) => Arc::clone(state),
        papaya::Compute::Updated { new: (_, state), .. } => Arc::clone(state),
        papaya::Compute::Removed(_, _) => unreachable!("compute closure never returns Remove"),
    }
}

/// Returns `(seed_tensor, counter_value)` for an upcoming draw of `num`
/// uint32 words. `counter_value` is stamped into the THREEFRY graph as a
/// `CONST u64` by the caller.
pub(crate) fn next_counter(device: &DeviceSpec, num: u64) -> (Tensor, u64) {
    let state = get_or_init_state(device);
    let counter = state.counter.fetch_add(num, Ordering::AcqRel);
    (state.seed.clone(), counter)
}