nmr_schedule/modifiers/psf_polisher.rs

1pub(crate) mod generators {
2    use super::PSFPolisher;
3
4    /// The generator after applying [`super::PSFPolisher`]
5    #[derive(Debug)]
6    pub struct PSFPolisherGenerator<T> {
7        pub(crate) precursor: T,
8        pub(crate) filter: PSFPolisher,
9    }
10}
11
12use core::fmt::Display;
13use core::ops::Deref;
14
15use alloc::{sync::Arc, vec::Vec};
16use ndarray::{Dimension, Ix1};
17use rand::SeedableRng;
18use rand_chacha::ChaCha12Rng;
19use rustfft::{
20    Fft, FftPlanner,
21    num_complex::{Complex, ComplexFloat},
22};
23
24use crate::{
25    ComplexSequence, DisplayMode, Schedule,
26    generators::{Generator, Trace},
27    modifier,
28    point_spread::PointSpread,
29    quickselect,
30};
31
32use self::generators::PSFPolisherGenerator;
33
34use super::{Filter, Modifier};
35
36// Precompute the FT of a swap and use the time shift theorem and linearity of the FT to efficiently calculate the FT of the schedule after swapping
37struct SwappingBox {
38    sched: Schedule<Ix1>,
39    psf: PointSpread,
40    // The FT of [1, -1, 0, ..., 0]
41    swap_ft: Vec<Complex<f64>>,
42    // A precomputed map f(i) → e ^ (τ / len * -i)
43    // for calculating the linear phase
44    complex_phase: Vec<Complex<f64>>,
45}
46
47impl SwappingBox {
48    fn new(sched: Schedule<Ix1>, scratch: &mut Scratch) -> SwappingBox {
49        // Precompute the FFT of a swap
50
51        let mut swap = alloc::vec![Complex::new(0., 0.); sched.len()];
52        swap[0] = Complex::new(1., 0.);
53        swap[1] = Complex::new(-1., 0.);
54
55        scratch
56            .fft
57            .process_with_scratch(&mut swap, &mut scratch.for_fft);
58        let rescale = (sched.len() as f64).sqrt().recip();
59        swap.apply(|v| v * rescale);
60
61        let psf = sched.point_spread();
62
63        // Precompute the complex phase
64
65        let complex_phase = (0..sched.len())
66            .map(|i| {
67                Complex::from_polar(
68                    1.,
69                    core::f64::consts::TAU / (sched.len() as f64) * -(i as f64),
70                )
71            })
72            .collect::<Vec<_>>();
73
74        SwappingBox {
75            sched,
76            psf,
77            swap_ft: swap,
78            complex_phase,
79        }
80    }
81
82    const fn point_spread(&self) -> &PointSpread {
83        &self.psf
84    }
85
86    fn swap(&mut self, pos: usize) {
87        assert_ne!(self[pos], self[pos + 1]);
88
89        let len = self.len();
90
91        // Depending on which way we have to swap, we have to either add to or subtract from the original
92        let flippy = match self[pos] {
93            true => -1.,
94            false => 1.,
95        };
96
97        // Apply the linear phase and add it to the original PSF
98        self.psf
99            .inner_mut()
100            .iter_mut()
101            .zip(self.swap_ft.iter())
102            .enumerate()
103            .for_each(|(i, (original_val, swap_val))| {
104                let current_phase = self.complex_phase[(pos * i) % len];
105
106                *original_val += flippy * swap_val * current_phase;
107            });
108
109        self.sched.swap(pos, pos + 1);
110    }
111}
112
113impl Deref for SwappingBox {
114    type Target = Schedule<Ix1>;
115
116    fn deref(&self) -> &Self::Target {
117        &self.sched
118    }
119}
120
121struct Scratch {
122    fft: Arc<dyn Fft<f64>>,
123    ift: Arc<dyn Fft<f64>>,
124    for_fft: Vec<Complex<f64>>,
125    thresholded: Vec<Complex<f64>>,
126    peaks: Vec<f64>,
127    rng: ChaCha12Rng,
128}
129
130impl Scratch {
131    fn new(len: usize) -> Scratch {
132        let fft = FftPlanner::new().plan_fft_forward(len);
133        let ift = FftPlanner::new().plan_fft_inverse(len);
134
135        Scratch {
136            for_fft: alloc::vec![Complex::new(0., 0.); fft.get_inplace_scratch_len().max(ift.get_inplace_scratch_len())],
137            thresholded: alloc::vec![Complex::new(0., 0.); len],
138            peaks: alloc::vec![0.; len / 2],
139            // RNG is for quickselect; OK to preset the seed
140            rng: ChaCha12Rng::from_seed(*b"Not all seeds plant plants, some"),
141            fft,
142            ift,
143        }
144    }
145}
146
/// Searches for the adjacent-sample swap that best reduces the thresholded
/// sidelobes of `sched`'s current point-spread function.
///
/// The steps are:
/// 1. Collect the sidelobe magnitudes outside the central peak (as sized by
///    `central_peak_radius`).
/// 2. Quickselect one of them as the threshold.
/// 3. Threshold a copy of the PSF and inverse-FFT it back to the time domain.
/// 4. Score every adjacent pair `(i, i + 1)` that holds exactly one sample by
///    the difference of the inverse-transformed values at the two positions.
///
/// * `threshold`: fraction used to derive the quickselect rank
///   `round(sched.len() * threshold * 0.5)`.
/// * `sched`: the schedule together with its incrementally maintained PSF.
/// * `can_swap`: the pair `(i, i + 1)` is only eligible when `can_swap[i]` is
///   `true`. It must have at least `sched.len() - 1` entries.
/// * `scratch`: FFT plans and buffers. `peaks`, `thresholded`, `for_fft` and
///   `rng` are overwritten or advanced.
/// * `mode`: passed to `magnitude`, `threshold`, `maybe_real_part` and
///   `central_peak_radius`.
///
/// Returns the index `i` of the highest-scoring eligible pair with a strictly
/// positive score. Ties go to the last such index, per `Iterator::max_by`.
/// Returns `None` if no eligible pair scores above zero.
///
/// NOTE(review): this panics when `zeroed_width` exceeds `sched.len() / 2`
/// (the subtraction underflows) or when it is `0` (the tail range start would
/// be `len + 1`). Confirm that `central_peak_radius` never returns such values.
fn find_best_swap(
    threshold: f64,
    sched: &SwappingBox,
    can_swap: &[bool],
    scratch: &mut Scratch,
    mode: DisplayMode,
) -> Option<usize> {
    let spread = sched.point_spread();

    // THRESHOLD THE PEAKS

    // Take the values that aren't the central peak

    let zeroed_width = spread.central_peak_radius(mode);

    // One half of the spectrum (minus the central region) is enough to choose
    // the threshold.
    let (highest_peaks, _) = scratch.peaks.split_at_mut(spread.len() / 2 - zeroed_width);

    // Copies the magnitudes of indices zeroed_width + 1 ..= len / 2.
    // NOTE(review): the thresholded copy below only zeroes indices below
    // `zeroed_width` (and their mirrors), so index `zeroed_width` itself is
    // thresholded but never considered when picking the threshold. Confirm
    // this off-by-one is intended.
    spread[zeroed_width + 1..zeroed_width + highest_peaks.len() + 1]
        .apply_into(highest_peaks, |v| mode.magnitude(v));

    // Find the value that should be taken to be the threshold

    let amount_to_correct = (sched.len() as f64 * threshold * 0.5).round() as usize;

    // Presumably leaves the element of rank `amount_to_correct` (ascending
    // `total_cmp` order) at that index. TODO confirm against `quickselect`.
    quickselect(
        &mut scratch.rng,
        highest_peaks,
        |a, b| a.total_cmp(b),
        amount_to_correct,
    );

    // Shadows the fractional `threshold` parameter with the chosen magnitude
    // cutoff. Falls back to 0 when the rank is past the end of the peaks.
    let threshold = highest_peaks.get(amount_to_correct).unwrap_or(&0.);

    // Perform the threshold

    // Work on a copy of the PSF with the central region zeroed at both ends of
    // the spectrum.
    let thresholded = &mut *scratch.thresholded;
    thresholded.copy_from_slice(spread);
    thresholded[0..zeroed_width].fill(Complex::new(0., 0.));
    thresholded[spread.len() - zeroed_width + 1..].fill(Complex::new(0., 0.));

    mode.threshold(thresholded, *threshold);
    mode.maybe_real_part(thresholded);

    // CALCULATE THE BEST PLACES TO MAKE THE SWAPS

    // IFT the thresholded peaks

    // The 1/sqrt(len) rescale makes the inverse transform unitary, matching the
    // forward rescale in `SwappingBox::new`.
    scratch
        .ift
        .process_with_scratch(thresholded, &mut scratch.for_fft);
    let rescale = (sched.len() as f64).sqrt().recip();
    thresholded.apply(|v| v * rescale);

    // Find the most impactful swap and return it (if there exists an allowable swap that makes the PSF better)

    // Pairs that hold zero or two samples score -inf, so the `> 0.` filter
    // below removes them.
    thresholded
        .windows(2)
        .zip(sched.windows(2))
        .map(|(v, s)| {
            // Note that the IFT values are guaranteed to have a zero imaginary part
            if s[0] && !s[1] {
                v[0].re() - v[1].re()
            } else if !s[0] && s[1] {
                v[1].re() - v[0].re()
            } else {
                -f64::INFINITY
            }
        })
        .enumerate()
        .filter(|(i, v)| can_swap[*i] && *v > 0.)
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(i, _)| i)
}
220
/// Trace information recorded by [`PSFPolisher`] while polishing one schedule.
#[derive(Debug)]
pub struct PSFPolisherTrace {
    /// Which swap was recommended by each iteration.
    pub swaps: Vec<usize>,
    /// The list of PSR scores before and after each iteration.
    ///
    /// When the swap search ran, this vector is one element longer than `swaps`,
    /// because it includes both the initial and the final PSR. It is empty when
    /// the schedule was too short (fewer than two points) to polish.
    pub psr_scores: Vec<f64>,
    /// The filter applied all swaps `swaps[0..taken_before]`, leaving the PSR
    /// `psr_scores[taken_before]` (when recorded).
    pub taken_before: usize,
}

