// oxicuda_graphalg/handle.rs

/// Target GPU architecture version, stored as a plain integer (e.g. `75` for
/// [`SmVersion::SM_75`]).
///
/// NOTE(review): the constants follow the CUDA `sm_XY` naming, presumably
/// `major * 10 + minor` (so `SM_86` = 8.6); confirm before relying on that
/// encoding. Ordering is the derived numeric ordering of the inner value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// `sm_75`.
    pub const SM_75: Self = Self(75);
    /// `sm_80`.
    pub const SM_80: Self = Self(80);
    /// `sm_86`.
    pub const SM_86: Self = Self(86);
    /// `sm_89`.
    pub const SM_89: Self = Self(89);
    /// `sm_90`.
    pub const SM_90: Self = Self(90);
    /// `sm_100`.
    pub const SM_100: Self = Self(100);

    /// Returns the raw integer version (e.g. `90` for [`SmVersion::SM_90`]).
    ///
    /// This is a `const fn`, so it can be evaluated in const contexts together
    /// with the associated `SM_*` constants.
    #[must_use]
    pub const fn value(self) -> u32 {
        self.0
    }
}
/// Small, fast, deterministic 64-bit linear congruential generator (LCG).
///
/// Each step computes `state = state * MUL + ADD (mod 2^64)`, using the
/// constants of Knuth's MMIX generator. It is not cryptographically secure.
/// Its purpose is reproducible randomness from a `u64` seed.
///
/// With a power-of-two modulus, the low-order state bits have short periods:
/// bit `k` repeats every `2^(k+1)` steps, so bit 0 simply alternates. The
/// samplers below therefore avoid the lowest-order bits.
#[derive(Debug, Clone)]
pub struct LcgRng {
    /// Current generator state; advanced once per `next_u64` call.
    state: u64,
}

impl LcgRng {
    /// Multiplier of the recurrence (≡ 1 mod 4, required for full period).
    const MUL: u64 = 6_364_136_223_846_793_005;
    /// Increment of the recurrence (odd, required for full period 2^64).
    const ADD: u64 = 1_442_695_040_888_963_407;

