Skip to main content

oxicuda_rl/
handle.rs

1//! # RlHandle — RL Session Handle
2//!
//! [`crate::handle::RlHandle`] is a lightweight context object that carries the
//! information RL kernels and buffers need: the SM version (for PTX header
//! selection), the device ordinal, and a seeded CPU-side RNG.
6
7use crate::error::{RlError, RlResult};
8
9// ─── SmVersion (local mirror of driver's u32 SM version) ────────────────────
10
/// Version of a GPU's Streaming Multiprocessor (SM), stored as one integer.
///
/// For instance 75 stands for SM 7.5 (Turing), 80 for SM 8.0 (Ampere) and
/// 90 for SM 9.0 (Hopper). Ordering and equality compare the wrapped integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// Unwrap the raw version integer (e.g. 80 for Ampere).
    #[must_use]
    #[inline]
    pub fn as_u32(self) -> u32 {
        let Self(raw) = self;
        raw
    }

    /// Value for the PTX `.version` directive matching this SM.
    ///
    /// The mapping is threshold based: `>= 100` gives `"8.7"`, `>= 90` gives
    /// `"8.4"`, `>= 80` gives `"8.0"`, and anything older gives `"7.5"`.
    #[must_use]
    pub fn ptx_version_str(self) -> &'static str {
        let sm = self.0;
        if sm >= 100 {
            "8.7"
        } else if sm >= 90 {
            "8.4"
        } else if sm >= 80 {
            "8.0"
        } else {
            "7.5"
        }
    }
}

