pcg_random/
lcg.rs

1use std::ops::BitOr;
2
/// A linear congruential generator: a mutable `state` together with the
/// `parameters` that define its transition.
///
/// The struct itself carries no trait bounds; the stepping methods are
/// available when `L: Copy + Into<Parameters<S>>` and `S: Integer`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Lcg<L, S> {
	/// Current state; `generate` and the jump methods overwrite it with the
	/// value they return.
	pub state: S,
	/// Source of the multiplier/increment used for each transition.
	pub parameters: L,
}
8
9impl<L, S> Lcg<L, S>
10where
11	L: Copy + Into<Parameters<S>>,
12	S: Integer
13{
14	pub fn multiplier(&self) -> S {
15		self.parameters.into().multiplier
16	}
17
18	pub fn increment(&self) -> S {
19		self.parameters.into().increment
20	}
21
22	pub fn current(&self) -> S {
23		self.parameters.into().apply(self.state)
24	}
25
26	pub fn generate(&mut self) -> S {
27		self.state = self.parameters.into().apply(self.state);
28		self.state
29	}
30
31	pub fn jump_forward(&mut self, steps: S) -> S {
32		self.state = self.parameters.into().jump_forward(steps).apply(self.state);
33		self.state
34	}
35
36	pub fn jump_backward(&mut self, steps: S) -> S {
37		self.state = self.parameters.into().jump_backward(steps).apply(self.state);
38		self.state
39	}
40}
41
/// Coefficients of the affine recurrence `x -> x * multiplier + increment`,
/// evaluated with the [`Integer`] `mul`/`add` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Parameters<T> {
	/// Factor applied to the previous state.
	pub multiplier: T,
	/// Constant added after the multiplication.
	pub increment: T,
}
47
48impl<T> Parameters<T>
49where
50	T: Integer
51{
52	pub fn apply(self, state: T) -> T {
53		state.mul(self.multiplier).add(self.increment)
54	}
55
56	pub fn jump_forward(mut self, mut steps: T) -> Self {
57		let mut acc = Parameters { multiplier: T::ONE, increment: T::ZERO };
58
59		while steps != T::ZERO {
60			if !steps.is_even() {
61				acc.multiplier = acc.multiplier.mul(self.multiplier);
62				acc.increment = acc.increment.mul(self.multiplier).add(self.increment);
63			}
64
65			self.increment = self.multiplier.add(T::ONE).mul(self.increment);
66			self.multiplier = self.multiplier.mul(self.multiplier);
67
68			steps = steps.div(T::TWO);
69		}
70
71		acc
72	}
73
74	pub fn jump_backward(self, steps: T) -> Self {
75		self.jump_forward(steps.neg())
76	}
77}
78
/// Integer operations required by the LCG machinery in this file.
///
/// Arithmetic goes through these named methods rather than the operator
/// traits. The impls generated by `impl_integer!` in this file use wrapping
/// arithmetic, so the recurrence is evaluated modulo `2^bits` for them.
pub trait Integer: Copy + Ord + Sized + BitOr<Output = Self> {
	/// Fixed-size byte buffer accepted by [`from_bytes`](Integer::from_bytes)
	/// (`[u8; size_of::<Self>()]` in the generated impls).
	type Bytes: AsRef<[u8]> + AsMut<[u8]> + Default + Copy;

	/// Additive identity.
	const ZERO: Self;
	/// Multiplicative identity.
	const ONE: Self;
	/// The value two (used to halve jump distances).
	const TWO: Self;

	/// Builds a value from `bytes` (interpreted little-endian by the generated impls).
	fn from_bytes(bytes: Self::Bytes) -> Self;

	/// Returns `true` when `self` is divisible by two.
	fn is_even(self) -> bool;

	/// Sum of `self` and `rhs` (`wrapping_add` in the generated impls).
	fn add(self, rhs: Self) -> Self;

	/// Product of `self` and `rhs` (`wrapping_mul` in the generated impls).
	fn mul(self, rhs: Self) -> Self;

