rs_fsrs/
alea.rs

1use crate::Seed;
2
/// Plain-data snapshot of an [`Alea`] generator's internal state.
///
/// Obtained via `From<Alea>` and turned back into a generator via `From<AleaState>`.
/// It is `Copy` (all fields are `f64`) so a snapshot can be stored and restored
/// more than once.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AleaState {
    /// Carry term.
    pub c: f64,
    /// Oldest of the three state slots.
    pub s0: f64,
    /// Middle state slot.
    pub s1: f64,
    /// Newest state slot (the most recently produced value).
    pub s2: f64,
}
10
11impl From<Alea> for AleaState {
12    fn from(alea: Alea) -> Self {
13        Self {
14            c: alea.c,
15            s0: alea.s0,
16            s1: alea.s1,
17            s2: alea.s2,
18        }
19    }
20}
21
/// Alea pseudo-random generator: a carry plus three rolling `f64` state slots.
///
/// Fields are private; use [`AleaState`] to export or import them.
#[derive(Clone, Copy, Debug)]
pub struct Alea {
    /// Carry term, recomputed on every step.
    c: f64,
    /// Oldest state slot.
    s0: f64,
    /// Middle state slot.
    s1: f64,
    /// Newest state slot; also the value last returned by `next`.
    s2: f64,
}
29
30impl Alea {
31    fn new(seed: Seed) -> Self {
32        let mut mash = Mash::new();
33        let blank_seed = Seed::new(" ");
34        let mut alea = Self {
35            c: 1.0,
36            s0: mash.mash(&blank_seed),
37            s1: mash.mash(&blank_seed),
38            s2: mash.mash(&blank_seed),
39        };
40
41        alea.s0 -= mash.mash(&seed);
42        if alea.s0 < 0.0 {
43            alea.s0 += 1.0;
44        }
45        alea.s1 -= mash.mash(&seed);
46        if alea.s1 < 0.0 {
47            alea.s1 += 1.0;
48        }
49        alea.s2 -= mash.mash(&seed);
50        if alea.s2 < 0.0 {
51            alea.s2 += 1.0;
52        }
53
54        alea
55    }
56}
57
58impl Iterator for Alea {
59    type Item = f64;
60
61    fn next(&mut self) -> Option<Self::Item> {
62        let t = 2091639.0f64.mul_add(self.s0, self.c * TWO_TO_THE_POWER_OF_MINUS_32);
63        self.s0 = self.s1;
64        self.s1 = self.s2;
65        self.c = t.floor();
66        self.s2 = t - self.c;
67
68        Some(self.s2)
69    }
70}
71
72impl From<AleaState> for Alea {
73    fn from(state: AleaState) -> Self {
74        Self {
75            c: state.c,
76            s0: state.s0,
77            s1: state.s1,
78            s2: state.s2,
79        }
80    }
81}
82
/// 2^32 (`0x100000000`).
const TWO_TO_THE_POWER_OF_32: u64 = 0x1_0000_0000;
/// 2^21 (`0x200000`).
const TWO_TO_THE_POWER_OF_21: u64 = 0x20_0000;
/// 2^-32, exactly representable as `f64`.
const TWO_TO_THE_POWER_OF_MINUS_32: f64 = 1.0 / TWO_TO_THE_POWER_OF_32 as f64;
/// 2^-53, exactly representable as `f64`.
const TWO_TO_THE_POWER_OF_MINUS_53: f64 = 1.0 / 9_007_199_254_740_992.0;
87
/// Stateful string hasher whose outputs seed an [`Alea`] generator.
struct Mash {
    /// Running hash state, carried over between successive `mash` calls.
    n: f64,
}
91
92impl Mash {
93    const N: u64 = 0xefc8249d;
94    const fn new() -> Self {
95        Self { n: Self::N as f64 }
96    }
97
98    fn mash(&mut self, seed: &Seed) -> f64 {
99        let mut n: f64 = self.n;
100        for c in seed.inner_str().chars() {
101            n += c as u32 as f64;
102            let mut h = 0.02519603282416938 * n;
103            n = (h as u32) as f64;
104            h -= n;
105            h *= n;
106            n = (h as u32) as f64;
107            h -= n;
108            n += h * TWO_TO_THE_POWER_OF_32 as f64;
109        }
110        self.n = n;
111        self.n * TWO_TO_THE_POWER_OF_MINUS_32 // 2^-32
112    }
113}
114
115#[derive(Debug)]
116pub struct Prng {
117    pub xg: Alea,
118}
119
120impl Prng {
121    fn new(seed: Seed) -> Self {
122        Self {
123            xg: Alea::new(seed),
124        }
125    }
126
127    pub fn gen_next(&mut self) -> f64 {
128        self.xg.next().unwrap()
129    }
130
131    pub fn int32(&mut self) -> i32 {
132        wrap_to_i32(self.gen_next() * TWO_TO_THE_POWER_OF_32 as f64)
133    }
134
135    pub fn double(&mut self) -> f64 {
136        ((self.gen_next() * TWO_TO_THE_POWER_OF_21 as f64) as u64 as f64)
137            .mul_add(TWO_TO_THE_POWER_OF_MINUS_53, self.gen_next())
138    }
139
140    pub fn get_state(&self) -> AleaState {
141        self.xg.into()
142    }
143
144    pub fn import_state(mut self, state: impl Into<Alea>) -> Self {
145        self.xg = state.into();
146        self
147    }
148}
149
/// Converts `input` to `i32` the way JavaScript's `input | 0` does (ECMAScript `ToInt32`).
///
/// The value is truncated toward zero and wrapped into `[0, 2^32)` with `rem_euclid`.
/// The resulting `u32` bit pattern is then reinterpreted as `i32`, so the upper half of
/// that range becomes negative. NaN and ±infinity map to `0`.
fn wrap_to_i32(input: f64) -> i32 {
    // Truncate before wrapping. Otherwise negative fractions are off by one: -1.5 would
    // wrap to 2^32 - 1.5 and yield -2, where JavaScript gives -1.
    input.trunc().rem_euclid((u32::MAX as f64) + 1.0) as u32 as i32
}
154
155pub fn alea(seed: Seed) -> Prng {
156    match seed {
157        Seed::String(_) => Prng::new(seed),
158        Seed::Empty => Prng::new(Seed::default()),
159        Seed::Default => Prng::new(Seed::default()),
160    }
161}