/// Wheel-factorization iterator over prime candidates modulo 30 = 2·3·5.
///
/// Yields 2, 3, 5 and then, in increasing order, every integer `>= 7` that is
/// coprime to 30 (eight candidates per block of 30 integers). Every prime is
/// produced; so are composites free of the factors 2, 3 and 5 (e.g. 49 = 7²).
/// `next` never returns `None`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct PrimeWheel30 {
    // Value most recently produced (0 before the first call to `next`).
    base: u128,
    // Index into `GAPS` of the step that produces the following value.
    index: usize,
}

impl PrimeWheel30 {
    /// Distance from each candidate to the next one.
    ///
    /// Entries `0..4` walk the non-repeating prefix 0 → 2 → 3 → 5 → 7; entries
    /// `4..12` form the cycle (summing to 30) that repeats from 7 onwards.
    const GAPS: [u128; 12] = [
        2, 1, 2, 2, //
        4, 2, 4, 2, 4, 6, 2, 6,
    ];

    /// Index in `GAPS` where the repeating cycle begins.
    const CYCLE_START: usize = 4;

    /// Creates a wheel positioned before 2, its first candidate.
    pub fn new() -> Self {
        Self { base: 0, index: 0 }
    }
}

impl Iterator for PrimeWheel30 {
    type Item = u128;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let step = *Self::GAPS.get(self.index)?;
        self.base += step;
        // Falling off the end of the table wraps back to the start of the cycle.
        let following = self.index + 1;
        self.index = if following == Self::GAPS.len() {
            Self::CYCLE_START
        } else {
            following
        };
        Some(self.base)
    }
}

/// Wheel-factorization iterator over prime candidates modulo 210 = 2·3·5·7.
///
/// Yields 2, 3, 5, 7 and then, in increasing order, every integer `>= 11` that
/// is coprime to 210 (48 candidates per block of 210 integers). Every prime is
/// produced; so are composites free of the factors 2, 3, 5 and 7 (e.g. 121 = 11²).
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct PrimeWheel210 {
    // Current position: the candidate most recently produced by `next`, or the
    // one the wheel was placed on by `from`/`prev`; 0 at the origin.
    base: u128,
    // Index into `GAPS` of the step that produces the following value.
    index: usize,
}

impl PrimeWheel210 {
    /// Distance from each candidate to the next one.
    ///
    /// Entries `0..5` walk the non-repeating prefix 0 → 2 → 3 → 5 → 7 → 11;
    /// entries `5..53` are the 48-step cycle (summing to 210) that repeats from
    /// 11 onwards.
    const GAPS: [u128; 53] = [
        2, 1, 2, 2, 4, //
        2, 4, 2, 4, 6, 2, 6, 4, 2, 4, 6, 6, 2, 6, 4, 2, 6, 4, 6, 8, 4, 2, 4, 2, //
        4, 8, 6, 4, 6, 2, 4, 6, 2, 6, 6, 4, 2, 4, 6, 2, 6, 4, 2, 4, 2, 10, 2, 10,
    ];

    /// Index in `GAPS` where the repeating cycle begins. The wheel is at this
    /// index exactly when its position is congruent to 11 (mod 210).
    const CYCLE_START: usize = 5;

    /// Creates a wheel positioned before 2, its first candidate.
    pub fn new() -> Self {
        Self { base: 0, index: 0 }
    }

    /// Creates a wheel whose first `next()` yields the smallest candidate that
    /// is `>= start`.
    ///
    /// Any `start <= 2` gives the same wheel as [`PrimeWheel210::new`].
    pub fn from(start: u128) -> Self {
        if start <= 2 {
            return Self::default();
        }
        // Small targets live in the prefix (2, 3, 5, 7, 11). Larger ones are found
        // by scanning the one 210-cycle that contains them, which begins at the
        // anchor 11 + 210·k.
        let (mut base, indices) = if start <= 11 {
            (0, 0..Self::CYCLE_START)
        } else {
            let cycle = start.saturating_sub(13) / 210;
            (11 + cycle * 210, Self::CYCLE_START..Self::GAPS.len())
        };
        for index in indices {
            let gap = Self::GAPS[index];
            if base + gap >= start {
                return Self { base, index };
            }
            base += gap;
        }
        // Reached only for start = 222 + 210·k. The next candidate is then the
        // first step of the following cycle.
        Self {
            base,
            index: Self::CYCLE_START,
        }
    }

