Skip to main content

oxicuda_graphalg/
handle.rs

//! Handle and RNG primitives for `oxicuda-graphalg`.
//!
//! Mirrors the pattern of `oxicuda-tn` and `oxicuda-cvx`.
//!
//! Provides:
//! - [`SmVersion`]: a tag for the target GPU compute capability.
//! - [`LcgRng`]: a minimal 64-bit LCG (Knuth MMIX constants). Boolean sampling uses
//!   bit 32 rather than bit 0, because the low-order bits of a power-of-two-modulus
//!   LCG have very short periods.
//! - [`GraphalgHandle`]: bundles an SM version with a seeded RNG.

/// GPU compute capability, encoded as `major * 10 + minor`
/// (e.g. `75` for SM 7.5, `100` for Blackwell SM 10.0).
///
/// The derived ordering compares the raw number, so newer capabilities
/// compare greater than older ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// SM 7.5.
    pub const SM_75: Self = Self(75);
    /// SM 8.0.
    pub const SM_80: Self = Self(80);
    /// SM 8.6.
    pub const SM_86: Self = Self(86);
    /// SM 8.9.
    pub const SM_89: Self = Self(89);
    /// SM 9.0.
    pub const SM_90: Self = Self(90);
    /// SM 10.0 (Blackwell).
    pub const SM_100: Self = Self(100);

    /// The raw compute-capability number wrapped by this tag.
    #[must_use]
    pub fn value(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}

/// Minimal 64-bit linear congruential generator using Knuth's MMIX constants.
///
/// The state advances as `state = state * MUL + ADD (mod 2^64)`. `ADD` is odd and
/// `MUL ≡ 1 (mod 4)`, so the Hull–Dobell conditions hold and the generator has
/// full period 2^64.
///
/// CRITICAL: the low-order bits of a power-of-two-modulus LCG are weak. Bit `k` of
/// the state has period only `2^(k+1)`, so bit 0 simply alternates. Every derived
/// sampler below therefore draws from the high bits:
/// - `next_f64` uses the top 53 bits;
/// - `next_bool` uses bit 32, not bit 0;
/// - `next_usize` uses a multiply-shift reduction driven by the high bits.
#[derive(Debug, Clone)]
pub struct LcgRng {
    state: u64,
}

impl LcgRng {
    const MUL: u64 = 6_364_136_223_846_793_005;
    const ADD: u64 = 1_442_695_040_888_963_407;

    /// Create a new LCG seeded with `seed`.
    ///
    /// Before use, the seed is multiplied by an odd constant and then incremented by one.
    /// Multiplying by an odd constant is a bijection modulo 2^64, so distinct seeds
    /// give distinct initial states; seed `0` maps to state `1`.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1),
        }
    }

    /// Advance the state and return it as the next 64-bit value.
    pub fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_mul(Self::MUL).wrapping_add(Self::ADD);
        self.state
    }

    /// Return a uniform `f64` in `[0, 1)`, built from the top 53 bits of the next value.
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Return a random `bool`.
    ///
    /// Uses bit 32, not bit 0, because bit 0 of this LCG alternates with period 2.
    pub fn next_bool(&mut self) -> bool {
        (self.next_u64() >> 32) & 1 == 1
    }

    /// Return a random `usize` in `[0, n)`. `n == 0` is treated as `1`, so the result is `0`.
    ///
    /// Uses the multiply-shift reduction `(x * n) >> 64`, which takes the result from
    /// the high bits of `x`. A plain `x % n` would take the low bits instead; for
    /// example, `next_usize(2)` would then strictly alternate 0, 1, 0, 1, ….
    /// The per-value bias is at most `n / 2^64`, which is negligible for realistic `n`.
    pub fn next_usize(&mut self, n: usize) -> usize {
        // usize -> u128 is lossless on every supported target.
        let bound = n.max(1) as u128;
        let wide = u128::from(self.next_u64()) * bound;
        // `wide >> 64` is strictly less than `bound`, so it fits back into `usize`.
        (wide >> 64) as usize
    }

    /// Return a uniform `f64` in `[lo, hi)`, up to floating-point rounding at `hi`.
    pub fn next_range(&mut self, lo: f64, hi: f64) -> f64 {
        lo + (hi - lo) * self.next_f64()
    }

    /// Draw from a standard normal distribution using the Box–Muller transform.
    ///
    /// Consumes two uniform draws and discards the companion (sine) variate.
    pub fn next_normal(&mut self) -> f64 {
        // Clamp away from 0 so that `ln` stays finite; `next_f64` can return exactly 0.0.
        let u1 = self.next_f64().max(1.0e-300);
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
    }
}

