1use std::{fmt::Debug, ops::BitOr};
2
/// A linear congruential generator.
///
/// Each step replaces `state` with `state * multiplier + increment`, where the
/// multiplier and increment come from converting `parameters` into
/// [`Parameters`]`<S>`. The impls in this module require
/// `L: Copy + Into<Parameters<S>>` and `S: Integer`; the struct itself is unbounded.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Lcg<L, S> {
    /// Generator state; every step overwrites it with the value that step produced.
    pub state: S,
    /// Source of the multiplier and increment used for each step.
    pub parameters: L,
}
8
9impl<L, S> Lcg<L, S>
10where
11 L: Copy + Default + Into<Parameters<S>>,
12 S: Integer
13{
14 pub fn new(seed: S) -> Self {
15 Self::with_parameters(seed, L::default())
16 }
17}
18
19impl<L, S> Lcg<L, S>
20where
21 L: Copy + Into<Parameters<S>>,
22 S: Integer
23{
24 pub fn with_parameters(seed: S, parameters: L) -> Self {
25 Self {
26 state: if parameters.into().increment == S::ZERO { seed | S::ONE } else { seed },
27 parameters
28 }
29 }
30
31 pub fn multiplier(&self) -> S {
32 self.parameters.into().multiplier
33 }
34
35 pub fn increment(&self) -> S {
36 self.parameters.into().increment
37 }
38
39 pub fn current(&self) -> S {
40 self.parameters.into().apply(self.state)
41 }
42
43 pub fn generate(&mut self) -> S {
44 self.state = self.parameters.into().apply(self.state);
45 self.state
46 }
47
48 pub fn jump_forward(&mut self, steps: S) -> S {
49 self.state = self.parameters.into().jump_forward(steps).apply(self.state);
50 self.state
51 }
52
53 pub fn jump_backward(&mut self, steps: S) -> S {
54 self.state = self.parameters.into().jump_backward(steps).apply(self.state);
55 self.state
56 }
57}
58
/// Coefficients of one LCG step, `state -> state * multiplier + increment`.
///
/// A `Parameters` value can also describe several steps at once: see
/// `Parameters::jump_forward`, which returns the coefficients of a multi-step jump.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Parameters<T> {
    /// Factor applied to the state.
    pub multiplier: T,
    /// Constant added after multiplying.
    pub increment: T,
}
64
65impl<T> Parameters<T>
66where
67 T: Integer
68{
69 pub fn apply(self, state: T) -> T {
70 state.mul(self.multiplier).add(self.increment)
71 }
72
73 pub fn jump_forward(mut self, mut steps: T) -> Self {
74 let mut acc = Parameters { multiplier: T::ONE, increment: T::ZERO };
75
76 while steps != T::ZERO {
77 if !steps.is_even() {
78 acc.multiplier = acc.multiplier.mul(self.multiplier);
79 acc.increment = acc.increment.mul(self.multiplier).add(self.increment);
80 }
81
82 self.increment = self.multiplier.add(T::ONE).mul(self.increment);
83 self.multiplier = self.multiplier.mul(self.multiplier);
84
85 steps = steps.div(T::TWO);
86 }
87
88 acc
89 }
90
91 pub fn jump_backward(self, steps: T) -> Self {
92 self.jump_forward(steps.neg())
93 }
94}
95
/// Fixed-width integer operations needed by [`Lcg`] and [`Parameters`].
///
/// Arithmetic goes through these named methods rather than the `std::ops`
/// operators. The implementations generated in this file make `add`, `sub`,
/// `mul` and `neg` wrap on overflow.
pub trait Integer: Sized + Copy + Debug + Ord + BitOr<Output = Self> {
    /// The value zero.
    const ZERO: Self;
    /// The value one.
    const ONE: Self;
    /// The value two; `Parameters::jump_forward` divides the step count by it.
    const TWO: Self;

    /// Byte-array representation of one value (`[u8; size_of::<Self>()]` in this
    /// file's implementations).
    type Bytes: Copy + Default + AsMut<[u8]> + AsRef<[u8]>;

    /// Builds a value from its byte representation (little-endian in this file's
    /// implementations).
    fn from_bytes(bytes: Self::Bytes) -> Self;

