/// CUDA streaming-multiprocessor (SM) version, stored as `major * 10 + minor`
/// (so `80` is SM 8.0 and `120` is SM 12.0).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// Returns the raw encoded value (e.g. `86` for SM 8.6).
    #[must_use]
    #[inline]
    pub fn as_u32(self) -> u32 {
        let Self(raw) = self;
        raw
    }

    /// Returns the PTX version string selected for this SM level.
    ///
    /// Thresholds: `>= 100` → `"8.7"`, `90..=99` → `"8.4"`, `80..=89` → `"8.0"`,
    /// and anything below 80 → `"7.5"`.
    #[must_use]
    pub fn ptx_version_str(self) -> &'static str {
        match self.as_u32() {
            100..=u32::MAX => "8.7",
            90..=99 => "8.4",
            80..=89 => "8.0",
            _ => "7.5",
        }
    }

    /// Returns the target name, e.g. `"sm_80"` or `"sm_120"`.
    #[must_use]
    pub fn target_str(self) -> String {
        ["sm_", &self.as_u32().to_string()].concat()
    }
}

impl std::fmt::Display for SmVersion {
    /// Formats as `SM <major>.<minor>`, e.g. `SM 8.6`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let major = self.0 / 10;
        let minor = self.0 % 10;
        write!(f, "SM {}.{}", major, minor)
    }
}
45
/// Small, seedable 64-bit linear congruential generator.
///
/// Deterministic: two generators created with the same seed produce the same
/// sequence. It is a plain LCG and therefore not suitable for cryptographic use.
#[derive(Debug, Clone)]
pub struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// Multiplier applied to the state on every step (arithmetic is mod 2^64).
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    /// Increment added on every step; also added to the seed in [`LcgRng::new`].
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    /// Creates a generator from `seed`.
    ///
    /// The initial state is `seed + INCREMENT` (wrapping), so a zero seed does
    /// not yield a zero initial state.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(Self::INCREMENT),
        }
    }

    /// Advances the state by one LCG step and returns 32 output bits.
    ///
    /// The output is `(state >> 33) ^ state` truncated to 32 bits, which mixes
    /// high state bits into the low ones (the low bits of a power-of-two-modulus
    /// LCG have short periods on their own).
    #[inline]
    pub fn next_u32(&mut self) -> u32 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        // Truncation to the low 32 bits is intentional.
        ((self.state >> 33) ^ self.state) as u32
    }

    /// Returns a uniform `f32` in `[0, 1)`.
    ///
    /// Uses the top 24 bits of [`next_u32`](Self::next_u32), so every result is
    /// exactly representable and strictly below `1.0`.
    #[inline]
    pub fn next_f32(&mut self) -> f32 {
        (self.next_u32() >> 8) as f32 / 16_777_216.0
    }

    /// Returns an integer in `0..n`.
    ///
    /// Uses a plain modulo reduction of a 32-bit draw, so:
    /// - when `n` does not divide `2^32` there is a slight bias toward smaller
    ///   values (negligible for small `n`);
    /// - when `n > 2^32` (64-bit targets), only values below `2^32` are produced.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`.
    #[inline]
    pub fn next_usize(&mut self, n: usize) -> usize {
        assert!(n > 0, "LcgRng::next_usize: n must be non-zero");
        (self.next_u32() as usize) % n
    }

    /// Returns `lo + u * (hi - lo)` with `u` drawn from [`next_f32`](Self::next_f32).
    ///
    /// Nominally uniform on `[lo, hi)`, but floating-point rounding can return
    /// exactly `hi` (e.g. `lo = 1.0, hi = 2.0` when `u` is just below 1).
    /// If `lo > hi` the interval is simply traversed in reverse.
    #[inline]
    pub fn next_f32_range(&mut self, lo: f32, hi: f32) -> f32 {
        lo + self.next_f32() * (hi - lo)
    }

    /// Draws two independent standard-normal samples (mean 0, variance 1)
    /// with the Box–Muller transform.
    pub fn next_normal_pair(&mut self) -> (f32, f32) {
        // Shift u1 off zero so `ln(u1)` is finite. In f32, `1.0 - 1e-10` rounds to
        // 1.0, so the upper clamp never binds (next_f32 is already < 1).
        // Kept byte-exact: the additive shift also perturbs small u1 values,
        // and changing it would change every seeded sample stream.
        let u1 = (self.next_f32() + 1e-10).min(1.0 - 1e-10);
        let u2 = self.next_f32();
        let r = (-2.0_f32 * u1.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * u2;
        (r * theta.cos(), r * theta.sin())
    }

    /// Fills `buf` with standard-normal samples.
    ///
    /// Consumes one [`next_normal_pair`](Self::next_normal_pair) call per two
    /// elements. For odd lengths the final element takes the first value of one
    /// extra pair, and the second value is discarded.
    pub fn fill_normal(&mut self, buf: &mut [f32]) {
        let mut pairs = buf.chunks_exact_mut(2);
        for pair in &mut pairs {
            let (a, b) = self.next_normal_pair();
            pair[0] = a;
            pair[1] = b;
        }
        if let [last] = pairs.into_remainder() {
            let (a, _) = self.next_normal_pair();
            *last = a;
        }
    }
}
117
118#[derive(Debug)]
122pub struct NerfHandle {
123 pub sm: SmVersion,
125 pub device: u32,
127 pub rng: LcgRng,
129}
130
131impl NerfHandle {
132 #[must_use]
134 pub fn new(device: u32, sm: SmVersion, seed: u64) -> Self {
135 Self {
136 sm,
137 device,
138 rng: LcgRng::new(seed),
139 }
140 }
141
142 #[must_use]
145 pub fn default_handle() -> Self {
146 Self::new(0, SmVersion(80), 42)
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lcg_rng_deterministic() {
        let mut a = LcgRng::new(42);
        let mut b = LcgRng::new(42);
        for _ in 0..100 {
            assert_eq!(a.next_u32(), b.next_u32());
        }
    }

    #[test]
    fn lcg_rng_f32_in_range() {
        let mut rng = LcgRng::new(7);
        for _ in 0..1000 {
            let v = rng.next_f32();
            assert!((0.0..1.0).contains(&v));
        }
    }

    #[test]
    fn lcg_rng_f32_range_within_bounds() {
        // Closed upper bound: rounding in `lo + u * (hi - lo)` may yield `hi`.
        let mut rng = LcgRng::new(3);
        for _ in 0..1000 {
            let v = rng.next_f32_range(-2.0, 3.0);
            assert!((-2.0..=3.0).contains(&v));
        }
    }

    #[test]
    fn lcg_rng_usize_in_range() {
        let mut rng = LcgRng::new(5);
        for _ in 0..1000 {
            assert!(rng.next_usize(7) < 7);
        }
    }

    #[test]
    #[should_panic]
    fn lcg_rng_usize_zero_panics() {
        let _ = LcgRng::new(5).next_usize(0);
    }

    #[test]
    fn lcg_rng_fill_normal_matches_pairs() {
        // Odd length: the last slot takes the first value of an extra pair.
        let mut a = LcgRng::new(11);
        let mut b = a.clone();
        let mut buf = [0.0_f32; 5];
        a.fill_normal(&mut buf);
        let (x0, x1) = b.next_normal_pair();
        let (x2, x3) = b.next_normal_pair();
        let (x4, _) = b.next_normal_pair();
        assert_eq!(buf, [x0, x1, x2, x3, x4]);
        assert_eq!(a.next_u32(), b.next_u32());
    }

    #[test]
    fn lcg_rng_normal_pairs_are_finite() {
        let mut rng = LcgRng::new(99);
        for _ in 0..1000 {
            let (x, y) = rng.next_normal_pair();
            assert!(x.is_finite() && y.is_finite());
        }
    }

    #[test]
    fn nerf_handle_default() {
        let h = NerfHandle::default_handle();
        assert_eq!(h.device, 0);
        assert_eq!(h.sm, SmVersion(80));
    }

    #[test]
    fn nerf_handle_rng_uses_seed() {
        let mut h = NerfHandle::new(2, SmVersion(90), 123);
        let mut expected = LcgRng::new(123);
        assert_eq!(h.device, 2);
        assert_eq!(h.sm, SmVersion(90));
        for _ in 0..10 {
            assert_eq!(h.rng.next_u32(), expected.next_u32());
        }
    }

    #[test]
    fn sm_version_target_str() {
        assert_eq!(SmVersion(80).target_str(), "sm_80");
        assert_eq!(SmVersion(90).target_str(), "sm_90");
        assert_eq!(SmVersion(120).target_str(), "sm_120");
    }

    #[test]
    fn sm_version_ptx_thresholds() {
        assert_eq!(SmVersion(120).ptx_version_str(), "8.7");
        assert_eq!(SmVersion(100).ptx_version_str(), "8.7");
        assert_eq!(SmVersion(99).ptx_version_str(), "8.4");
        assert_eq!(SmVersion(90).ptx_version_str(), "8.4");
        assert_eq!(SmVersion(89).ptx_version_str(), "8.0");
        assert_eq!(SmVersion(80).ptx_version_str(), "8.0");
        assert_eq!(SmVersion(75).ptx_version_str(), "7.5");
    }

    #[test]
    fn sm_version_display() {
        assert_eq!(SmVersion(86).to_string(), "SM 8.6");
        assert_eq!(SmVersion(120).to_string(), "SM 12.0");
        assert_eq!(SmVersion(86).as_u32(), 86);
    }
}