impl std::fmt::Display for SmVersion {
    /// Render as `sm_<version>`, e.g. `sm_80`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = *self;
        write!(f, "sm_{raw}")
    }
}
43
// ─── LcgRng (CPU-side random state) ──────────────────────────────────────────
45
/// Minimal 64-bit LCG random state for CPU-side sampling in tests and buffers.
///
/// The recurrence is
/// `x_{n+1} = 6364136223846793005 * x_n + 1442695040888963407 (mod 2⁶⁴)`,
/// i.e. the multiplier/increment pair of Knuth's MMIX generator. Only the
/// high bits of the state are handed out, because the low bits of a
/// power-of-two-modulus LCG have short periods.
///
/// The generator is deterministic for a given seed and is **not**
/// cryptographically secure.
#[derive(Debug, Clone)]
pub struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// Create a new LCG with the given seed.
    ///
    /// The internal state starts at `seed + 1` (wrapping).
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advance one step and return a `u32` in `[0, 2³¹)`.
    ///
    /// The result is the top 31 bits of the updated 64-bit state.
    #[inline]
    pub fn next_u32(&mut self) -> u32 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        (self.state >> 33) as u32
    }

    /// Return a `f32` in `[0, 1)`.
    ///
    /// Advances the generator by one step. The top 24 bits of the 31-bit draw
    /// are scaled by 2⁻²⁴: 24 bits fit an `f32` mantissa exactly, so the
    /// conversion never rounds up to `1.0`, and the values cover the whole
    /// unit interval.
    #[inline]
    pub fn next_f32(&mut self) -> f32 {
        // 2^24 as f32; every integer below it is exactly representable.
        const SCALE: f32 = 16_777_216.0;
        // next_u32 yields 31 bits; dropping the low 7 leaves 24.
        let bits = self.next_u32() >> 7;
        bits as f32 / SCALE
    }

    /// Return a `usize` in `[0, n)`.
    ///
    /// Advances the generator by one step and reduces the 31-bit draw modulo
    /// `n`. Consequently the result is always below 2³¹ even when `n` is
    /// larger, and the distribution carries the usual modulo bias when `n`
    /// does not divide 2³¹.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`, since the range `[0, 0)` is empty.
    #[inline]
    pub fn next_usize(&mut self, n: usize) -> usize {
        assert!(n > 0, "LcgRng::next_usize: n must be > 0");
        (self.next_u32() as usize) % n
    }
}
86
87// ─── RlHandle ────────────────────────────────────────────────────────────────
88
89/// RL session handle: carries GPU metadata and a CPU-side RNG for buffer
90/// operations.
91#[derive(Debug, Clone)]
92pub struct RlHandle {
93    sm: SmVersion,
94    rng: LcgRng,
95    /// Device ordinal (0-indexed).
96    device: u32,
97}
98
99impl RlHandle {
100    /// Create a new handle for the given SM version and device.
101    #[must_use]
102    pub fn new(sm: u32, device: u32, seed: u64) -> Self {
103        Self {
104            sm: SmVersion(sm),
105            rng: LcgRng::new(seed),
106            device,
107        }
108    }
109
110    /// Create a default handle (SM 8.0 / Ampere, device 0, seed 42).
111    #[must_use]
112    pub fn default_handle() -> Self {
113        Self::new(80, 0, 42)
114    }
115
116    /// SM version.
117    #[must_use]
118    #[inline]
119    pub fn sm(&self) -> SmVersion {
120        self.sm
121    }
122
123    /// Device ordinal.
124    #[must_use]
125    #[inline]
126    pub fn device(&self) -> u32 {
127        self.device
128    }
129
130    /// Mutable access to the internal RNG (used by replay buffer sampling).
131    #[inline]
132    pub fn rng_mut(&mut self) -> &mut LcgRng {
133        &mut self.rng
134    }
135
136    /// Validate that `batch_size > 0` and `batch_size <= capacity`.
137    pub fn validate_batch(batch_size: usize, capacity: usize) -> RlResult<()> {
138        if batch_size == 0 {
139            return Err(RlError::InvalidHyperparameter {
140                name: "batch_size".into(),
141                msg: "must be > 0".into(),
142            });
143        }
144        if batch_size > capacity {
145            return Err(RlError::InsufficientTransitions {
146                have: capacity,
147                need: batch_size,
148            });
149        }
150        Ok(())
151    }
152}
153
154// ─── Tests ───────────────────────────────────────────────────────────────────
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Two consecutive draws from a fresh generator must not coincide.
    #[test]
    fn lcg_different_values() {
        let mut rng = LcgRng::new(123);
        let first = rng.next_u32();
        let second = rng.next_u32();
        assert_ne!(first, second, "LCG should produce different values");
    }

    /// Every `next_f32` draw stays inside the half-open unit interval.
    #[test]
    fn lcg_f32_in_range() {
        let mut rng = LcgRng::new(0);
        for v in std::iter::repeat_with(|| rng.next_f32()).take(1000) {
            assert!((0.0..1.0).contains(&v), "f32 out of [0,1): {v}");
        }
    }

    /// Every `next_usize(10)` draw stays below the bound.
    #[test]
    fn lcg_usize_in_range() {
        let mut rng = LcgRng::new(7);
        for v in std::iter::repeat_with(|| rng.next_usize(10)).take(1000) {
            assert!(v < 10, "usize out of [0,10): {v}");
        }
    }

    /// `SmVersion` orders by its wrapped integer.
    #[test]
    fn sm_version_ordering() {
        let (turing, ampere, hopper) = (SmVersion(75), SmVersion(80), SmVersion(90));
        assert!(ampere > turing);
        assert!(hopper > ampere);
    }

    /// Spot-check the SM to PTX version mapping.
    #[test]
    fn sm_version_ptx_str() {
        for (sm, expected) in [(75, "7.5"), (80, "8.0"), (90, "8.4")] {
            assert_eq!(SmVersion(sm).ptx_version_str(), expected);
        }
    }

    /// The default handle targets SM 80 on device 0.
    #[test]
    fn rl_handle_default() {
        let handle = RlHandle::default_handle();
        assert_eq!(handle.sm().as_u32(), 80);
        assert_eq!(handle.device(), 0);
    }

    /// A batch that fits within capacity is accepted.
    #[test]
    fn validate_batch_ok() {
        let outcome = RlHandle::validate_batch(32, 1024);
        outcome.unwrap();
    }

    /// A zero batch size is rejected.
    #[test]
    fn validate_batch_zero_error() {
        let outcome = RlHandle::validate_batch(0, 100);
        assert!(outcome.is_err());
    }

    /// A batch larger than capacity is rejected.
    #[test]
    fn validate_batch_too_large_error() {
        let outcome = RlHandle::validate_batch(200, 100);
        assert!(outcome.is_err());
    }
}
220}