    /// Steps the wheel back by one candidate.
    ///
    /// Returns the current position (the value `next` most recently produced,
    /// or, right after `from(start)`, the candidate just below `start`). It then
    /// rewinds so that the following `next` yields that value again. Returns
    /// `None` once the wheel is back at its origin.
    pub fn prev(&mut self) -> Option<u128> {
        let current = self.base;
        if current < 2 {
            return None;
        }
        // At the anchor of a repeated cycle (≡ 11 mod 210, past the first 11),
        // step back across the wrap-around gap instead of into the prefix.
        let at_cycle_anchor = self.index == Self::CYCLE_START && current > 11;
        let upper = if at_cycle_anchor {
            Self::GAPS.len()
        } else {
            self.index
        };
        match upper.checked_sub(1) {
            Some(index) => {
                self.index = index;
                self.base = current.saturating_sub(Self::GAPS[index]);
            }
            None => self.base = 0,
        }
        Some(current)
    }
}

impl Iterator for PrimeWheel210 {
    type Item = u128;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let step = *Self::GAPS.get(self.index)?;
        self.base += step;
        // Falling off the end of the table wraps back to the start of the cycle.
        let following = self.index + 1;
        self.index = if following == Self::GAPS.len() {
            Self::CYCLE_START
        } else {
            following
        };
        Some(self.base)
    }
}

/// Cheap primality pre-filter.
///
/// Returns `true` when `n` is one of the small primes 2, 3, 5, 7, or when
/// `n >= 11` and `n` shares no factor with 210 = 2·3·5·7. A `false` result
/// proves `n` is not prime. A `true` result does not prove primality
/// (121 = 11² passes).
#[inline(always)]
pub(crate) fn is_prime_candidate(n: u128) -> bool {
    // 210-bit table packed into u32 words: bit `r` is set iff gcd(r, 210) == 1.
    const COPRIME_TO_210: [u32; 7] = [
        0xa08a_2802, 0x2820_8a20, 0x0208_8288, 0x8202_28a2,
        0x20a0_8a08, 0x8828_2288, 0x0002_00a2,
    ];
    match n {
        0..=10 => matches!(n, 2 | 3 | 5 | 7),
        _ => {
            let residue = (n % 210) as usize;
            let (word, bit) = (residue >> 5, residue & 31);
            (COPRIME_TO_210[word] >> bit) & 1 == 1
        }
    }
}

/// Computes `base^exp mod modulus` by square-and-multiply.
///
/// Every product goes through [`mod_mul`], so nothing overflows even for
/// moduli close to `u128::MAX`. Returns `0` when `modulus == 1`. Panics
/// (remainder by zero) when `modulus == 0`.
#[inline]
fn mod_pow(mut base: u128, mut exp: u128, modulus: u128) -> u128 {
    if modulus == 1 {
        return 0;
    }
    let mut result: u128 = 1;
    base %= modulus;
    while exp > 0 {
        if exp & 1 == 1 {
            result = mod_mul(result, base, modulus);
        }
        exp >>= 1;
        // Skip the final squaring, whose result would never be used.
        if exp > 0 {
            base = mod_mul(base, base, modulus);
        }
    }
    result
}

/// Returns `(a + b) mod m` for operands already reduced below `m`.
///
/// This works even when the true sum `a + b` would exceed `u128::MAX`. The
/// precondition `a < m && b < m` is checked only in debug builds; the result
/// is meaningless if it is violated.
#[inline]
fn add_mod(a: u128, b: u128, m: u128) -> u128 {
    debug_assert!(a < m);
    debug_assert!(b < m);
    // `a >= m - b` is the overflow-free spelling of `a + b >= m`.
    if a >= m - b {
        a - (m - b)
    } else {
        a + b
    }
}

