Skip to main content

oxicuda_seq/
handle.rs

//! Handle and RNG primitives for `oxicuda-seq`.
2
/// Streaming-multiprocessor (SM) compute capability, stored as `major * 10 + minor`.
///
/// For example, SM 7.5 is stored as `75`, SM 8.6 as `86`, and SM 10.0 as `100`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// Compute capability 7.5.
    pub const SM_75: Self = Self(75);
    /// Compute capability 8.0.
    pub const SM_80: Self = Self(80);
    /// Compute capability 8.6.
    pub const SM_86: Self = Self(86);
    /// Compute capability 8.9.
    pub const SM_89: Self = Self(89);
    /// Compute capability 9.0.
    pub const SM_90: Self = Self(90);
    /// Compute capability 10.0.
    pub const SM_100: Self = Self(100);

    /// Return the raw numeric compute-capability code (e.g. `86` for SM 8.6).
    pub fn value(self) -> u32 {
        let Self(code) = self;
        code
    }
}
20
/// Minimal 64-bit linear congruential generator using Knuth's MMIX constants.
///
/// The state advances as `state = state * MUL + ADD (mod 2^64)`. With a
/// power-of-two modulus, the low bits of the state have very short periods:
/// bit `k` repeats with period `2^(k+1)`, so bit 0 simply alternates. Every
/// derived sample is therefore taken from the HIGH bits of the state:
///
/// * [`next_f64`](Self::next_f64) uses the top 53 bits,
/// * [`next_bool`](Self::next_bool) uses bit 32 (NOT bit 0),
/// * [`next_usize`](Self::next_usize) uses a widening multiply, which keeps the
///   high bits of the product.
///
/// The generator is deterministic and cheap. It is not cryptographically secure.
#[derive(Debug, Clone)]
pub struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// LCG multiplier (Knuth, MMIX).
    const MUL: u64 = 6_364_136_223_846_793_005;
    /// LCG increment (Knuth, MMIX).
    const ADD: u64 = 1_442_695_040_888_963_407;

    /// Create a new LCG seeded with `seed`.
    ///
    /// Before use, the seed is multiplied by `0x9E37_79B9_7F4A_7C15` (the 64-bit
    /// golden-ratio constant) and then incremented by one. This is presumably so
    /// that small, consecutive seeds do not start from adjacent states.
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1),
        }
    }

    /// Advance the state and return the new 64-bit state value.
    pub fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_mul(Self::MUL).wrapping_add(Self::ADD);
        self.state
    }

    /// Return a uniform `f64` in `[0, 1)` built from the top 53 bits of the state.
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Return a random bool.
    ///
    /// Uses bit 32 (NOT bit 0, which merely alternates) for better statistical quality.
    pub fn next_bool(&mut self) -> bool {
        (self.next_u64() >> 32) & 1 == 1
    }

    /// Return a standard-normal sample via the Box-Muller transform.
    ///
    /// Consumes two uniform draws. `u1` is clamped away from zero because
    /// `next_f64` can return exactly `0.0`, and `ln(0)` is `-inf`.
    pub fn next_normal(&mut self) -> f64 {
        let u1 = self.next_f64().max(1e-300);
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
    }

    /// Uniform sample on `[lo, hi)`.
    pub fn next_range(&mut self, lo: f64, hi: f64) -> f64 {
        lo + (hi - lo) * self.next_f64()
    }

    /// Sample a categorical index from the probability vector `probs`.
    ///
    /// Index `i` is returned when the uniform draw `u ∈ [0, 1)` lies in the
    /// half-open interval `[acc_{i-1}, acc_i)` of the running cumulative sum.
    /// This gives each category exactly its share and never selects a category
    /// with zero probability while the cumulative sum covers `u`.
    ///
    /// Probabilities are expected to be non-negative and to sum to ~1. If the sum
    /// falls short of `u` (e.g. due to rounding), the last index with a positive
    /// probability is returned. If no entry is positive, the last index is
    /// returned. An empty slice yields `0`.
    pub fn sample_categorical(&mut self, probs: &[f64]) -> usize {
        let u = self.next_f64();
        let mut acc = 0.0;
        for (i, &p) in probs.iter().enumerate() {
            acc += p;
            // Strict `<` matches the half-open `[0, 1)` draw; `<=` would favour
            // earlier indices on ties and let a leading zero-probability entry
            // win when `u == 0.0`.
            if u < acc {
                return i;
            }
        }
        // Rounding shortfall: fall back to the last category that can actually
        // occur, rather than a possibly zero-probability trailing entry.
        probs
            .iter()
            .rposition(|&p| p > 0.0)
            .unwrap_or_else(|| probs.len().saturating_sub(1))
    }

    /// Return a random `usize` in `[0, n)`. Returns 0 if `n == 0`.
    ///
    /// Uses the widening-multiply reduction `(x * n) >> 64` instead of `x % n`.
    /// The modulo form reads the LOW bits of the state, which for this LCG have
    /// tiny periods (`x % 2` simply alternates 0, 1, 0, 1, ...). On 32-bit targets
    /// it also truncated `x` to those low bits. The multiply form is driven by the
    /// high bits and has the same small bias bound as modulo reduction.
    pub fn next_usize(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        let x = u128::from(self.next_u64());
        // x < 2^64, so x * n < n * 2^64 and the shifted result is < n: fits in usize.
        ((x * n as u128) >> 64) as usize
    }
}
93
94/// Top-level sequence-model computation handle bundling SM version and a seeded RNG.
95#[derive(Debug, Clone)]
96pub struct SeqHandle {
97    pub sm: SmVersion,
98    pub rng: LcgRng,
99}
100
101impl SeqHandle {
102    /// Create a new handle from an `SmVersion` and seed.
103    pub fn new(sm: SmVersion, seed: u64) -> Self {
104        Self {
105            sm,
106            rng: LcgRng::new(seed),
107        }
108    }
109
110    /// Create a new handle from a raw SM numeric code and seed.
111    pub fn from_sm_code(sm_code: u32, seed: u64) -> Self {
112        Self::new(SmVersion(sm_code), seed)
113    }
114}