1use core::fmt::Display;
2
3use alloc::vec;
4use ndarray::{Array, Ix1};
5
6use crate::{pdf::PdfGenerator, schedule::Schedule};
7
8use super::{Generator, Trace};
9
10#[derive(Clone, Copy, Debug)]
14pub struct Quantiles<G: PdfGenerator<Ix1>>(G);
15
16impl<G: PdfGenerator<Ix1>> Quantiles<G> {
17 pub const fn new(pdf: G) -> Quantiles<G> {
19 Quantiles(pdf)
20 }
21}
22
/// Trace value produced by the quantile generator.
///
/// The generator records no diagnostic information, so this is a
/// field-less marker. `Default`, equality and hashing are derived so callers
/// can construct and compare traces generically.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct QuantilesTrace;
26
27impl Display for QuantilesTrace {
28 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
29 write!(f, "No trace information")
30 }
31}
32
33impl<G: PdfGenerator<Ix1>> Generator<Ix1> for Quantiles<G> {
34 fn _generate_no_trace(&self, count: usize, dims: Ix1, _iteration: u64) -> Schedule<Ix1> {
35 if count == 0 {
36 return Schedule::new(Array::from_vec(vec![false; dims[0]]));
37 }
38
39 let pdf = self.0.get(dims).pop().unwrap();
40
41 let integral = pdf.continuous_integral();
42
43 let interval = 1. / (count - 1) as f64;
44
45 let mut sched = alloc::vec![false; pdf.len()];
46 let mut current_quantile = 0.;
47 let mut excess = 0;
48
49 for (spot, pos) in sched.iter_mut().enumerate() {
50 while integral(spot as f64 + 1.5) >= current_quantile * interval {
52 if *pos {
53 excess += 1;
54 }
55
56 *pos = true;
57 current_quantile += 1.;
58 }
59 }
60
61 let last = sched.len() - 1;
62 sched[last] = true;
63
64 let mut pos = 0;
65
66 while excess > 0 {
67 if !sched[pos] {
68 sched[pos] = true;
69 excess -= 1;
70 }
71
72 pos += 1;
73 }
74
75 Schedule::new(Array::from_vec(sched))
76 }
77
78 fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
79 Trace::new(
80 self._generate_no_trace(count, dims, iteration),
81 QuantilesTrace,
82 )
83 }
84}
85
#[cfg(test)]
mod tests {
    use core::f64::consts::PI;

    use ndarray::Ix1;

    use crate::{
        generators::Generator,
        modifiers::FillCornersBuilder,
        pdf::{exponential, qsin, QSinBias},
        EncodingType, Schedule,
    };

    use super::Quantiles;

    /// Regenerates schedules with the parameters used for the reference `.sch`
    /// fixtures under `tests/qsched_compat` (presumably produced by qsched) and
    /// checks that each generated schedule matches its fixture exactly.
    #[test]
    fn qsched_compatibility() {
        // Each entry pairs a fixture's encoded text with the schedule that is
        // expected to reproduce it.
        let cases = [
            (
                include_str!("./tests/qsched_compat/32x128-low.sch"),
                Quantiles::new(|len| qsin(len, QSinBias::Low, PI)).generate(32, Ix1(128)),
            ),
            (
                include_str!("./tests/qsched_compat/32x128-med.sch"),
                Quantiles::new(|len| qsin(len, QSinBias::Med, 3.)).generate(32, Ix1(128)),
            ),
            (
                include_str!("./tests/qsched_compat/32x128-high.sch"),
                Quantiles::new(|len| qsin(len, QSinBias::High, 3.)).generate(32, Ix1(128)),
            ),
            (
                include_str!("./tests/qsched_compat/32x128.sch"),
                Quantiles::new(|len| qsin(len, QSinBias::Med, 3.))
                    .fill_corners(|_, _| [8, 1])
                    .generate(32, Ix1(128)),
            ),
            (
                include_str!("./tests/qsched_compat/42x128.sch"),
                Quantiles::new(|len| exponential(len, 3.))
                    .fill_corners(|_, _| [10, 1])
                    .generate(42, Ix1(128)),
            ),
            (
                include_str!("./tests/qsched_compat/64x256.sch"),
                Quantiles::new(|len| qsin(len, QSinBias::Med, 3.))
                    .fill_corners(|_, _| [12, 0])
                    .generate(64, Ix1(256)),
            ),
            (
                include_str!("./tests/qsched_compat/85x256.sch"),
                Quantiles::new(|len| exponential(len, 3.))
                    .fill_corners(|_, _| [18, 1])
                    .generate(85, Ix1(256)),
            ),
            (
                include_str!("./tests/qsched_compat/130x512.sch"),
                Quantiles::new(|len| qsin(len, QSinBias::Med, 3.))
                    .fill_corners(|_, _| [15, 0])
                    .generate(130, Ix1(512)),
            ),
            (
                include_str!("./tests/qsched_compat/169x512.sch"),
                Quantiles::new(|len| exponential(len, 3.))
                    .fill_corners(|_, _| [15, 1])
                    .generate(169, Ix1(512)),
            ),
            (
                include_str!("./tests/qsched_compat/257x1024.sch"),
                Quantiles::new(|len| qsin(len, QSinBias::Med, 3.))
                    .fill_corners(|_, _| [30, 0])
                    .generate(257, Ix1(1024)),
            ),
            (
                include_str!("./tests/qsched_compat/337x1024.sch"),
                Quantiles::new(|len| exponential(len, 3.))
                    .fill_corners(|_, _| [30, 1])
                    .generate(337, Ix1(1024)),
            ),
        ];

        for (fixture, generated) in cases {
            let expected =
                Schedule::decode(fixture, EncodingType::ZeroBased, |v: Ix1| Ok(v)).unwrap();

            // Checking the sample count first gives a more readable failure than
            // a full schedule diff when the totals already disagree.
            assert_eq!(expected.count(), generated.count());
            assert_eq!(expected, generated);
        }
    }
}