1use core::any;
4use core::any::Any;
5
6use crate::{modifiers::Modifier, Schedule};
7
8mod averaging;
9mod poisson_gap;
10mod quantiles;
11
12use alloc::{borrow::ToOwned, boxed::Box, vec, vec::Vec};
13pub use averaging::*;
14use ndarray::{Dimension, ShapeBuilder};
15pub use poisson_gap::*;
16pub use quantiles::*;
17
18pub trait Generator<Dim: Dimension> {
26 fn generate(&self, count: usize, dims: Dim) -> Schedule<Dim> {
28 self.generate_with_iter_and_trace(count, dims, 0)
29 .into_sched()
30 }
31
32 fn generate_with_iter(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
34 assert!(
35 count <= (0..dims.ndim()).map(|v| dims[v]).product(),
36 "Count must be less than the number of positions"
37 );
38
39 let sched = self._generate_no_trace(count, dims.to_owned(), iteration);
40
41 validate_schedule::<_, Self>(&sched, dims, count);
42
43 sched
44 }
45
46 fn generate_with_trace(&self, count: usize, dims: Dim) -> Trace<Dim> {
50 self.generate_with_iter_and_trace(count, dims, 0)
51 }
52
53 fn generate_with_iter_and_trace(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim> {
55 assert!(
56 count <= (0..dims.ndim()).map(|v| dims[v]).product(),
57 "Count must be less than the number of positions"
58 );
59
60 let trace = self._generate(count, dims.to_owned(), iteration);
61
62 validate_schedule::<_, Self>(trace.sched(), dims, count);
63
64 trace
65 }
66
67 fn then<T: Modifier<Dim>>(self, modifier: T) -> T::Output<Self>
71 where
72 Self: Sized,
73 {
74 modifier.modify(self)
75 }
76
77 fn _generate(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim>;
83
84 fn _generate_no_trace(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
86 self._generate(count, dims, iteration).into_sched()
87 }
88}
89
90impl<T: core::ops::Deref<Target = dyn Generator<Dim>>, Dim: Dimension> Generator<Dim> for T {
91 fn _generate(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim> {
92 (**self)._generate(count, dims, iteration)
93 }
94
95 fn _generate_no_trace(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
96 (**self)._generate_no_trace(count, dims, iteration)
97 }
98}
99
100fn validate_schedule<Dim: Dimension, T: Generator<Dim> + ?Sized>(
101 sched: &Schedule<Dim>,
102 dims: Dim,
103 count: usize,
104) {
105 let real_count = sched.iter().filter(|v| **v).count();
106
107 assert!(
108 real_count == count,
109 "Returned the wrong count (found {real_count}, expected {count})! In {}",
110 any::type_name::<T>()
111 );
112
113 assert!(
114 dims == sched.raw_dim(),
115 "Returned the wrong length (found {:?}, expected {:?})! In {}",
116 dims,
117 sched.raw_dim(),
118 any::type_name::<T>()
119 );
120}
121
/// Mixes an iteration number into a 32-byte seed.
///
/// The eight little-endian bytes of `iteration` are XORed into the first eight bytes of `seed`;
/// the remaining 24 bytes are returned untouched. Iteration `0` leaves the seed unchanged.
pub fn xor_iteration(mut seed: [u8; 32], iteration: u64) -> [u8; 32] {
    seed.iter_mut()
        .zip(iteration.to_le_bytes())
        .for_each(|(slot, byte)| *slot ^= byte);
    seed
}
129
/// Annotation attached to each step of a [`Trace`].
///
/// Every `'static` type that implements both `Debug` and `Display` is a `TraceOutput` through the
/// blanket implementation below, so callers never need to implement it by hand.
pub trait TraceOutput: Any + core::fmt::Debug + core::fmt::Display {
    /// Returns `self` as `&dyn Any` so the concrete annotation type can be recovered with
    /// `downcast_ref`.
    fn as_any(&self) -> &dyn Any;
}

impl<T> TraceOutput for T
where
    T: Any + core::fmt::Debug + core::fmt::Display,
{
    fn as_any(&self) -> &dyn Any {
        let erased: &dyn Any = self;
        erased
    }
}
141
142pub struct Trace<Dim: Dimension> {
146 pub(crate) stack: Vec<(Schedule<Dim>, Box<dyn TraceOutput>)>,
147}
148
149impl<Dim: Dimension> core::fmt::Debug for Trace<Dim> {
150 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
151 for (sched, trace) in self.iter() {
152 writeln!(f, "- {trace}")?;
153 if sched.dim().into_shape_with_order().size() == 1 {
154 writeln!(f, " {sched}")?;
155 }
156 }
157
158 Ok(())
159 }
160}
161
162impl<Dim: Dimension> Trace<Dim> {
163 pub fn new<T: TraceOutput>(sched: Schedule<Dim>, trace: T) -> Trace<Dim> {
165 Trace {
166 stack: vec![(sched, Box::new(trace))],
167 }
168 }
169
170 #[allow(clippy::missing_panics_doc)]
172 pub fn sched(&self) -> &Schedule<Dim> {
173 &self.stack.last().unwrap().0
175 }
176
177 #[allow(clippy::missing_panics_doc)]
179 pub fn into_sched(mut self) -> Schedule<Dim> {
180 self.stack.pop().unwrap().0
182 }
183
184 pub fn get<T: TraceOutput>(&self) -> Option<&T> {
186 for v in self.stack.iter().rev() {
187 if let Some(t) = (*v.1).as_any().downcast_ref::<T>() {
188 return Some(t);
189 }
190 }
191
192 None
193 }
194
195 pub fn with<T: TraceOutput>(mut self, sched: Schedule<Dim>, trace: T) -> Trace<Dim> {
197 self.stack.push((sched, Box::new(trace)));
198 self
199 }
200
201 pub fn iter(&self) -> impl Iterator<Item = (&Schedule<Dim>, &dyn TraceOutput)> {
203 self.stack.iter().map(|v| (&v.0, &*v.1))
204 }
205}
206
#[cfg(test)]
mod tests {
    use core::any::TypeId;
    use std::panic::resume_unwind;
    use std::{fs, thread};

    use alloc::vec::Vec;
    use alloc::{borrow::ToOwned, sync::Arc};
    use ndarray::{Array, Ix1};

    use crate::{
        modifiers::{FillCornersBuilder, Filter, PSFPolisher, TMFilter},
        pdf::{exponential, qsin, unweighted, QSinBias},
        DisplayMode, Schedule,
    };

    use super::{Averaging, Generator, Quantiles, RandomSampling, SinWeightedPoissonGap, Trace};

    /// `Trace` keeps every step in insertion order, reports the newest schedule, and `get`
    /// returns the newest annotation of the requested type.
    #[test]
    fn trace() {
        let s1 = Schedule::new(Array::from_vec(vec![true, false, true]));
        let s2 = Schedule::new(Array::from_vec(vec![true, false, false]));
        let s3 = Schedule::new(Array::from_vec(vec![false, false, true]));

        let trace = Trace::new(s1.to_owned(), 1_u8)
            .with(s2.to_owned(), 2_u16)
            .with(s3.to_owned(), 3_u8);

        assert_eq!(trace.sched(), &s3);

        // The newest `u8` annotation shadows the older one; older types are still reachable.
        assert_eq!(*trace.get::<u8>().unwrap(), 3);
        assert_eq!(*trace.get::<u16>().unwrap(), 2);
        assert_eq!(trace.get::<u32>(), None);

        let expected = [
            (&s1, TypeId::of::<u8>()),
            (&s2, TypeId::of::<u16>()),
            (&s3, TypeId::of::<u8>()),
        ];

        // `zip` stops at the shorter side, so check the number of steps explicitly.
        assert_eq!(trace.iter().count(), expected.len());
        for ((sched, output), (expected_sched, expected_type)) in trace.iter().zip(expected) {
            assert_eq!(sched, expected_sched);
            assert_eq!(output.type_id(), expected_type);
        }
    }

    /// Applies the optional TM filter and PSF polisher to `sched`, then asserts that it equals
    /// the fixture `src/generators/tests/forwards_compat/{base_name}[-tm][-itp].sch`.
    ///
    /// The fixture path is relative to the current working directory, which is printed to help
    /// diagnose missing fixtures.
    fn assert_matches_fixture(
        mut sched: Schedule<Ix1>,
        base_name: &str,
        length: usize,
        tm: bool,
        itp: bool,
    ) {
        let mut name = base_name.to_owned();

        // The TM filter is applied before the PSF polisher.
        if tm {
            name.push_str("-tm");
            sched = TMFilter::new().filter(sched);
        }

        if itp {
            name.push_str("-itp");
            sched = PSFPolisher::new(0.1, 0.32, DisplayMode::Abs).filter(sched);
        }

        name.push_str(".sch");

        let path = format!("src/generators/tests/forwards_compat/{name}");

        println!("{}/{path}", std::env::current_dir().unwrap().display());

        let target = fs::read_to_string(&path)
            .unwrap_or_else(|err| panic!("failed to read fixture {path}: {err}"));
        let decoded =
            Schedule::decode(&target, crate::EncodingType::ZeroBased, |_| Ok(Ix1(length)))
                .unwrap();

        assert_eq!(sched, decoded, "{path}");
    }

    /// Regenerates schedules from fixed seeds and compares them against the stored fixtures, so
    /// any change to a generator's output is caught.
    #[test]
    fn forwards_compatibility() {
        let scheds: [(&'static str, Arc<dyn Generator<Ix1> + Send + Sync>); 4] = [
            (
                "qt",
                Arc::from(Quantiles::new(|len| qsin(len, QSinBias::Low, 3.))),
            ),
            (
                "pg",
                Arc::from(
                    SinWeightedPoissonGap::new(*b"F R U' R' U' R U R F' R U R' U' ")
                        .fill_corners(|_, _| [1, 1]),
                ),
            ),
            (
                "ru",
                Arc::from(
                    RandomSampling::new(unweighted, *b"Butter, Honey, Sugar, Cinnamon, ")
                        .fill_corners(|_, _| [1, 1]),
                ),
            ),
            (
                "av",
                Arc::from(Averaging::new(
                    |v| exponential(v, 4.),
                    8,
                    *b"when life gives you f(x), f(henr",
                )),
            ),
        ];

        // (count, length, backfill, tm, itp)
        let configs = [
            (64, 256, 8, false, false),
            (64, 256, 8, true, false),
            (64, 256, 8, false, true),
            (64, 256, 8, true, true),
            (128, 512, 12, false, false),
            (128, 512, 12, true, false),
            (128, 512, 12, false, true),
            (128, 512, 12, true, true),
            (96, 512, 12, true, true),
            (96, 512, 12, false, false),
            (96, 512, 12, true, false),
            (96, 512, 12, false, true),
            (52, 256, 8, false, false),
            (52, 256, 8, true, true),
            (192, 1024, 16, true, true),
            (154, 1024, 16, true, true),
            (308, 2048, 20, true, true),
            (410, 4096, 24, true, true),
            (20, 48, 5, true, true),
            (32, 128, 6, true, true),
            (48, 192, 6, false, false),
            (48, 192, 6, false, true),
            (48, 192, 6, true, false),
            (48, 192, 6, true, true),
        ];

        let mut threads = Vec::new();

        for (prefix, generator) in &scheds {
            for (count, length, backfill, tm, itp) in configs {
                let generator = Arc::clone(generator);
                let prefix = *prefix;
                threads.push(thread::spawn(move || {
                    let sched = (generator as Arc<dyn Generator<Ix1>>)
                        .fill_corners(|_, _| [backfill, 1])
                        .generate(count, Ix1(length));

                    assert_matches_fixture(
                        sched,
                        &format!("{prefix}-{count}x{length}"),
                        length,
                        tm,
                        itp,
                    );
                }));
            }
        }

        // (count, length, backfill, tm, itp, iteration) for the Poisson-gap generator.
        let seed_variants = [
            (52, 256, 0, false, false, 1),
            (52, 256, 0, true, true, 1),
            (52, 256, 0, false, false, 2),
            (52, 256, 0, true, true, 2),
            (52, 256, 0, false, false, 3),
            (52, 256, 0, true, true, 3),
            (52, 256, 0, false, false, 4),
            (52, 256, 0, true, true, 4),
            (52, 256, 0, false, false, 5),
            (52, 256, 0, true, true, 5),
            (52, 256, 0, false, false, 6),
            (52, 256, 0, true, true, 6),
            (52, 256, 0, false, false, 7),
            (52, 256, 0, true, true, 7),
            (52, 256, 0, false, false, 8),
            (52, 256, 0, true, true, 8),
        ];

        for (count, length, backfill, tm, itp, iteration) in seed_variants {
            let generator = Arc::clone(&scheds[1].1);
            threads.push(thread::spawn(move || {
                let sched = (generator as Arc<dyn Generator<Ix1>>)
                    .fill_corners(|_, _| [backfill, 1])
                    .generate_with_iter(count, Ix1(length), iteration);

                assert_matches_fixture(
                    sched,
                    &format!("pg-{count}x{length}-{iteration}"),
                    length,
                    tm,
                    itp,
                );
            }));
        }

        // Re-raise the first panic from any worker so the test fails with its original message.
        for thread in threads {
            if let Err(e) = thread.join() {
                resume_unwind(e);
            }
        }
    }
}
423}