    /// Returns `true` when the value is even.
    fn is_even(self) -> bool;

    /// Returns `self + rhs`.
    fn add(self, rhs: Self) -> Self;

    /// Returns `self - rhs`.
    fn sub(self, rhs: Self) -> Self;

    /// Returns `self * rhs`.
    fn mul(self, rhs: Self) -> Self;

    /// Returns `self / rhs`.
    fn div(self, rhs: Self) -> Self;

    /// Returns `-self`.
    fn neg(self) -> Self;
}
111
112macro_rules! impl_integer {
113 ($ty:ty) => {
114 impl Integer for $ty {
115 type Bytes = [u8; size_of::<Self>()];
116
117 const ZERO: Self = 0;
118 const ONE: Self = 1;
119 const TWO: Self = 2;
120
121 #[inline]
122 fn from_bytes(bytes: Self::Bytes) -> Self {
123 Self::from_le_bytes(bytes)
124 }
125
126 #[inline]
127 fn is_even(self) -> bool {
128 self & 1 == 0
129 }
130
131 #[inline]
132 fn add(self, rhs: Self) -> Self {
133 self.wrapping_add(rhs)
134 }
135
136 #[inline]
137 fn sub(self, rhs: Self) -> Self {
138 self.wrapping_sub(rhs)
139 }
140
141 #[inline]
142 fn mul(self, rhs: Self) -> Self {
143 self.wrapping_mul(rhs)
144 }
145
146 #[inline]
147 fn div(self, rhs: Self) -> Self {
148 self / rhs
149 }
150
151 #[inline]
152 fn neg(self) -> Self {
153 self.wrapping_neg()
154 }
155 }
156 };
157}
158
159impl_integer!(u8);
160impl_integer!(u16);
161impl_integer!(u32);
162impl_integer!(u64);
163impl_integer!(u128);
164impl_integer!(usize);
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DefaultLcgParameters;

    /// The crate's default 64-bit parameters, used as the reference generator.
    const P: Parameters<u64> = Parameters {
        multiplier: DefaultLcgParameters::<u64>::multiplier(),
        increment: DefaultLcgParameters::<u64>::increment(),
    };

    /// Arbitrary starting state shared by the tests.
    const SEED: u64 = 12345;

    /// Parameters of a zero-step jump: multiplier 1, increment 0.
    const IDENTITY: Parameters<u64> = Parameters { multiplier: 1, increment: 0 };

    /// Applies `P` to `state` one step at a time, `times` times.
    fn step_one_by_one(state: u64, times: usize) -> u64 {
        (0..times).fold(state, |acc, _| P.apply(acc))
    }

    #[test]
    fn jump_forward() {
        assert_eq!(P.jump_forward(0), IDENTITY);
        assert_eq!(P.jump_forward(1), P);
        // Each call raises the previous parameters to a power, so exponents multiply:
        // 1 * 2 * 3 == 6.
        assert_eq!(P.jump_forward(1).jump_forward(2).jump_forward(3), P.jump_forward(6));
        assert_eq!(step_one_by_one(SEED, 2), P.jump_forward(2).apply(SEED));
        assert_eq!(step_one_by_one(SEED, 997), P.jump_forward(997).apply(SEED));
    }

    #[test]
    fn jump_backward() {
        assert_eq!(P.jump_backward(0), IDENTITY);
        // One step back is a forward jump of `u64::MAX` (the wrapping negation of 1).
        assert_eq!(P.jump_backward(1).apply(SEED), P.jump_forward(u64::MAX).apply(SEED));
        // Exponents multiply here too: (-1) * (-2) * (-3) == -6.
        assert_eq!(P.jump_backward(1).jump_backward(2).jump_backward(3), P.jump_backward(6));
    }

    #[test]
    fn lcg() {
        let mut generator = Lcg { state: SEED, parameters: P };
        let first = generator.generate();
        // A forward jump of `u64::MAX` steps undoes the step just taken.
        let rewound = generator.jump_forward(u64::MAX);
        assert_eq!(first, generator.generate());
        assert_eq!(rewound, generator.jump_backward(1));
        assert_eq!(first, generator.generate());
    }
}