1use crate::error::{RlError, RlResult};
8
/// Numeric SM architecture version, as it appears in the `sm_NN` target name
/// (e.g. `SmVersion(80)` renders as `sm_80`).
///
/// Ordering compares the raw numbers, so newer architectures sort higher.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// Returns the raw architecture number (`80` for `sm_80`).
    #[must_use]
    #[inline]
    pub fn as_u32(self) -> u32 {
        let Self(raw) = self;
        raw
    }

    /// PTX ISA version string selected for this architecture.
    ///
    /// | SM range   | PTX   |
    /// |------------|-------|
    /// | `>= 100`   | `8.7` |
    /// | `90..=99`  | `8.4` |
    /// | `80..=89`  | `8.0` |
    /// | below `80` | `7.5` |
    #[must_use]
    pub fn ptx_version_str(self) -> &'static str {
        match self.as_u32() {
            100.. => "8.7",
            90..=99 => "8.4",
            80..=89 => "8.0",
            _ => "7.5",
        }
    }
}

impl std::fmt::Display for SmVersion {
    /// Formats as the compiler target name, e.g. `sm_90`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let arch = self.as_u32();
        write!(f, "sm_{}", arch)
    }
}
43
/// Small, fast, deterministic pseudo-random number generator.
///
/// A 64-bit linear congruential generator (multiplier
/// `6_364_136_223_846_793_005`, increment `1_442_695_040_888_963_407`).
/// Not suitable for cryptographic use; intended for reproducible,
/// seedable sampling.
#[derive(Debug, Clone)]
pub struct LcgRng {
    /// Current 64-bit generator state; advanced on every draw.
    state: u64,
}

impl LcgRng {
    /// Creates a generator from `seed`.
    ///
    /// The seed is offset by one before use, so seed `0` does not start the
    /// generator from an all-zero state.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the generator and returns a `u32` spanning the full
    /// `0..=u32::MAX` range.
    #[inline]
    pub fn next_u32(&mut self) -> u32 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        // The high bits of an LCG have the longest periods, so return the top
        // 32 bits. (A shift of 33 would only ever produce 31 bits, leaving
        // every value >= 2^31 unreachable.) The shift guarantees the value
        // fits in a u32, so the cast cannot truncate.
        (self.state >> 32) as u32
    }

    /// Returns a uniformly distributed `f32` in `[0, 1)`.
    #[inline]
    pub fn next_f32(&mut self) -> f32 {
        // f32 has a 24-bit significand. Using exactly 24 random bits keeps the
        // integer-to-float conversion exact, so the result can never round up
        // to 1.0 (as `u32::MAX as f32 / 2^32` would).
        const SCALE: f32 = 1.0 / 16_777_216.0; // 2^-24
        (self.next_u32() >> 8) as f32 * SCALE
    }

    /// Returns a value in `0..n`.
    ///
    /// Uses modulo reduction of a 32-bit draw. This carries a slight bias
    /// toward smaller values when `n` does not divide 2^32 (negligible for
    /// small `n`), and only values below 2^32 are reachable.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`.
    #[inline]
    pub fn next_usize(&mut self, n: usize) -> usize {
        assert!(n > 0, "LcgRng::next_usize: n must be > 0");
        (self.next_u32() as usize) % n
    }
}
86
87#[derive(Debug, Clone)]
92pub struct RlHandle {
93 sm: SmVersion,
94 rng: LcgRng,
95 device: u32,
97}
98
99impl RlHandle {
100 #[must_use]
102 pub fn new(sm: u32, device: u32, seed: u64) -> Self {
103 Self {
104 sm: SmVersion(sm),
105 rng: LcgRng::new(seed),
106 device,
107 }
108 }
109
110 #[must_use]
112 pub fn default_handle() -> Self {
113 Self::new(80, 0, 42)
114 }
115
116 #[must_use]
118 #[inline]
119 pub fn sm(&self) -> SmVersion {
120 self.sm
121 }
122
123 #[must_use]
125 #[inline]
126 pub fn device(&self) -> u32 {
127 self.device
128 }
129
130 #[inline]
132 pub fn rng_mut(&mut self) -> &mut LcgRng {
133 &mut self.rng
134 }
135
136 pub fn validate_batch(batch_size: usize, capacity: usize) -> RlResult<()> {
138 if batch_size == 0 {
139 return Err(RlError::InvalidHyperparameter {
140 name: "batch_size".into(),
141 msg: "must be > 0".into(),
142 });
143 }
144 if batch_size > capacity {
145 return Err(RlError::InsufficientTransitions {
146 have: capacity,
147 need: batch_size,
148 });
149 }
150 Ok(())
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    // Two consecutive draws from the same generator should differ.
    #[test]
    fn lcg_different_values() {
        let mut source = LcgRng::new(123);
        let (first, second) = (source.next_u32(), source.next_u32());
        assert_ne!(first, second, "LCG should produce different values");
    }

    // Every f32 draw must land in the half-open unit interval.
    #[test]
    fn lcg_f32_in_range() {
        let mut source = LcgRng::new(0);
        std::iter::repeat_with(|| source.next_f32())
            .take(1000)
            .for_each(|v| assert!((0.0..1.0).contains(&v), "f32 out of [0,1): {v}"));
    }

    // Bounded integer draws must stay below the bound.
    #[test]
    fn lcg_usize_in_range() {
        let mut source = LcgRng::new(7);
        for v in std::iter::repeat_with(|| source.next_usize(10)).take(1000) {
            assert!(v < 10, "usize out of [0,10): {v}");
        }
    }

    // Newer architectures compare greater than older ones.
    #[test]
    fn sm_version_ordering() {
        let (older, middle, newer) = (SmVersion(75), SmVersion(80), SmVersion(90));
        assert!(middle > older);
        assert!(newer > middle);
    }

    // Table-driven check of the SM -> PTX version mapping.
    #[test]
    fn sm_version_ptx_str() {
        for (sm, expected) in [(75, "7.5"), (80, "8.0"), (90, "8.4")] {
            assert_eq!(SmVersion(sm).ptx_version_str(), expected);
        }
    }

    // The convenience handle targets sm_80 on device 0.
    #[test]
    fn rl_handle_default() {
        let handle = RlHandle::default_handle();
        assert_eq!(handle.sm().as_u32(), 80);
        assert_eq!(handle.device(), 0);
    }

    #[test]
    fn validate_batch_ok() {
        let outcome = RlHandle::validate_batch(32, 1024);
        outcome.unwrap();
    }

    #[test]
    fn validate_batch_zero_error() {
        let outcome = RlHandle::validate_batch(0, 100);
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_batch_too_large_error() {
        let outcome = RlHandle::validate_batch(200, 100);
        assert!(outcome.is_err());
    }
}