nmr_schedule/modifiers/
basic.rs

1use core::fmt::Display;
2
3use crate::{
4    generators::{Generator, Trace},
5    modifier, Schedule,
6};
7
pub(crate) mod generators {
    /// Generator produced by applying [`super::FillCorners`] to another generator.
    ///
    /// It pairs the upstream generator with the closure that decides how many
    /// points are forced at the start and at the end of the schedule.
    #[derive(Debug)]
    pub struct FillCornersGenerator<T, F>
    where
        F: Fn(usize, usize) -> [usize; 2],
    {
        /// The upstream generator whose schedule gets its corners filled.
        pub(crate) generator: T,
        /// Closure invoked as `fill_amt(count, len)`, returning `[start, end]`.
        pub(crate) fill_amt: F,
    }
}
16
17use alloc::borrow::ToOwned;
18use generators::*;
19use ndarray::Ix1;
20
21use super::Modifier;
22
/// Diagnostic information recorded by `FillCorners` for each generated schedule.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct FillCornersTrace {
    /// The `[start, end]` fill that was actually applied: the closure's result, after any reduction needed to fit within the requested sample count
    pub fill_amt: [usize; 2],
    /// How many times the previous generator was invoked before a count was found that allows the backfill without exceeding `count`
    pub iterations: usize,
    /// The sample count passed to the previous generator on the final, successful attempt
    pub count_passed_along: usize,
    /// Whether the closure requested more fill than the number of points. If true, the fill was reduced to accommodate the number of points available.
    pub too_much_fill: bool,
}
36
37impl Display for FillCornersTrace {
38    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
39        write!(
40            f,
41            "Fill amt: {:?}, Generation tries needed: {}, Sample count passed upstream: {}{}",
42            self.fill_amt,
43            self.iterations,
44            self.count_passed_along,
45            if self.too_much_fill {
46                ", More fill specified than sample points, reducing fill"
47            } else {
48                ""
49            }
50        )
51    }
52}
53
54impl<T: Generator<Ix1>, F: Fn(usize, usize) -> [usize; 2]> Generator<Ix1>
55    for FillCornersGenerator<T, F>
56{
57    fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
58        let mut fill_amt = (self.fill_amt)(count, dims[0]);
59        let sum = fill_amt.iter().sum::<usize>();
60
61        if sum > count {
62            let mut true_sum = sum;
63
64            if fill_amt[1] > 1 && true_sum > count {
65                let dec_count = (true_sum - count).min(fill_amt[1] - 1);
66                fill_amt[1] -= dec_count;
67                true_sum -= dec_count;
68            }
69
70            if fill_amt[0] > 1 && true_sum > count {
71                let dec_count = (true_sum - count).min(fill_amt[0] - 1);
72                fill_amt[0] -= dec_count;
73                true_sum -= dec_count;
74            }
75
76            if fill_amt[1] == 1 && true_sum > count {
77                fill_amt[1] = 0;
78                true_sum -= 1;
79            }
80
81            if fill_amt[0] == 1 && true_sum > count {
82                fill_amt[0] = 0;
83            }
84
85            // At this point there's zero backfill left
86        }
87
88        let mut iterations = 0;
89        let mut current = count;
90        let trace: Trace<Ix1>;
91        let sched: Schedule<Ix1>;
92
93        loop {
94            iterations += 1;
95
96            let trying_trace = self
97                .generator
98                .generate_with_iter_and_trace(current, dims, iteration);
99
100            let mut trying_sched = trying_trace.sched().to_owned();
101
102            let mut amt_added = 0;
103
104            for i in 0..fill_amt[1] {
105                if !trying_sched[dims[0] - i - 1] {
106                    trying_sched[dims[0] - i - 1] = true;
107                    amt_added += 1;
108                }
109            }
110
111            let mut i = 0;
112            while i < fill_amt[0] || current + amt_added < count {
113                if !trying_sched[i] {
114                    trying_sched[i] = true;
115                    amt_added += 1;
116                }
117
118                i += 1;
119            }
120
121            if current + amt_added == count {
122                sched = trying_sched;
123                trace = trying_trace;
124                break;
125            }
126
127            current -= current + amt_added - count;
128        }
129
130        trace.with(
131            sched,
132            FillCornersTrace {
133                fill_amt,
134                iterations,
135                count_passed_along: current,
136                too_much_fill: sum > count,
137            },
138        )
139    }
140}
141
142modifier!(
143    <F: Fn(usize, usize) -> [usize; 2]>
144    FillCorners<Ix1>,
145    FillCornersBuilder,
146    r"Guarantee that the first `amt(count, len)[0]` samples and the last `amt(count, len)[1]` samples are taken.
147    
148A small amount of linear backfill (2-4% of the length of the schedule) can improve the quality of a schedule.
149    
150L. E. Cullen, A. Marchiori, D. Rovnyak, Magn Reson Chem 2023, 61(6), 337. <https://doi.org/10.1002/mrc.5340>
151    ",
152    fill_corners,
153    amt: F
154);
155
156impl<F: Fn(usize, usize) -> [usize; 2]> Modifier<Ix1> for FillCorners<F> {
157    type Output<T: Generator<Ix1>> = FillCornersGenerator<T, F>;
158
159    fn modify<T: Generator<Ix1>>(self, generator: T) -> Self::Output<T> {
160        FillCornersGenerator {
161            generator,
162            fill_amt: self.0,
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use ndarray::Ix1;

    use crate::{
        generators::{Generator, SinWeightedPoissonGap},
        modifiers::FillCornersTrace,
    };

    use super::FillCornersBuilder;

    /// Seed shared by every case so the generated schedules are reproducible.
    const SEED: [u8; 32] = *b"Not all plants make 32 byte seed";

    /// Requested fills that fit inside `count` must be honoured exactly.
    #[test]
    fn fills_correctly() {
        let cases = [
            (30, 20, 100),
            (8, 1, 64),
            (100, 0, 100),
            (0, 100, 100),
            (50, 50, 100),
        ];

        for (iteration, (start, end, count)) in (0u64..).zip(cases) {
            let trace = SinWeightedPoissonGap::new(SEED)
                .fill_corners(|_, _| [start, end])
                .generate_with_iter_and_trace(count, Ix1(256), iteration);

            let sched = trace.sched();
            let info = trace.get::<FillCornersTrace>().unwrap();

            assert!(!info.too_much_fill);
            assert!(
                sched.iter().take(start).all(|v| *v),
                "{sched} {start} {end} {count}"
            );
            assert!(
                sched.iter().rev().take(end).all(|v| *v),
                "{sched} {start} {end} {count}"
            );
        }
    }

    /// The fill closure receives `(count, len)` and its answer is applied.
    #[test]
    fn calls_correctly() {
        let f = |count, len| [count / 10, len / 10];
        let cases = [(100, 200), (64, 128), (110, 301), (400, 1003), (1020, 1021)];

        for (iteration, (count, len)) in (0u64..).zip(cases) {
            let trace = SinWeightedPoissonGap::new(SEED)
                .fill_corners(f)
                .generate_with_iter_and_trace(count, Ix1(len), iteration);

            let sched = trace.sched();
            let info = trace.get::<FillCornersTrace>().unwrap();

            assert!(!info.too_much_fill);

            let [start, end] = f(count, len);

            assert!(
                sched.iter().take(start).all(|v| *v),
                "{sched} {start} {end} {count}"
            );
            assert!(
                sched.iter().rev().take(end).all(|v| *v),
                "{sched} {start} {end} {count}"
            );
        }
    }

    /// Fills larger than `count` are reduced to `[true_start, true_end]`.
    #[test]
    fn overflows_correctly() {
        let cases = [
            (30, 20, 40, 30, 10),
            (8, 5, 7, 6, 1),
            (4, 5, 2, 1, 1),
            (16, 3, 1, 1, 0),
            (20, 30, 0, 0, 0),
            (300, 200, 256, 255, 1),
        ];

        for (iteration, (start, end, count, true_start, true_end)) in (0u64..).zip(cases) {
            let trace = SinWeightedPoissonGap::new(SEED)
                .fill_corners(|_, _| [start, end])
                .generate_with_iter_and_trace(count, Ix1(256), iteration);

            let sched = trace.sched();
            let info = trace.get::<FillCornersTrace>().unwrap();

            assert!(info.too_much_fill);
            assert!(
                sched.iter().take(true_start).all(|v| *v),
                "{sched} {start} {end} {count}"
            );
            assert!(
                sched.iter().rev().take(true_end).all(|v| *v),
                "{sched} {start} {end} {count}"
            );
        }
    }
}