nmr_schedule/generators/
mod.rs

1//! Implements various algorithms for generating new schedules.
2
3use core::any::Any;
4use core::fmt::Display;
5use core::{any, fmt::Debug};
6
7use crate::{Schedule, modifiers::Modifier};
8
9mod averaging;
10mod poisson_gap;
11mod quantiles;
12
13use alloc::{borrow::ToOwned, boxed::Box, vec, vec::Vec};
14pub use averaging::*;
15use ndarray::{Dimension, ShapeBuilder};
16pub use poisson_gap::*;
17pub use quantiles::*;
18
19/// Generates new schedules.
20///
21/// Generators take in the number of samples to generate, the dimensions of the schedule, and an `iteration` parameter.
22///
23/// The `iteration` parameter controls implementation-specific arbitrary parameters of the algorithm like random seeds. This allows seed searching using [`crate::modifiers::Iterate`].
24///
25/// Implementations are expected to generate schedules with the dimensions given and the number of samples specified by `count`. This is verified by assertions in the default implementations of `generate`, `generate_with_trace`, and `generate_with_iter_and_trace`.
26pub trait Generator<Dim: Dimension> {
27    /// Generate a schedule where the iteration parameter is set to zero.
28    fn generate(&self, count: usize, dims: Dim) -> Schedule<Dim> {
29        self.generate_with_iter_and_trace(count, dims, 0)
30            .into_sched()
31    }
32
33    /// Generate a schedule with a user-defined iteration parameter.
34    fn generate_with_iter(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
35        assert!(
36            count <= (0..dims.ndim()).map(|v| dims[v]).product(),
37            "Count must be less than the number of positions"
38        );
39
40        let sched = self._generate_no_trace(count, dims.to_owned(), iteration);
41
42        validate_schedule::<_, Self>(&sched, dims, count);
43
44        sched
45    }
46
47    /// Generate a schedule including trace output from each generation step.
48    ///
49    /// The iteration parameter is set to zero.
50    fn generate_with_trace(&self, count: usize, dims: Dim) -> Trace<Dim> {
51        self.generate_with_iter_and_trace(count, dims, 0)
52    }
53
54    /// Generate a schedule with a user-defined iteration parameter while returning a trace.
55    fn generate_with_iter_and_trace(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim> {
56        assert!(
57            count <= (0..dims.ndim()).map(|v| dims[v]).product(),
58            "Count must be less than the number of positions"
59        );
60
61        let trace = self._generate(count, dims.to_owned(), iteration);
62
63        validate_schedule::<_, Self>(trace.sched(), dims, count);
64
65        trace
66    }
67
68    /// Apply a modifier to the generator.
69    ///
70    /// You may use a modifier's builder method instead of this method directly if you do not need to determine the modifier at runtime.
71    fn then<T: Modifier<Dim>>(self, modifier: T) -> T::Output<Self>
72    where
73        Self: Sized,
74    {
75        modifier.modify(self)
76    }
77
78    /// The underlying implementation of a schedule generator. Users should not call this directly because it doesn't perform correctness assertions.
79    ///
80    /// Implementors should not override any other methods of `Generator` except possibly [`Generator::_generate_no_trace`].
81    ///
82    /// Implementors must push their trace output value to the top of the trace.
83    fn _generate(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim>;
84
85    /// This function may be overridden when a generator can be sped up in cases where the trace is not needed. Users should not call this directly because it doesn't perform correctness assertions.
86    fn _generate_no_trace(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
87        self._generate(count, dims, iteration).into_sched()
88    }
89}
90
91impl<T: core::ops::Deref<Target = dyn Generator<Dim>>, Dim: Dimension> Generator<Dim> for T {
92    fn _generate(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim> {
93        (**self)._generate(count, dims, iteration)
94    }
95
96    fn _generate_no_trace(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
97        (**self)._generate_no_trace(count, dims, iteration)
98    }
99}
100
101fn validate_schedule<Dim: Dimension, T: Generator<Dim> + ?Sized>(
102    sched: &Schedule<Dim>,
103    dims: Dim,
104    count: usize,
105) {
106    let real_count = sched.iter().filter(|v| **v).count();
107
108    assert!(
109        real_count == count,
110        "Returned the wrong count (found {real_count}, expected {count})! In {}",
111        any::type_name::<T>()
112    );
113
114    assert!(
115        dims == sched.raw_dim(),
116        "Returned the wrong length (found {:?}, expected {:?})! In {}",
117        dims,
118        sched.raw_dim(),
119        any::type_name::<T>()
120    );
121}
122
/// Mixes an `iteration` value into a 32-byte seed.
///
/// The little-endian bytes of `iteration` are XORed into the first eight bytes of `seed`; the
/// remaining 24 bytes are returned untouched. An iteration of zero returns `seed` unchanged.
pub fn xor_iteration(seed: [u8; 32], iteration: u64) -> [u8; 32] {
    let mut mixed = seed;
    let iteration_bytes = iteration.to_le_bytes();

    mixed
        .iter_mut()
        .zip(iteration_bytes)
        .for_each(|(slot, byte)| *slot ^= byte);

    mixed
}
130
/// Information attached to a [`Trace`] by a generator or modifier.
///
/// It must be `'static` (via [`Any`]) so that [`Trace::get`] can downcast it back to its
/// concrete type. It must also be printable with both `{}` and `{:?}`.
pub trait TraceOutput: Any + Debug + Display {}

/// Every `'static` type with `Debug` and `Display` implementations can be used as trace output.
impl<T> TraceOutput for T where T: Any + Debug + Display {}
135
/// A trace of the steps taken to generate a schedule.
///
/// The trace can be queried for the schedule output of each generator and filter. Each generator and filter can also attach useful information to the trace describing what it did to the schedule.
pub struct Trace<Dim: Dimension> {
    // One entry per generation step, in push order (oldest first): the schedule that step
    // produced, paired with the output it attached. `Trace::new` creates the first entry and
    // `Trace::with` pushes further ones. `Trace::sched` and `Trace::into_sched` rely on the
    // stack never being empty.
    pub(crate) stack: Vec<(Schedule<Dim>, Box<dyn TraceOutput>)>,
}
142
143impl<Dim: Dimension> Debug for Trace<Dim> {
144    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
145        for (sched, trace) in self.iter() {
146            writeln!(f, "- {trace}")?;
147            if sched.dim().into_shape_with_order().size() == 1 {
148                writeln!(f, "  {sched}")?;
149            }
150        }
151
152        Ok(())
153    }
154}
155
156impl<Dim: Dimension> Trace<Dim> {
157    /// Create a new trace that starts with `sched` and the trace output `trace`.
158    pub fn new<T: TraceOutput>(sched: Schedule<Dim>, trace: T) -> Trace<Dim> {
159        Trace {
160            stack: vec![(sched, Box::new(trace))],
161        }
162    }
163
164    /// Get the final schedule that was generated.
165    #[allow(clippy::missing_panics_doc)]
166    pub fn sched(&self) -> &Schedule<Dim> {
167        // The stack is guaranteed not to be empty
168        &self.stack.last().unwrap().0
169    }
170
171    /// Discard the trace and take ownership of the final schedule.
172    #[allow(clippy::missing_panics_doc)]
173    pub fn into_sched(mut self) -> Schedule<Dim> {
174        // The stack is guaranteed not to be empty
175        self.stack.pop().unwrap().0
176    }
177
178    /// Get the trace ouput of a particular generation step. Returns the highest value in the stack or `None` if it is not in the trace.
179    pub fn get<T: TraceOutput>(&self) -> Option<&T> {
180        for v in self.stack.iter().rev() {
181            if let Some(t) = (&*v.1 as &dyn Any).downcast_ref::<T>() {
182                return Some(t);
183            }
184        }
185
186        None
187    }
188
189    /// Push a schedule and trace output onto the stack.
190    pub fn with<T: TraceOutput>(mut self, sched: Schedule<Dim>, trace: T) -> Trace<Dim> {
191        self.stack.push((sched, Box::new(trace)));
192        self
193    }
194
195    /// Iterate over all generation steps.
196    pub fn iter(&self) -> impl Iterator<Item = (&Schedule<Dim>, &dyn TraceOutput)> {
197        self.stack.iter().map(|v| (&v.0, &*v.1))
198    }
199}
200
#[cfg(test)]
mod tests {
    use core::any::TypeId;
    use std::panic::resume_unwind;
    use std::{fs, thread};

    use alloc::string::String;
    use alloc::vec::Vec;
    use alloc::{borrow::ToOwned, sync::Arc};
    use ndarray::{Array, Ix1};

    use crate::{
        DisplayMode, Schedule,
        modifiers::{FillCornersBuilder, Filter, PSFPolisher, TMFilter},
        pdf::{QSinBias, exponential, qsin, unweighted},
    };

    use super::{Averaging, Generator, Quantiles, RandomSampling, SinWeightedPoissonGap, Trace};

    /// Directory (relative to the working directory the tests run in) holding the reference
    /// schedules.
    const FORWARDS_COMPAT_DIR: &str = "src/generators/tests/forwards_compat";

    /// Appends the optional `-tm` / `-itp` suffixes and the `.sch` extension to a file stem.
    fn reference_file_name(mut name: String, tm: bool, itp: bool) -> String {
        if tm {
            name.push_str("-tm");
        }

        if itp {
            name.push_str("-itp");
        }

        name.push_str(".sch");
        name
    }

    /// Applies the TM filter and then the PSF polisher, each only when requested.
    fn apply_filters(mut sched: Schedule<Ix1>, tm: bool, itp: bool) -> Schedule<Ix1> {
        if tm {
            sched = TMFilter::new().filter(sched);
        }

        if itp {
            sched = PSFPolisher::new(0.1, 0.32, DisplayMode::Abs).filter(sched);
        }

        sched
    }

    /// Asserts that `sched` equals the reference schedule stored in `FORWARDS_COMPAT_DIR/name`.
    fn assert_matches_reference(sched: Schedule<Ix1>, name: &str, length: usize) {
        let path = format!("{FORWARDS_COMPAT_DIR}/{name}");

        println!("{}/{path}", std::env::current_dir().unwrap().display());

        let target = fs::read_to_string(&path)
            .unwrap_or_else(|err| panic!("failed to read reference schedule {path}: {err}"));
        let decoded =
            Schedule::decode(&target, crate::EncodingType::ZeroBased, |_| Ok(Ix1(length)))
                .unwrap();

        assert_eq!(sched, decoded, "{}", path);
    }

    #[test]
    fn trace() {
        let s1 = Schedule::new(Array::from_vec(vec![true, false, true]));
        let s2 = Schedule::new(Array::from_vec(vec![true, false, false]));
        let s3 = Schedule::new(Array::from_vec(vec![false, false, true]));

        let trace = Trace::new(s1.to_owned(), 1_u8)
            .with(s2.to_owned(), 2_u16)
            .with(s3.to_owned(), 3_u8);

        assert_eq!(trace.sched(), &s3);

        assert_eq!(*trace.get::<u8>().unwrap(), 3);
        assert_eq!(trace.get::<u32>(), None);
        trace
            .iter()
            .zip([
                (&s1, TypeId::of::<u8>()),
                (&s2, TypeId::of::<u16>()),
                (&s3, TypeId::of::<u8>()),
            ])
            .for_each(|(a, b)| {
                assert_eq!(a.0, b.0);
                assert_eq!(a.1.type_id(), b.1);
            });
    }

    #[test]
    fn forwards_compatibility() {
        let scheds: [(&'static str, Arc<dyn Generator<Ix1> + Send + Sync>); 4] = [
            (
                "qt",
                Arc::from(Quantiles::new(|len| qsin(len, QSinBias::Low, 3.))),
            ),
            (
                "pg",
                Arc::from(
                    // Y-Perm
                    SinWeightedPoissonGap::new(*b"F R U' R' U' R U R F' R U R' U' ")
                        .fill_corners(|_, _| [1, 1]),
                ),
            ),
            (
                "ru",
                Arc::from(
                    RandomSampling::new(unweighted, *b"Butter, Honey, Sugar, Cinnamon, ")
                        .fill_corners(|_, _| [1, 1]),
                ),
            ),
            (
                "av",
                Arc::from(Averaging::new(
                    |v| exponential(v, 4.),
                    8,
                    *b"when life gives you f(x), f(henr",
                )),
            ),
        ];

        // (count, length, backfill, TM, ITP)
        let configs = [
            (64, 256, 8, false, false),
            (64, 256, 8, true, false),
            (64, 256, 8, false, true),
            (64, 256, 8, true, true),
            (128, 512, 12, false, false),
            (128, 512, 12, true, false),
            (128, 512, 12, false, true),
            (128, 512, 12, true, true),
            (96, 512, 12, true, true),
            (96, 512, 12, false, false),
            (96, 512, 12, true, false),
            (96, 512, 12, false, true),
            (52, 256, 8, false, false),
            (52, 256, 8, true, true),
            (192, 1024, 16, true, true),
            (154, 1024, 16, true, true),
            (308, 2048, 20, true, true),
            (410, 4096, 24, true, true),
            (20, 48, 5, true, true),
            (32, 128, 6, true, true),
            (48, 192, 6, false, false),
            (48, 192, 6, false, true),
            (48, 192, 6, true, false),
            (48, 192, 6, true, true),
        ];

        let mut threads = Vec::new();

        for (name, generator) in &scheds {
            for (count, length, backfill, tm, itp) in configs {
                let generator = Arc::clone(generator);
                let name = *name;
                threads.push(thread::spawn(move || {
                    let file = reference_file_name(format!("{name}-{count}x{length}"), tm, itp);

                    let sched = (generator as Arc<dyn Generator<Ix1>>)
                        .fill_corners(|_, _| [backfill, 1])
                        .generate(count, Ix1(length));

                    assert_matches_reference(apply_filters(sched, tm, itp), &file, length);
                }));
            }
        }

        // (count, length, backfill, TM, ITP, iteration) for the Poisson-gap generator.
        let seed_variants = [
            (52, 256, 0, false, false, 1),
            (52, 256, 0, true, true, 1),
            (52, 256, 0, false, false, 2),
            (52, 256, 0, true, true, 2),
            (52, 256, 0, false, false, 3),
            (52, 256, 0, true, true, 3),
            (52, 256, 0, false, false, 4),
            (52, 256, 0, true, true, 4),
            (52, 256, 0, false, false, 5),
            (52, 256, 0, true, true, 5),
            (52, 256, 0, false, false, 6),
            (52, 256, 0, true, true, 6),
            (52, 256, 0, false, false, 7),
            (52, 256, 0, true, true, 7),
            (52, 256, 0, false, false, 8),
            (52, 256, 0, true, true, 8),
        ];

        for (count, length, backfill, tm, itp, iteration) in seed_variants {
            let generator = Arc::clone(&scheds[1].1);
            threads.push(thread::spawn(move || {
                let file =
                    reference_file_name(format!("pg-{count}x{length}-{iteration}"), tm, itp);

                let sched = (generator as Arc<dyn Generator<Ix1>>)
                    .fill_corners(|_, _| [backfill, 1])
                    .generate_with_iter(count, Ix1(length), iteration);

                assert_matches_reference(apply_filters(sched, tm, itp), &file, length);
            }));
        }

        // Re-raise the first worker panic so the failing case's message reaches the harness.
        for thread in threads {
            if let Err(e) = thread.join() {
                resume_unwind(e);
            };
        }
    }
}