nmr_schedule/
lib.rs

1//! This crate implements algorithms for Non-Uniform Sampling in NMR spectroscopy.
2//!
3//! This crate primarily implements algorithms for generating and filtering sampling schedules. It includes various base schedules along with various post-processing filters for schedules.
4//!
5//! Schedulers implemented:
6//! - [Quantiles](`generators::Quantiles`)
7//! - [Poisson Gap](`generators::SinWeightedPoissonGap`)
8//! - [Random](`generators::RandomSampling`)
9//! - [Averaging](`generators::Averaging`)
10//!
11//! Filters implemented:
12//! - [Backfill](`modifiers::FillCorners`)
13//! - [Seed searching](`modifiers::Iterate`)
14//! - [Thue-Morse filter](`modifiers::TMFilter`)
15//! - [PSF polisher](`modifiers::PSFPolisher`)
16//!
17//! This library currently supports only 1D schedules. However, there are plans to support higher-dimensional schedules in the future, and the architecture is already generic over dimension.
18//!
19//! # Examples
20//!
21//! ```
22//! # use nmr_schedule::{*, pdf::*, generators::*, modifiers::*, ndarray::Ix1};
23//! // Generate a schedule of 64 samples out of 256 points using Quantiles with QSin weighting, 8 points of backfill, and TMPF filtering
24//! let sched = Quantiles::new(|len| qsin(len, QSinBias::Low, 3.))
25//!     .fill_corners(|_, _| [8, 1]) // Any function of the count and length of the schedule
26//!     .tm_filter()
27//!     .polish_psf(0.1, 0.32, DisplayMode::Abs)
28//!     .generate(64, Ix1(256));
29//!
30//! println!("{sched}");
31//! ```
32//!
33//! ```
34//! # use nmr_schedule::{*, pdf::*, generators::*, modifiers::*, ndarray::Ix1};
35//! // Apply TMPF filtering to an existing schedule
36//! let sched_encoded = "0\n1\n2\n5\n7\n9\n20";
37//! let sched =
38//!     Schedule::decode(sched_encoded, EncodingType::ZeroBased, |dim| Ok(dim)).unwrap();
39//!
40//! let filtered =
41//!     PSFPolisher::new(0.1, 0.32, DisplayMode::Abs).filter(TMFilter::new().filter(sched));
42//!
43//! let encoded = filtered.encode(EncodingType::ZeroBased);
44//! println!("{encoded}");
45//! ```
46
47#![cfg_attr(not(test), no_std)]
48#![warn(clippy::cargo)]
49#![warn(clippy::complexity)]
50#![warn(clippy::correctness)]
51#![warn(clippy::perf)]
52#![warn(clippy::style)]
53#![warn(clippy::suspicious)]
54#![warn(missing_docs)]
55#![warn(missing_copy_implementations)]
56#![warn(missing_debug_implementations)]
57#![warn(clippy::missing_panics_doc)]
58#![warn(clippy::missing_const_for_fn)]
59
60extern crate alloc;
61
62use core::cmp::Ordering;
63
64pub mod generators;
65pub mod modifiers;
66pub mod pdf;
67pub mod reconstruction;
68mod schedule;
69#[cfg(feature = "terminal-viz")]
70pub mod terminal_viz;
71
72use rand::Rng;
73use rand_chacha::ChaCha12Rng;
74use rustfft::num_complex::{Complex, ComplexFloat};
75pub use schedule::*;
76
77pub use rustfft;
78
79pub use ndarray;
80
81/// Represents whether a spectrum is meant to be displayed as the real part or as the absolute value
82///
83/// Note that this is very experimental and you should default to using [`DisplayMode::Abs`] in parameters. The PSF Polisher publication only uses [`DisplayMode::Abs`].
84///
85/// Future research can look into whether knowing the display mode of an experiment can inform schedule generation and reconstruction.
86///
87/// The PSF polisher, when passed in [`DisplayMode::RealPart`], will only "see" the real part of the PSF during filtering and ignore the imaginary part, potentially leaving artifacts in and moving PSF noise into the imaginary part. If it is known that all of the signal will be exclusively in the real part, then moving sampling noise from the real part to the imaginary part may be desirable.
88///
89/// The IST reconstructor will also only "see" the real part when passed in [`DisplayMode::RealPart`]. It will assume that all signals in the imaginary axis are noise and ignore them.
90#[derive(Clone, Copy, Debug, Default)]
91pub enum DisplayMode {
92    /// The spectrum will be displayed in absolute value mode
93    #[default]
94    Abs,
95    /// The spectrum will be displayed in real part mode; the complex part will be ignored.
96    RealPart,
97}
98
99impl DisplayMode {
100    /// Calculate the magnitude of the complex number given the display mode. In `Abs` mode, this will return `complex.abs()`, and in `RealPart` mode, this will return `complex.re().abs()`.
101    pub fn magnitude(self, complex: Complex<f64>) -> f64 {
102        match self {
103            DisplayMode::Abs => complex.abs(),
104            DisplayMode::RealPart => complex.re().abs(),
105        }
106    }
107
108    /// Apply a soft threshold to a sequence, zeroing out all components with [`DisplayMode::magnitude`] less than the threshold.
109    pub fn threshold<T: ComplexSequence + ?Sized>(self, sequence: &mut T, threshold: f64) {
110        sequence.apply(|v| {
111            let mag = self.magnitude(v);
112            if mag > threshold {
113                v * (mag - threshold) / mag
114            } else {
115                Complex::new(0., 0.)
116            }
117        })
118    }
119
120    /// Take the real part of the sequence if `self` is `RealPart`, otherwise leaving it alone.
121    pub fn maybe_real_part<T: ComplexSequence + ?Sized>(self, sequence: &mut T) {
122        if let Self::RealPart = self {
123            sequence.apply(|v| Complex::new(v.re(), 0.));
124        }
125    }
126}
127
/// Partition `slice` around a randomly chosen pivot for a *descending* quickselect.
///
/// A pivot index is drawn from `rng` and swapped to the front. The slice is then rearranged so
/// that, on return, the pivot sits at the returned index `p`. Every element of `slice[..p]` does
/// not compare `Less` than the pivot under `by`, and every element of `slice[p + 1..]` compares
/// `Less` than it. Elements comparing `Equal` to the pivot always end up on the left.
///
/// `slice` must be non-empty; otherwise `random_range` receives an empty range and
/// `slice.len() - 1` underflows. `quickselect` only calls this with two or more elements.
///
/// NOTE(review): because every tie goes left, a slice of all-equal elements always partitions
/// at `len - 1`. As a result, `quickselect` degrades to O(n²) on heavily duplicated input when it
/// is looking for a small index. Switching to a tie-splitting scheme would change how much
/// randomness is drawn from `rng` (and therefore seeded outputs), so it is left unchanged.
128fn partition<T>(rng: &mut ChaCha12Rng, slice: &mut [T], by: &impl Fn(&T, &T) -> Ordering) -> usize {
    // Move the randomly chosen pivot to index 0; it stays there until the final swap.
129    slice.swap(0, rng.random_range(0..slice.len()));
130
    // Invariant: slice[1..i] does not compare Less than the pivot, and slice[j + 1..] does.
131    let mut i = 1;
132    let mut j = slice.len() - 1;
133
134    loop {
        // Advance i past elements that belong on the left (not Less than the pivot).
135        while i < slice.len() && !matches!(by(&slice[i], &slice[0]), Ordering::Less) {
136            i += 1;
137        }
138
        // Retreat j past elements that belong on the right. This stops at index 0 at the latest,
        // since the pivot compares Equal to itself (assuming `by` is a consistent ordering).
139        while matches!(by(&slice[j], &slice[0]), Ordering::Less) {
140            j -= 1;
141        }
142
143        // If the indices crossed, everything in 1..=j is non-Less and everything after j is Less,
        // so moving the pivot to j completes the partition.
144        if i > j {
145            slice.swap(0, j);
146            return j;
147        }
148
149        // Here slice[i] is Less and slice[j] is not (so i < j); swapping them restores both invariants.
150        slice.swap(i, j);
151        i += 1;
152    }
153}
154
155/// Standard quickselect algorithm: https://en.wikipedia.org/wiki/Quickselect
156/// Modified to sort in descending order (since I want maxima in all of my usecases)
157///
158/// After calling this function, the value at index `find_spot` is guaranteed to be at the correctly sorted position and all values at indices less than `find_spot` are guaranteed to be greater than the value at `find_spot` and vice versa for indices greater.
159pub(crate) fn quickselect<T>(
160    rng: &mut ChaCha12Rng,
161    mut slice: &mut [T],
162    by: impl Fn(&T, &T) -> Ordering,
163    mut find_spot: usize,
164) {
165    loop {
166        let len = slice.len();
167
168        if len < 2 {
169            return;
170        }
171
172        let spot_found = partition(rng, slice, &by);
173
174        match find_spot.cmp(&spot_found) {
175            Ordering::Less => slice = &mut slice[0..spot_found],
176            Ordering::Equal => return,
177            Ordering::Greater => {
178                slice = &mut slice[spot_found + 1..len];
179                find_spot = find_spot - spot_found - 1;
180            }
181        }
182    }
183}
184
185/// An extension trait adding utility functions to lists of complex numbers.
186pub trait ComplexSequence {
187    /// Apply a function componentwise on each complex number.
188    fn apply(&mut self, f: impl FnMut(Complex<f64>) -> Complex<f64>);
189
190    /// Apply a function componentwise on each complex number and copy the output to the `out` array.
191    ///
192    /// # Panics
193    ///
194    /// The buffer of complex numbers must be the same length `out`.
195    fn apply_into<T>(&self, out: &mut [T], f: impl FnMut(Complex<f64>) -> T);
196
197    /// Multiply each value in the sequence by e^iθ, rotating the phase by θ radians.
198    fn phase(&mut self, θ: f64) {
199        let rotator = Complex::new(0., θ).exp();
200        self.apply(|v| v * rotator);
201    }
202
203    /// Copy the real part of each value into `out`.
204    ///
205    /// # Panics
206    ///
207    /// The buffer of complex numbers must be the same length as `out`.
208    fn re(&self, out: &mut [f64]) {
209        self.apply_into(out, |v| v.re());
210    }
211}
212
213impl ComplexSequence for [Complex<f64>] {
214    fn apply(&mut self, mut f: impl FnMut(Complex<f64>) -> Complex<f64>) {
215        for v in self {
216            *v = f(*v);
217        }
218    }
219
220    fn apply_into<T>(&self, out: &mut [T], mut f: impl FnMut(Complex<f64>) -> T) {
221        assert_eq!(self.len(), out.len());
222
223        out.iter_mut()
224            .zip(self.iter())
225            .for_each(|(out_v, complex)| *out_v = f(*complex));
226    }
227}
228
229impl<V: AsMut<[Complex<f64>]> + AsRef<[Complex<f64>]>> ComplexSequence for V {
230    fn apply(&mut self, f: impl FnMut(Complex<f64>) -> Complex<f64>) {
231        self.as_mut().apply(f);
232    }
233
234    fn apply_into<T>(&self, out: &mut [T], f: impl FnMut(Complex<f64>) -> T) {
235        self.as_ref().apply_into(out, f);
236    }
237}
238
/// Whether a function is monotonically increasing or decreasing.
#[derive(Clone, Copy, Debug)]
enum Monotonicity {
    /// The function is increasing.
    Increasing,
    /// The function is decreasing.
    Decreasing,
}