    /// Creates a generator from `seed`.
    ///
    /// The seed is multiplied by the 64-bit golden-ratio constant
    /// `0x9E37_79B9_7F4A_7C15` and offset by one (so seed `0` starts from state `1`).
    /// Equal seeds always produce identical streams.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1),
        }
    }

    /// Advances the generator one step and returns the new 64-bit state.
    pub fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_mul(Self::MUL).wrapping_add(Self::ADD);
        self.state
    }

    /// Returns a uniform `f64` in `[0, 1)`, built from the top 53 bits
    /// (the width of an `f64` significand).
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Returns a pseudo-random `bool` taken from bit 32 of the next state.
    pub fn next_bool(&mut self) -> bool {
        (self.next_u64() >> 32) & 1 == 1
    }

    /// Returns a pseudo-random index in `0..n`, or `0` when `n == 0`.
    ///
    /// Uses the multiply-high reduction `(x * n) >> 64`, which draws the result
    /// from the high-order bits of `x`. Reducing with `x % n` would read the
    /// weak low-order bits instead: for `n == 2` it would just alternate
    /// 0, 1, 0, 1, …. Truncating `x` to `usize` would also discard the upper half
    /// on 32-bit targets. A residual bias on the order of `n / 2^64` remains,
    /// which is negligible for practical `n`.
    pub fn next_usize(&mut self, n: usize) -> usize {
        let n = n.max(1);
        // `usize` has no `From` impl into `u128`, but the cast is lossless on all targets.
        let wide = u128::from(self.next_u64()) * n as u128;
        // The high 64 bits of the product lie in `0..n`, so they always fit in `usize`.
        (wide >> 64) as usize
    }

    /// Maps a [`LcgRng::next_f64`] sample linearly onto the interval from
    /// `lo` to `hi`: `lo + (hi - lo) * u` with `u` in `[0, 1)`.
    ///
    /// Floating-point rounding may occasionally produce exactly `hi`.
    pub fn next_range(&mut self, lo: f64, hi: f64) -> f64 {
        lo + (hi - lo) * self.next_f64()
    }

    /// Draws a standard-normal sample (mean 0, variance 1) via the Box–Muller
    /// transform, using only the cosine branch. Two uniforms are consumed per call.
    pub fn next_normal(&mut self) -> f64 {
        // Clamp away from zero so `ln(u1)` stays finite.
        let u1 = self.next_f64().max(1.0e-300);
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
    }
}
85#[derive(Debug, Clone)]
87pub struct GraphalgHandle {
88 pub sm: SmVersion,
89 pub rng: LcgRng,
90}
91
92impl GraphalgHandle {
93 #[must_use]
95 pub fn new(sm: SmVersion, seed: u64) -> Self {
96 Self {
97 sm,
98 rng: LcgRng::new(seed),
99 }
100 }
101}
102
// Unit tests for the handle types. Every test here exercises only the public
// API and holds for any correct implementation of the documented contracts.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sm_constants_correct() {
        assert_eq!(SmVersion::SM_75.value(), 75);
        assert_eq!(SmVersion::SM_100.value(), 100);
    }

    #[test]
    fn sm_versions_ordered() {
        // Derived `Ord` compares the inner integer.
        assert!(SmVersion::SM_75 < SmVersion::SM_80);
        assert!(SmVersion::SM_90 < SmVersion::SM_100);
        assert_eq!(SmVersion::SM_89.max(SmVersion::SM_86), SmVersion::SM_89);
    }

    #[test]
    fn lcg_deterministic() {
        let mut r1 = LcgRng::new(42);
        let mut r2 = LcgRng::new(42);
        for _ in 0..16 {
            assert_eq!(r1.next_u64(), r2.next_u64());
        }
    }

    #[test]
    fn lcg_unit_interval() {
        let mut r = LcgRng::new(7);
        for _ in 0..1000 {
            let v = r.next_f64();
            assert!((0.0..1.0).contains(&v));
        }
    }

    #[test]
    fn lcg_bool_balanced() {
        let mut r = LcgRng::new(13);
        let mut trues = 0usize;
        for _ in 0..1000 {
            if r.next_bool() {
                trues += 1;
            }
        }
        assert!(trues > 400 && trues < 600);
    }

    #[test]
    fn lcg_usize_in_range() {
        let mut r = LcgRng::new(21);
        // `n == 0` is documented to return 0 rather than panic.
        assert_eq!(r.next_usize(0), 0);
        assert_eq!(r.next_usize(1), 0);
        for _ in 0..1000 {
            assert!(r.next_usize(17) < 17);
        }
    }

    #[test]
    fn lcg_range_within_bounds() {
        let mut r = LcgRng::new(3);
        for _ in 0..1000 {
            let v = r.next_range(-2.0, 3.0);
            // Inclusive upper bound: rounding may in principle land on `hi`.
            assert!((-2.0..=3.0).contains(&v));
        }
    }

    #[test]
    fn lcg_normal_mean_zero() {
        let mut r = LcgRng::new(99);
        let n = 10_000;
        let mean: f64 = (0..n).map(|_| r.next_normal()).sum::<f64>() / n as f64;
        assert!(mean.abs() < 0.2);
    }

    #[test]
    fn lcg_normal_unit_variance() {
        let mut r = LcgRng::new(99);
        let n = 10_000;
        let samples: Vec<f64> = (0..n).map(|_| r.next_normal()).collect();
        let mean = samples.iter().sum::<f64>() / n as f64;
        let var = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        assert!((var - 1.0).abs() < 0.1);
    }

    #[test]
    fn handle_construction() {
        let h = GraphalgHandle::new(SmVersion::SM_90, 0);
        assert_eq!(h.sm.value(), 90);
    }

    #[test]
    fn handle_rng_matches_seeded_lcg() {
        let mut h = GraphalgHandle::new(SmVersion::SM_80, 5);
        let mut r = LcgRng::new(5);
        for _ in 0..8 {
            assert_eq!(h.rng.next_u64(), r.next_u64());
        }
    }
}