use core::any::Any;
use core::fmt::Display;
use core::{any, fmt::Debug};
use crate::{Schedule, modifiers::Modifier};
mod averaging;
mod poisson_gap;
mod quantiles;
use alloc::{borrow::ToOwned, boxed::Box, vec, vec::Vec};
pub use averaging::*;
use ndarray::{Dimension, ShapeBuilder};
pub use poisson_gap::*;
pub use quantiles::*;
/// A generator of sampling schedules over grids of shape `Dim`.
///
/// Implementors provide [`Generator::_generate`] and may override
/// [`Generator::_generate_no_trace`] with a cheaper path. Callers use the `generate*`
/// methods, which check the arguments and validate the returned schedule before
/// handing it back.
pub trait Generator<Dim: Dimension> {
    /// Generates a schedule with exactly `count` `true` positions in a grid of shape
    /// `dims`, using iteration `0`.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds the number of positions in `dims`, or if the
    /// implementation returns a schedule with the wrong count or shape.
    fn generate(&self, count: usize, dims: Dim) -> Schedule<Dim> {
        // The trace would be discarded anyway, so take the no-trace path, which
        // implementors may override to avoid building it.
        self.generate_with_iter(count, dims, 0)
    }

    /// Like [`Generator::generate`], but passes `iteration` through to the
    /// implementation instead of `0`.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Generator::generate`].
    fn generate_with_iter(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
        assert!(
            count <= dims.size(),
            "Count must be less than or equal to the number of positions"
        );
        let sched = self._generate_no_trace(count, dims.to_owned(), iteration);
        validate_schedule::<_, Self>(&sched, dims, count);
        sched
    }

    /// Like [`Generator::generate`], but returns the whole [`Trace`] instead of only
    /// the final schedule.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Generator::generate`].
    fn generate_with_trace(&self, count: usize, dims: Dim) -> Trace<Dim> {
        self.generate_with_iter_and_trace(count, dims, 0)
    }