/// Returns `(a * b) mod m` for any `a`, `b` and non-zero `m`, without overflow.
///
/// Both operands are reduced first. If the reduced product fits in 128 bits,
/// a single multiplication is used. Otherwise the function falls back to
/// double-and-add, which costs one or two [`add_mod`] calls per bit of the
/// smaller reduced operand.
#[inline]
fn mod_mul(a: u128, b: u128, m: u128) -> u128 {
    debug_assert!(m > 0);
    // Reduce before the fast-path test, so that large operands with small
    // residues still take the single-multiply path.
    let a = a % m;
    let b = b % m;
    // bit_len(a) + bit_len(b) <= 128 guarantees `a * b` cannot overflow.
    if a.leading_zeros() + b.leading_zeros() >= 128 {
        return (a * b) % m;
    }
    // Drive the loop with the smaller operand: one iteration per bit of it.
    let (mut addend, mut multiplier) = if a >= b { (a, b) } else { (b, a) };
    let mut result: u128 = 0;
    while multiplier > 0 {
        if multiplier & 1 == 1 {
            result = add_mod(result, addend, m);
        }
        multiplier >>= 1;
        if multiplier > 0 {
            addend = add_mod(addend, addend, m);
        }
    }
    result
}

/// Runs one strong-probable-prime round of Miller–Rabin with base `a`, where
/// `n - 1 = d · 2^r` and `d` is odd.
///
/// Returns `true` when `n` passes the round, i.e. `a` is *not* a witness to
/// compositeness. Returns `false` when `a` proves `n` composite.
fn miller_rabin_witness(n: u128, a: u128, d: u128, r: u32) -> bool {
    debug_assert!(n >= 2);
    let mut x = mod_pow(a, d, n);
    if x == 1 || x == n - 1 {
        return true;
    }
    // Square up to r - 1 more times, looking for -1 (mod n).
    for _ in 1..r {
        x = mod_mul(x, x, n);
        if x == n - 1 {
            return true;
        }
    }
    false
}

/// Miller–Rabin primality test using the first thirteen primes (2 through 41)
/// as bases.
///
/// The result is exact for every
/// `n < 3_317_044_064_679_887_385_961_981` (≈ 3.3·10^24, the smallest strong
/// pseudoprime to all thirteen bases). Above that bound, `true` means
/// "probable prime"; `false` is always a proof of compositeness.
///
/// Bases `>= n` are skipped, so 2 and the other small primes are reported
/// prime. `n >= 2` is a precondition asserted in debug builds. Release builds
/// return `false` for `n < 2`.
pub(crate) fn miller_rabin(n: u128) -> bool {
    // Base 41 is required: 318_665_857_834_031_151_167_461
    // (= 399_165_290_221 · 798_330_580_441) is a strong pseudoprime to every
    // prime base up to 37, so the 12-base set would report it as prime.
    const WITNESSES: [u128; 13] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41];
    debug_assert!(n >= 2);
    // Without this guard, a release build would compute `1 >> 128` as 0
    // (masked shift), skip every base, and report 1 as prime.
    if n < 2 {
        return false;
    }
    let n_minus_1 = n - 1;
    let r = n_minus_1.trailing_zeros();
    let d = n_minus_1 >> r;
    WITNESSES.iter().all(|&a| a >= n || miller_rabin_witness(n, a, d, r))
}

#[cfg(test)]
mod tests {
    // Primality helpers from the external `reikna` crate act as the reference
    // oracle the wheels are checked against.
    use reikna::prime::{is_prime, next_prime};
    use super::{PrimeWheel30, PrimeWheel210, add_mod, mod_mul};

    /// Reference `(a * b) mod m` via plain double-and-add, with no
    /// single-multiply shortcut. Used to cross-check `mod_mul`.
    fn mod_mul_reference(a: u128, b: u128, m: u128) -> u128 {
        let mut result = 0;
        let mut a = a % m;
        let mut b = b % m;
        while b > 0 {
            if b & 1 == 1 {
                result = add_mod(result, a, m);
            }
            b >>= 1;
            if b > 0 {
                a = add_mod(a, a, m);
            }
        }

        result
    }

