/// GPU SM (compute-capability) version packed as `major * 10 + minor`,
/// e.g. `80` is SM 8.0 and `120` is SM 12.0.
///
/// Ordering and equality compare the packed value directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// Returns the packed numeric value (e.g. `86` for SM 8.6).
    #[must_use]
    #[inline]
    pub fn as_u32(self) -> u32 {
        let SmVersion(raw) = self;
        raw
    }

    /// Returns the PTX ISA version string associated with this SM version.
    ///
    /// Buckets: `100` and above → `"8.7"`, `90..=99` → `"8.4"`,
    /// `80..=89` → `"8.0"`, anything older → `"7.5"`.
    #[must_use]
    pub fn ptx_version_str(self) -> &'static str {
        match self.as_u32() {
            100..=u32::MAX => "8.7",
            90..=99 => "8.4",
            80..=89 => "8.0",
            _ => "7.5",
        }
    }

    /// Returns the `sm_<packed>` target name, e.g. `"sm_90"` or `"sm_120"`.
    #[must_use]
    pub fn target_str(self) -> String {
        let mut name = String::from("sm_");
        name.push_str(&self.as_u32().to_string());
        name
    }
}

/// Formats as `SM <major>.<minor>`, e.g. `SM 8.0` or `SM 12.0`.
impl std::fmt::Display for SmVersion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let major = self.0 / 10;
        let minor = self.0 % 10;
        write!(f, "SM {}.{}", major, minor)
    }
}
47
/// Small, deterministic 64-bit linear congruential generator.
///
/// The multiplier and increment are the well-known MMIX LCG constants. Output
/// is produced by xor-folding the high state bits into the low word. The
/// generator is fast and reproducible (equal seeds give equal streams), but it
/// is not cryptographically secure.
#[derive(Debug, Clone)]
pub struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// Creates a generator from `seed`.
    ///
    /// The seed is offset by one (wrapping) before being used as the initial
    /// state, so `new(u64::MAX)` starts from state `0`.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the generator one step and returns 32 pseudo-random bits.
    #[inline]
    pub fn next_u32(&mut self) -> u32 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        // Low LCG bits are weak, so fold bits 33..64 down before the
        // (intentional) truncation to the low 32 bits.
        ((self.state >> 33) ^ self.state) as u32
    }

    /// Returns a uniform sample in the half-open interval `[0.0, 1.0)`.
    ///
    /// The sample is `next_u32() / 2^32` computed in `f32`. Because
    /// `u32 as f32` rounds to nearest, the 128 largest `u32` values round up
    /// to `2^32`, and the plain ratio would be exactly `1.0`. Those samples are
    /// mapped to the largest `f32` below one instead. All other outputs are
    /// bit-identical to the plain formula, so existing seeded streams are
    /// unchanged.
    #[inline]
    pub fn next_f32(&mut self) -> f32 {
        // 1 - 2^-24: the largest f32 strictly less than 1.0.
        const BELOW_ONE: f32 = 1.0 - f32::EPSILON / 2.0;
        // `u32::MAX as f32 + 1.0` is exactly 2^32.
        let v = self.next_u32() as f32 / (u32::MAX as f32 + 1.0);
        if v < 1.0 {
            v
        } else {
            BELOW_ONE
        }
    }

    /// Returns a value in `0..n`, or `0` (without advancing the generator)
    /// when `n == 0`.
    ///
    /// This is a plain modulo reduction of one 32-bit sample. It is slightly
    /// biased toward small results when `n` does not divide `2^32`, and it can
    /// never produce values at or above `2^32`, even if `n` is larger.
    #[inline]
    pub fn next_usize(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        (self.next_u32() as usize) % n
    }

    /// Draws two independent standard-normal samples via the Box–Muller
    /// transform, consuming two `next_f32` draws.
    pub fn next_normal_pair(&mut self) -> (f32, f32) {
        // The offset keeps u1 > 0 so ln(u1) stays finite. In f32,
        // `1.0 - 1e-10` rounds to 1.0, so the upper clamp is effectively a
        // no-op; ln(1.0) = 0 is harmless.
        let u1 = (self.next_f32() + 1e-10).min(1.0 - 1e-10);
        let u2 = self.next_f32();
        let r = (-2.0 * u1.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * u2;
        (r * theta.cos(), r * theta.sin())
    }

    /// Fills `buf` with standard-normal samples.
    ///
    /// Each pair of slots consumes one `next_normal_pair` call. For odd
    /// lengths, the final slot takes the first value of one extra pair and the
    /// second value is discarded.
    pub fn fill_normal(&mut self, buf: &mut [f32]) {
        for pair in buf.chunks_mut(2) {
            let (a, b) = self.next_normal_pair();
            pair[0] = a;
            if let Some(second) = pair.get_mut(1) {
                *second = b;
            }
        }
    }

    /// Shuffles `slice` in place with the Fisher–Yates algorithm, drawing
    /// indices from `next_usize` (so it inherits that method's modulo bias).
    pub fn shuffle<T>(&mut self, slice: &mut [T]) {
        let n = slice.len();
        for i in (1..n).rev() {
            let j = self.next_usize(i + 1);
            slice.swap(i, j);
        }
    }
}
126
127#[derive(Debug, Clone)]
131pub struct SslHandle {
132 sm: SmVersion,
133 rng: LcgRng,
134 device: u32,
135}
136
137impl SslHandle {
138 #[must_use]
140 pub fn new(device: u32, sm: SmVersion, seed: u64) -> Self {
141 Self {
142 sm,
143 rng: LcgRng::new(seed),
144 device,
145 }
146 }
147
148 #[must_use]
151 pub fn default_handle() -> Self {
152 Self::new(0, SmVersion(80), 42)
153 }
154
155 #[must_use]
157 pub fn sm_version(&self) -> SmVersion {
158 self.sm
159 }
160
161 #[must_use]
163 pub fn device(&self) -> u32 {
164 self.device
165 }
166
167 pub fn rng_mut(&mut self) -> &mut LcgRng {
169 &mut self.rng
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// PTX version buckets for representative SM versions.
    #[test]
    fn sm_version_ptx_strings() {
        let cases = [
            (75, "7.5"),
            (80, "8.0"),
            (86, "8.0"),
            (90, "8.4"),
            (100, "8.7"),
            (120, "8.7"),
        ];
        for &(sm, expected) in cases.iter() {
            assert_eq!(SmVersion(sm).ptx_version_str(), expected, "SM {}", sm);
        }
    }

    /// `sm_XX` target names.
    #[test]
    fn sm_version_target_str() {
        for &(sm, expected) in [(80, "sm_80"), (90, "sm_90"), (120, "sm_120")].iter() {
            assert_eq!(SmVersion(sm).target_str(), expected);
        }
    }

    /// Human-readable `SM major.minor` formatting.
    #[test]
    fn sm_version_display() {
        for &(sm, expected) in [(80, "SM 8.0"), (120, "SM 12.0")].iter() {
            assert_eq!(SmVersion(sm).to_string(), expected);
        }
    }

    /// Ordering follows the packed numeric value.
    #[test]
    fn sm_version_ordering() {
        let (v80, v90, v100) = (SmVersion(80), SmVersion(90), SmVersion(100));
        assert!(v80 < v90);
        assert!(v100 > v90);
    }

    /// The default handle targets device 0 at SM 8.0.
    #[test]
    fn ssl_handle_default() {
        let handle = SslHandle::default_handle();
        assert_eq!(handle.device(), 0);
        assert_eq!(handle.sm_version(), SmVersion(80));
    }

    /// A custom handle reports the device and SM it was built with.
    #[test]
    fn ssl_handle_custom() {
        let handle = SslHandle::new(2, SmVersion(120), 99);
        assert_eq!(handle.device(), 2);
        assert_eq!(handle.sm_version(), SmVersion(120));
    }

    /// Two generators with the same seed emit the same stream.
    #[test]
    fn lcg_rng_determinism() {
        let mut first = LcgRng::new(7);
        let mut second = LcgRng::new(7);
        let xs: Vec<u32> = (0..100).map(|_| first.next_u32()).collect();
        let ys: Vec<u32> = (0..100).map(|_| second.next_u32()).collect();
        assert_eq!(xs, ys);
    }

    /// Uniform floats stay inside `[0, 1)`.
    #[test]
    fn lcg_rng_f32_in_range() {
        let mut rng = LcgRng::new(11);
        for _ in 0..1000 {
            let sample = rng.next_f32();
            assert!((0.0..1.0).contains(&sample));
        }
    }

    /// Normal samples are always finite.
    #[test]
    fn lcg_rng_normal_finite() {
        let mut rng = LcgRng::new(13);
        let mut samples = [0.0_f32; 64];
        rng.fill_normal(&mut samples);
        assert!(samples.iter().all(|x| x.is_finite()));
    }

    /// Shuffling permutes but never loses or duplicates elements.
    #[test]
    fn lcg_rng_shuffle_preserves_elements() {
        let mut rng = LcgRng::new(17);
        let expected: Vec<usize> = (0..16).collect();
        let mut shuffled = expected.clone();
        rng.shuffle(&mut shuffled);
        shuffled.sort_unstable();
        assert_eq!(shuffled, expected);
    }

    /// Bounded integers respect the upper bound.
    #[test]
    fn lcg_next_usize_in_range() {
        let mut rng = LcgRng::new(19);
        assert!((0..200).all(|_| rng.next_usize(10) < 10));
    }
}