nmr_schedule/modifiers/psf_polisher.rs

1pub(crate) mod generators {
2    use super::PSFPolisher;
3
4    /// The generator after applying [`super::PSFPolisher`]
5    #[derive(Debug)]
6    pub struct PSFPolisherGenerator<T> {
7        pub(crate) precursor: T,
8        pub(crate) filter: PSFPolisher,
9    }
10}
11
12use core::fmt::Display;
13use core::ops::Deref;
14
15use alloc::vec::Vec;
16use ndarray::{Dimension, Ix1};
17use rand::SeedableRng;
18use rand_chacha::ChaCha12Rng;
19use rustfft::{
20    num_complex::{Complex, ComplexFloat},
21    FftPlanner,
22};
23
24use crate::{
25    generators::{Generator, Trace},
26    modifier,
27    point_spread::PointSpread,
28    quickselect, ComplexSequence, DisplayMode, Schedule,
29};
30
31use self::generators::PSFPolisherGenerator;
32
33use super::{Filter, Modifier};
34
35// Precompute the FT of a swap and use the time shift theorem and linearity of the FT to efficiently calculate the FT of the schedule after swapping
36struct SwappingBox {
37    sched: Schedule<Ix1>,
38    psf: PointSpread,
39    // The FT of [1, -1, 0, ..., 0]
40    swap_ft: Vec<Complex<f64>>,
41    // A precomputed map f(i) → e ^ (τ / len * -i)
42    // for calculating the linear phase
43    complex_phase: Vec<Complex<f64>>,
44}
45
46impl SwappingBox {
47    fn new(sched: Schedule<Ix1>, scratch: &mut Scratch) -> SwappingBox {
48        // Precompute the FFT of a swap
49
50        let fft = FftPlanner::new().plan_fft_forward(sched.len());
51
52        let mut swap = alloc::vec![Complex::new(0., 0.); sched.len()];
53        swap[0] = Complex::new(1., 0.);
54        swap[1] = Complex::new(-1., 0.);
55
56        fft.process_with_scratch(&mut swap, &mut scratch.fft);
57        let rescale = (sched.len() as f64).sqrt().recip();
58        swap.apply(|v| v * rescale);
59
60        let psf = sched.point_spread();
61
62        // Precompute the complex phase
63
64        let complex_phase = (0..sched.len())
65            .map(|i| {
66                Complex::from_polar(
67                    1.,
68                    core::f64::consts::TAU / (sched.len() as f64) * -(i as f64),
69                )
70            })
71            .collect::<Vec<_>>();
72
73        SwappingBox {
74            sched,
75            psf,
76            swap_ft: swap,
77            complex_phase,
78        }
79    }
80
81    const fn point_spread(&self) -> &PointSpread {
82        &self.psf
83    }
84
85    fn swap(&mut self, pos: usize) {
86        assert_ne!(self[pos], self[pos + 1]);
87
88        let len = self.len();
89
90        // Depending on which way we have to swap, we have to either add to or subtract from the original
91        let flippy = match self[pos] {
92            true => -1.,
93            false => 1.,
94        };
95
96        // Apply the linear phase and add it to the original PSF
97        self.psf
98            .inner_mut()
99            .iter_mut()
100            .zip(self.swap_ft.iter())
101            .enumerate()
102            .for_each(|(i, (original_val, swap_val))| {
103                let current_phase = self.complex_phase[(pos * i) % len];
104
105                *original_val += flippy * swap_val * current_phase;
106            });
107
108        self.sched.swap(pos, pos + 1);
109    }
110}
111
112impl Deref for SwappingBox {
113    type Target = Schedule<Ix1>;
114
115    fn deref(&self) -> &Self::Target {
116        &self.sched
117    }
118}
119
120struct Scratch {
121    fft: Vec<Complex<f64>>,
122    thresholded: Vec<Complex<f64>>,
123    peaks: Vec<f64>,
124    rng: ChaCha12Rng,
125}
126
127impl Scratch {
128    fn new(len: usize) -> Scratch {
129        Scratch {
130            fft: alloc::vec![Complex::new(0., 0.); len],
131            thresholded: alloc::vec![Complex::new(0., 0.); len],
132            peaks: alloc::vec![0.; len / 2],
133            // RNG is for quickselect; OK to preset the seed
134            rng: ChaCha12Rng::from_seed(*b"Not all seeds plant plants, some"),
135        }
136    }
137}
138
139fn find_best_swap(
140    threshold: f64,
141    sched: &SwappingBox,
142    can_swap: &[bool],
143    scratch: &mut Scratch,
144    mode: DisplayMode,
145) -> Option<usize> {
146    let spread = sched.point_spread();
147
148    // THRESHOLD THE PEAKS
149
150    // Take the values that aren't the central peak
151
152    let zeroed_width = spread.central_peak_radius(mode);
153
154    let (highest_peaks, _) = scratch.peaks.split_at_mut(spread.len() / 2 - zeroed_width);
155
156    spread[zeroed_width + 1..zeroed_width + highest_peaks.len() + 1]
157        .apply_into(highest_peaks, |v| mode.magnitude(v));
158
159    // Find the value that should be taken to be the threshold
160
161    let amount_to_correct = (sched.len() as f64 * threshold * 0.5).round() as usize;
162
163    quickselect(
164        &mut scratch.rng,
165        highest_peaks,
166        |a, b| a.total_cmp(b),
167        amount_to_correct,
168    );
169
170    let threshold = highest_peaks.get(amount_to_correct).unwrap_or(&0.);
171
172    // Perform the threshold
173
174    let thresholded = &mut *scratch.thresholded;
175    thresholded.copy_from_slice(spread);
176    thresholded[0..zeroed_width].fill(Complex::new(0., 0.));
177    thresholded[spread.len() - zeroed_width + 1..].fill(Complex::new(0., 0.));
178
179    mode.threshold(thresholded, *threshold);
180    mode.maybe_real_part(thresholded);
181
182    // CALCULATE THE BEST PLACES TO MAKE THE SWAPS
183
184    // IFT the thresholded peaks
185
186    let ift = FftPlanner::new().plan_fft_inverse(sched.len());
187    ift.process_with_scratch(thresholded, &mut scratch.fft);
188    let rescale = (sched.len() as f64).sqrt().recip();
189    thresholded.apply(|v| v * rescale);
190
191    // Find the most impactful swap and return it (if there exists an allowable swap that makes the PSF better)
192
193    thresholded
194        .windows(2)
195        .zip(sched.windows(2))
196        .map(|(v, s)| {
197            // Note that the IFT values are guaranteed to have a zero imaginary part
198            if s[0] && !s[1] {
199                v[0].re() - v[1].re()
200            } else if !s[0] && s[1] {
201                v[1].re() - v[0].re()
202            } else {
203                -f64::INFINITY
204            }
205        })
206        .enumerate()
207        .filter(|(i, v)| can_swap[*i] && *v > 0.)
208        .max_by(|a, b| a.1.total_cmp(&b.1))
209        .map(|(i, _)| i)
210}
211
212/// Trace information for `PointSpreadFilter`
213#[derive(Debug)]
214pub struct PSFPolisherTrace {
215    /// Which swap was recommended by each iteration.
216    pub swaps: Vec<usize>,
217    /// The list of PSR scores before and after each iteration. This vector will be one element longer than `swaps` because it includes the initial and final PSRs.
218    pub psr_scores: Vec<f64>,
219    /// The filter applied all swaps `swaps[0..taken_before]` leaving the PSR `psr_scores[taken_before]`.
220    pub taken_before: usize,
221}
222
223impl Display for PSFPolisherTrace {
224    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
225        write!(
226            f,
227            "Performed {} swaps, Final PSR score: {}",
228            self.taken_before, self.psr_scores[self.taken_before]
229        )
230    }
231}
232
233impl<T: Generator<Ix1>> Generator<Ix1> for PSFPolisherGenerator<T> {
234    fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
235        let trace = self
236            .precursor
237            .generate_with_iter_and_trace(count, dims, iteration);
238
239        self.filter.filter_from_iter_and_trace(trace, iteration)
240    }
241}
242
243modifier!(
244    PSFPolisher<Ix1>,
245    PSFPolisherBuilder,
246    "Swap sample points to smooth away point-spread function artifacts.",
247    polish_psf,
248    threshold: f64,
249    swap_value: f64,
250    mode: DisplayMode
251);
252
253impl Modifier<Ix1> for PSFPolisher {
254    type Output<T: Generator<Ix1>> = PSFPolisherGenerator<T>;
255
256    fn modify<T: Generator<Ix1>>(self, generator: T) -> Self::Output<T> {
257        PSFPolisherGenerator {
258            precursor: generator,
259            filter: self,
260        }
261    }
262}
263
264impl Filter<Ix1> for PSFPolisher {
265    fn filter_with_iter_and_trace(&self, sched: Schedule<Ix1>, _iteration: u64) -> Trace<Ix1> {
266        let threshold = self.0;
267        let swap_value = self.1;
268        let mode = self.2;
269
270        let dims = sched.raw_dim();
271
272        if (0..dims.ndim()).all(|v| dims[v] < 2) {
273            return Trace::new(
274                sched,
275                PSFPolisherTrace {
276                    swaps: Vec::new(),
277                    psr_scores: Vec::new(),
278                    taken_before: 0,
279                },
280            );
281        }
282
283        // Calculate the swap value from the universal user-supplied parameter
284        let individual_swap_value = swap_value / (sched.count() as f64);
285
286        let mut can_swap = alloc::vec![true; dims[0] - 1];
287
288        can_swap[0] = false;
289        *can_swap.last_mut().unwrap() = false;
290
291        let mut scratch = Scratch::new(dims[0]);
292
293        let mut swapping_box = SwappingBox::new(sched, &mut scratch);
294
295        let mut swaps = Vec::new();
296        let mut psrs = Vec::new();
297
298        // PERFORM SWAPS UNTIL DOING SO IS NO LONGER POSSIBLE
299
300        // Guaranteed to terminate because the `can_swap` array will get more `false`s each iteration.
301        while let Some(swap) =
302            find_best_swap(threshold, &swapping_box, &can_swap, &mut scratch, mode)
303        {
304            let psr = swapping_box.point_spread().peak_to_sidelobe_ratio(mode);
305
306            swaps.push(swap);
307            psrs.push(psr);
308
309            swapping_box.swap(swap);
310
311            can_swap[swap] = false;
312            if swap < can_swap.len() - 1 {
313                can_swap[swap + 1] = false;
314            }
315            if swap > 0 {
316                can_swap[swap - 1] = false;
317            }
318        }
319
320        psrs.push(swapping_box.point_spread().peak_to_sidelobe_ratio(mode));
321
322        let mut sched = swapping_box.sched;
323
324        // PICK THE BEST SCHEDULE TO RETURN
325
326        // Pick the index of the schedule to take to be the final one
327
328        let taken = psrs
329            .iter()
330            .enumerate()
331            .map(|(i, score)| (i, score + (i as f64 * individual_swap_value)))
332            .min_by(|a, b| a.1.total_cmp(&b.1))
333            .unwrap()
334            .0;
335
336        // Undo swaps that turned out not to be necessary
337
338        swaps
339            .iter()
340            .enumerate()
341            .rev()
342            .take_while(|(i, _)| *i >= taken)
343            .for_each(|(_, swap)| sched.swap(*swap, *swap + 1));
344
345        Trace::new(
346            sched,
347            PSFPolisherTrace {
348                swaps,
349                psr_scores: psrs,
350                taken_before: taken,
351            },
352        )
353    }
354}
355
356#[cfg(test)]
357mod tests {
358
359    use alloc::borrow::ToOwned;
360    use alloc::string::ToString;
361    use core::f64::consts::PI;
362
363    use alloc::{format, vec};
364    use ndarray::{s, Array1, Ix1};
365    use rustfft::num_complex::ComplexFloat;
366
367    use crate::{
368        generators::{Generator, Quantiles},
369        pdf::{qsin, QSinBias},
370        point_spread::PointSpread,
371        DisplayMode, Schedule,
372    };
373
374    use super::{find_best_swap, Scratch, SwappingBox};
375
376    #[test]
377    fn test_swapping_box() {
378        fn assert_psf_eq(psf_a: &PointSpread, psf_b: &PointSpread, name: &str) {
379            assert_eq!(psf_a.len(), psf_b.len());
380
381            psf_a.iter().zip(psf_b.iter()).for_each(|(a, b)| {
382                assert!((a - b).abs() < 0.000000000000001, "{name}: {a}, {b}");
383            })
384        }
385
386        let sched = Quantiles::new(|len| qsin(len, QSinBias::Low, PI)).generate(64, Ix1(256));
387
388        let mut ground_truth_sched = sched.to_owned();
389        let mut swapping_box = SwappingBox::new(sched.to_owned(), &mut Scratch::new(256));
390
391        assert_eq!(&*ground_truth_sched, &**swapping_box);
392        assert_psf_eq(
393            &ground_truth_sched.point_spread(),
394            swapping_box.point_spread(),
395            "Start",
396        );
397
398        let swaps = [40, 50, 100];
399
400        for swap in swaps {
401            ground_truth_sched[swap] = !ground_truth_sched[swap];
402            ground_truth_sched[swap + 1] = !ground_truth_sched[swap + 1];
403
404            swapping_box.swap(swap);
405
406            assert_eq!(&*ground_truth_sched, &**swapping_box);
407            assert_psf_eq(
408                &ground_truth_sched.point_spread(),
409                swapping_box.point_spread(),
410                &swap.to_string(),
411            );
412        }
413
414        let sched = Schedule::new(sched.into_inner().slice(s![0..255]).to_owned());
415
416        let mut ground_truth_sched = sched.to_owned();
417        let mut swapping_box = SwappingBox::new(sched, &mut Scratch::new(255));
418
419        assert_eq!(&*ground_truth_sched, &**swapping_box);
420        assert_psf_eq(
421            &ground_truth_sched.point_spread(),
422            swapping_box.point_spread(),
423            "Start",
424        );
425
426        let swaps = [40, 50, 100];
427
428        for swap in swaps {
429            ground_truth_sched[swap] = !ground_truth_sched[swap];
430            ground_truth_sched[swap + 1] = !ground_truth_sched[swap + 1];
431
432            swapping_box.swap(swap);
433
434            assert_eq!(&*ground_truth_sched, &**swapping_box);
435            assert_psf_eq(
436                &ground_truth_sched.point_spread(),
437                swapping_box.point_spread(),
438                &swap.to_string(),
439            );
440        }
441    }
442
443    #[test]
444    fn test_find_best_swap() {
445        let sched = Schedule::new(Array1::from_vec(vec![
446            true, true, true, true, true, true, true, false, true, false, true, false, true, false,
447            true, false, false, false, true, false, false, false, false, false, false, false,
448            false, false, false, false, false, true,
449        ]));
450        assert_eq!(format!("{sched}"), "▩▩▩▩▩▩▩_▩_▩_▩_▩___▩____________▩");
451        let mut scratch = Scratch::new(32);
452        let sched_box = SwappingBox::new(sched, &mut scratch);
453        let mut can_swap = [true; 32];
454        can_swap[0] = false;
455        can_swap[31] = false;
456        let idx =
457            find_best_swap(0.2, &sched_box, &can_swap, &mut scratch, DisplayMode::Abs).unwrap();
458        assert_eq!(idx, 10);
459        // panic!()
460    }
461}