1use std::ops::BitOr;
2
/// A linear congruential generator: a state value plus the recurrence
/// parameters used to advance it.
///
/// `parameters` may be any type convertible into [`Parameters`]; the bounds
/// are placed on the `impl` block rather than on the struct.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Lcg<L, S> {
    /// Current state; replaced (and returned) by the advancing methods.
    pub state: S,
    /// Recurrence parameters, or a value that converts into them.
    pub parameters: L,
}
8
9impl<L, S> Lcg<L, S>
10where
11 L: Copy + Into<Parameters<S>>,
12 S: Integer
13{
14 pub fn multiplier(&self) -> S {
15 self.parameters.into().multiplier
16 }
17
18 pub fn increment(&self) -> S {
19 self.parameters.into().increment
20 }
21
22 pub fn current(&self) -> S {
23 self.parameters.into().apply(self.state)
24 }
25
26 pub fn generate(&mut self) -> S {
27 self.state = self.parameters.into().apply(self.state);
28 self.state
29 }
30
31 pub fn jump_forward(&mut self, steps: S) -> S {
32 self.state = self.parameters.into().jump_forward(steps).apply(self.state);
33 self.state
34 }
35
36 pub fn jump_backward(&mut self, steps: S) -> S {
37 self.state = self.parameters.into().jump_backward(steps).apply(self.state);
38 self.state
39 }
40}
41
/// Coefficients of the affine recurrence `x -> x * multiplier + increment`
/// (see [`Parameters::apply`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Parameters<T> {
    /// Factor the state is multiplied by on each step.
    pub multiplier: T,
    /// Constant added after the multiplication.
    pub increment: T,
}
47
48impl<T> Parameters<T>
49where
50 T: Integer
51{
52 pub fn apply(self, state: T) -> T {
53 state.mul(self.multiplier).add(self.increment)
54 }
55
56 pub fn jump_forward(mut self, mut steps: T) -> Self {
57 let mut acc = Parameters { multiplier: T::ONE, increment: T::ZERO };
58
59 while steps != T::ZERO {
60 if !steps.is_even() {
61 acc.multiplier = acc.multiplier.mul(self.multiplier);
62 acc.increment = acc.increment.mul(self.multiplier).add(self.increment);
63 }
64
65 self.increment = self.multiplier.add(T::ONE).mul(self.increment);
66 self.multiplier = self.multiplier.mul(self.multiplier);
67
68 steps = steps.div(T::TWO);
69 }
70
71 acc
72 }
73
74 pub fn jump_backward(self, steps: T) -> Self {
75 self.jump_forward(steps.neg())
76 }
77}
78
/// Fixed-width integer operations required by [`Parameters`] and [`Lcg`].
///
/// The implementations generated in this module by `impl_integer!` cover the
/// unsigned primitives and use wrapping arithmetic for `add`, `mul` and `neg`.
pub trait Integer: Sized + Copy + Ord + BitOr<Output = Self> {
    /// The value zero (additive identity).
    const ZERO: Self;
    /// The value one (multiplicative identity).
    const ONE: Self;
    /// The value two; used as the divisor when halving a step count.
    const TWO: Self;

    /// Byte-array representation accepted by [`Integer::from_bytes`].
    type Bytes: Copy + Default + AsMut<[u8]> + AsRef<[u8]>;

    /// Builds a value from its byte representation (little-endian in the
    /// built-in implementations).
    fn from_bytes(bytes: Self::Bytes) -> Self;

    /// Returns `true` when the value is even.
    fn is_even(self) -> bool;

    /// Sum of `self` and `rhs`.
    fn add(self, rhs: Self) -> Self;

    /// Product of `self` and `rhs`.
    fn mul(self, rhs: Self) -> Self;

    /// Quotient of `self` divided by `rhs`.
    fn div(self, rhs: Self) -> Self;