/// The starting state of a binary search.
#[derive(Clone, Copy, Debug)]
struct InitialState {
    /// The first guess.
    start: f64,
    /// The lower end of the search bracket.
    min: f64,
    /// The upper end of the search bracket; may be `f64::INFINITY`.
    max: f64,
}

impl InitialState {
    /// Create the starting state for a binary search.
    ///
    /// `start` is the initial guess; `min` and `max` bound the search. `max` may be
    /// `f64::INFINITY`, in which case the search first expands upward until it brackets the zero.
    pub const fn new(start: f64, min: f64, max: f64) -> Self {
        Self { start, min, max }
    }
}

/// An `(input, (output, extra))` pair evaluated during a binary search.
type BinsearchPoint<T> = (f64, (f64, T));

/// The precision required from a binary search.
enum Precision<'a, T> {
    /// Stop once the bracket `[min, max]` is narrower than this width (in input units).
    // Kept in case future changes need it, even though it may currently be unused outside tests.
    #[allow(dead_code)]
    Preimage(f64),
    /// Stop once the absolute value of `f`'s output is below `amount`.
    Image {
        /// The output tolerance.
        amount: f64,
        /// Called if the required precision can't be achieved: either floating-point resolution
        /// runs out before `|f(x)| < amount`, or no upper bound for the zero can be found. It
        /// receives the `(input, f(input))` pairs at the bracket's minimum and maximum.
        debug: &'a dyn Fn(BinsearchPoint<T>, BinsearchPoint<T>),
    },
}

