1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
//! Handle and RNG primitives for `oxicuda-seq`.
/// Streaming-multiprocessor compute capability stored as a single integer,
/// where e.g. SM 7.5 is `75` and SM 8.6 is `86`.
///
/// Ordering follows the numeric code, so newer capabilities compare greater.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// SM 7.5.
    pub const SM_75: Self = Self(75);
    /// SM 8.0.
    pub const SM_80: Self = Self(80);
    /// SM 8.6.
    pub const SM_86: Self = Self(86);
    /// SM 8.9.
    pub const SM_89: Self = Self(89);
    /// SM 9.0.
    pub const SM_90: Self = Self(90);
    /// SM 10.0.
    pub const SM_100: Self = Self(100);

    /// Returns the raw numeric capability code wrapped by this value.
    pub fn value(self) -> u32 {
        let Self(code) = self;
        code
    }
}
/// Minimal LCG (MMIX variant) pseudo-random number generator.
///
/// Uses Knuth's MMIX constants for a full-period 64-bit LCG:
/// `state = state * MUL + ADD (mod 2^64)`.
///
/// The low bits of a power-of-two-modulus LCG are weak: bit `k` of the state
/// has period `2^(k+1)`, so bit 0 simply alternates. Every derived sampler
/// below therefore draws from the HIGH bits of [`LcgRng::next_u64`].
///
/// CRITICAL: `next_bool()` uses bit 32, NOT bit 0 (bit 0 has period 2 in MMIX LCG).
#[derive(Debug, Clone)]
pub struct LcgRng {
    state: u64,
}

impl LcgRng {
    const MUL: u64 = 6_364_136_223_846_793_005;
    const ADD: u64 = 1_442_695_040_888_963_407;

    /// Create a new LCG seeded with `seed`.
    ///
    /// The seed is pre-mixed with the 64-bit golden-ratio constant (plus 1)
    /// so that adjacent seeds do not start from adjacent states.
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1),
        }
    }

    /// Advance the state and return the next 64-bit value (the new state).
    ///
    /// Callers needing a few random bits should take them from the top of the
    /// word, not the bottom; see the type-level docs.
    pub fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_mul(Self::MUL).wrapping_add(Self::ADD);
        self.state
    }

    /// Return a uniform `f64` in `[0, 1)` built from the top 53 bits of
    /// [`LcgRng::next_u64`] scaled by `2^-53`.
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Return a random bool. Uses bit 32 (NOT bit 0) for better statistical quality.
    pub fn next_bool(&mut self) -> bool {
        (self.next_u64() >> 32) & 1 == 1
    }

    /// Return a standard-normal sample via the Box-Muller transform.
    ///
    /// `u1` is clamped to at least `1e-300` so `ln(u1)` stays finite. Only the
    /// cosine branch is returned; the paired sine sample is discarded.
    pub fn next_normal(&mut self) -> f64 {
        let u1 = self.next_f64().max(1e-300);
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
    }

    /// Uniform sample on `[lo, hi)`.
    ///
    /// Computed as `lo + (hi - lo) * u`; floating-point rounding can make the
    /// result equal `hi` for some inputs.
    pub fn next_range(&mut self, lo: f64, hi: f64) -> f64 {
        lo + (hi - lo) * self.next_f64()
    }

    /// Sample a categorical index given a probability vector.
    ///
    /// Probabilities are expected to be non-negative and sum to ~1. Category
    /// `i` is chosen when `u` falls in `[acc_before_i, acc_before_i + p_i)`, so
    /// a category with `p_i == 0` is never selected by the main loop.
    ///
    /// If the sum is short due to rounding (`u` lands past the total), the last
    /// index with positive probability is returned. If no entry is positive,
    /// the last index is returned. An empty slice yields `0`.
    pub fn sample_categorical(&mut self, probs: &[f64]) -> usize {
        let u = self.next_f64();
        let mut acc = 0.0;
        for (i, &p) in probs.iter().enumerate() {
            acc += p;
            // Strict `<`: with `<=`, a zero-weight category whose running sum
            // equals `u` exactly (e.g. a leading `0.0` when `u == 0.0`) would win.
            if u < acc {
                return i;
            }
        }
        // Rounding shortfall: prefer a category that actually has weight.
        probs
            .iter()
            .rposition(|&p| p > 0.0)
            .unwrap_or(probs.len().saturating_sub(1))
    }

    /// Return a random `usize` in `[0, n)`. Returns 0 if `n == 0`.
    ///
    /// Uses a widening multiply-shift (`(x * n) >> 64`) rather than `x % n`.
    /// The result depends on the high bits of `x`, whereas `x % n` would expose
    /// the LCG's weak low bits (for `n == 2` it would strictly alternate). The
    /// result is always `< n`, and the bias is at most `n / 2^64`.
    pub fn next_usize(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        let x = u128::from(self.next_u64());
        // `usize -> u128` is a lossless widening; the product's top 64 bits
        // are `< n`, so narrowing back to `usize` cannot truncate.
        ((x * n as u128) >> 64) as usize
    }
}
/// Top-level handle for sequence-model computation, pairing a target
/// [`SmVersion`] with a seeded [`LcgRng`].
#[derive(Debug, Clone)]
pub struct SeqHandle {
    /// Target SM compute capability.
    pub sm: SmVersion,
    /// Random-number generator, seeded when the handle is constructed.
    pub rng: LcgRng,
}

impl SeqHandle {
    /// Builds a handle for `sm` whose generator is seeded with `seed`.
    pub fn new(sm: SmVersion, seed: u64) -> Self {
        let rng = LcgRng::new(seed);
        Self { sm, rng }
    }

    /// Builds a handle from a raw numeric SM code (e.g. `86`) and `seed`.
    ///
    /// The code is wrapped in [`SmVersion`] as-is; it is not checked against
    /// the predefined constants.
    pub fn from_sm_code(sm_code: u32, seed: u64) -> Self {
        let sm = SmVersion(sm_code);
        Self::new(sm, seed)
    }
}