Skip to main content

oxicuda_nerf/
handle.rs

//! Session handle for `oxicuda-nerf`.
//!
//! Provides `LcgRng` (a linear congruential generator using Knuth's MMIX
//! constants), `SmVersion`, and `NerfHandle`.

// ─── SmVersion ───────────────────────────────────────────────────────────────
6
/// Streaming Multiprocessor (SM) compute capability, packed as `major * 10 + minor`.
///
/// For instance `SmVersion(80)` is SM 8.0 (Ampere A100), `SmVersion(90)` is
/// SM 9.0 (Hopper H100) and `SmVersion(120)` is SM 12.0 (Blackwell).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// The packed `major * 10 + minor` value.
    #[must_use]
    #[inline]
    pub fn as_u32(self) -> u32 {
        let SmVersion(raw) = self;
        raw
    }

    /// PTX ISA version string to emit in the `.version` directive.
    ///
    /// SM 10.0 and newer map to `"8.7"`, SM 9.x to `"8.4"`, SM 8.x to `"8.0"`,
    /// and anything older falls back to `"7.5"`.
    #[must_use]
    pub fn ptx_version_str(self) -> &'static str {
        let raw = self.as_u32();
        if raw < 80 {
            "7.5"
        } else if raw < 90 {
            "8.0"
        } else if raw < 100 {
            "8.4"
        } else {
            "8.7"
        }
    }

    /// PTX `.target` string for this SM, e.g. `"sm_80"` or `"sm_120"`.
    #[must_use]
    pub fn target_str(self) -> String {
        format!("sm_{}", self.as_u32())
    }
}