    /// Walks one `PrimeWheel30` alongside the first 1000 primes.
    ///
    /// Every candidate between consecutive primes must be composite. A skipped
    /// prime would surface as a prime candidate and fail the inner assert. The
    /// total count of composite candidates (`misses`) pins the exact sequence.
    #[test]
    fn test_prime_wheel_30_first_1000() {
        let mut wheel = PrimeWheel30::new();
        let mut misses = 0;
        let mut p = 0;
        for _ in 0..1000 {
            p = next_prime(p);
            // `by_ref` keeps the wheel's position across outer iterations.
            for n in wheel.by_ref() {
                if n == p as u128 {
                    break;
                }
                assert!(!is_prime(n as u64));
                misses += 1;
            }
        }
        assert_eq!(misses, 1114);
    }

    /// Same walk as the 30-wheel test, for `PrimeWheel210`. Fewer composite
    /// candidates are expected because 7 is also sieved out.
    #[test]
    fn test_prime_wheel_210_first_1000() {
        let mut wheel = PrimeWheel210::new();
        let mut misses = 0;
        let mut p = 0;
        for _ in 0..1000 {
            p = next_prime(p);
            for n in wheel.by_ref() {
                if n == p as u128 {
                    break;
                }
                assert!(!is_prime(n as u64));
                misses += 1;
            }
        }
        assert_eq!(misses, 813);
    }

    /// More than 25% of the first 1,000,000 `PrimeWheel30` candidates must be prime.
    #[test]
    fn test_prime_wheel_30_quality() {
        const TOTAL: u128 = 1000000;
        let mut primes: u128 = 0;
        let pw_iter = PrimeWheel30::new();
        for p in pw_iter.take(TOTAL as usize) {
            primes += is_prime(p as u64) as u128;
        }
        let percent = primes as f64 / TOTAL as f64 * 100.0;
        println!("Prime wheel generated {}/{} ({:.3}%) primes",
                 primes, TOTAL, percent);
        assert!(percent > 25.0);
    }

    /// More than 30% of the first 1,000,000 `PrimeWheel210` candidates must be prime.
    #[test]
    fn test_prime_wheel_210_quality() {
        const TOTAL: u128 = 1000000;
        let mut primes: u128 = 0;
        let pw_iter = PrimeWheel210::new();
        for p in pw_iter.take(TOTAL as usize) {
            primes += is_prime(p as u64) as u128;
        }
        let percent = primes as f64 / TOTAL as f64 * 100.0;
        println!("Prime wheel generated {}/{} ({:.3}%) primes",
                 primes, TOTAL, percent);
        assert!(percent > 30.0);
    }

    /// `add_mod` must not overflow when `a + b` exceeds `u128::MAX`
    /// (m = u128::MAX - 158 = 2^128 - 159).
    #[test]
    fn test_add_mod_large_values() {
        let m = u128::MAX - 158;
        let a = m - 1;
        let b = m - 1;
        assert_eq!(add_mod(a, b, m), m - 2);
    }

    /// Cross-checks `mod_mul` against `mod_mul_reference` for modulus 2^127 - 1,
    /// over operand pairs that exercise both the fast and the slow path.
    #[test]
    fn test_mod_mul_matches_direct_below_2pow127_boundary() {
        let m = (1_u128 << 127) - 1;
        let values = [
            0,
            1,
            2,
            3,
            17,
            (1_u128 << 64) + 13,
            m - 2,
            m - 1,
        ];

        for &a in &values {
            for &b in &values {
                assert_eq!(mod_mul(a, b, m), mod_mul_reference(a, b, m));
            }
        }
    }