85/// Top-level graph algorithm computation handle bundling SM version and a seeded RNG.
86#[derive(Debug, Clone)]
87pub struct GraphalgHandle {
88    pub sm: SmVersion,
89    pub rng: LcgRng,
90}
91
92impl GraphalgHandle {
93    /// Create a new handle for the given SM version and RNG seed.
94    #[must_use]
95    pub fn new(sm: SmVersion, seed: u64) -> Self {
96        Self {
97            sm,
98            rng: LcgRng::new(seed),
99        }
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sm_constants_correct() {
        assert_eq!(SmVersion::SM_75.value(), 75);
        assert_eq!(SmVersion::SM_100.value(), 100);
    }

    #[test]
    fn lcg_deterministic() {
        let mut r1 = LcgRng::new(42);
        let mut r2 = LcgRng::new(42);
        for _ in 0..16 {
            assert_eq!(r1.next_u64(), r2.next_u64());
        }
    }

    #[test]
    fn lcg_unit_interval() {
        let mut r = LcgRng::new(7);
        for _ in 0..1000 {
            let v = r.next_f64();
            assert!((0.0..1.0).contains(&v));
        }
    }

    #[test]
    fn lcg_bool_balanced() {
        let mut r = LcgRng::new(13);
        let trues = (0..1000).filter(|_| r.next_bool()).count();
        assert!(trues > 400 && trues < 600);
    }

    #[test]
    fn lcg_usize_in_bounds() {
        let mut r = LcgRng::new(5);
        for n in 1..=16usize {
            for _ in 0..64 {
                assert!(r.next_usize(n) < n);
            }
        }
        // `n == 0` is treated as `n == 1`, so the only possible result is 0.
        assert_eq!(r.next_usize(0), 0);
    }

    #[test]
    fn lcg_range_bounds() {
        let mut r = LcgRng::new(21);
        for _ in 0..1000 {
            let v = r.next_range(-2.5, 4.0);
            // Inclusive upper bound tolerates floating-point rounding at `hi`.
            assert!((-2.5..=4.0).contains(&v));
        }
    }

    #[test]
    fn lcg_normal_mean_zero() {
        let mut r = LcgRng::new(99);
        let n = 10_000;
        let mean: f64 = (0..n).map(|_| r.next_normal()).sum::<f64>() / n as f64;
        assert!(mean.abs() < 0.2);
    }

    #[test]
    fn lcg_normal_unit_variance() {
        let mut r = LcgRng::new(99);
        let n = 10_000;
        let second_moment: f64 =
            (0..n).map(|_| r.next_normal().powi(2)).sum::<f64>() / n as f64;
        assert!((second_moment - 1.0).abs() < 0.15);
    }

    #[test]
    fn handle_construction() {
        let h = GraphalgHandle::new(SmVersion::SM_90, 0);
        assert_eq!(h.sm.value(), 90);
    }

    #[test]
    fn handle_rng_matches_seeded_lcg() {
        let mut h = GraphalgHandle::new(SmVersion::SM_90, 0);
        let mut reference = LcgRng::new(0);
        for _ in 0..8 {
            assert_eq!(h.rng.next_u64(), reference.next_u64());
        }
    }
}
156}