pcg_random/
lcg.rs

1use std::{fmt::Debug, ops::BitOr};
2
/// A linear congruential generator: a `state` word together with the
/// `parameters` that define the step `state = state * multiplier + increment`.
///
/// `parameters` may be any type convertible into [`Parameters`]; the generator
/// converts it each time it needs the multiplier or increment.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Lcg<L, S> {
	/// Current generator state (the seed initially, then the last produced value).
	pub state: S,
	/// Source of the multiplier and increment.
	pub parameters: L,
}
8
9impl<L, S> Lcg<L, S>
10where
11	L: Copy + Default + Into<Parameters<S>>,
12	S: Integer
13{
14	pub fn new(seed: S) -> Self {
15		Self::with_parameters(seed, L::default())
16	}
17}
18
19impl<L, S> Lcg<L, S>
20where
21	L: Copy + Into<Parameters<S>>,
22	S: Integer
23{
24	pub fn with_parameters(seed: S, parameters: L) -> Self {
25		Self {
26			state: if parameters.into().increment == S::ZERO { seed | S::ONE } else { seed },
27			parameters
28		}
29	}
30
31	pub fn multiplier(&self) -> S {
32		self.parameters.into().multiplier
33	}
34
35	pub fn increment(&self) -> S {
36		self.parameters.into().increment
37	}
38
39	pub fn current(&self) -> S {
40		self.parameters.into().apply(self.state)
41	}
42
43	pub fn generate(&mut self) -> S {
44		self.state = self.parameters.into().apply(self.state);
45		self.state
46	}
47
48	pub fn jump_forward(&mut self, steps: S) -> S {
49		self.state = self.parameters.into().jump_forward(steps).apply(self.state);
50		self.state
51	}
52
53	pub fn jump_backward(&mut self, steps: S) -> S {
54		self.state = self.parameters.into().jump_backward(steps).apply(self.state);
55		self.state
56	}
57}
58
/// Coefficients of the affine LCG step `x -> x * multiplier + increment`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Parameters<T> {
	/// Factor applied to the state on every step.
	pub multiplier: T,
	/// Constant added after the multiplication.
	pub increment: T,
}
64
65impl<T> Parameters<T>
66where
67	T: Integer
68{
69	pub fn apply(self, state: T) -> T {
70		state.mul(self.multiplier).add(self.increment)
71	}
72
73	pub fn jump_forward(mut self, mut steps: T) -> Self {
74		let mut acc = Parameters { multiplier: T::ONE, increment: T::ZERO };
75
76		while steps != T::ZERO {
77			if !steps.is_even() {
78				acc.multiplier = acc.multiplier.mul(self.multiplier);
79				acc.increment = acc.increment.mul(self.multiplier).add(self.increment);
80			}
81
82			self.increment = self.multiplier.add(T::ONE).mul(self.increment);
83			self.multiplier = self.multiplier.mul(self.multiplier);
84
85			steps = steps.div(T::TWO);
86		}
87
88		acc
89	}
90
91	pub fn jump_backward(self, steps: T) -> Self {
92		self.jump_forward(steps.neg())
93	}
94}
95
/// Integer operations needed by the LCG code.
///
/// Arithmetic is exposed as named methods rather than operators. The
/// implementations in this file (for the unsigned primitive types) wrap
/// modulo 2^bits.
pub trait Integer
where
	Self: Sized + Copy + Debug + Ord + BitOr<Output = Self>,
{
	/// Fixed-size byte buffer holding one value's raw representation.
	type Bytes: Copy + Default + AsMut<[u8]> + AsRef<[u8]>;

	/// The value zero.
	const ZERO: Self;
	/// The value one.
	const ONE: Self;
	/// The value two (used to halve step counts during jump-ahead).
	const TWO: Self;

	/// Builds a value from its byte representation (little-endian in this file's impls).
	fn from_bytes(bytes: Self::Bytes) -> Self;

	/// Returns `true` if the value is even.
	fn is_even(self) -> bool;

	/// Addition.
	fn add(self, rhs: Self) -> Self;
	/// Subtraction.
	fn sub(self, rhs: Self) -> Self;
	/// Multiplication.
	fn mul(self, rhs: Self) -> Self;
	/// Division.
	fn div(self, rhs: Self) -> Self;
	/// Negation (two's-complement wrap for unsigned impls in this file).
	fn neg(self) -> Self;
}
111
112macro_rules! impl_integer {
113	($ty:ty) => {
114		impl Integer for $ty {
115			type Bytes = [u8; size_of::<Self>()];
116
117			const ZERO: Self = 0;
118			const ONE: Self = 1;
119			const TWO: Self = 2;
120
121			#[inline]
122			fn from_bytes(bytes: Self::Bytes) -> Self {
123				Self::from_le_bytes(bytes)
124			}
125
126			#[inline]
127			fn is_even(self) -> bool {
128				self & 1 == 0
129			}
130
131			#[inline]
132			fn add(self, rhs: Self) -> Self {
133				self.wrapping_add(rhs)
134			}
135
136			#[inline]
137			fn sub(self, rhs: Self) -> Self {
138				self.wrapping_sub(rhs)
139			}
140
141			#[inline]
142			fn mul(self, rhs: Self) -> Self {
143				self.wrapping_mul(rhs)
144			}
145
146			#[inline]
147			fn div(self, rhs: Self) -> Self {
148				self / rhs
149			}
150
151			#[inline]
152			fn neg(self) -> Self {
153				self.wrapping_neg()
154			}
155		}
156	};
157}
158
159impl_integer!(u8);
160impl_integer!(u16);
161impl_integer!(u32);
162impl_integer!(u64);
163impl_integer!(u128);
164impl_integer!(usize);
165
#[cfg(test)]
mod tests {
	use super::*;
	use crate::DefaultLcgParameters;