    /// Like [`Generator::generate_with_trace`], but passes `iteration` through to the
    /// implementation instead of `0`.
    ///
    /// Only the final schedule of the trace is validated.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Generator::generate`].
    fn generate_with_iter_and_trace(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim> {
        assert!(
            count <= dims.size(),
            "Count must be less than or equal to the number of positions"
        );
        let trace = self._generate(count, dims.to_owned(), iteration);
        validate_schedule::<_, Self>(trace.sched(), dims, count);
        trace
    }

    /// Wraps this generator with `modifier` by handing `self` to [`Modifier::modify`].
    fn then<T: Modifier<Dim>>(self, modifier: T) -> T::Output<Self>
    where
        Self: Sized,
    {
        modifier.modify(self)
    }

    /// Implementation hook: produces the schedule together with its [`Trace`].
    ///
    /// Called by the `generate*` methods, which validate the result; it is not meant to
    /// be called directly.
    fn _generate(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim>;

    /// Implementation hook returning only the final schedule.
    ///
    /// The default runs [`Generator::_generate`] and keeps the last schedule of the
    /// trace; override it when the trace can be skipped more cheaply.
    fn _generate_no_trace(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
        self._generate(count, dims, iteration).into_sched()
    }
}
/// Lets any smart pointer whose target is `dyn Generator<Dim>` (for example
/// `Box<dyn Generator<Dim>>` or `Arc<dyn Generator<Dim>>`) act as a generator itself.
///
/// Only the two implementation hooks are forwarded to the pointee. The public `generate*`
/// methods keep their default bodies, which call these hooks. The pointer's `Target` must
/// be exactly `dyn Generator<Dim>`, so a pointer to `dyn Generator<Dim> + Send + Sync` has
/// to be coerced first.
impl<P, Dim> Generator<Dim> for P
where
    P: core::ops::Deref<Target = dyn Generator<Dim>>,
    Dim: Dimension,
{
    fn _generate(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim> {
        // Deref explicitly down to the trait object so the call goes through its vtable
        // instead of re-entering this impl.
        let inner: &dyn Generator<Dim> = &**self;
        inner._generate(count, dims, iteration)
    }

    fn _generate_no_trace(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
        let inner: &dyn Generator<Dim> = &**self;
        inner._generate_no_trace(count, dims, iteration)
    }
}
/// Asserts that a generator's output matches what the caller asked for.
///
/// `T` is only used to name the offending generator in the panic message.
///
/// # Panics
///
/// Panics if `sched` does not contain exactly `count` `true` entries, or if its shape
/// differs from `dims`.
fn validate_schedule<Dim: Dimension, T: Generator<Dim> + ?Sized>(
    sched: &Schedule<Dim>,
    dims: Dim,
    count: usize,
) {
    let real_count = sched.iter().filter(|v| **v).count();
    assert!(
        real_count == count,
        "Returned the wrong count (found {real_count}, expected {count})! In {}",
        any::type_name::<T>()
    );
    // "found" is what the generator returned and "expected" is what was requested.
    let real_dims = sched.raw_dim();
    assert!(
        real_dims == dims,
        "Returned the wrong length (found {real_dims:?}, expected {dims:?})! In {}",
        any::type_name::<T>()
    );
}
/// Derives a per-iteration seed from `seed`.
///
/// The 8 little-endian bytes of `iteration` are XOR-ed into the first 8 bytes of the
/// seed, and the remaining 24 bytes are returned unchanged. Iteration `0` therefore
/// returns `seed` as-is.
pub fn xor_iteration(seed: [u8; 32], iteration: u64) -> [u8; 32] {
    let mut mixed = seed;
    mixed
        .iter_mut()
        .zip(iteration.to_le_bytes())
        .for_each(|(slot, mask)| *slot ^= mask);
    mixed
}
/// Bound for the values that label each layer of a [`Trace`].
///
/// `Any` lets [`Trace::get`] recover the concrete type by downcasting, and
/// `Debug + Display` let the trace be printed. Every type meeting those bounds
/// implements this trait automatically.
pub trait TraceOutput: Any + Debug + Display {}

impl<T> TraceOutput for T where T: Any + Debug + Display {}
/// A stack of schedules, each labelled with a [`TraceOutput`] value.
///
/// This is what [`Generator::_generate`] returns. Layers are appended with
/// [`Trace::with`], and the last layer holds the schedule returned by [`Trace::sched`]
/// and [`Trace::into_sched`].
pub struct Trace<Dim: Dimension> {
    // Oldest layer first. Expected to be non-empty: `Trace::new` always creates one
    // layer, and `sched`/`into_sched` unwrap `last()`/`pop()`.
    pub(crate) stack: Vec<(Schedule<Dim>, Box<dyn TraceOutput>)>,
}
/// Lists the trace's layers oldest first, one `- {label}` line per layer.
///
/// For one-dimensional schedules, the schedule itself is printed on the following line.
/// Higher-dimensional schedules are left out, presumably because their `Display`
/// output does not fit on a single line.
impl<Dim: Dimension> Debug for Trace<Dim> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        for (sched, trace) in self.iter() {
            writeln!(f, "- {trace}")?;
            // Compare the number of axes, not the number of elements.
            if sched.raw_dim().ndim() == 1 {
                writeln!(f, " {sched}")?;
            }
        }
        Ok(())
    }
}
impl<Dim: Dimension> Trace<Dim> {
    /// Starts a trace whose only layer is `sched`, labelled with `trace`.
    pub fn new<T: TraceOutput>(sched: Schedule<Dim>, trace: T) -> Trace<Dim> {
        Trace {
            stack: vec![(sched, Box::new(trace))],
        }
    }

    /// Returns the schedule of the most recently added layer.
    ///
    /// # Panics
    ///
    /// Panics if the trace has no layers. Traces built with [`Trace::new`] always have
    /// at least one, so this only happens if crate code empties `stack` directly.
    pub fn sched(&self) -> &Schedule<Dim> {
        &self
            .stack
            .last()
            .expect("a Trace always holds at least one layer")
            .0
    }

