// ruby_math/sampling/rng/pcg32.rs

1#![allow(dead_code)]
2
3use std::fmt::Display;
4
/// Multiplier of the 64-bit LCG underlying PCG32 (`0x5851f42d4c957f2d`, the constant used
/// by the reference PCG implementation).
const MULTIPLIER: u64 = 6364136223846793005;

/// A PCG32 random number generator: a 64-bit LCG whose state is scrambled by the XSH-RR
/// output permutation into 32-bit values.
///
/// `increment` is always odd (see [`Pcg32::new`]) and `MULTIPLIER % 4 == 1`, so the
/// underlying LCG has the full period of 2^64 on every stream.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Pcg32 {
    /// LCG state; the next output is derived from this value before it is advanced.
    state: u64,
    /// Odd LCG increment; different values select different streams.
    increment: u64,
}

impl Display for Pcg32 {
    /// Formats the raw generator state, e.g. `Rng(state: 1, increment: 3)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Rng(state: {}, increment: {})", self.state, self.increment)
    }
}

impl Default for Pcg32 {
    /// A generator with a fixed seed and stream, so default-constructed instances are
    /// reproducible across runs.
    fn default() -> Self {
        Self::new(0xcafef00dd15ea5e5, 0xa02bdbf7bb3c0a7)
    }
}

impl Pcg32 {
    /// `1 - 2^-24`: the largest `f32` strictly below 1.
    const ONE_MINUS_EPSILON_F32: f32 = 1.0 - f32::EPSILON / 2.0;
    /// `1 - 2^-53`: the largest `f64` strictly below 1.
    const ONE_MINUS_EPSILON_F64: f64 = 1.0 - f64::EPSILON / 2.0;
    /// `2^-32`, the scale that maps a `u32` draw towards `[0, 1)`.
    const INV_2_POW_32_F32: f32 = 1.0 / 4_294_967_296.0;
    /// `2^-32`, the scale that maps a `u32` draw onto `[0, 1)`.
    const INV_2_POW_32_F64: f64 = 1.0 / 4_294_967_296.0;

    /// Creates a generator seeded with `state` on the stream selected by `stream`.
    ///
    /// The LCG increment must be odd, so `stream` is shifted left and OR-ed with 1 (its top
    /// bit is discarded). Seeding matches the reference `pcg32_srandom_r`: the increment is
    /// added to the seed and the generator is stepped once, so `Pcg32::new(42, 54)`
    /// reproduces the official PCG32 test vectors.
    pub fn new(state: u64, stream: u64) -> Self {
        let increment = (stream << 1) | 1;
        let mut pcg = Self {
            state: state.wrapping_add(increment),
            increment,
        };
        // Step once to move away from the raw seed. Without this the first output is a poor
        // function of the seed: whenever `state + increment < 2^27` it is exactly 0.
        pcg.step();
        pcg
    }

    /// Advances the LCG by one step: `state = state * MULTIPLIER + increment (mod 2^64)`.
    fn step(&mut self) {
        self.state = self
            .state
            .wrapping_mul(MULTIPLIER)
            .wrapping_add(self.increment);
    }

    /// Returns the next 32-bit output.
    ///
    /// The output is computed from the state *before* stepping using the XSH-RR
    /// permutation: the top 5 bits choose a rotation that is applied to a xorshifted slice
    /// of the remaining bits.
    pub fn next_u32(&mut self) -> u32 {
        let old = self.state;
        self.step();
        // 59 = 64 - 5: the top 5 bits select the rotation amount (0..=31).
        let rot = (old >> 59) as u32;
        // Xorshift by 18 = (32 + 5) / 2, then keep bits 27..=58 (27 = 64 - 32 - 5); the
        // `as u32` truncation is intentional.
        let xsh = (((old >> 18) ^ old) >> 27) as u32;
        xsh.rotate_right(rot)
    }

    /// Returns the next `u64`, built from two consecutive [`Pcg32::next_u32`] draws with the
    /// first draw in the low half.
    pub fn next_u64(&mut self) -> u64 {
        let low = u64::from(self.next_u32());
        let high = u64::from(self.next_u32());
        (high << 32) | low
    }

    /// Returns an `f32` in `[0, 1)` from one [`Pcg32::next_u32`] draw.
    pub fn next_f32(&mut self) -> f32 {
        // `u32 as f32` rounds, so values near `u32::MAX` become 2^32 and would scale to
        // exactly 1.0; clamp to keep the interval half-open.
        (self.next_u32() as f32 * Self::INV_2_POW_32_F32).min(Self::ONE_MINUS_EPSILON_F32)
    }

    /// Returns an `f64` in `[0, 1)` from one [`Pcg32::next_u32`] draw.
    ///
    /// NOTE(review): only 32 random bits are used (resolution 2^-32). Each call consumes
    /// exactly one step, like `next_f32`. A 53-bit variant would need two draws per call
    /// and would change how `advance` counts map to calls.
    pub fn next_f64(&mut self) -> f64 {
        // The product is at most 1 - 2^-32, so the clamp never binds; it is kept as a guard.
        (f64::from(self.next_u32()) * Self::INV_2_POW_32_F64).min(Self::ONE_MINUS_EPSILON_F64)
    }

    /// Moves the generator forward by `steps` outputs in O(log steps) time, leaving it in
    /// the same state as `steps` calls to [`Pcg32::next_u32`] would.
    ///
    /// The period is 2^64, so `advance(k.wrapping_neg())` moves the generator back `k` steps.
    pub fn advance(&mut self, steps: u64) {
        // Build the affine map `x -> acc_mult * x + acc_plus` equal to `steps` applications
        // of `x -> MULTIPLIER * x + increment`, by repeatedly squaring the map
        // (cur_mult, cur_plus) and folding in the powers selected by the bits of `steps`.
        let (mut acc_mult, mut acc_plus) = (1u64, 0u64);
        let (mut cur_mult, mut cur_plus) = (MULTIPLIER, self.increment);
        let mut remaining = steps;

        while remaining > 0 {
            if remaining & 1 != 0 {
                acc_mult = acc_mult.wrapping_mul(cur_mult);
                acc_plus = acc_plus.wrapping_mul(cur_mult).wrapping_add(cur_plus);
            }
            // Compose the current map with itself, doubling the number of steps it covers.
            cur_plus = cur_mult.wrapping_add(1).wrapping_mul(cur_plus);
            cur_mult = cur_mult.wrapping_mul(cur_mult);
            remaining >>= 1;
        }

        self.state = acc_mult.wrapping_mul(self.state).wrapping_add(acc_plus);
    }
}
85}