	/// The crate's default 64-bit parameters, used as the fixture for every test.
	const P: Parameters<u64> = Parameters {
		multiplier: DefaultLcgParameters::<u64>::multiplier(),
		increment: DefaultLcgParameters::<u64>::increment(),
	};

	/// The identity step: multiply by one, add zero.
	const IDENTITY: Parameters<u64> = Parameters { multiplier: 1, increment: 0 };

	/// Arbitrary starting state shared by the tests.
	const SEED: u64 = 12345;

	#[test]
	fn jump_forward() {
		assert_eq!(P.jump_forward(0), IDENTITY);
		assert_eq!(P.jump_forward(1), P);

		let chained = P.jump_forward(1).jump_forward(2).jump_forward(3);
		assert_eq!(chained, P.jump_forward(6));

		let twice = P.apply(P.apply(SEED));
		assert_eq!(twice, P.jump_forward(2).apply(SEED));

		let mut stepped = SEED;
		for _ in 0..997 {
			stepped = P.apply(stepped);
		}
		assert_eq!(stepped, P.jump_forward(997).apply(SEED));
	}

	#[test]
	fn jump_backward() {
		assert_eq!(P.jump_backward(0), IDENTITY);

		// One step back equals 2^64 - 1 steps forward.
		let back_one = P.jump_backward(1).apply(SEED);
		assert_eq!(back_one, P.jump_forward(u64::MAX).apply(SEED));

		let chained = P.jump_backward(1).jump_backward(2).jump_backward(3);
		assert_eq!(chained, P.jump_backward(6));
	}

	#[test]
	fn lcg() {
		let mut rng = Lcg { state: SEED, parameters: P };
		let first = rng.generate();
		// u64::MAX steps forward is one step short of a full 2^64 lap.
		let rewound = rng.jump_forward(u64::MAX);
		assert_eq!(first, rng.generate());
		assert_eq!(rewound, rng.jump_backward(1));
		assert_eq!(first, rng.generate());
	}
}
201}