    /// Consumes the trace and returns the schedule of the most recently added layer.
    ///
    /// # Panics
    ///
    /// Panics if the trace has no layers. Traces built with [`Trace::new`] always have
    /// at least one, so this only happens if crate code empties `stack` directly.
    pub fn into_sched(mut self) -> Schedule<Dim> {
        self.stack
            .pop()
            .expect("a Trace always holds at least one layer")
            .0
    }

    /// Returns the label of the most recently added layer whose label has concrete type
    /// `T`, or `None` if no layer has such a label.
    pub fn get<T: TraceOutput>(&self) -> Option<&T> {
        self.stack
            .iter()
            .rev()
            // Deref the box first, so the downcast inspects the label's own type
            // rather than the type of the `Box`.
            .find_map(|(_, output)| (&**output as &dyn Any).downcast_ref::<T>())
    }

    /// Appends a new layer with schedule `sched` and label `trace`, and returns the trace.
    pub fn with<T: TraceOutput>(mut self, sched: Schedule<Dim>, trace: T) -> Trace<Dim> {
        self.stack.push((sched, Box::new(trace)));
        self
    }

    /// Iterates over `(schedule, label)` pairs, oldest layer first.
    pub fn iter(&self) -> impl Iterator<Item = (&Schedule<Dim>, &dyn TraceOutput)> {
        self.stack.iter().map(|v| (&v.0, &*v.1))
    }
}
#[cfg(test)]
mod tests {
    use core::any::TypeId;
    use std::panic::resume_unwind;
    use std::{fs, thread};

    use alloc::vec::Vec;
    use alloc::{borrow::ToOwned, sync::Arc};
    use ndarray::{Array, Ix1};

    use crate::{
        DisplayMode, Schedule,
        modifiers::{FillCornersBuilder, Filter, PSFPolisher, TMFilter},
        pdf::{QSinBias, exponential, qsin, unweighted},
    };

    use super::{Averaging, Generator, Quantiles, RandomSampling, SinWeightedPoissonGap, Trace};

    /// Directory, relative to the crate root, holding the reference `.sch` fixtures.
    const FIXTURE_DIR: &str = "src/generators/tests/forwards_compat";

    /// Applies the optional post-processing filters to `sched` and asserts that the result
    /// equals the stored fixture `{stem}[-tm][-itp].sch`.
    ///
    /// The filters and the fixture file name are both derived from `tm`/`itp` here, so the
    /// two cannot drift apart between the test loops. The failure message names the
    /// fixture path.
    fn assert_matches_fixture(
        mut sched: Schedule<Ix1>,
        stem: &str,
        length: usize,
        tm: bool,
        itp: bool,
    ) {
        if tm {
            sched = TMFilter::new().filter(sched);
        }
        if itp {
            sched = PSFPolisher::new(0.1, 0.32, DisplayMode::Abs).filter(sched);
        }
        let name = format!(
            "{stem}{}{}.sch",
            if tm { "-tm" } else { "" },
            if itp { "-itp" } else { "" },
        );
        let path = format!("{FIXTURE_DIR}/{name}");
        println!("{}/{path}", std::env::current_dir().unwrap().display());
        let target = fs::read_to_string(&path)
            .unwrap_or_else(|err| panic!("failed to read fixture {path}: {err}"));
        let decoded =
            Schedule::decode(&target, crate::EncodingType::ZeroBased, |_| Ok(Ix1(length)))
                .unwrap();
        assert_eq!(sched, decoded, "{path}");
    }

    #[test]
    fn trace() {
        let s1 = Schedule::new(Array::from_vec(vec![true, false, true]));
        let s2 = Schedule::new(Array::from_vec(vec![true, false, false]));
        let s3 = Schedule::new(Array::from_vec(vec![false, false, true]));
        let trace = Trace::new(s1.to_owned(), 1_u8)
            .with(s2.to_owned(), 2_u16)
            .with(s3.to_owned(), 3_u8);
        assert_eq!(trace.sched(), &s3);
        assert_eq!(*trace.get::<u8>().unwrap(), 3);
        assert_eq!(trace.get::<u32>(), None);
        trace
            .iter()
            .zip([
                (&s1, TypeId::of::<u8>()),
                (&s2, TypeId::of::<u16>()),
                (&s3, TypeId::of::<u8>()),
            ])
            .for_each(|(a, b)| {
                assert_eq!(a.0, b.0);
                assert_eq!(a.1.type_id(), b.1);
            });
    }