impl Display for PSFPolisherTrace {
    /// Summarises how many swaps were kept and, when recorded, the resulting PSR.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `psr_scores` is empty for schedules shorter than two points, so index
        // defensively rather than panicking while formatting.
        match self.psr_scores.get(self.taken_before) {
            Some(psr) => write!(
                f,
                "Performed {} swaps, Final PSR score: {}",
                self.taken_before, psr
            ),
            None => write!(f, "Performed {} swaps", self.taken_before),
        }
    }
}
241
242impl<T: Generator<Ix1>> Generator<Ix1> for PSFPolisherGenerator<T> {
243    fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
244        let trace = self
245            .precursor
246            .generate_with_iter_and_trace(count, dims, iteration);
247
248        self.filter.filter_from_iter_and_trace(trace, iteration)
249    }
250}
251
// Declares the `PSFPolisher` modifier through the crate's `modifier!` macro.
// Judging from its use in `filter_with_iter_and_trace` (`self.0`, `self.1`,
// `self.2`), the generated type stores the listed parameters positionally:
//   0: `threshold`: fraction passed to `find_best_swap` to pick the quickselect rank.
//   1: `swap_value`: swap penalty, divided by `sched.count()` to get a per-swap cost.
//   2: `mode`: the `DisplayMode` used for magnitudes, thresholding and PSR scoring.
// NOTE(review): presumably this also generates `PSFPolisherBuilder` and a
// `polish_psf` method documented by the string below. Confirm against `modifier!`.
modifier!(
    PSFPolisher<Ix1>,
    PSFPolisherBuilder,
    "Swap sample points to smooth away point-spread function artifacts.",
    polish_psf,
    threshold: f64,
    swap_value: f64,
    mode: DisplayMode
);
261
262impl Modifier<Ix1> for PSFPolisher {
263    type Output<T: Generator<Ix1>> = PSFPolisherGenerator<T>;
264
265    fn modify<T: Generator<Ix1>>(self, generator: T) -> Self::Output<T> {
266        PSFPolisherGenerator {
267            precursor: generator,
268            filter: self,
269        }
270    }
271}
272
273impl Filter<Ix1> for PSFPolisher {
274    fn filter_with_iter_and_trace(&self, sched: Schedule<Ix1>, _iteration: u64) -> Trace<Ix1> {
275        let threshold = self.0;
276        let swap_value = self.1;
277        let mode = self.2;
278
279        let dims = sched.raw_dim();
280
281        if (0..dims.ndim()).all(|v| dims[v] < 2) {
282            return Trace::new(
283                sched,
284                PSFPolisherTrace {
285                    swaps: Vec::new(),
286                    psr_scores: Vec::new(),
287                    taken_before: 0,
288                },
289            );
290        }
291
292        // Calculate the swap value from the universal user-supplied parameter
293        let individual_swap_value = swap_value / (sched.count() as f64);
294
295        let mut can_swap = alloc::vec![true; dims[0] - 1];
296
297        can_swap[0] = false;
298        *can_swap.last_mut().unwrap() = false;
299
300        let mut scratch = Scratch::new(dims[0]);
301
302        let mut swapping_box = SwappingBox::new(sched, &mut scratch);
303
304        let mut swaps = Vec::new();
305        let mut psrs = Vec::new();
306
307        // PERFORM SWAPS UNTIL DOING SO IS NO LONGER POSSIBLE
308
309        // Guaranteed to terminate because the `can_swap` array will get more `false`s each iteration.
310        while let Some(swap) =
311            find_best_swap(threshold, &swapping_box, &can_swap, &mut scratch, mode)
312        {
313            let psr = swapping_box.point_spread().peak_to_sidelobe_ratio(mode);
314
315            swaps.push(swap);
316            psrs.push(psr);
317
318            swapping_box.swap(swap);
319
320            can_swap[swap] = false;
321            if swap < can_swap.len() - 1 {
322                can_swap[swap + 1] = false;
323            }
324            if swap > 0 {
325                can_swap[swap - 1] = false;
326            }
327        }
328
329        psrs.push(swapping_box.point_spread().peak_to_sidelobe_ratio(mode));
330
331        let mut sched = swapping_box.sched;
332
333        // PICK THE BEST SCHEDULE TO RETURN
334
335        // Pick the index of the schedule to take to be the final one
336
337        let taken = psrs
338            .iter()
339            .enumerate()
340            .map(|(i, score)| (i, score + (i as f64 * individual_swap_value)))
341            .min_by(|a, b| a.1.total_cmp(&b.1))
342            .unwrap()
343            .0;
344
345        // Undo swaps that turned out not to be necessary
346
347        swaps
348            .iter()
349            .enumerate()
350            .rev()
351            .take_while(|(i, _)| *i >= taken)
352            .for_each(|(_, swap)| sched.swap(*swap, *swap + 1));
353
354        Trace::new(
355            sched,
356            PSFPolisherTrace {
357                swaps,
358                psr_scores: psrs,
359                taken_before: taken,
360            },
361        )
362    }
363}
364
#[cfg(test)]
mod tests {
    use alloc::borrow::ToOwned;
    use alloc::string::ToString;
    use alloc::{format, vec};
    use core::f64::consts::PI;