    /// Negation of `self`.
    fn neg(self) -> Self;
}
93
/// Implements [`Integer`] for a primitive integer type, using the type's
/// wrapping arithmetic and little-endian byte order.
macro_rules! impl_integer {
    ($ty:ty) => {
        impl Integer for $ty {
            // Byte array exactly as wide as the integer itself.
            type Bytes = [u8; size_of::<Self>()];

            const ZERO: Self = 0;
            const ONE: Self = 1;
            const TWO: Self = 2;

            #[inline]
            fn from_bytes(bytes: Self::Bytes) -> Self {
                <$ty>::from_le_bytes(bytes)
            }

            #[inline]
            fn is_even(self) -> bool {
                self % 2 == 0
            }

            #[inline]
            fn add(self, rhs: Self) -> Self {
                <$ty>::wrapping_add(self, rhs)
            }

            #[inline]
            fn mul(self, rhs: Self) -> Self {
                <$ty>::wrapping_mul(self, rhs)
            }

            // Plain integer division: panics if `rhs` is zero.
            #[inline]
            fn div(self, rhs: Self) -> Self {
                self / rhs
            }

            #[inline]
            fn neg(self) -> Self {
                <$ty>::wrapping_neg(self)
            }
        }
    };
}
135
136impl_integer!(u8);
137impl_integer!(u16);
138impl_integer!(u32);
139impl_integer!(u64);
140impl_integer!(u128);
141impl_integer!(usize);
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DefaultLcgParameters;

    // Parameters taken from the crate's `DefaultLcgParameters` for `u64`.
    // The round-trip checks assume its multiplier is odd (invertible map),
    // which the `lcg` test already relies on.
    const P: Parameters<u64> = Parameters {
        multiplier: DefaultLcgParameters::<u64>::multiplier(),
        increment: DefaultLcgParameters::<u64>::increment()
    };

    #[test]
    fn jump_forward() {
        assert_eq!(P.jump_forward(0), Parameters { multiplier: 1, increment: 0 });
        assert_eq!(P.jump_forward(1), P);
        assert_eq!(P.jump_forward(1).jump_forward(2).jump_forward(3), P.jump_forward(6));
        assert_eq!(P.apply(P.apply(12345)), P.jump_forward(2).apply(12345));
        assert_eq!((0..997).fold(12345, |acc, _| P.apply(acc)), P.jump_forward(997).apply(12345));
    }

    #[test]
    fn jump_backward() {
        assert_eq!(P.jump_backward(0), Parameters { multiplier: 1, increment: 0 });
        assert_eq!(P.jump_backward(1).apply(12345), P.jump_forward(u64::MAX).apply(12345));
        assert_eq!(P.jump_backward(1).jump_backward(2).jump_backward(3), P.jump_backward(6));
    }

    // Forward and backward jumps by the same distance must cancel out, in
    // either order, including at the extremes of the step range.
    #[test]
    fn jump_round_trip() {
        for steps in [0, 1, 2, 3, 997, u64::MAX / 3, u64::MAX] {
            let forward = P.jump_forward(steps);
            let backward = P.jump_backward(steps);
            assert_eq!(backward.apply(forward.apply(12345)), 12345);
            assert_eq!(forward.apply(backward.apply(12345)), 12345);
        }
    }

    // Exhaustive check on a tiny width: every start state and every step
    // count agrees with naive stepping, and jump_backward undoes it.
    #[test]
    fn jumps_exhaustive_u8() {
        let p = Parameters { multiplier: 5u8, increment: 3u8 };
        for start in 0..=u8::MAX {
            let mut expected = start;
            for steps in 0..=u8::MAX {
                assert_eq!(p.jump_forward(steps).apply(start), expected);
                assert_eq!(p.jump_backward(steps).apply(expected), start);
                expected = p.apply(expected);
            }
        }
    }

    #[test]
    fn lcg() {
        let mut lcg = Lcg { state: 12345, parameters: P };
        let a = lcg.generate();
        let b = lcg.jump_forward(u64::MAX);
        assert_eq!(a, lcg.generate());
        assert_eq!(b, lcg.jump_backward(1));
        assert_eq!(a, lcg.generate());
    }

    // `current` peeks at the next output without advancing the state;
    // the accessors expose the converted parameters.
    #[test]
    fn lcg_accessors_and_peek() {
        let mut lcg = Lcg { state: 12345u64, parameters: P };
        assert_eq!(lcg.multiplier(), P.multiplier);
        assert_eq!(lcg.increment(), P.increment);
        let peeked = lcg.current();
        assert_eq!(lcg.state, 12345);
        assert_eq!(peeked, P.apply(12345));
        assert_eq!(lcg.generate(), peeked);
    }

    // A single jump must land where repeated `generate` calls do, and
    // jumping back must restore the seed.
    #[test]
    fn lcg_jump_matches_repeated_generate() {
        let mut stepped = Lcg { state: 42u64, parameters: P };
        let mut jumped = stepped.clone();
        for _ in 0..10 {
            stepped.generate();
        }
        assert_eq!(jumped.jump_forward(10), stepped.state);
        assert_eq!(jumped.jump_backward(10), 42);
    }
}