/// Binary search for the zero of the monotonic function `f`.
///
/// `f` maps a guess to `(value, extra)`. The search only inspects `value`, and it hands back the
/// `extra` produced at the final guess so that callers can reuse work done inside `f`.
///
/// * `monotonicity` says whether `f` increases or decreases.
/// * `initial_state` gives the first guess and the bracket `[min, max]` to search in. When `max`
///   is infinite, the search first walks upward (doubling positive guesses; a non-positive guess
///   first jumps to a positive one) until it finds a guess on the far side of the zero. If it
///   reaches `+∞` without doing so, it gives up and returns `+∞`.
/// * `precision` gives the stopping rule:
///   - [`Precision::Preimage`]: stop once the bracket is narrower than the given width, or the
///     midpoint can no longer move.
///   - [`Precision::Image`]: stop once `|value| < amount`. If that cannot be reached, `debug` is
///     called and the last guess is returned regardless of whether it is correct.
///
/// A guess where `value` is exactly zero (`+0.0` or `-0.0`) is returned immediately.
///
/// Returns the final guess together with the output of `f` at that guess.
fn find_zero<T>(
    monotonicity: Monotonicity,
    initial_state: InitialState,
    precision: Precision<'_, T>,
    f: impl Fn(f64) -> (f64, T),
) -> (f64, T) {
    let mut min = initial_state.min;
    let mut current = initial_state.start;
    let mut max = initial_state.max;

    loop {
        let (v, ret) = f(current);

        if let Precision::Image { amount, debug: _ } = precision {
            if v.abs() < amount {
                return (current, ret);
            }
        }

        // `total_cmp` orders -0.0 below +0.0, so test for an exact zero with `==` first;
        // otherwise an exact root whose image is -0.0 would be classified as "below zero".
        if v == 0. {
            return (current, ret);
        }

        match (v.total_cmp(&0.), monotonicity) {
            (Ordering::Less, Monotonicity::Increasing)
            | (Ordering::Greater, Monotonicity::Decreasing) => min = current,
            // Only +0.0 compares Equal here, and it was already handled above.
            (Ordering::Equal, _) => return (current, ret),
            (Ordering::Greater, Monotonicity::Increasing)
            | (Ordering::Less, Monotonicity::Decreasing) => max = current,
        }

        if max != f64::INFINITY {
            let new_current = min / 2. + max / 2.;

            match precision {
                Precision::Preimage(amount) => {
                    if max - min < amount || new_current == current {
                        return (current, ret);
                    }
                }
                Precision::Image { amount: _, debug } => {
                    if new_current == current {
                        debug((min, f(min)), (max, f(max)));
                        return (current, ret);
                    }
                }
            }

            current = new_current;
        } else {
            // No finite upper bound yet, so expand upward. Doubling only moves a positive guess
            // upward; a guess of zero or below would otherwise never make progress.
            let next = if current > 0. {
                current * 2.
            } else {
                (-current).max(1.)
            };

            // `current` is already +inf: there is nowhere left to expand, so the zero was never
            // bracketed. Give up instead of spinning forever.
            if next == current {
                if let Precision::Image { amount: _, debug } = precision {
                    debug((min, f(min)), (max, f(max)));
                }
                return (current, ret);
            }

            current = next;
        }
    }
}
339
// Unit tests for the crate-root helpers: `DisplayMode`, the `find_zero` binary search and `quickselect`.
340#[cfg(test)]
341mod tests {
342    use alloc::borrow::ToOwned;
343    use core::cell::OnceCell;
344    use rand_chacha::ChaCha12Rng;
345
346    use alloc::vec::Vec;
347    use rand::{Rng, SeedableRng};
348    use rustfft::num_complex::{Complex, ComplexFloat};
349
350    use crate::{find_zero, quickselect, DisplayMode, InitialState, Monotonicity};
351
    // Checks `magnitude`, `threshold` and `maybe_real_part` on hand-computed values, then on
    // 1000 seeded random values per mode.
352    #[test]
353    fn test_display_mode() {
354        assert_eq!(DisplayMode::Abs.magnitude(Complex::new(4., 3.)), 5.);
355        assert_eq!(DisplayMode::RealPart.magnitude(Complex::new(4., 3.)), 4.);
356
        // Abs mode: |6+8i| = 10 shrinks by 5 to 3+4i; |1+i| < 5 is zeroed.
357        let mut seq = [Complex::new(6., 8.), Complex::new(1., 1.)];
358
359        DisplayMode::Abs.threshold(&mut seq[..], 5.);
360        DisplayMode::Abs.maybe_real_part(&mut seq[..]);
361        assert!((seq[0] - Complex::new(3., 4.)).abs() < 0.000000001);
362        assert_eq!(seq[1], Complex::ZERO);
363
        // RealPart mode measures |Re| only: 2+9i has |Re| = 2, so it is scaled by (2 - 1) / 2,
        // and 0.5+1000i (|Re| = 0.5 < 1) is zeroed despite its large imaginary part.
364        seq = [Complex::new(2., 9.), Complex::new(0.5, 1000.)];
365
366        DisplayMode::RealPart.threshold(&mut seq[..], 1.);
367
368        assert!((seq[0] - Complex::new(1., 4.5)).abs() < 0.000000001);
369        assert_eq!(seq[1], Complex::ZERO);
370
371        DisplayMode::RealPart.maybe_real_part(&mut seq[..]);
372        assert!((seq[0] - Complex::new(1., 0.)).abs() < 0.000000001);
373
374        let mut rng = ChaCha12Rng::from_seed(*b"Not all plants spread seeds, and");
375
        // Randomized check: thresholding must either zero a value or reduce its mode-specific
        // magnitude by exactly 32 while keeping its phase (and, in RealPart mode, the sign of Re).
376        for mode in [DisplayMode::Abs, DisplayMode::RealPart] {
377            let seq = (0..1000)
378                .map(|_| Complex::new(rng.random::<f64>() - 0.5, rng.random::<f64>() - 0.5) * 128.)
379                .collect::<Vec<_>>();
380
381            let mut seq_new = seq.to_owned();
382
383            mode.threshold(&mut *seq_new, 32.);
384
385            seq.iter().zip(seq_new.iter()).for_each(|(before, after)| {
386                if mode.magnitude(*before) < 32. {
387                    assert_eq!(*after, Complex::ZERO);
388                } else {
389                    assert!(
390                        (mode.magnitude(*before) - mode.magnitude(*after) - 32.).abs() < 0.00000001,
391                        "{before} {after} {mode:?}"
392                    );
393
394                    assert!(
395                        (before.arg() - after.arg()).abs() < 0.00000001,
396                        "{before} {after} {mode:?}"
397                    );
398
399                    match mode {
400                        DisplayMode::Abs => {}
401                        DisplayMode::RealPart => {
402                            assert_eq!(
403                                before.re().signum(),
404                                after.re().signum(),
405                                "{before} {after} {mode:?}"
406                            );
407                        }
408                    }
409                }
410            });
411
            // `maybe_real_part` must be a no-op in Abs mode and must clear Im in RealPart mode.
412            let mut seq_maybe_re = seq_new.to_owned();
413
414            mode.maybe_real_part(&mut *seq_maybe_re);
415
416            seq_new
417                .iter()
418                .zip(seq_maybe_re.iter())
419                .for_each(|(before, after)| match mode {
420                    DisplayMode::Abs => assert_eq!(before, after),
421                    DisplayMode::RealPart => {
422                        assert_eq!(after.im(), 0.);
423                    }
424                });
425        }
426    }
427
    // Exercises both precision kinds with both monotonicities, starting from an unbounded bracket.
428    #[test]
429    fn test_binsearch() {
        // Preimage precision, increasing function with its root at 5.
430        let v = find_zero(
431            Monotonicity::Increasing,
432            InitialState {
433                start: 1.,
434                min: 0.,
435                max: f64::INFINITY,
436            },
437            crate::Precision::Preimage(0.1),
438            |v| (-5. + v, ()),
439        )
440        .0;
441
442        assert!(v > 5. - 0.1 && v < 5. + 0.1);
443
        // Preimage precision, decreasing function with its root at 40.
444        let v = find_zero(
445            Monotonicity::Decreasing,
446            InitialState {
447                start: 1.,
448                min: 0.,
449                max: f64::INFINITY,
450            },
451            crate::Precision::Preimage(0.1),
452            |v| (4. - 0.1 * v, ()),
453        )
454        .0;
455
456        assert!(v > 40. - 0.1 && v < 40. + 0.1, "v = {v}");
457
        // Image precision, increasing function with its root at 0.3; `debug` must not fire.
458        let v = find_zero(
459            Monotonicity::Increasing,
460            InitialState {
461                start: 1.,
462                min: 0.,
463                max: f64::INFINITY,
464            },
465            crate::Precision::Image {
466                amount: 0.1,
467                debug: (&|_, _| panic!("Debug should not have been called!")) as &dyn Fn(_, _),
468            },
469            |v| (-3. + 10. * v, ()),
470        )
471        .0;
472
473        assert!(v > 0.3 - 0.01 && v < 0.3 + 0.01, "v = {v}");
474
        // Image precision, decreasing function with its root at 1; `debug` must not fire.
475        let v = find_zero(
476            Monotonicity::Decreasing,
477            InitialState {
478                start: 1.,
479                min: 0.,
480                max: f64::INFINITY,
481            },
482            crate::Precision::Image {
483                amount: 0.1,
484                debug: (&|_, _| panic!("Debug should not have been called!")) as &dyn Fn(_, _),
485            },
486            |v| (5. - 5. * v, ()),
487        )
488        .0;
489
490        assert!(v > 1. - 0.02 && v < 1. + 0.02, "v = {v}");
491
        // `f` only ever returns ±1, so the Image precision of 0.1 is unreachable and the search
        // must fall back to calling `debug`. The OnceCell also fails the test if it is called twice.
492        let debug_called = OnceCell::new();
493
494        find_zero(
495            Monotonicity::Decreasing,
496            InitialState {
497                start: 1.,
498                min: 0.,
499                max: f64::INFINITY,
500            },
501            crate::Precision::Image {
502                amount: 0.1,
503                debug: (&|_, _| {
504                    debug_called.set(true).unwrap();
505                }) as &dyn Fn(_, _),
506            },
507            |v| (if v > 5. { 1. } else { -1. }, ()),
508        );
509
510        assert!(debug_called.get().is_some());
511    }
512
    // Checks the descending quickselect postcondition on fixed and seeded random inputs.
513    #[test]
514    fn test_quickselect() {
        // Tags each value with its original index (only the value `.0` is compared), runs
        // quickselect, checks that everything before `pos` is >= and everything after is <=, and
        // confirms that a full descending sort puts the same value at `pos`.
515        fn verify(rng: &mut ChaCha12Rng, pos: usize, slice: &[f64]) {
516            let mut slice = slice
517                .iter()
518                .enumerate()
519                .map(|(a, b)| (*b, a))
520                .collect::<Vec<_>>();
521
522            quickselect(rng, &mut slice, |a, b| a.0.total_cmp(&b.0), pos);
523
524            for i in 0..pos {
525                assert!(
526                    slice[i].0 >= slice[pos].0,
527                    "Pos: {pos}, Index: {i} - {slice:?}"
528                );
529            }
530
531            for i in pos + 1..slice.len() {
532                assert!(
533                    slice[i].0 <= slice[pos].0,
534                    "Pos: {pos}, Index: {i} - {slice:?}"
535                );
536            }
537
538            let v = slice[pos];
539
540            slice.sort_by(|a, b| b.0.total_cmp(&a.0));
541
542            assert_eq!(slice[pos].0, v.0);
543        }
544
545        let mut rng = ChaCha12Rng::from_seed(*b"Not all seeds plant plants, some");
546
        // Already-descending input, ascending input, and input with duplicates.
547        verify(&mut rng, 2, &[5., 4., 3., 2., 1.]);
548        verify(&mut rng, 2, &[1., 2., 3., 4., 5.]);
549        verify(&mut rng, 3, &[1., 2., 1., 4., 3.]);
550
        // Random inputs of lengths 1..=100 with a random target position each.
551        for i in 0..100 {
552            let pos = rng.random_range(0..i + 1);
553            let data = (0..i + 1).map(|_| rng.random()).collect::<Vec<_>>();
554            verify(&mut rng, pos, &data);
555        }
556    }
557}