impl std::fmt::Display for SmVersion {
    /// Renders as `SM <major>.<minor>`, e.g. `SM 8.6` or `SM 12.0`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (major, minor) = (self.0 / 10, self.0 % 10);
        write!(f, "SM {}.{}", major, minor)
    }
}
45
// ─── LcgRng ──────────────────────────────────────────────────────────────────
47
/// Minimal LCG random number generator using Knuth's MMIX multiplier.
///
/// `state = state * 6364136223846793005 + 1442695040888963407 (mod 2⁶⁴)`
///
/// The sequence is fully determined by the seed passed to [`LcgRng::new`].
/// It is intended for reproducible CPU-side sampling and is not
/// cryptographically secure.
#[derive(Debug, Clone)]
pub struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// Create a new LCG RNG with the given seed.
    ///
    /// The seed is offset by the LCG increment before it becomes the state.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1_442_695_040_888_963_407),
        }
    }

    /// Advance one step and return a mixed `u32`.
    ///
    /// The result is the low 32 bits of `(state >> 33) ^ state`. XOR-ing in the
    /// upper state bits matters because the low bits of a power-of-two-modulus
    /// LCG have short periods on their own.
    #[inline]
    pub fn next_u32(&mut self) -> u32 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        ((self.state >> 33) ^ self.state) as u32
    }

    /// Return a `f32` uniformly in `[0, 1)`.
    ///
    /// Only the top 24 bits of [`LcgRng::next_u32`] are used. Every result is
    /// therefore a multiple of 2⁻²⁴ and exactly representable, and the
    /// maximum is `1 - 2⁻²⁴`.
    #[inline]
    pub fn next_f32(&mut self) -> f32 {
        (self.next_u32() >> 8) as f32 / 16_777_216.0
    }

    /// Return a `usize` drawn from `[0, n)`.
    ///
    /// For `n <= u32::MAX`, one `next_u32` is consumed and reduced modulo `n`.
    /// This leaves a small modulo bias of order `n / 2³²`. Larger `n` (only
    /// possible on 64-bit targets) combine two draws into a 64-bit value so
    /// the whole range is reachable.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`.
    #[inline]
    pub fn next_usize(&mut self, n: usize) -> usize {
        assert!(n != 0, "LcgRng::next_usize: n must be non-zero");
        if n <= u32::MAX as usize {
            // Identical to the historical `(next_u32() as usize) % n`, so seeded
            // sequences are unchanged; `n as u32` is lossless after the check.
            (self.next_u32() % (n as u32)) as usize
        } else {
            let hi = u64::from(self.next_u32());
            let lo = u64::from(self.next_u32());
            // `usize` is at most 64 bits wide, so the casts below are lossless,
            // and the remainder is `< n`, so it fits back into `usize`.
            (((hi << 32) | lo) % (n as u64)) as usize
        }
    }

    /// Return a `f32` in `[lo, hi)` (for `lo < hi`).
    ///
    /// The value is computed as `lo + next_f32() * (hi - lo)`. Floating-point
    /// rounding of that expression can land exactly on (or past) `hi`, e.g.
    /// when `hi - lo` is small relative to the magnitude of `lo`. Such results
    /// are clamped to the largest `f32` strictly below `hi` so the half-open
    /// contract holds. If `lo >= hi` or either bound is NaN, the raw
    /// expression is returned unchanged.
    #[inline]
    pub fn next_f32_range(&mut self, lo: f32, hi: f32) -> f32 {
        let v = lo + self.next_f32() * (hi - lo);
        if lo < hi && v >= hi {
            Self::largest_below(hi)
        } else {
            v
        }
    }

    /// Largest `f32` strictly less than `x` (meaningful for non-NaN `x`).
    fn largest_below(x: f32) -> f32 {
        if x > 0.0 {
            f32::from_bits(x.to_bits() - 1)
        } else if x < 0.0 {
            // Negative floats grow in magnitude as their bit pattern increases.
            f32::from_bits(x.to_bits() + 1)
        } else {
            // For both +0.0 and -0.0 the next value down is the smallest
            // negative subnormal.
            -f32::from_bits(1)
        }
    }

    /// Sample two independent N(0, 1) values via the Box-Muller transform.
    pub fn next_normal_pair(&mut self) -> (f32, f32) {
        // Nudge `u1` away from 0 so `ln(u1)` stays finite when `next_f32`
        // returns 0. The upper clamp never binds in f32 (`1.0 - 1e-10 == 1.0`
        // and `next_f32` is always < 1); it is kept to preserve the exact
        // sample sequence.
        let u1 = (self.next_f32() + 1e-10).min(1.0 - 1e-10);
        let u2 = self.next_f32();
        let r = (-2.0_f32 * u1.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * u2;
        (r * theta.cos(), r * theta.sin())
    }

    /// Fill `buf` with N(0, 1) samples via Box-Muller.
    ///
    /// Consumes one [`LcgRng::next_normal_pair`] per two elements. For an odd
    /// length, the final element takes the first value of an extra pair and
    /// the second value is discarded.
    pub fn fill_normal(&mut self, buf: &mut [f32]) {
        let mut pairs = buf.chunks_exact_mut(2);
        for pair in &mut pairs {
            let (a, b) = self.next_normal_pair();
            pair[0] = a;
            pair[1] = b;
        }
        if let [last] = pairs.into_remainder() {
            let (a, _) = self.next_normal_pair();
            *last = a;
        }
    }
}
117
// ─── NerfHandle ──────────────────────────────────────────────────────────────
119
120/// Lightweight session descriptor for NeRF / neural rendering operations.
121#[derive(Debug)]
122pub struct NerfHandle {
123    /// SM version for PTX generation.
124    pub sm: SmVersion,
125    /// CUDA device ordinal.
126    pub device: u32,
127    /// Deterministic RNG for CPU-side operations.
128    pub rng: LcgRng,
129}
130
131impl NerfHandle {
132    /// Create a new handle.
133    #[must_use]
134    pub fn new(device: u32, sm: SmVersion, seed: u64) -> Self {
135        Self {
136            sm,
137            device,
138            rng: LcgRng::new(seed),
139        }
140    }
141
142    /// Convenience constructor for unit-test / CPU environments
143    /// (device 0, SM 8.0, seed 42).
144    #[must_use]
145    pub fn default_handle() -> Self {
146        Self::new(0, SmVersion(80), 42)
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Two generators built from the same seed must emit identical streams.
    #[test]
    fn lcg_rng_deterministic() {
        let (mut first, mut second) = (LcgRng::new(42), LcgRng::new(42));
        for _ in 0..100 {
            let (x, y) = (first.next_u32(), second.next_u32());
            assert_eq!(x, y);
        }
    }

    /// `next_f32` must stay inside the half-open unit interval.
    #[test]
    fn lcg_rng_f32_in_range() {
        let mut rng = LcgRng::new(7);
        for sample in (0..1000).map(|_| rng.next_f32()) {
            assert!((0.0..1.0).contains(&sample));
        }
    }

    /// The convenience constructor targets device 0 on SM 8.0.
    #[test]
    fn nerf_handle_default() {
        let NerfHandle { device, sm, .. } = NerfHandle::default_handle();
        assert_eq!(device, 0);
        assert_eq!(sm, SmVersion(80));
    }

    /// `target_str` renders the packed version after an `sm_` prefix.
    #[test]
    fn sm_version_target_str() {
        let cases = [(80, "sm_80"), (90, "sm_90"), (120, "sm_120")];
        for (raw, expected) in cases {
            assert_eq!(SmVersion(raw).target_str(), expected);
        }
    }
}