    /// Products near a modulus close to `u128::MAX`, with expected values
    /// derived from (-1)·(-1) = 1, (-1)·2 = -2, and so on. Both argument
    /// orders are checked.
    #[test]
    fn test_mod_mul_large_modulus_regression_cases() {
        let m = u128::MAX - 158;
        let cases = [
            (m - 1, m - 1, 1),
            (m - 1, 2, m - 2),
            (m - 2, 2, m - 4),
            (m - 1, m - 2, 2),
        ];

        for (a, b, expected) in cases {
            assert_eq!(mod_mul(a, b, m), expected);
            assert_eq!(mod_mul(b, a, m), expected);
        }
    }

    /// For `start <= 2`, `PrimeWheel210::from(start)` must produce the same
    /// sequence as `PrimeWheel210::new()`.
    #[test]
    fn test_prime_wheel_210_from_matches_new() {
        for start in [0, 1, 2] {
            let from_iter: Vec<u128> = PrimeWheel210::from(start).take(200).collect();
            let new_iter: Vec<u128> = PrimeWheel210::new().take(200).collect();
            assert_eq!(from_iter, new_iter, "from({start}) differs from new()");
        }
    }

    /// Starting exactly on one of the prefix primes yields that prime first.
    #[test]
    fn test_prime_wheel_210_from_initial_primes() {
        for &p in &[2u128, 3, 5, 7, 11] {
            let first = PrimeWheel210::from(p).next().unwrap();
            assert_eq!(first, p, "from({p}) should yield {p} first");
        }
    }

    /// For every start in 0..=500, the first value is `>= start`, belongs to the
    /// `new()` sequence, and the next 50 values continue that sequence.
    /// NOTE(review): this does not assert that the first value is the
    /// *smallest* candidate `>= start`.
    #[test]
    fn test_prime_wheel_210_from_never_skips_candidates() {
        let all: Vec<u128> = PrimeWheel210::new().take(2000).collect();
        for start in 0..=500 {
            let first = match PrimeWheel210::from(start).next() {
                Some(v) => v,
                None => continue,
            };
            assert!(first >= start,
                "from({start}) yielded {first} which is below start");
            let pos = all.iter().position(|&v| v == first)
                .unwrap_or_else(|| panic!(
                    "from({start}) yielded {first} not in wheel sequence"));
            let from_vals: Vec<u128> = PrimeWheel210::from(start).take(50).collect();
            assert_eq!(from_vals, all[pos..pos + 50],
                "from({start}) sequence diverges from new() at offset {pos}");
        }
    }

    /// Starts just before, on, and just after multiples of 210. Membership in the
    /// precomputed sequence is only checked for values inside its range.
    #[test]
    fn test_prime_wheel_210_from_at_cycle_boundaries() {
        let all: Vec<u128> = PrimeWheel210::new().take(5000).collect();
        for block in [1u128, 2, 5, 10, 100, 1000] {
            let boundary = block * 210;
            for offset in [0, 1, 2, 209, 210, 211] {
                let start = boundary + offset;
                let first = PrimeWheel210::from(start).next().unwrap();
                assert!(first >= start,
                    "from({start}) yielded {first} below start");
                if first < *all.last().unwrap() {
                    assert!(all.contains(&first),
                        "from({start}): {first} not a wheel candidate");
                }
            }
        }
    }

    /// From each start, the wheel must reach each of the next 100 primes (as
    /// produced by `next_prime`), with only composites in between.
    /// NOTE(review): assumes `next_prime(x)` returns the smallest prime strictly
    /// greater than `x` — confirm against the `reikna` docs.
    #[test]
    fn test_prime_wheel_210_from_finds_all_primes() {
        for start in [0u128, 1, 13, 100, 210, 211, 420, 1000, 10000] {
            let mut wheel = PrimeWheel210::from(start);
            let mut p = if start <= 2 { 0 } else { (start - 1) as u64 };
            for _ in 0..100 {
                p = next_prime(p);
                if (p as u128) < start { continue; }
                for n in wheel.by_ref() {
                    if n == p as u128 {
                        break;
                    }
                    assert!(!is_prime(n as u64),
                        "from({start}): wheel candidate {n} is prime but was skipped");
                }
            }
        }
    }
}