nmr_schedule/generators/
averaging.rs1use core::fmt::Display;
2
3use alloc::{collections::BTreeSet, vec::Vec};
4
5use ndarray::{Array, Ix1};
6use rand::{Rng, SeedableRng};
7use rand_chacha::ChaCha12Rng;
8
9use crate::{generators::xor_iteration, pdf::PdfGenerator, quickselect, Schedule};
10
11use super::{Generator, Trace};
12
13#[derive(Clone, Copy, Debug)]
17pub struct RandomSampling<G: PdfGenerator<Ix1>>(G, [u8; 32]);
18
19impl<G: PdfGenerator<Ix1>> RandomSampling<G> {
20 pub const fn new(pdf: G, seed: [u8; 32]) -> RandomSampling<G> {
22 RandomSampling(pdf, seed)
23 }
24}
25
/// Trace payload with no associated data.
///
/// Its `Display` output is a fixed placeholder message.
#[derive(Clone, Copy, Debug)]
pub struct RandomSamplingTrace;

impl Display for RandomSamplingTrace {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("No trace information")
    }
}
35
36impl<G: PdfGenerator<Ix1>> Generator<Ix1> for RandomSampling<G> {
37 fn _generate_no_trace(&self, count: usize, dims: Ix1, iteration: u64) -> Schedule<Ix1> {
39 let pdf = self.0.get(dims).pop().unwrap();
40 let mut rng = ChaCha12Rng::from_seed(xor_iteration(self.1, iteration));
41
42 let mut values = pdf
43 .get_distribution()
44 .iter()
45 .enumerate()
46 .map(|(i, v)| (rng.random::<f64>().powf(v.recip()), i))
47 .collect::<Vec<_>>();
48
49 quickselect(&mut rng, &mut values, |a, b| a.0.total_cmp(&b.0), count);
51
52 let mut sched = alloc::vec![false; pdf.len()];
53
54 for i in 0..count {
55 sched[values[i].1] = true;
56 }
57
58 Schedule::new(Array::from_vec(sched))
59 }
60
61 fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
62 Trace::new(
63 self._generate_no_trace(count, dims, iteration),
64 AveragingTrace,
65 )
66 }
67}
68
69#[derive(Clone, Copy, Debug)]
79pub struct Averaging<G: PdfGenerator<Ix1>> {
80 avg_count: usize,
81 random: RandomSampling<G>,
82}
83
84impl<G: PdfGenerator<Ix1>> Averaging<G> {
85 pub const fn new(pdf: G, avg_count: usize, seed: [u8; 32]) -> Averaging<G> {
87 Averaging {
88 avg_count,
89 random: RandomSampling::new(pdf, seed),
90 }
91 }
92}
93
/// Trace payload attached to averaged schedules; it carries no data.
///
/// Its `Display` output is a fixed placeholder message.
#[derive(Clone, Copy, Debug)]
pub struct AveragingTrace;

impl Display for AveragingTrace {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("No trace information")
    }
}
103
104impl<G: PdfGenerator<Ix1>> Generator<Ix1> for Averaging<G> {
105 fn _generate_no_trace(&self, count: usize, dims: Ix1, iteration: u64) -> Schedule<Ix1> {
106 let mut sum = alloc::vec![0; count];
107
108 for i in 0..self.avg_count {
109 let sched = self
110 .random
111 .generate_with_iter(count, dims, iteration + i as u64);
112
113 let mut found = 0;
114 for (i, item) in sched.iter().enumerate() {
115 if *item {
116 sum[found] += i;
117 found += 1;
118 }
119 }
120 }
121
122 for value in sum.iter_mut() {
123 *value /= self.avg_count;
124 }
125
126 assert!(sum.iter().collect::<BTreeSet<_>>().len() == count);
127
128 let mut result = alloc::vec![false; dims[0]];
129
130 for value in sum {
131 result[value] = true;
132 }
133
134 Schedule::new(Array::from_vec(result))
135 }
136
137 fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
138 Trace::new(
139 self._generate_no_trace(count, dims, iteration),
140 AveragingTrace,
141 )
142 }
143}