nmr_schedule/generators/
mod.rs

1//! Implements various algorithms for generating new schedules.
2
3use core::any;
4use core::any::Any;
5
6use crate::{modifiers::Modifier, Schedule};
7
8mod averaging;
9mod poisson_gap;
10mod quantiles;
11
12use alloc::{borrow::ToOwned, boxed::Box, vec, vec::Vec};
13pub use averaging::*;
14use ndarray::{Dimension, ShapeBuilder};
15pub use poisson_gap::*;
16pub use quantiles::*;
17
18/// Generates new schedules.
19///
20/// Generators take in the number of samples to generate, the dimensions of the schedule, and an `iteration` parameter.
21///
22/// The `iteration` parameter controls implementation-specific arbitrary parameters of the algorithm like random seeds. This allows seed searching using [`crate::modifiers::Iterate`].
23///
24/// Implementations are expected to generate schedules with the dimensions given and the number of samples specified by `count`. This is verified by assertions in the default implementations of `generate`, `generate_with_trace`, and `generate_with_iter_and_trace`.
25pub trait Generator<Dim: Dimension> {
26    /// Generate a schedule where the iteration parameter is set to zero.
27    fn generate(&self, count: usize, dims: Dim) -> Schedule<Dim> {
28        self.generate_with_iter_and_trace(count, dims, 0)
29            .into_sched()
30    }
31
32    /// Generate a schedule with a user-defined iteration parameter.
33    fn generate_with_iter(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
34        assert!(
35            count <= (0..dims.ndim()).map(|v| dims[v]).product(),
36            "Count must be less than the number of positions"
37        );
38
39        let sched = self._generate_no_trace(count, dims.to_owned(), iteration);
40
41        validate_schedule::<_, Self>(&sched, dims, count);
42
43        sched
44    }
45
46    /// Generate a schedule including trace output from each generation step.
47    ///
48    /// The iteration parameter is set to zero.
49    fn generate_with_trace(&self, count: usize, dims: Dim) -> Trace<Dim> {
50        self.generate_with_iter_and_trace(count, dims, 0)
51    }
52
53    /// Generate a schedule with a user-defined iteration parameter while returning a trace.
54    fn generate_with_iter_and_trace(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim> {
55        assert!(
56            count <= (0..dims.ndim()).map(|v| dims[v]).product(),
57            "Count must be less than the number of positions"
58        );
59
60        let trace = self._generate(count, dims.to_owned(), iteration);
61
62        validate_schedule::<_, Self>(trace.sched(), dims, count);
63
64        trace
65    }
66
67    /// Apply a modifier to the generator.
68    ///
69    /// You may use a modifier's builder method instead of this method directly if you do not need to determine the modifier at runtime.
70    fn then<T: Modifier<Dim>>(self, modifier: T) -> T::Output<Self>
71    where
72        Self: Sized,
73    {
74        modifier.modify(self)
75    }
76
77    /// The underlying implementation of a schedule generator. Users should not call this directly because it doesn't perform correctness assertions.
78    ///
79    /// Implementors should not override any other methods of `Generator` except possibly [`Generator::_generate_no_trace`].
80    ///
81    /// Implementors must push their trace output value to the top of the trace.
82    fn _generate(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim>;
83
84    /// This function may be overridden when a generator can be sped up in cases where the trace is not needed. Users should not call this directly because it doesn't perform correctness assertions.
85    fn _generate_no_trace(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
86        self._generate(count, dims, iteration).into_sched()
87    }
88}
89
90impl<T: core::ops::Deref<Target = dyn Generator<Dim>>, Dim: Dimension> Generator<Dim> for T {
91    fn _generate(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim> {
92        (**self)._generate(count, dims, iteration)
93    }
94
95    fn _generate_no_trace(&self, count: usize, dims: Dim, iteration: u64) -> Schedule<Dim> {
96        (**self)._generate_no_trace(count, dims, iteration)
97    }
98}
99
100fn validate_schedule<Dim: Dimension, T: Generator<Dim> + ?Sized>(
101    sched: &Schedule<Dim>,
102    dims: Dim,
103    count: usize,
104) {
105    let real_count = sched.iter().filter(|v| **v).count();
106
107    assert!(
108        real_count == count,
109        "Returned the wrong count (found {real_count}, expected {count})! In {}",
110        any::type_name::<T>()
111    );
112
113    assert!(
114        dims == sched.raw_dim(),
115        "Returned the wrong length (found {:?}, expected {:?})! In {}",
116        dims,
117        sched.raw_dim(),
118        any::type_name::<T>()
119    );
120}
121
/// XORs the little-endian bytes of `iteration` into the first eight bytes of `seed`.
///
/// Handy for deriving a distinct seed per iteration from a fixed base seed. Bytes 8..32 are
/// untouched, and an `iteration` of zero returns `seed` unchanged.
pub fn xor_iteration(seed: [u8; 32], iteration: u64) -> [u8; 32] {
    let mut mixed = seed;
    mixed
        .iter_mut()
        .zip(iteration.to_le_bytes())
        .for_each(|(dst, src)| *dst ^= src);
    mixed
}
129
/// Output attached to a [`Trace`] step by a generator or filter.
///
/// Every `'static` type that implements both `Debug` and `Display` gets this trait
/// automatically through the blanket implementation below.
pub trait TraceOutput: Any + core::fmt::Debug + core::fmt::Display {
    /// View this value as `&dyn Any` so it can be downcast back to its concrete type.
    ///
    /// This predates stable trait upcasting (Rust 1.86). It may be removed once the crate's
    /// minimum supported Rust version allows relying on upcasting instead.
    fn as_any(&self) -> &dyn Any;
}

impl<T> TraceOutput for T
where
    T: Any + core::fmt::Debug + core::fmt::Display,
{
    fn as_any(&self) -> &dyn Any {
        let erased: &dyn Any = self;
        erased
    }
}
141
142/// A trace of the steps taken to generate a schedule.
143///
144/// The trace can be queried for the schedule output of each generator and filter, and each generator and filter can attach useful information to the trace detailing what it did to the schedule.
145pub struct Trace<Dim: Dimension> {
146    pub(crate) stack: Vec<(Schedule<Dim>, Box<dyn TraceOutput>)>,
147}
148
149impl<Dim: Dimension> core::fmt::Debug for Trace<Dim> {
150    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
151        for (sched, trace) in self.iter() {
152            writeln!(f, "- {trace}")?;
153            if sched.dim().into_shape_with_order().size() == 1 {
154                writeln!(f, "  {sched}")?;
155            }
156        }
157
158        Ok(())
159    }
160}
161
162impl<Dim: Dimension> Trace<Dim> {
163    /// Create a new trace that starts with `sched` and the trace output `trace`.
164    pub fn new<T: TraceOutput>(sched: Schedule<Dim>, trace: T) -> Trace<Dim> {
165        Trace {
166            stack: vec![(sched, Box::new(trace))],
167        }
168    }
169
170    /// Get the final schedule that was generated.
171    #[allow(clippy::missing_panics_doc)]
172    pub fn sched(&self) -> &Schedule<Dim> {
173        // The stack is guaranteed not to be empty
174        &self.stack.last().unwrap().0
175    }
176
177    /// Discard the trace and take ownership of the final schedule.
178    #[allow(clippy::missing_panics_doc)]
179    pub fn into_sched(mut self) -> Schedule<Dim> {
180        // The stack is guaranteed not to be empty
181        self.stack.pop().unwrap().0
182    }
183
184    /// Get the trace ouput of a particular generation step. Returns the highest value in the stack or `None` if it is not in the trace.
185    pub fn get<T: TraceOutput>(&self) -> Option<&T> {
186        for v in self.stack.iter().rev() {
187            if let Some(t) = (*v.1).as_any().downcast_ref::<T>() {
188                return Some(t);
189            }
190        }
191
192        None
193    }
194
195    /// Push a schedule and trace output onto the stack.
196    pub fn with<T: TraceOutput>(mut self, sched: Schedule<Dim>, trace: T) -> Trace<Dim> {
197        self.stack.push((sched, Box::new(trace)));
198        self
199    }
200
201    /// Iterate over all generation steps.
202    pub fn iter(&self) -> impl Iterator<Item = (&Schedule<Dim>, &dyn TraceOutput)> {
203        self.stack.iter().map(|v| (&v.0, &*v.1))
204    }
205}
206
#[cfg(test)]
mod tests {
    //! Unit tests for [`Trace`], plus regression tests that compare generator output with the
    //! schedule fixtures stored under `src/generators/tests/forwards_compat/`.

    use core::any::TypeId;
    use std::panic::resume_unwind;
    use std::{fs, thread};

    use alloc::vec::Vec;
    use alloc::{borrow::ToOwned, sync::Arc};
    use ndarray::{Array, Ix1};

    use crate::{
        modifiers::{FillCornersBuilder, Filter, PSFPolisher, TMFilter},
        pdf::{exponential, qsin, unweighted, QSinBias},
        DisplayMode, Schedule,
    };

    use super::{Averaging, Generator, Quantiles, RandomSampling, SinWeightedPoissonGap, Trace};

    /// Builds a three-step trace and checks `sched`, `get`, and `iter`.
    #[test]
    fn trace() {
        let s1 = Schedule::new(Array::from_vec(vec![true, false, true]));
        let s2 = Schedule::new(Array::from_vec(vec![true, false, false]));
        let s3 = Schedule::new(Array::from_vec(vec![false, false, true]));

        let trace = Trace::new(s1.to_owned(), 1_u8)
            .with(s2.to_owned(), 2_u16)
            .with(s3.to_owned(), 3_u8);

        // The final schedule is the one from the last step pushed.
        assert_eq!(trace.sched(), &s3);

        // Two steps attached a `u8`; `get` must return the most recent one (3, not 1).
        assert_eq!(*trace.get::<u8>().unwrap(), 3);
        // No step attached a `u32`.
        assert_eq!(trace.get::<u32>(), None);
        // `iter` yields steps oldest first, each with its schedule and output type.
        // NOTE(review): `zip` stops at the shorter side, so this would not detect `iter`
        // yielding fewer than three steps.
        trace
            .iter()
            .zip([
                (&s1, TypeId::of::<u8>()),
                (&s2, TypeId::of::<u16>()),
                (&s3, TypeId::of::<u8>()),
            ])
            .for_each(|(a, b)| {
                assert_eq!(a.0, b.0);
                assert_eq!(a.1.type_id(), b.1);
            });
    }

    /// Regression test: regenerates schedules for every generator/configuration combination
    /// below and compares each one with its stored fixture file, so any change in generator
    /// output is caught.
    #[test]
    fn forwards_compatibility() {
        // Generators under test. The string is the fixture-name prefix.
        let scheds: [(&'static str, Arc<dyn Generator<Ix1> + Send + Sync>); 4] = [
            (
                "qt",
                Arc::from(Quantiles::new(|len| qsin(len, QSinBias::Low, 3.))),
            ),
            (
                "pg",
                Arc::from(
                    // Y-Perm
                    SinWeightedPoissonGap::new(*b"F R U' R' U' R U R F' R U R' U' ")
                        .fill_corners(|_, _| [1, 1]),
                ),
            ),
            (
                "ru",
                Arc::from(
                    RandomSampling::new(unweighted, *b"Butter, Honey, Sugar, Cinnamon, ")
                        .fill_corners(|_, _| [1, 1]),
                ),
            ),
            (
                "av",
                Arc::from(Averaging::new(
                    |v| exponential(v, 4.),
                    8,
                    *b"when life gives you f(x), f(henr",
                )),
            ),
        ];

        // (count, length, backfill, TM, ITP)
        // `backfill` is returned as the first element from the `fill_corners` closure; `TM` and
        // `ITP` toggle the `TMFilter` and `PSFPolisher` post-filters respectively.
        let configs = [
            (64, 256, 8, false, false),
            (64, 256, 8, true, false),
            (64, 256, 8, false, true),
            (64, 256, 8, true, true),
            (128, 512, 12, false, false),
            (128, 512, 12, true, false),
            (128, 512, 12, false, true),
            (128, 512, 12, true, true),
            (96, 512, 12, true, true),
            (96, 512, 12, false, false),
            (96, 512, 12, true, false),
            (96, 512, 12, false, true),
            (52, 256, 8, false, false),
            (52, 256, 8, true, true),
            (192, 1024, 16, true, true),
            (154, 1024, 16, true, true),
            (308, 2048, 20, true, true),
            (410, 4096, 24, true, true),
            (20, 48, 5, true, true),
            (32, 128, 6, true, true),
            (48, 192, 6, false, false),
            (48, 192, 6, false, true),
            (48, 192, 6, true, false),
            (48, 192, 6, true, true),
        ];

        // Each case runs on its own thread. Panics are re-raised when the threads are joined below.
        let mut threads = Vec::new();

        for (name, gen) in &scheds {
            for (count, length, backfill, tm, itp) in configs {
                let gen = Arc::clone(gen);
                let name = *name;
                threads.push(thread::spawn(move || {
                    // Fixture file name: `{prefix}-{count}x{length}[-tm][-itp].sch`.
                    let mut name = format!("{name}-{count}x{length}");

                    if tm {
                        name.push_str("-tm");
                    }

                    if itp {
                        name.push_str("-itp");
                    }

                    name.push_str(".sch");

                    // Casting away `Send + Sync` makes the `Deref` target exactly
                    // `dyn Generator<Ix1>`, which the blanket `Generator` impl requires.
                    let mut sched = (gen as Arc<dyn Generator<Ix1>>)
                        .fill_corners(|_, _| [backfill, 1])
                        .generate(count, Ix1(length));

                    if tm {
                        sched = TMFilter::new().filter(sched);
                    }

                    if itp {
                        sched = PSFPolisher::new(0.1, 0.32, DisplayMode::Abs).filter(sched);
                    }

                    // Relative to the working directory (presumably the crate root under
                    // `cargo test`); the absolute path is printed to help diagnose missing files.
                    let path = format!("src/generators/tests/forwards_compat/{name}");

                    println!("{}/{path}", std::env::current_dir().unwrap().display());

                    let target = fs::read_to_string(&path).unwrap();
                    let decoded = Schedule::decode(&target, crate::EncodingType::ZeroBased, |_| {
                        Ok(Ix1(length))
                    })
                    .unwrap();

                    assert_eq!(sched, decoded, "{}", path);
                }));
            }
        }

        // (count, length, backfill, TM, ITP, iteration): non-zero iteration parameters for the
        // poisson-gap generator (`scheds[1]`), each checked against its own fixture.
        let seed_variants = [
            (52, 256, 0, false, false, 1),
            (52, 256, 0, true, true, 1),
            (52, 256, 0, false, false, 2),
            (52, 256, 0, true, true, 2),
            (52, 256, 0, false, false, 3),
            (52, 256, 0, true, true, 3),
            (52, 256, 0, false, false, 4),
            (52, 256, 0, true, true, 4),
            (52, 256, 0, false, false, 5),
            (52, 256, 0, true, true, 5),
            (52, 256, 0, false, false, 6),
            (52, 256, 0, true, true, 6),
            (52, 256, 0, false, false, 7),
            (52, 256, 0, true, true, 7),
            (52, 256, 0, false, false, 8),
            (52, 256, 0, true, true, 8),
        ];

        for (count, length, backfill, tm, itp, iteration) in seed_variants {
            let gen = Arc::clone(&scheds[1].1);
            threads.push(thread::spawn(move || {
                // Fixture file name: `pg-{count}x{length}-{iteration}[-tm][-itp].sch`.
                let mut name = format!("pg-{count}x{length}-{iteration}");

                if tm {
                    name.push_str("-tm");
                }

                if itp {
                    name.push_str("-itp");
                }

                name.push_str(".sch");

                // Same `Send + Sync` cast as above, then generate with the explicit iteration.
                let mut sched = (gen as Arc<dyn Generator<Ix1>>)
                    .fill_corners(|_, _| [backfill, 1])
                    .generate_with_iter(count, Ix1(length), iteration);

                if tm {
                    sched = TMFilter::new().filter(sched);
                }

                if itp {
                    sched = PSFPolisher::new(0.1, 0.32, DisplayMode::Abs).filter(sched);
                }

                let path = format!("src/generators/tests/forwards_compat/{name}");

                println!("{}/{path}", std::env::current_dir().unwrap().display());

                let target = fs::read_to_string(path).unwrap();
                let decoded =
                    Schedule::decode(&target, crate::EncodingType::ZeroBased, |_| Ok(Ix1(length)))
                        .unwrap();

                assert_eq!(sched, decoded);
            }));
        }

        // Join in spawn order and re-raise the first failing thread's panic so the test fails
        // with that thread's message.
        for thread in threads {
            if let Err(e) = thread.join() {
                resume_unwind(e);
            };
        }
    }
}