Skip to main content

oxicuda_ssl/
handle.rs

//! Session handle for `oxicuda-ssl`.
//!
//! `SslHandle` stores the compute device index, the GPU SM version,
//! and a deterministic LCG random number generator for CPU-side operations
//! (data augmentation, momentum schedules, masking, projector init).
6
// ─── SmVersion ───────────────────────────────────────────────────────────────
8
/// GPU Streaming Multiprocessor (SM) version, stored as `major * 10 + minor`.
///
/// For instance `80` means SM 8.0 (Ampere A100), `90` means SM 9.0
/// (Hopper H100) and `120` means SM 12.0 (Blackwell).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// Unwrap the encoded version number.
    #[must_use]
    #[inline]
    pub fn as_u32(self) -> u32 {
        let Self(raw) = self;
        raw
    }

    /// Value to put in the PTX `.version` directive when targeting this SM.
    ///
    /// The mapping is tiered: below 80 → `"7.5"`, 80–89 → `"8.0"`,
    /// 90–99 → `"8.4"`, 100 and above → `"8.7"`.
    #[must_use]
    pub fn ptx_version_str(self) -> &'static str {
        let sm = self.0;
        if sm < 80 {
            "7.5"
        } else if sm < 90 {
            "8.0"
        } else if sm < 100 {
            "8.4"
        } else {
            "8.7"
        }
    }

    /// Value to put in the PTX `.target` directive, e.g. `"sm_80"`.
    #[must_use]
    pub fn target_str(self) -> String {
        let mut target = String::from("sm_");
        target.push_str(&self.0.to_string());
        target
    }
}

