/// GPU streaming-multiprocessor (SM) architecture version, encoded as
/// `major * 10 + minor` (so `SmVersion(86)` displays as `SM 8.6`).
///
/// Ordering compares the encoded number, so newer architectures sort later.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// Returns the raw encoded number, e.g. `80` for SM 8.0.
    #[must_use]
    #[inline]
    pub fn as_u32(self) -> u32 {
        let SmVersion(raw) = self;
        raw
    }

    /// Returns the PTX ISA version string associated with this SM level.
    ///
    /// Step mapping on the raw number: `100` and above → `"8.7"`,
    /// `90..=99` → `"8.4"`, `80..=89` → `"8.0"`, anything lower → `"7.5"`.
    /// (Presumably emitted in the PTX `.version` directive — confirm with the
    /// code generator.)
    #[must_use]
    pub fn ptx_version_str(self) -> &'static str {
        match self.as_u32() {
            100..=u32::MAX => "8.7",
            90..=99 => "8.4",
            80..=89 => "8.0",
            _ => "7.5",
        }
    }

    /// Returns the compiler target name, e.g. `"sm_90"`.
    #[must_use]
    pub fn target_str(self) -> String {
        let raw = self.0;
        format!("sm_{raw}")
    }
}

impl std::fmt::Display for SmVersion {
    /// Formats as `SM <major>.<minor>`, splitting the raw number at the last
    /// decimal digit (`120` → `SM 12.0`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let major = self.0 / 10;
        let minor = self.0 % 10;
        write!(f, "SM {major}.{minor}")
    }
}
50
/// Small deterministic pseudo-random number generator.
///
/// A 64-bit linear congruential generator (multiplier
/// `6364136223846793005`, increment `1442695040888963407`, modulus `2^64`)
/// whose 32-bit output is the low half of `(state >> 33) ^ state`.
/// Identical seeds always produce identical sequences. Not cryptographically
/// secure.
#[derive(Debug, Clone)]
pub struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// `2^32` as an `f32` (exactly representable). Equal to the
    /// `u32::MAX as f32 + 1.0` divisor used previously.
    const TWO_POW_32: f32 = 4_294_967_296.0;

    /// Largest `f32` strictly below `1.0`, i.e. `1 - 2^-24`.
    const MAX_BELOW_ONE: f32 = 1.0 - f32::EPSILON / 2.0;

    /// Creates a generator whose internal state starts at `seed + 1`
    /// (wrapping on overflow).
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the state by one LCG step and returns the next 32-bit output.
    #[inline]
    pub fn next_u32(&mut self) -> u32 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        // XOR-ing in `state >> 33` folds high state bits into the returned low
        // half (the low-order bits of a power-of-two LCG have short periods).
        // Truncating to the low 32 bits is intentional.
        ((self.state >> 33) ^ self.state) as u32
    }

    /// Returns an `f32` in the half-open range `[0, 1)`.
    ///
    /// Consumes exactly one [`next_u32`](Self::next_u32) draw.
    #[inline]
    pub fn next_f32(&mut self) -> f32 {
        Self::unit_f32_from_bits(self.next_u32())
    }

    /// Maps a raw 32-bit draw onto `[0, 1)`.
    ///
    /// `bits as f32` rounds to the nearest `f32`. Every input `>= 2^32 - 128`
    /// therefore becomes `2^32`, and the plain quotient would be exactly `1.0`.
    /// Those inputs are clamped to the largest `f32` below one. All other
    /// inputs keep the value `bits as f32 / 2^32`, so seeded sequences are
    /// otherwise unchanged.
    #[inline]
    fn unit_f32_from_bits(bits: u32) -> f32 {
        let unit = bits as f32 / Self::TWO_POW_32;
        if unit < 1.0 {
            unit
        } else {
            Self::MAX_BELOW_ONE
        }
    }

    /// Returns an integer in `0..n`, or `0` when `n == 0` (without drawing).
    ///
    /// Uses plain modulo reduction, so results are slightly biased toward small
    /// values unless `n` is a power of two. Bounds up to `u32::MAX` consume one
    /// draw. Larger bounds (only possible on 64-bit targets) consume two draws,
    /// so values `>= 2^32` are reachable.
    #[inline]
    pub fn next_usize(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        // Lossless on every supported target (usize is at most 64 bits).
        let bound = n as u64;
        if bound <= u64::from(u32::MAX) {
            // Same value as `(draw as usize) % n`; the result is < n.
            return (u64::from(self.next_u32()) % bound) as usize;
        }
        // Combine two draws (first = high half) into a 64-bit value.
        let wide = (u64::from(self.next_u32()) << 32) | u64::from(self.next_u32());
        // The remainder is < n, so it converts back to usize losslessly.
        (wide % bound) as usize
    }

    /// Draws two standard-normal samples with the Box–Muller transform.
    ///
    /// Consumes two [`next_f32`](Self::next_f32) draws.
    pub fn next_normal_pair(&mut self) -> (f32, f32) {
        // The `+ 1e-10` offset keeps `u1 > 0` so `ln(u1)` stays finite. In f32,
        // `1.0 - 1e-10` rounds to `1.0`, so the upper clamp only bounds `u1 <= 1`.
        let u1 = (self.next_f32() + 1e-10).min(1.0 - 1e-10);
        let u2 = self.next_f32();
        let r = (-2.0 * u1.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * u2;
        (r * theta.cos(), r * theta.sin())
    }

    /// Fills `buf` with standard-normal samples.
    ///
    /// Samples are written two at a time from [`next_normal_pair`](Self::next_normal_pair).
    /// For an odd length, the final element takes the first value of one more
    /// pair and the second value is discarded.
    pub fn fill_normal(&mut self, buf: &mut [f32]) {
        let mut pairs = buf.chunks_exact_mut(2);
        for pair in &mut pairs {
            let (a, b) = self.next_normal_pair();
            pair[0] = a;
            pair[1] = b;
        }
        if let [last] = pairs.into_remainder() {
            let (a, _) = self.next_normal_pair();
            *last = a;
        }
    }

    /// Shuffles `slice` in place with the Fisher–Yates algorithm, drawing
    /// swap indices from [`next_usize`](Self::next_usize).
    pub fn shuffle<T>(&mut self, slice: &mut [T]) {
        for i in (1..slice.len()).rev() {
            let j = self.next_usize(i + 1);
            slice.swap(i, j);
        }
    }
}
132
133#[derive(Debug, Clone)]
142pub struct VisionHandle {
143 sm: SmVersion,
144 rng: LcgRng,
145 device: u32,
146}
147
148impl VisionHandle {
149 #[must_use]
151 pub fn new(device: u32, sm: SmVersion, seed: u64) -> Self {
152 Self {
153 sm,
154 rng: LcgRng::new(seed),
155 device,
156 }
157 }
158
159 #[must_use]
162 pub fn default_handle() -> Self {
163 Self::new(0, SmVersion(80), 42)
164 }
165
166 #[must_use]
168 pub fn sm_version(&self) -> SmVersion {
169 self.sm
170 }
171
172 #[must_use]
174 pub fn device(&self) -> u32 {
175 self.device
176 }
177
178 pub fn rng_mut(&mut self) -> &mut LcgRng {
180 &mut self.rng
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sm_version_ptx_strings() {
        let table = [
            (75, "7.5"),
            (80, "8.0"),
            (86, "8.0"),
            (90, "8.4"),
            (100, "8.7"),
            (120, "8.7"),
        ];
        for (raw, expected) in table {
            assert_eq!(SmVersion(raw).ptx_version_str(), expected);
        }
    }

    #[test]
    fn sm_version_target_str() {
        for (raw, expected) in [(80, "sm_80"), (90, "sm_90"), (120, "sm_120")] {
            assert_eq!(SmVersion(raw).target_str(), expected);
        }
    }

    #[test]
    fn sm_version_display() {
        for (raw, expected) in [(80, "SM 8.0"), (90, "SM 9.0")] {
            assert_eq!(SmVersion(raw).to_string(), expected);
        }
    }

    #[test]
    fn sm_version_ordering() {
        let (older, mid, newer) = (SmVersion(80), SmVersion(90), SmVersion(100));
        assert!(older < mid);
        assert!(newer > mid);
        assert_eq!(older, SmVersion(80));
    }

    #[test]
    fn sm_version_as_u32() {
        let version = SmVersion(86);
        assert_eq!(version.as_u32(), 86);
    }

    #[test]
    fn vision_handle_default() {
        let handle = VisionHandle::default_handle();
        assert_eq!((handle.device(), handle.sm_version()), (0, SmVersion(80)));
    }

    #[test]
    fn vision_handle_custom() {
        let handle = VisionHandle::new(2, SmVersion(90), 12345);
        assert_eq!((handle.device(), handle.sm_version()), (2, SmVersion(90)));
    }

    #[test]
    fn lcg_rng_determinism() {
        let (mut first, mut second) = (LcgRng::new(42), LcgRng::new(42));
        for _ in 0..100 {
            assert_eq!(first.next_u32(), second.next_u32());
        }
    }

    #[test]
    fn lcg_rng_f32_in_range() {
        let mut rng = LcgRng::new(7);
        (0..1000).map(|_| rng.next_f32()).for_each(|v| {
            assert!((0.0..1.0).contains(&v), "out of range: {v}");
        });
    }

    #[test]
    fn lcg_rng_usize_in_range() {
        let mut rng = LcgRng::new(99);
        (0..1000).map(|_| rng.next_usize(7)).for_each(|v| {
            assert!(v < 7, "out of range: {v}");
        });
    }

    #[test]
    fn lcg_rng_normal_fill_finite() {
        let mut rng = LcgRng::new(13);
        let mut samples = [0.0_f32; 64];
        rng.fill_normal(&mut samples);
        assert!(samples.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn lcg_rng_shuffle_permutes() {
        let mut rng = LcgRng::new(77);
        let mut values: Vec<usize> = (0..8).collect();
        rng.shuffle(&mut values);
        values.sort_unstable();
        assert_eq!(values, (0..8).collect::<Vec<_>>());
    }

    #[test]
    fn vision_handle_rng_mut() {
        let mut handle = VisionHandle::default_handle();
        let sample = handle.rng_mut().next_f32();
        assert!((0.0..1.0).contains(&sample));
    }
}