Skip to main content

oxicuda_vision/
handle.rs

1//! Session handle for `oxicuda-vision`.
2//!
3//! `VisionHandle` stores the compute device index, the GPU SM version,
4//! and a deterministic LCG random number generator for CPU-side operations
5//! (buffer initialisation, test scaffolding).
6
7// ─── SmVersion ───────────────────────────────────────────────────────────────
8
/// Streaming-multiprocessor (SM) version, packed into a single integer as
/// `major * 10 + minor`.
///
/// For instance `SmVersion(80)` is SM 8.0 (Ampere A100), `SmVersion(90)` is
/// SM 9.0 (Hopper H100) and `SmVersion(120)` is SM 12.0 (Blackwell).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// Unwrap the packed `major * 10 + minor` value.
    #[must_use]
    #[inline]
    pub fn as_u32(self) -> u32 {
        let Self(raw) = self;
        raw
    }

    /// Version string to place in the PTX `.version` directive for this SM.
    ///
    /// The packed value is bucketed as follows:
    ///
    /// | packed value | returned string |
    /// |--------------|-----------------|
    /// | below 80     | `"7.5"`         |
    /// | 80 – 89      | `"8.0"`         |
    /// | 90 – 99      | `"8.4"`         |
    /// | 100 and up   | `"8.7"`         |
    #[must_use]
    pub fn ptx_version_str(self) -> &'static str {
        let Self(raw) = self;
        match raw {
            0..=79 => "7.5",
            80..=89 => "8.0",
            90..=99 => "8.4",
            _ => "8.7",
        }
    }

    /// Name to place in the PTX `.target` directive: the packed value
    /// prefixed with `sm_` (e.g. `"sm_80"`).
    #[must_use]
    pub fn target_str(self) -> String {
        let Self(raw) = self;
        format!("sm_{raw}")
    }
}

impl std::fmt::Display for SmVersion {
    /// Render as `SM <major>.<minor>`, splitting the packed value at the
    /// last decimal digit.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (major, minor) = (self.0 / 10, self.0 % 10);
        write!(f, "SM {major}.{minor}")
    }
}
50
51// ─── LcgRng ──────────────────────────────────────────────────────────────────
52
/// Minimal LCG random number generator for deterministic CPU-side sampling
/// in tests and buffer initialisation.
///
/// Uses the Knuth MMIX 64-bit LCG constants:
/// `x_{n+1} = 6364136223846793005 * x_n + 1442695040888963407 (mod 2⁶⁴)`.
///
/// Not suitable for anything security-sensitive: the whole state can be
/// recovered from a handful of outputs.
#[derive(Debug, Clone)]
pub struct LcgRng {
    /// Current 64-bit LCG state; advanced once per `next_u32` call.
    state: u64,
}