impl std::fmt::Display for SmVersion {
    /// Render as `SM <major>.<minor>`, e.g. `SM 8.0` for `SmVersion(80)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (major, minor) = (self.0 / 10, self.0 % 10);
        write!(f, "SM {}.{}", major, minor)
    }
}
47
// ─── LcgRng ──────────────────────────────────────────────────────────────────
49
/// Minimal LCG random number generator for deterministic CPU-side sampling.
///
/// Uses the Knuth MMIX 64-bit LCG constants:
/// `x_{n+1} = 6364136223846793005 * x_n + 1442695040888963407 (mod 2⁶⁴)`.
///
/// Two generators created with the same seed produce the same sequence.
/// The generator is meant for reproducibility, not for cryptographic use.
#[derive(Debug, Clone)]
pub struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// Knuth MMIX multiplier.
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    /// Knuth MMIX increment.
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    /// Create a new LCG with the given seed.
    ///
    /// The internal state starts at `seed + 1` (wrapping).
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advance one step and return a `u32`.
    ///
    /// The output is the low 32 bits of `state ^ (state >> 33)`, i.e. the
    /// low half of the new state XOR-mixed with its highest 31 bits.
    #[inline]
    pub fn next_u32(&mut self) -> u32 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        // Truncation to the low 32 bits is intentional.
        ((self.state >> 33) ^ self.state) as u32
    }

    /// Return a `f32` uniformly distributed in `[0, 1)`.
    ///
    /// Consumes exactly one `next_u32` draw. Only the top 24 bits of the
    /// draw are used: an `f32` has a 24-bit significand, so `k / 2^24` is
    /// exact for every `k < 2^24` and the result can never round up to `1.0`
    /// (which a direct `u32 as f32 / 2^32` conversion does for the largest
    /// draws).
    #[inline]
    pub fn next_f32(&mut self) -> f32 {
        /// 2⁻²⁴, exactly representable in `f32`.
        const SCALE: f32 = 1.0 / 16_777_216.0;
        (self.next_u32() >> 8) as f32 * SCALE
    }

    /// Return a `usize` drawn from `[0, n)`; returns `0` when `n == 0`.
    ///
    /// The value is `next_u32() % n`, so it carries the usual small modulo
    /// bias when `n` does not divide 2³², and it never exceeds `u32::MAX`
    /// even if `n` is larger. When `n == 0` no draw is consumed.
    #[inline]
    pub fn next_usize(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        (self.next_u32() as usize) % n
    }

    /// Sample two independent N(0, 1) values via Box-Muller transform.
    ///
    /// Consumes two `next_f32` draws.
    pub fn next_normal_pair(&mut self) -> (f32, f32) {
        // The offset keeps `u1` strictly positive so that `ln(u1)` is finite.
        // No upper clamp is needed: `next_f32` is strictly below 1.0, and a
        // clamp at `1.0 - 1e-10` would be a no-op anyway because that
        // constant rounds to 1.0 in `f32`.
        let u1 = self.next_f32() + 1e-10;
        let u2 = self.next_f32();
        let r = (-2.0 * u1.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * u2;
        (r * theta.cos(), r * theta.sin())
    }

    /// Fill `buf` with N(0, 1) samples via Box-Muller.
    ///
    /// Samples are generated in pairs; when `buf.len()` is odd the final
    /// element takes the first value of one extra pair and the second value
    /// is discarded.
    pub fn fill_normal(&mut self, buf: &mut [f32]) {
        let mut pairs = buf.chunks_exact_mut(2);
        for pair in &mut pairs {
            let (a, b) = self.next_normal_pair();
            pair[0] = a;
            pair[1] = b;
        }
        if let Some(last) = pairs.into_remainder().first_mut() {
            let (a, _) = self.next_normal_pair();
            *last = a;
        }
    }

    /// Shuffle a slice in-place using Fisher-Yates.
    ///
    /// Walks from the last index down to 1, swapping each element with one
    /// at an index drawn from `[0, i]` via `next_usize`.
    pub fn shuffle<T>(&mut self, slice: &mut [T]) {
        let n = slice.len();
        for i in (1..n).rev() {
            let j = self.next_usize(i + 1);
            slice.swap(i, j);
        }
    }
}
126
// ─── SslHandle ───────────────────────────────────────────────────────────────
128
129/// Lightweight session descriptor for SSL operations.
130#[derive(Debug, Clone)]
131pub struct SslHandle {
132    sm: SmVersion,
133    rng: LcgRng,
134    device: u32,
135}
136
137impl SslHandle {
138    /// Create a new handle.
139    #[must_use]
140    pub fn new(device: u32, sm: SmVersion, seed: u64) -> Self {
141        Self {
142            sm,
143            rng: LcgRng::new(seed),
144            device,
145        }
146    }
147
148    /// Convenience constructor for unit-test / CPU environments
149    /// (device 0, SM 8.0, seed 42).
150    #[must_use]
151    pub fn default_handle() -> Self {
152        Self::new(0, SmVersion(80), 42)
153    }
154
155    /// Return the SM version.
156    #[must_use]
157    pub fn sm_version(&self) -> SmVersion {
158        self.sm
159    }
160
161    /// Return the device ordinal.
162    #[must_use]
163    pub fn device(&self) -> u32 {
164        self.device
165    }
166
167    /// Return a mutable reference to the internal RNG.
168    pub fn rng_mut(&mut self) -> &mut LcgRng {
169        &mut self.rng
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Each SM tier maps to its expected PTX `.version` string.
    #[test]
    fn sm_version_ptx_strings() {
        let expected = [
            (75, "7.5"),
            (80, "8.0"),
            (86, "8.0"),
            (90, "8.4"),
            (100, "8.7"),
            (120, "8.7"),
        ];
        for &(raw, ptx) in expected.iter() {
            assert_eq!(SmVersion(raw).ptx_version_str(), ptx);
        }
    }

    /// `.target` strings are `sm_` followed by the raw number.
    #[test]
    fn sm_version_target_str() {
        let expected = [(80, "sm_80"), (90, "sm_90"), (120, "sm_120")];
        for &(raw, target) in expected.iter() {
            assert_eq!(SmVersion(raw).target_str(), target);
        }
    }

    /// `Display` splits the raw number into major and minor parts.
    #[test]
    fn sm_version_display() {
        let expected = [(80, "SM 8.0"), (120, "SM 12.0")];
        for &(raw, shown) in expected.iter() {
            assert_eq!(SmVersion(raw).to_string(), shown);
        }
    }

    /// Derived ordering follows the raw number.
    #[test]
    fn sm_version_ordering() {
        let (sm80, sm90, sm100) = (SmVersion(80), SmVersion(90), SmVersion(100));
        assert!(sm80 < sm90);
        assert!(sm100 > sm90);
    }

    /// The convenience constructor yields device 0 and SM 8.0.
    #[test]
    fn ssl_handle_default() {
        let handle = SslHandle::default_handle();
        assert_eq!(handle.device(), 0);
        assert_eq!(handle.sm_version(), SmVersion(80));
    }

    /// Explicit constructor arguments are reported back by the getters.
    #[test]
    fn ssl_handle_custom() {
        let handle = SslHandle::new(2, SmVersion(120), 99);
        assert_eq!(handle.device(), 2);
        assert_eq!(handle.sm_version(), SmVersion(120));
    }

    /// Two generators with the same seed emit the same sequence.
    #[test]
    fn lcg_rng_determinism() {
        let mut first = LcgRng::new(7);
        let mut second = LcgRng::new(7);
        for _ in 0..100 {
            let (x, y) = (first.next_u32(), second.next_u32());
            assert_eq!(x, y);
        }
    }

    /// `next_f32` stays inside `[0, 1)`.
    #[test]
    fn lcg_rng_f32_in_range() {
        let mut rng = LcgRng::new(11);
        for _ in 0..1000 {
            let sample = rng.next_f32();
            assert!((0.0..1.0).contains(&sample));
        }
    }

    /// `fill_normal` never produces NaN or infinities.
    #[test]
    fn lcg_rng_normal_finite() {
        let mut rng = LcgRng::new(13);
        let mut samples = vec![0.0_f32; 64];
        rng.fill_normal(&mut samples);
        assert!(samples.iter().all(|s| s.is_finite()));
    }

    /// Shuffling permutes the elements without losing or duplicating any.
    #[test]
    fn lcg_rng_shuffle_preserves_elements() {
        let mut rng = LcgRng::new(17);
        let mut items: Vec<usize> = (0..16).collect();
        rng.shuffle(&mut items);
        let mut sorted = items.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..16).collect::<Vec<_>>());
    }

    /// `next_usize(n)` stays below `n`.
    #[test]
    fn lcg_next_usize_in_range() {
        let mut rng = LcgRng::new(19);
        for _ in 0..200 {
            let drawn = rng.next_usize(10);
            assert!(drawn < 10);
        }
    }
}