    use ndarray::{Array1, Ix1, s};
    use rustfft::num_complex::ComplexFloat;

    use crate::{
        DisplayMode, Schedule,
        generators::{Generator, Quantiles},
        pdf::{QSinBias, qsin},
        point_spread::PointSpread,
    };

    use super::{Scratch, SwappingBox, find_best_swap};

    /// Asserts element-wise that two PSFs differ by less than 1e-15, labelling
    /// any failure with `name`.
    fn assert_psf_eq(psf_a: &PointSpread, psf_b: &PointSpread, name: &str) {
        assert_eq!(psf_a.len(), psf_b.len());

        for (a, b) in psf_a.iter().zip(psf_b.iter()) {
            assert!((a - b).abs() < 0.000000000000001, "{name}: {a}, {b}");
        }
    }

    /// Applies the same swaps to a plain schedule and to a `SwappingBox`. After
    /// every step it checks that the schedules match and that the incrementally
    /// updated PSF equals a freshly computed one.
    fn check_incremental_swaps(sched: Schedule<Ix1>) {
        let len = sched.len();
        let mut ground_truth_sched = sched.to_owned();
        let mut swapping_box = SwappingBox::new(sched, &mut Scratch::new(len));

        assert_eq!(&*ground_truth_sched, &**swapping_box);
        assert_psf_eq(
            &ground_truth_sched.point_spread(),
            swapping_box.point_spread(),
            "Start",
        );

        for swap in [40, 50, 100] {
            ground_truth_sched[swap] = !ground_truth_sched[swap];
            ground_truth_sched[swap + 1] = !ground_truth_sched[swap + 1];

            swapping_box.swap(swap);

            assert_eq!(&*ground_truth_sched, &**swapping_box);
            assert_psf_eq(
                &ground_truth_sched.point_spread(),
                swapping_box.point_spread(),
                &swap.to_string(),
            );
        }
    }