impl LcgRng {
    /// Create a new LCG with the given seed.
    ///
    /// The initial state is `seed + 1` (wrapping), so two generators built
    /// from the same seed produce the same stream.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advance one step and return a `u32`.
    ///
    /// The output is the low 32 bits of `state ^ (state >> 33)`, i.e. the
    /// low word of the new state mixed with its upper 31 bits.
    #[inline]
    pub fn next_u32(&mut self) -> u32 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        // The truncating cast deliberately keeps only the low 32 bits.
        ((self.state >> 33) ^ self.state) as u32
    }

    /// Return a `f32` in `[0, 1)`, derived from one `next_u32` draw.
    ///
    /// The draw is converted to `f32` and divided by 2³². Because the
    /// `u32 -> f32` conversion rounds to nearest, the very largest draws
    /// round up to 2³² and would yield exactly `1.0`; those are clamped to
    /// the largest `f32` strictly below one so the half-open range holds.
    /// Every other draw maps to the same value it always did.
    #[inline]
    pub fn next_f32(&mut self) -> f32 {
        /// 2³² as an `f32` (exactly representable).
        const TWO_POW_32: f32 = 4_294_967_296.0;
        /// Largest `f32` strictly less than 1.0, i.e. `1 - 2⁻²⁴`.
        const LARGEST_BELOW_ONE: f32 = 1.0 - f32::EPSILON / 2.0;

        (self.next_u32() as f32 / TWO_POW_32).min(LARGEST_BELOW_ONE)
    }

    /// Return a `usize` drawn from `[0, n)`.
    ///
    /// Returns 0 if `n == 0` (no draw is consumed in that case).
    ///
    /// The value is a single 32-bit draw reduced modulo `n`, so it carries
    /// the usual modulo bias when `n` does not divide 2³², and it can never
    /// exceed `u32::MAX` even if `n` is larger than that.
    #[inline]
    pub fn next_usize(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        (self.next_u32() as usize) % n
    }

    /// Sample two N(0, 1) values via the Box-Muller transform.
    ///
    /// Consumes two `next_f32` draws: the first sets the radius, the second
    /// the angle.
    pub fn next_normal_pair(&mut self) -> (f32, f32) {
        // Offset keeps the radius draw away from 0 so `ln` stays finite.
        // No upper guard is needed: `next_f32` is strictly below 1 and the
        // offset is far smaller than the f32 spacing there.
        const LN_GUARD: f32 = 1e-10;

        let u1 = self.next_f32() + LN_GUARD;
        let u2 = self.next_f32();
        let r = (-2.0 * u1.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * u2;
        (r * theta.cos(), r * theta.sin())
    }

    /// Fill `buf` with N(0, 1) samples via Box-Muller.
    ///
    /// Samples are generated two at a time; when `buf` has odd length the
    /// final pair is still drawn and its second value is discarded.
    pub fn fill_normal(&mut self, buf: &mut [f32]) {
        for pair in buf.chunks_mut(2) {
            let (a, b) = self.next_normal_pair();
            pair[0] = a;
            // Only the trailing chunk of an odd-length buffer lacks a second slot.
            if let Some(second) = pair.get_mut(1) {
                *second = b;
            }
        }
    }

    /// Shuffle a slice in-place using Fisher-Yates.
    ///
    /// Walks from the last index down to 1, swapping each element with one
    /// chosen by `next_usize` from the prefix ending at that index.
    /// Empty and single-element slices are left untouched and consume no
    /// draws.
    pub fn shuffle<T>(&mut self, slice: &mut [T]) {
        let n = slice.len();
        for i in (1..n).rev() {
            let j = self.next_usize(i + 1);
            slice.swap(i, j);
        }
    }
}
132
133// ─── VisionHandle ────────────────────────────────────────────────────────────
134
135/// Lightweight session descriptor for vision model operations.
136///
137/// A `VisionHandle` does **not** open a CUDA context; it merely records which
138/// device and SM version are targeted so that PTX kernel generators can
139/// emit architecture-appropriate code, and carries an LCG RNG for
140/// deterministic CPU-side sampling and parameter initialisation.
141#[derive(Debug, Clone)]
142pub struct VisionHandle {
143    sm: SmVersion,
144    rng: LcgRng,
145    device: u32,
146}
147
148impl VisionHandle {
149    /// Create a new handle for the given device, SM version, and RNG seed.
150    #[must_use]
151    pub fn new(device: u32, sm: SmVersion, seed: u64) -> Self {
152        Self {
153            sm,
154            rng: LcgRng::new(seed),
155            device,
156        }
157    }
158
159    /// Convenience constructor for unit-test / CPU environments
160    /// (device 0, SM 8.0, seed 42).
161    #[must_use]
162    pub fn default_handle() -> Self {
163        Self::new(0, SmVersion(80), 42)
164    }
165
166    /// Return the SM version.
167    #[must_use]
168    pub fn sm_version(&self) -> SmVersion {
169        self.sm
170    }
171
172    /// Return the device ordinal.
173    #[must_use]
174    pub fn device(&self) -> u32 {
175        self.device
176    }
177
178    /// Return a mutable reference to the internal RNG.
179    pub fn rng_mut(&mut self) -> &mut LcgRng {
180        &mut self.rng
181    }
182}
183
184// ─── Tests ───────────────────────────────────────────────────────────────────
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Each listed SM value maps to the expected PTX `.version` string.
    #[test]
    fn sm_version_ptx_strings() {
        let cases = [
            (75, "7.5"),
            (80, "8.0"),
            (86, "8.0"),
            (90, "8.4"),
            (100, "8.7"),
            (120, "8.7"),
        ];
        for (raw, ptx) in cases {
            assert_eq!(SmVersion(raw).ptx_version_str(), ptx);
        }
    }

    /// `.target` strings are the packed value prefixed with `sm_`.
    #[test]
    fn sm_version_target_str() {
        let cases = [(80, "sm_80"), (90, "sm_90"), (120, "sm_120")];
        for (raw, target) in cases {
            assert_eq!(SmVersion(raw).target_str(), target);
        }
    }

    /// `Display` prints `SM <major>.<minor>`.
    #[test]
    fn sm_version_display() {
        let cases = [(80, "SM 8.0"), (90, "SM 9.0")];
        for (raw, shown) in cases {
            assert_eq!(SmVersion(raw).to_string(), shown);
        }
    }

    /// Derived ordering and equality follow the packed integer.
    #[test]
    fn sm_version_ordering() {
        let (sm80, sm90, sm100) = (SmVersion(80), SmVersion(90), SmVersion(100));
        assert!(sm80 < sm90);
        assert!(sm100 > sm90);
        assert_eq!(sm80, SmVersion(80));
    }

    /// `as_u32` hands back the wrapped value unchanged.
    #[test]
    fn sm_version_as_u32() {
        let sm = SmVersion(86);
        assert_eq!(sm.as_u32(), 86);
    }

    /// The convenience constructor targets device 0 and SM 8.0.
    #[test]
    fn vision_handle_default() {
        let handle = VisionHandle::default_handle();
        assert_eq!(handle.device(), 0);
        assert_eq!(handle.sm_version(), SmVersion(80));
    }

    /// `new` stores the device and SM version it was given.
    #[test]
    fn vision_handle_custom() {
        let handle = VisionHandle::new(2, SmVersion(90), 12345);
        assert_eq!(handle.device(), 2);
        assert_eq!(handle.sm_version(), SmVersion(90));
    }

    /// Two generators with the same seed emit the same stream.
    #[test]
    fn lcg_rng_determinism() {
        let mut first = LcgRng::new(42);
        let mut second = first.clone();
        for _ in 0..100 {
            let (x, y) = (first.next_u32(), second.next_u32());
            assert_eq!(x, y);
        }
    }

    /// `next_f32` stays inside `[0, 1)` over many draws.
    #[test]
    fn lcg_rng_f32_in_range() {
        let mut rng = LcgRng::new(7);
        let unit = 0.0..1.0;
        for _ in 0..1000 {
            let v = rng.next_f32();
            assert!(unit.contains(&v), "out of range: {v}");
        }
    }

    /// `next_usize(7)` never reaches 7.
    #[test]
    fn lcg_rng_usize_in_range() {
        let bound = 7;
        let mut rng = LcgRng::new(99);
        for _ in 0..1000 {
            let v = rng.next_usize(bound);
            assert!(v < bound, "out of range: {v}");
        }
    }

    /// Every sample written by `fill_normal` is finite.
    #[test]
    fn lcg_rng_normal_fill_finite() {
        let mut samples = [0.0_f32; 64];
        LcgRng::new(13).fill_normal(&mut samples);
        assert!(samples.iter().all(|s| s.is_finite()));
    }

    /// Shuffling rearranges elements without losing or duplicating any.
    #[test]
    fn lcg_rng_shuffle_permutes() {
        let original: Vec<usize> = (0..8).collect();
        let mut shuffled = original.clone();
        LcgRng::new(77).shuffle(&mut shuffled);
        shuffled.sort_unstable();
        assert_eq!(shuffled, original);
    }

    /// The RNG reached through the handle yields in-range floats.
    #[test]
    fn vision_handle_rng_mut() {
        let mut handle = VisionHandle::default_handle();
        let rng = handle.rng_mut();
        let v = rng.next_f32();
        assert!((0.0..1.0).contains(&v));
    }
}