1use crate::LatinSquare;
2use crate::jacobson_matthews::JMState;
3use rand::Rng;
4
/// Tuning knobs for the Latin-square Markov chain used by [`sample`] and [`Sampler`].
///
/// Use struct-update syntax to override only what you need:
/// `SamplerParams { burn_in: Some(1000), ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct SamplerParams {
    /// Number of chain steps taken before the first square is produced.
    ///
    /// `None` means `n³` steps, where `n` is the order of the square.
    pub burn_in: Option<u64>,
    /// Number of chain steps taken between consecutive squares yielded by [`Sampler`].
    ///
    /// `None` means `3·n²` steps. Ignored by [`sample`], which draws only one square.
    pub thinning: Option<u64>,
    /// Probability that a single step leaves the state unchanged.
    ///
    /// Must lie in `[0.0, 1.0]`; the samplers panic otherwise. A small positive value
    /// makes the walk "lazy" (presumably to guarantee aperiodicity of the chain).
    pub p_do_nothing: f64,
}

impl Default for SamplerParams {
    /// Chain-size-dependent burn-in and thinning (`None`) and `p_do_nothing = 0.01`.
    fn default() -> Self {
        Self {
            burn_in: None,
            thinning: None,
            p_do_nothing: 0.01,
        }
    }
}
37
38pub fn sample<R: Rng + ?Sized>(n: usize, rng: &mut R, params: &SamplerParams) -> LatinSquare {
48 assert!((2..=255).contains(&n), "n must be in range 2..=255");
49 assert!(
50 (0.0..=1.0).contains(¶ms.p_do_nothing),
51 "p_do_nothing must be in [0.0, 1.0]"
52 );
53
54 let mut state = JMState::new_cyclic(n);
55
56 let burn_in = params.burn_in.unwrap_or((n * n * n) as u64);
59 for _ in 0..burn_in {
60 step(&mut state, rng, params);
61 }
62
63 while !state.is_proper() {
65 step(&mut state, rng, params);
66 }
67
68 state.to_latin_square()
69}
70
71pub struct Sampler<R> {
92 n: usize,
93 state: JMState,
94 rng: R,
95 params: SamplerParams,
96 burned_in: bool,
97}
98
99impl<R: Rng> Sampler<R> {
100 pub fn new(n: usize, rng: R, params: SamplerParams) -> Self {
110 assert!((2..=255).contains(&n), "n must be in range 2..=255");
111 assert!(
112 (0.0..=1.0).contains(¶ms.p_do_nothing),
113 "p_do_nothing must be in [0.0, 1.0]"
114 );
115
116 Self {
117 n,
118 state: JMState::new_cyclic(n),
119 rng,
120 params,
121 burned_in: false,
122 }
123 }
124}
125
126impl<R: Rng> Iterator for Sampler<R> {
127 type Item = LatinSquare;
128
129 fn next(&mut self) -> Option<Self::Item> {
130 if !self.burned_in {
132 let burn_in = self
133 .params
134 .burn_in
135 .unwrap_or((self.n * self.n * self.n) as u64);
136 for _ in 0..burn_in {
137 step(&mut self.state, &mut self.rng, &self.params);
138 }
139 self.burned_in = true;
140 } else {
141 let thinning = self.params.thinning.unwrap_or((3 * self.n * self.n) as u64);
144 for _ in 0..thinning {
145 step(&mut self.state, &mut self.rng, &self.params);
146 }
147 }
148
149 while !self.state.is_proper() {
151 step(&mut self.state, &mut self.rng, &self.params);
152 }
153
154 Some(self.state.to_latin_square())
155 }
156}
157
158fn step<R: Rng + ?Sized>(state: &mut JMState, rng: &mut R, params: &SamplerParams) {
160 if rng.random::<f64>() < params.p_do_nothing {
162 return;
163 }
164
165 state.step(rng);
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;

    /// Short burn-in and minimal thinning so the tests run quickly.
    fn quick_params() -> SamplerParams {
        SamplerParams {
            burn_in: Some(1000),
            thinning: Some(1),
            p_do_nothing: 0.01,
        }
    }

    #[test]
    fn reproducibility_same_seed_same_output() {
        let params = quick_params();

        let mut rng1 = ChaCha20Rng::seed_from_u64(0);
        let sq1 = sample(7, &mut rng1, &params);

        let mut rng2 = ChaCha20Rng::seed_from_u64(0);
        let sq2 = sample(7, &mut rng2, &params);

        assert_eq!(sq1, sq2, "Same seed should produce identical squares");
    }

    #[test]
    fn different_seed_different_output_smoke() {
        let params = quick_params();

        // One differing pair is enough. Checking several pairs makes a spurious failure
        // vanishingly unlikely.
        let any_differ = (0u64..5).any(|offset| {
            let mut rng1 = ChaCha20Rng::seed_from_u64(offset);
            let sq1 = sample(7, &mut rng1, &params);

            let mut rng2 = ChaCha20Rng::seed_from_u64(offset + 100);
            let sq2 = sample(7, &mut rng2, &params);

            sq1 != sq2
        });

        assert!(
            any_differ,
            "All tested seed pairs produced identical squares (extremely unlikely)"
        );
    }

    #[test]
    fn iterator_reproducibility() {
        let params = quick_params();

        let rng1 = ChaCha20Rng::seed_from_u64(0);
        let sampler1 = Sampler::new(5, rng1, params.clone());

        let rng2 = ChaCha20Rng::seed_from_u64(0);
        let sampler2 = Sampler::new(5, rng2, params);

        let squares1: Vec<_> = sampler1.take(10).collect();
        let squares2: Vec<_> = sampler2.take(10).collect();

        assert_eq!(
            squares1, squares2,
            "Same seed should produce identical sequence"
        );
    }

    #[test]
    fn iterator_thinning_spacing() {
        let params_thin1 = SamplerParams {
            burn_in: Some(1000),
            thinning: Some(1),
            ..Default::default()
        };
        let params_thin100 = SamplerParams {
            burn_in: Some(1000),
            thinning: Some(100),
            ..Default::default()
        };

        let rng1 = ChaCha20Rng::seed_from_u64(0);
        let sampler1 = Sampler::new(5, rng1, params_thin1);

        let rng2 = ChaCha20Rng::seed_from_u64(0);
        let sampler2 = Sampler::new(5, rng2, params_thin100);

        let squares1: Vec<_> = sampler1.take(5).collect();
        let squares2: Vec<_> = sampler2.take(5).collect();

        // Thinning only applies after the first sample, so with the same seed and burn-in
        // the first squares must coincide.
        assert_eq!(
            squares1[0], squares2[0],
            "First sample after burn-in should be identical"
        );

        let different_count = squares1
            .iter()
            .zip(squares2.iter())
            .skip(1)
            .filter(|(a, b)| a != b)
            .count();

        assert!(
            different_count > 0,
            "Different thinning should produce different sequences after first sample"
        );
    }

    #[test]
    #[should_panic(expected = "n must be in range 2..=255")]
    fn sample_rejects_order_below_two() {
        let mut rng = ChaCha20Rng::seed_from_u64(0);
        let _ = sample(1, &mut rng, &quick_params());
    }

    #[test]
    #[should_panic(expected = "n must be in range 2..=255")]
    fn sampler_rejects_order_above_255() {
        let rng = ChaCha20Rng::seed_from_u64(0);
        let _ = Sampler::new(256, rng, quick_params());
    }

    #[test]
    #[should_panic(expected = "p_do_nothing must be in [0.0, 1.0]")]
    fn sample_rejects_out_of_range_probability() {
        let params = SamplerParams {
            p_do_nothing: 1.5,
            ..quick_params()
        };
        let mut rng = ChaCha20Rng::seed_from_u64(0);
        let _ = sample(5, &mut rng, &params);
    }

    #[test]
    #[should_panic(expected = "p_do_nothing must be in [0.0, 1.0]")]
    fn sampler_rejects_nan_probability() {
        let params = SamplerParams {
            p_do_nothing: f64::NAN,
            ..quick_params()
        };
        let rng = ChaCha20Rng::seed_from_u64(0);
        let _ = Sampler::new(5, rng, params);
    }
}