nmr_schedule/generators/averaging.rs

1use core::fmt::Display;
2
3use alloc::{collections::BTreeSet, vec::Vec};
4
5use ndarray::{Array, Ix1};
6use rand::{Rng, SeedableRng};
7use rand_chacha::ChaCha12Rng;
8
9use crate::{generators::xor_iteration, pdf::PdfGenerator, quickselect, Schedule};
10
11use super::{Generator, Trace};
12
13/// A generator that randomly samples the PDF without replacement.
14///
15/// Iteration alters the random seed.
16#[derive(Clone, Copy, Debug)]
17pub struct RandomSampling<G: PdfGenerator<Ix1>>(G, [u8; 32]);
18
19impl<G: PdfGenerator<Ix1>> RandomSampling<G> {
20    /// Create a new `RandomSampling` from a seed.
21    pub const fn new(pdf: G, seed: [u8; 32]) -> RandomSampling<G> {
22        RandomSampling(pdf, seed)
23    }
24}
25
/// Trace data emitted by `RandomSampling`; it carries no information at present.
#[derive(Clone, Copy, Debug)]
pub struct RandomSamplingTrace;

impl Display for RandomSamplingTrace {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("No trace information")
    }
}
35
36impl<G: PdfGenerator<Ix1>> Generator<Ix1> for RandomSampling<G> {
37    // P. Efraimidis, P. Spirakis, Information Processing Letters, 97, 181-185. <https://doi.org/10.1016/j.ipl.2005.11.003>
38    fn _generate_no_trace(&self, count: usize, dims: Ix1, iteration: u64) -> Schedule<Ix1> {
39        let pdf = self.0.get(dims).pop().unwrap();
40        let mut rng = ChaCha12Rng::from_seed(xor_iteration(self.1, iteration));
41
42        let mut values = pdf
43            .get_distribution()
44            .iter()
45            .enumerate()
46            .map(|(i, v)| (rng.random::<f64>().powf(v.recip()), i))
47            .collect::<Vec<_>>();
48
49        // Make it so all items greater than the count-th largest item are below the count-th index
50        quickselect(&mut rng, &mut values, |a, b| a.0.total_cmp(&b.0), count);
51
52        let mut sched = alloc::vec![false; pdf.len()];
53
54        for i in 0..count {
55            sched[values[i].1] = true;
56        }
57
58        Schedule::new(Array::from_vec(sched))
59    }
60
61    fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
62        Trace::new(
63            self._generate_no_trace(count, dims, iteration),
64            AveragingTrace,
65        )
66    }
67}
68
69/// TODO: Citation
70/// A generator that...
71/// 1. Generates a specified number of randomly sampled schedules
72/// 2. Lists the positions of the samples in sorted order
73/// 3. Averages each position across each of the randomly sampled schedules
74/// 4. Quantizes them to the Nyquist grid.
75///
76/// The iteration parameter alters the random seed.
77#[derive(Clone, Copy, Debug)]
78pub struct Averaging<G: PdfGenerator<Ix1>> {
79    avg_count: usize,
80    random: RandomSampling<G>,
81}
82
83impl<G: PdfGenerator<Ix1>> Averaging<G> {
84    /// Create a new `AveragedSchedule` where `pdf` is the PDF, `avg_count` is the number of random schedules to average, and `seed` is the random seed.
85    pub const fn new(pdf: G, avg_count: usize, seed: [u8; 32]) -> Averaging<G> {
86        Averaging {
87            avg_count,
88            random: RandomSampling::new(pdf, seed),
89        }
90    }
91}
92
/// Trace data emitted by `Averaging`; it carries no information at present.
#[derive(Clone, Copy, Debug)]
pub struct AveragingTrace;

impl Display for AveragingTrace {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("No trace information")
    }
}
102
103impl<G: PdfGenerator<Ix1>> Generator<Ix1> for Averaging<G> {
104    fn _generate_no_trace(&self, count: usize, dims: Ix1, iteration: u64) -> Schedule<Ix1> {
105        let mut sum = alloc::vec![0; count];
106
107        for i in 0..self.avg_count {
108            let sched = self
109                .random
110                .generate_with_iter(count, dims, iteration + i as u64);
111
112            let mut found = 0;
113            for (i, item) in sched.iter().enumerate() {
114                if *item {
115                    sum[found] += i;
116                    found += 1;
117                }
118            }
119        }
120
121        for value in sum.iter_mut() {
122            *value /= self.avg_count;
123        }
124
125        assert!(sum.iter().collect::<BTreeSet<_>>().len() == count);
126
127        let mut result = alloc::vec![false; dims[0]];
128
129        for value in sum {
130            result[value] = true;
131        }
132
133        Schedule::new(Array::from_vec(result))
134    }
135
136    fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
137        Trace::new(
138            self._generate_no_trace(count, dims, iteration),
139            AveragingTrace,
140        )
141    }
142}