    #[test]
    fn forwards_compatibility() {
        let scheds: [(&'static str, Arc<dyn Generator<Ix1> + Send + Sync>); 4] = [
            (
                "qt",
                Arc::from(Quantiles::new(|len| qsin(len, QSinBias::Low, 3.))),
            ),
            (
                "pg",
                Arc::from(
                    SinWeightedPoissonGap::new(*b"F R U' R' U' R U R F' R U R' U' ")
                        .fill_corners(|_, _| [1, 1]),
                ),
            ),
            (
                "ru",
                Arc::from(
                    RandomSampling::new(unweighted, *b"Butter, Honey, Sugar, Cinnamon, ")
                        .fill_corners(|_, _| [1, 1]),
                ),
            ),
            (
                "av",
                Arc::from(Averaging::new(
                    |v| exponential(v, 4.),
                    8,
                    *b"when life gives you f(x), f(henr",
                )),
            ),
        ];
        let configs = [
            (64, 256, 8, false, false),
            (64, 256, 8, true, false),
            (64, 256, 8, false, true),
            (64, 256, 8, true, true),
            (128, 512, 12, false, false),
            (128, 512, 12, true, false),
            (128, 512, 12, false, true),
            (128, 512, 12, true, true),
            (96, 512, 12, true, true),
            (96, 512, 12, false, false),
            (96, 512, 12, true, false),
            (96, 512, 12, false, true),
            (52, 256, 8, false, false),
            (52, 256, 8, true, true),
            (192, 1024, 16, true, true),
            (154, 1024, 16, true, true),
            (308, 2048, 20, true, true),
            (410, 4096, 24, true, true),
            (20, 48, 5, true, true),
            (32, 128, 6, true, true),
            (48, 192, 6, false, false),
            (48, 192, 6, false, true),
            (48, 192, 6, true, false),
            (48, 192, 6, true, true),
        ];
        let mut threads = Vec::new();
        for (name, generator) in &scheds {
            for (count, length, backfill, tm, itp) in configs {
                let generator = Arc::clone(generator);
                let name = *name;
                threads.push(thread::spawn(move || {
                    let sched = (generator as Arc<dyn Generator<Ix1>>)
                        .fill_corners(|_, _| [backfill, 1])
                        .generate(count, Ix1(length));
                    assert_matches_fixture(
                        sched,
                        &format!("{name}-{count}x{length}"),
                        length,
                        tm,
                        itp,
                    );
                }));
            }
        }
        let seed_variants = [
            (52, 256, 0, false, false, 1),
            (52, 256, 0, true, true, 1),
            (52, 256, 0, false, false, 2),
            (52, 256, 0, true, true, 2),
            (52, 256, 0, false, false, 3),
            (52, 256, 0, true, true, 3),
            (52, 256, 0, false, false, 4),
            (52, 256, 0, true, true, 4),
            (52, 256, 0, false, false, 5),
            (52, 256, 0, true, true, 5),
            (52, 256, 0, false, false, 6),
            (52, 256, 0, true, true, 6),
            (52, 256, 0, false, false, 7),
            (52, 256, 0, true, true, 7),
            (52, 256, 0, false, false, 8),
            (52, 256, 0, true, true, 8),
        ];
        for (count, length, backfill, tm, itp, iteration) in seed_variants {
            let generator = Arc::clone(&scheds[1].1);
            threads.push(thread::spawn(move || {
                let sched = (generator as Arc<dyn Generator<Ix1>>)
                    .fill_corners(|_, _| [backfill, 1])
                    .generate_with_iter(count, Ix1(length), iteration);
                assert_matches_fixture(
                    sched,
                    &format!("pg-{count}x{length}-{iteration}"),
                    length,
                    tm,
                    itp,
                );
            }));
        }
        // Join every worker and re-raise the first panic so the test reports it.
        for thread in threads {
            if let Err(e) = thread.join() {
                resume_unwind(e);
            };
        }
    }
}