    #[test]
    fn test_swapping_box() {
        let sched = Quantiles::new(|len| qsin(len, QSinBias::Low, PI)).generate(64, Ix1(256));

        check_incremental_swaps(sched.to_owned());

        // Repeat the check on an odd length (255 points).
        let truncated = Schedule::new(sched.into_inner().slice(s![0..255]).to_owned());
        check_incremental_swaps(truncated);
    }

    #[test]
    fn test_find_best_swap() {
        let pattern = vec![
            true, true, true, true, true, true, true, false, true, false, true, false, true, false,
            true, false, false, false, true, false, false, false, false, false, false, false,
            false, false, false, false, false, true,
        ];
        let sched = Schedule::new(Array1::from_vec(pattern));
        assert_eq!(format!("{sched}"), "▩▩▩▩▩▩▩_▩_▩_▩_▩___▩____________▩");

        let mut scratch = Scratch::new(32);
        let sched_box = SwappingBox::new(sched, &mut scratch);

        // Entries 0 and 31 are disabled; every other window is eligible.
        let mut can_swap = [true; 32];
        can_swap[0] = false;
        can_swap[31] = false;

        let best = find_best_swap(0.2, &sched_box, &can_swap, &mut scratch, DisplayMode::Abs);
        assert_eq!(best, Some(10));
    }
}