1use core::any::Any;
4use core::fmt::Display;
5use core::{any, fmt::Debug};
6
7use crate::{Schedule, modifiers::Modifier};
8
9mod averaging;
10mod poisson_gap;
11mod quantiles;
12
13use alloc::{borrow::ToOwned, boxed::Box, vec, vec::Vec};
14pub use averaging::*;
15use ndarray::{Dimension, ShapeBuilder};
16pub use poisson_gap::*;
17pub use quantiles::*;
18
19pub trait Generator<Dim: Dimension> {
27 fn generate(&self, count: usize, dims: Dim) -> Schedule<Dim> {
29 self.generate_with_iter_and_trace(count, dims, 0)
30 .into_sched()
31 }
32
33 fn generate_with_iter(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
35 assert!(
36 count <= (0..dims.ndim()).map(|v| dims[v]).product(),
37 "Count must be less than the number of positions"
38 );
39
40 let sched = self._generate_no_trace(count, dims.to_owned(), iteration);
41
42 validate_schedule::<_, Self>(&sched, dims, count);
43
44 sched
45 }
46
47 fn generate_with_trace(&self, count: usize, dims: Dim) -> Trace<Dim> {
51 self.generate_with_iter_and_trace(count, dims, 0)
52 }
53
54 fn generate_with_iter_and_trace(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim> {
56 assert!(
57 count <= (0..dims.ndim()).map(|v| dims[v]).product(),
58 "Count must be less than the number of positions"
59 );
60
61 let trace = self._generate(count, dims.to_owned(), iteration);
62
63 validate_schedule::<_, Self>(trace.sched(), dims, count);
64
65 trace
66 }
67
68 fn then<T: Modifier<Dim>>(self, modifier: T) -> T::Output<Self>
72 where
73 Self: Sized,
74 {
75 modifier.modify(self)
76 }
77
78 fn _generate(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim>;
84
85 fn _generate_no_trace(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
87 self._generate(count, dims, iteration).into_sched()
88 }
89}
90
91impl<T: core::ops::Deref<Target = dyn Generator<Dim>>, Dim: Dimension> Generator<Dim> for T {
92 fn _generate(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim> {
93 (**self)._generate(count, dims, iteration)
94 }
95
96 fn _generate_no_trace(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
97 (**self)._generate_no_trace(count, dims, iteration)
98 }
99}
100
101fn validate_schedule<Dim: Dimension, T: Generator<Dim> + ?Sized>(
102 sched: &Schedule<Dim>,
103 dims: Dim,
104 count: usize,
105) {
106 let real_count = sched.iter().filter(|v| **v).count();
107
108 assert!(
109 real_count == count,
110 "Returned the wrong count (found {real_count}, expected {count})! In {}",
111 any::type_name::<T>()
112 );
113
114 assert!(
115 dims == sched.raw_dim(),
116 "Returned the wrong length (found {:?}, expected {:?})! In {}",
117 dims,
118 sched.raw_dim(),
119 any::type_name::<T>()
120 );
121}
122
/// Mixes an iteration number into a 32-byte seed.
///
/// The eight little-endian bytes of `iteration` are XORed into the first eight bytes of
/// `seed`. The remaining 24 bytes are returned unchanged, so iteration `0` returns `seed`
/// as-is.
pub fn xor_iteration(mut seed: [u8; 32], iteration: u64) -> [u8; 32] {
    let mixer = iteration.to_le_bytes();
    seed.iter_mut()
        .zip(mixer)
        .for_each(|(slot, bits)| *slot ^= bits);
    seed
}
130
/// Payload describing one step of a [`Trace`].
///
/// Any `'static` type that is both `Debug` and `Display` qualifies automatically through
/// the blanket implementation below. The `Any` supertrait is what allows a stored payload
/// to be downcast back to its concrete type.
pub trait TraceOutput: Any + Debug + Display {}

// Every eligible type is a trace output; nobody needs to implement this by hand.
impl<T> TraceOutput for T where T: Any + Debug + Display {}
135
136pub struct Trace<Dim: Dimension> {
140 pub(crate) stack: Vec<(Schedule<Dim>, Box<dyn TraceOutput>)>,
141}
142
143impl<Dim: Dimension> Debug for Trace<Dim> {
144 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
145 for (sched, trace) in self.iter() {
146 writeln!(f, "- {trace}")?;
147 if sched.dim().into_shape_with_order().size() == 1 {
148 writeln!(f, " {sched}")?;
149 }
150 }
151
152 Ok(())
153 }
154}
155
156impl<Dim: Dimension> Trace<Dim> {
157 pub fn new<T: TraceOutput>(sched: Schedule<Dim>, trace: T) -> Trace<Dim> {
159 Trace {
160 stack: vec![(sched, Box::new(trace))],
161 }
162 }
163
164 #[allow(clippy::missing_panics_doc)]
166 pub fn sched(&self) -> &Schedule<Dim> {
167 &self.stack.last().unwrap().0
169 }
170
171 #[allow(clippy::missing_panics_doc)]
173 pub fn into_sched(mut self) -> Schedule<Dim> {
174 self.stack.pop().unwrap().0
176 }
177
178 pub fn get<T: TraceOutput>(&self) -> Option<&T> {
180 for v in self.stack.iter().rev() {
181 if let Some(t) = (&*v.1 as &dyn Any).downcast_ref::<T>() {
182 return Some(t);
183 }
184 }
185
186 None
187 }
188
189 pub fn with<T: TraceOutput>(mut self, sched: Schedule<Dim>, trace: T) -> Trace<Dim> {
191 self.stack.push((sched, Box::new(trace)));
192 self
193 }
194
195 pub fn iter(&self) -> impl Iterator<Item = (&Schedule<Dim>, &dyn TraceOutput)> {
197 self.stack.iter().map(|v| (&v.0, &*v.1))
198 }
199}
200
#[cfg(test)]
mod tests {
    use core::any::TypeId;
    use std::panic::resume_unwind;
    use std::{fs, thread};

    use alloc::string::String;
    use alloc::vec::Vec;
    use alloc::{borrow::ToOwned, sync::Arc};
    use ndarray::{Array, Ix1};

    use crate::{
        DisplayMode, Schedule,
        modifiers::{FillCornersBuilder, Filter, PSFPolisher, TMFilter},
        pdf::{QSinBias, exponential, qsin, unweighted},
    };

    use super::{Averaging, Generator, Quantiles, RandomSampling, SinWeightedPoissonGap, Trace};

    /// Directory, relative to the crate root, holding the reference schedules.
    const FIXTURE_DIR: &str = "src/generators/tests/forwards_compat";

    /// Builds a fixture file name of the form `<stem>[-tm][-itp].sch`.
    fn fixture_name(stem: &str, tm: bool, itp: bool) -> String {
        let mut name = String::from(stem);

        if tm {
            name.push_str("-tm");
        }

        if itp {
            name.push_str("-itp");
        }

        name.push_str(".sch");
        name
    }

    /// Applies the optional post-processing filters the fixtures were generated with.
    fn post_process(mut sched: Schedule<Ix1>, tm: bool, itp: bool) -> Schedule<Ix1> {
        if tm {
            sched = TMFilter::new().filter(sched);
        }

        if itp {
            sched = PSFPolisher::new(0.1, 0.32, DisplayMode::Abs).filter(sched);
        }

        sched
    }

    /// Asserts that `sched` equals the fixture stored as `name`. Failure messages name
    /// the offending file.
    fn assert_matches_fixture(sched: &Schedule<Ix1>, name: &str, length: usize) {
        let path = format!("{FIXTURE_DIR}/{name}");

        println!("{}/{path}", std::env::current_dir().unwrap().display());

        let target = fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("failed to read fixture {path}: {e}"));
        let decoded = Schedule::decode(&target, crate::EncodingType::ZeroBased, |_| {
            Ok(Ix1(length))
        })
        .unwrap_or_else(|e| panic!("failed to decode fixture {path}: {e:?}"));

        assert_eq!(*sched, decoded, "{}", path);
    }

    #[test]
    fn trace() {
        let s1 = Schedule::new(Array::from_vec(vec![true, false, true]));
        let s2 = Schedule::new(Array::from_vec(vec![true, false, false]));
        let s3 = Schedule::new(Array::from_vec(vec![false, false, true]));

        let trace = Trace::new(s1.to_owned(), 1_u8)
            .with(s2.to_owned(), 2_u16)
            .with(s3.to_owned(), 3_u8);

        assert_eq!(trace.sched(), &s3);

        assert_eq!(*trace.get::<u8>().unwrap(), 3);
        assert_eq!(trace.get::<u32>(), None);

        // `zip` stops at the shorter side, so check the length explicitly first.
        assert_eq!(trace.iter().count(), 3);
        trace
            .iter()
            .zip([
                (&s1, TypeId::of::<u8>()),
                (&s2, TypeId::of::<u16>()),
                (&s3, TypeId::of::<u8>()),
            ])
            .for_each(|(a, b)| {
                assert_eq!(a.0, b.0);
                assert_eq!(a.1.type_id(), b.1);
            });
    }

    #[test]
    fn forwards_compatibility() {
        let scheds: [(&'static str, Arc<dyn Generator<Ix1> + Send + Sync>); 4] = [
            (
                "qt",
                Arc::from(Quantiles::new(|len| qsin(len, QSinBias::Low, 3.))),
            ),
            (
                "pg",
                Arc::from(
                    SinWeightedPoissonGap::new(*b"F R U' R' U' R U R F' R U R' U' ")
                        .fill_corners(|_, _| [1, 1]),
                ),
            ),
            (
                "ru",
                Arc::from(
                    RandomSampling::new(unweighted, *b"Butter, Honey, Sugar, Cinnamon, ")
                        .fill_corners(|_, _| [1, 1]),
                ),
            ),
            (
                "av",
                Arc::from(Averaging::new(
                    |v| exponential(v, 4.),
                    8,
                    *b"when life gives you f(x), f(henr",
                )),
            ),
        ];

        // (count, length, backfill, tm, itp)
        let configs = [
            (64, 256, 8, false, false),
            (64, 256, 8, true, false),
            (64, 256, 8, false, true),
            (64, 256, 8, true, true),
            (128, 512, 12, false, false),
            (128, 512, 12, true, false),
            (128, 512, 12, false, true),
            (128, 512, 12, true, true),
            (96, 512, 12, true, true),
            (96, 512, 12, false, false),
            (96, 512, 12, true, false),
            (96, 512, 12, false, true),
            (52, 256, 8, false, false),
            (52, 256, 8, true, true),
            (192, 1024, 16, true, true),
            (154, 1024, 16, true, true),
            (308, 2048, 20, true, true),
            (410, 4096, 24, true, true),
            (20, 48, 5, true, true),
            (32, 128, 6, true, true),
            (48, 192, 6, false, false),
            (48, 192, 6, false, true),
            (48, 192, 6, true, false),
            (48, 192, 6, true, true),
        ];

        let mut threads = Vec::new();

        for (name, generator) in &scheds {
            for (count, length, backfill, tm, itp) in configs {
                let generator = Arc::clone(generator);
                let name = *name;
                threads.push(thread::spawn(move || {
                    let sched = (generator as Arc<dyn Generator<Ix1>>)
                        .fill_corners(|_, _| [backfill, 1])
                        .generate(count, Ix1(length));
                    let sched = post_process(sched, tm, itp);

                    let file = fixture_name(&format!("{name}-{count}x{length}"), tm, itp);
                    assert_matches_fixture(&sched, &file, length);
                }));
            }
        }

        // (count, length, backfill, tm, itp, iteration)
        let seed_variants = [
            (52, 256, 0, false, false, 1),
            (52, 256, 0, true, true, 1),
            (52, 256, 0, false, false, 2),
            (52, 256, 0, true, true, 2),
            (52, 256, 0, false, false, 3),
            (52, 256, 0, true, true, 3),
            (52, 256, 0, false, false, 4),
            (52, 256, 0, true, true, 4),
            (52, 256, 0, false, false, 5),
            (52, 256, 0, true, true, 5),
            (52, 256, 0, false, false, 6),
            (52, 256, 0, true, true, 6),
            (52, 256, 0, false, false, 7),
            (52, 256, 0, true, true, 7),
            (52, 256, 0, false, false, 8),
            (52, 256, 0, true, true, 8),
        ];

        for (count, length, backfill, tm, itp, iteration) in seed_variants {
            // The seed-variant fixtures all come from the "pg" generator (`scheds[1]`).
            let generator = Arc::clone(&scheds[1].1);
            threads.push(thread::spawn(move || {
                let sched = (generator as Arc<dyn Generator<Ix1>>)
                    .fill_corners(|_, _| [backfill, 1])
                    .generate_with_iter(count, Ix1(length), iteration);
                let sched = post_process(sched, tm, itp);

                let file = fixture_name(&format!("pg-{count}x{length}-{iteration}"), tm, itp);
                assert_matches_fixture(&sched, &file, length);
            }));
        }

        // Re-raise the first worker panic on the test thread so the failure is reported.
        for thread in threads {
            if let Err(e) = thread.join() {
                resume_unwind(e);
            }
        }
    }
}