	/// Quotient of `self` by `rhs` (plain `/` in the generated impls, which
	/// panics on a zero divisor).
	fn div(self, rhs: Self) -> Self;

	/// Additive inverse (`wrapping_neg` in the generated impls).
	fn neg(self) -> Self;
}
93
94macro_rules! impl_integer {
95	($ty:ty) => {
96		impl Integer for $ty {
97			type Bytes = [u8; size_of::<Self>()];
98
99			const ZERO: Self = 0;
100			const ONE: Self = 1;
101			const TWO: Self = 2;
102
103			#[inline]
104			fn from_bytes(bytes: Self::Bytes) -> Self {
105				Self::from_le_bytes(bytes)
106			}
107
108			#[inline]
109			fn is_even(self) -> bool {
110				self & 1 == 0
111			}
112
113			#[inline]
114			fn add(self, rhs: Self) -> Self {
115				self.wrapping_add(rhs)
116			}
117
118			#[inline]
119			fn mul(self, rhs: Self) -> Self {
120				self.wrapping_mul(rhs)
121			}
122
123			#[inline]
124			fn div(self, rhs: Self) -> Self {
125				self / rhs
126			}
127
128			#[inline]
129			fn neg(self) -> Self {
130				self.wrapping_neg()
131			}
132		}
133	};
134}
135
136impl_integer!(u8);
137impl_integer!(u16);
138impl_integer!(u32);
139impl_integer!(u64);
140impl_integer!(u128);
141impl_integer!(usize);
142
#[cfg(test)]
mod tests {
	use super::*;
	use crate::DefaultLcgParameters;

	// Fixture: the crate's default 64-bit LCG multiplier and increment.
	const P: Parameters<u64> = Parameters {
		multiplier: DefaultLcgParameters::<u64>::multiplier(),
		increment: DefaultLcgParameters::<u64>::increment(),
	};

	// Arbitrary starting state shared by the tests.
	const SEED: u64 = 12345;

	// Parameters of the identity map `x -> x * 1 + 0`.
	const IDENTITY: Parameters<u64> = Parameters { multiplier: 1, increment: 0 };

	#[test]
	fn jump_forward() {
		// Degenerate jumps: zero steps is the identity, one step is `P` itself.
		assert_eq!(P.jump_forward(0), IDENTITY);
		assert_eq!(P.jump_forward(1), P);

		// Jumping the *parameters* raises them to a power, so chained jumps
		// multiply: ((P^1)^2)^3 == P^6.
		assert_eq!(P.jump_forward(1).jump_forward(2).jump_forward(3), P.jump_forward(6));

		// A jump must agree with stepping the same number of times by hand.
		assert_eq!(P.jump_forward(2).apply(SEED), P.apply(P.apply(SEED)));
		let stepped = (0..997).fold(SEED, |state, _| P.apply(state));
		assert_eq!(P.jump_forward(997).apply(SEED), stepped);
	}

	#[test]
	fn jump_backward() {
		assert_eq!(P.jump_backward(0), IDENTITY);

		// One step back is the same as (2^64 - 1) steps forward.
		assert_eq!(P.jump_forward(u64::MAX).apply(SEED), P.jump_backward(1).apply(SEED));

		// Chained backward jumps multiply as well: ((P^-1)^-2)^-3 == P^-6.
		assert_eq!(P.jump_backward(1).jump_backward(2).jump_backward(3), P.jump_backward(6));
	}

	#[test]
	fn lcg() {
		let mut rng = Lcg { state: SEED, parameters: P };
		let first = rng.generate();
		// Jumping u64::MAX steps forward lands one step before `first`
		// (back at SEED if the parameters have full period 2^64).
		let wrapped = rng.jump_forward(u64::MAX);
		assert_eq!(first, rng.generate());
		assert_eq!(wrapped, rng.jump_backward(1));
		assert_eq!(first, rng.generate());
	}
}