nmr_schedule/
lib.rs

1//! This crate implements algorithms for Non-Uniform Sampling in NMR spectroscopy.
2//!
3//! This crate primarily implements algorithms for generating and filtering sampling schedules. It includes various base schedules along with various post-processing filters for schedules.
4//!
5//! Schedulers implemented:
6//! - [Quantiles](`generators::Quantiles`)
7//! - [Poisson Gap](`generators::SinWeightedPoissonGap`)
8//! - [Random](`generators::RandomSampling`)
9//! - [Averaging](`generators::Averaging`)
10//!
11//! Filters implemented:
12//! - [Backfill](`modifiers::FillCorners`)
13//! - [Seed searching](`modifiers::Iterate`)
14//! - [Thue-Morse filter](`modifiers::TMFilter`)
15//! - [PSF polisher](`modifiers::PSFPolisher`)
16//!
//! This library currently supports only 1D schedules; however, there are plans to support higher-dimensional schedules in the future, and the architecture is already generic over dimension.
18//!
19//! # Examples
20//!
21//! ```
22//! # use nmr_schedule::{*, pdf::*, generators::*, modifiers::*, ndarray::Ix1};
23//! // Generate a 64x256 schedule with Quantiles with QSin weighting, 8 points backfill, and TMPF filtering
24//! let sched = Quantiles::new(|len| qsin(len, QSinBias::Low, 3.))
25//!     .fill_corners(|_, _| [8, 1]) // Any function of the count and length of the schedule
26//!     .tm_filter()
27//!     .polish_psf(0.1, 0.32, DisplayMode::Abs)
28//!     .generate(64, Ix1(256));
29//!
30//! println!("{sched}");
31//! ```
32//!
33//! ```
34//! # use nmr_schedule::{*, pdf::*, generators::*, modifiers::*, ndarray::Ix1};
35//! // Apply TMPF filtering to an existing schedule
36//! let sched_encoded = "0\n1\n2\n5\n7\n9\n20";
37//! let sched =
38//!     Schedule::decode(sched_encoded, EncodingType::ZeroBased, |dim| Ok(dim)).unwrap();
39//!
40//! let filtered =
41//!     PSFPolisher::new(0.1, 0.32, DisplayMode::Abs).filter(TMFilter::new().filter(sched));
42//!
43//! let encoded = filtered.encode(EncodingType::ZeroBased);
44//! println!("{encoded}");
45//! ```
46
47#![cfg_attr(not(test), no_std)]
48#![cfg_attr(docsrs, feature(doc_cfg))]
49#![warn(clippy::cargo)]
50#![warn(clippy::complexity)]
51#![warn(clippy::correctness)]
52#![warn(clippy::perf)]
53#![warn(clippy::style)]
54#![warn(clippy::suspicious)]
55#![warn(missing_docs)]
56#![warn(missing_copy_implementations)]
57#![warn(missing_debug_implementations)]
58#![warn(clippy::missing_panics_doc)]
59#![warn(clippy::missing_const_for_fn)]
60// Remove allow after https://github.com/rust-lang/rust-clippy/pull/14609 goes into stable
61#![allow(clippy::unused_unit)]
62
63extern crate alloc;
64
65use core::cmp::Ordering;
66
67pub mod generators;
68pub mod modifiers;
69pub mod pdf;
70pub mod reconstruction;
71mod schedule;
72#[cfg(feature = "terminal-viz")]
73#[cfg_attr(docsrs, doc(cfg(feature = "terminal-viz")))]
74pub mod terminal_viz;
75
76use rand::Rng;
77use rand_chacha::ChaCha12Rng;
78use rustfft::num_complex::{Complex, ComplexFloat};
79pub use schedule::*;
80
81pub use rustfft;
82
83pub use ndarray;
84
/// Represents whether a spectrum is meant to be displayed as the real part or as the absolute value.
///
/// Note that this is very experimental and you should default to using [`DisplayMode::Abs`] in parameters. The PSF Polisher publication only uses [`DisplayMode::Abs`].
///
/// Future research can look into whether knowing the display mode of an experiment can inform schedule generation and reconstruction.
///
/// The PSF polisher, when passed in [`DisplayMode::RealPart`], will only "see" the real part of the PSF during filtering and ignore the imaginary part, potentially leaving artifacts in and moving PSF noise into the imaginary part. If it is known that all of the signal will be exclusively in the real part, then moving sampling noise from the real part to the imaginary part may be desirable.
///
/// The IST reconstructor will also only "see" the real part when passed in [`DisplayMode::RealPart`]. It will assume that all signals in the imaginary axis are noise and ignore them.
///
/// The type is a plain field-less enum, so it implements the common comparison and hashing traits in addition to being `Copy`; the default is [`DisplayMode::Abs`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum DisplayMode {
    /// The spectrum will be displayed in absolute value mode.
    #[default]
    Abs,
    /// The spectrum will be displayed in real part mode; the imaginary part will be ignored.
    RealPart,
}
102
103impl DisplayMode {
104    /// Calculate the magnitude of the complex number given the display mode. In `Abs` mode, this will return `complex.abs()`, and in `RealPart` mode, this will return `complex.re().abs()`.
105    pub fn magnitude(self, complex: Complex<f64>) -> f64 {
106        match self {
107            DisplayMode::Abs => complex.abs(),
108            DisplayMode::RealPart => complex.re().abs(),
109        }
110    }
111
112    /// Apply a soft threshold to a sequence, zeroing out all components with [`DisplayMode::magnitude`] less than the threshold.
113    pub fn threshold<T: ComplexSequence + ?Sized>(self, sequence: &mut T, threshold: f64) {
114        sequence.apply(|v| {
115            let mag = self.magnitude(v);
116            if mag > threshold {
117                v * (mag - threshold) / mag
118            } else {
119                Complex::new(0., 0.)
120            }
121        })
122    }
123
124    /// Take the real part of the sequence if `self` is `RealPart`, otherwise leaving it alone.
125    pub fn maybe_real_part<T: ComplexSequence + ?Sized>(self, sequence: &mut T) {
126        if let Self::RealPart = self {
127            sequence.apply(|v| Complex::new(v.re(), 0.));
128        }
129    }
130}
131
132fn partition<T>(rng: &mut ChaCha12Rng, slice: &mut [T], by: &impl Fn(&T, &T) -> Ordering) -> usize {
133    slice.swap(0, rng.random_range(0..slice.len()));
134
135    let mut i = 1;
136    let mut j = slice.len() - 1;
137
138    loop {
139        while i < slice.len() && !matches!(by(&slice[i], &slice[0]), Ordering::Less) {
140            i += 1;
141        }
142
143        while matches!(by(&slice[j], &slice[0]), Ordering::Less) {
144            j -= 1;
145        }
146
147        // If the indices crossed, return
148        if i > j {
149            slice.swap(0, j);
150            return j;
151        }
152
153        // Swap the elements at the left and right indices
154        slice.swap(i, j);
155        i += 1;
156    }
157}
158
159/// Standard quickselect algorithm: https://en.wikipedia.org/wiki/Quickselect
160/// Modified to sort in descending order (since I want maxima in all of my usecases)
161///
162/// After calling this function, the value at index `find_spot` is guaranteed to be at the correctly sorted position and all values at indices less than `find_spot` are guaranteed to be greater than the value at `find_spot` and vice versa for indices greater.
163pub(crate) fn quickselect<T>(
164    rng: &mut ChaCha12Rng,
165    mut slice: &mut [T],
166    by: impl Fn(&T, &T) -> Ordering,
167    mut find_spot: usize,
168) {
169    loop {
170        let len = slice.len();
171
172        if len < 2 {
173            return;
174        }
175
176        let spot_found = partition(rng, slice, &by);
177
178        match find_spot.cmp(&spot_found) {
179            Ordering::Less => slice = &mut slice[0..spot_found],
180            Ordering::Equal => return,
181            Ordering::Greater => {
182                slice = &mut slice[spot_found + 1..len];
183                find_spot = find_spot - spot_found - 1;
184            }
185        }
186    }
187}
188
189/// An extension trait adding utility functions to lists of complex numbers.
190pub trait ComplexSequence {
191    /// Apply a function componentwise on each complex number.
192    fn apply(&mut self, f: impl FnMut(Complex<f64>) -> Complex<f64>);
193
194    /// Apply a function componentwise on each complex number and copy the output to the `out` array.
195    ///
196    /// # Panics
197    ///
198    /// The buffer of complex numbers must be the same length `out`.
199    fn apply_into<T>(&self, out: &mut [T], f: impl FnMut(Complex<f64>) -> T);
200
201    /// Multiply each value in the sequence by e^iθ, rotating the phase by θ radians.
202    fn phase(&mut self, θ: f64) {
203        let rotator = Complex::new(0., θ).exp();
204        self.apply(|v| v * rotator);
205    }
206
207    /// Copy the real part of each value into `out`.
208    ///
209    /// # Panics
210    ///
211    /// The buffer of complex numbers must be the same length as `out`.
212    fn re(&self, out: &mut [f64]) {
213        self.apply_into(out, |v| v.re());
214    }
215}
216
217impl ComplexSequence for [Complex<f64>] {
218    fn apply(&mut self, mut f: impl FnMut(Complex<f64>) -> Complex<f64>) {
219        for v in self {
220            *v = f(*v);
221        }
222    }
223
224    fn apply_into<T>(&self, out: &mut [T], mut f: impl FnMut(Complex<f64>) -> T) {
225        assert_eq!(self.len(), out.len());
226
227        out.iter_mut()
228            .zip(self.iter())
229            .for_each(|(out_v, complex)| *out_v = f(*complex));
230    }
231}
232
233impl<V: AsMut<[Complex<f64>]> + AsRef<[Complex<f64>]>> ComplexSequence for V {
234    fn apply(&mut self, f: impl FnMut(Complex<f64>) -> Complex<f64>) {
235        self.as_mut().apply(f);
236    }
237
238    fn apply_into<T>(&self, out: &mut [T], f: impl FnMut(Complex<f64>) -> T) {
239        self.as_ref().apply_into(out, f);
240    }
241}
242
/// The direction in which a monotonic function moves as its input grows.
///
/// `find_zero` uses this to decide which end of the search bracket to move after each evaluation.
#[derive(Debug, Clone, Copy)]
enum Monotonicity {
    /// The function's output grows as its input grows
    Increasing,
    /// The function's output shrinks as its input grows
    Decreasing,
}
251
/// The starting configuration of a binary search: a first guess and the bracket to search within.
#[derive(Debug, Clone, Copy)]
struct InitialState {
    /// The first input that will be evaluated
    start: f64,
    /// The lower end of the search bracket
    min: f64,
    /// The upper end of the search bracket; `f64::INFINITY` means no upper bound is known yet
    max: f64,
}
259
260impl InitialState {
261    /// Create initial parameters for binary search
262    ///
263    /// `start` is the initial guess.
264    pub const fn new(start: f64, min: f64, max: f64) -> Self {
265        Self { start, min, max }
266    }
267}
268
/// One evaluated point of a binary search: the input, paired with everything the searched function returned for it (its output and the extra payload).
type BinsearchPoint<T> = (f64, (f64, T));

/// How precise the result of a binary search has to be before the search may stop
enum Precision<'a, T> {
    // Unused at the moment, but kept in case future changes need it.
    #[allow(dead_code)]
    /// Stop once the search bracket (the distance between the bounds on the input) is narrower than this amount
    Preimage(f64),
    /// Stop once the absolute value of the function's output is below `amount`
    Image {
        /// The largest acceptable absolute output value (exclusive)
        amount: f64,
        /// A function to be called if the required precision can't be achieved. It receives the evaluated lower and upper ends of the bracket.
        debug: &'a dyn Fn(BinsearchPoint<T>, BinsearchPoint<T>),
    },
}
282
283/// Perform binary search over `f` to find where it is zero.
284///
285/// ```
286/// ```
287///
288/// `monotonicity` tells whether `f` is monotonically increasing or decreasing.
289/// `initial_state` defines the initial state for binary search.
290/// `precision` defines the target precision. The absolute value of `f`'s return value is guaranteed to be less than `precision` except when convergence fails.
291/// `f` is the function to perform binary search over. It is allowed to return an arbitrary value along with the value to binary search over.
292/// `debug` will be called if the binary search doesn't converge and runs out of precision to differentiate the minimum and maximum values. It's called with tuples of (input, output) for the minimum and maximum states. In this situation, the function will return the value of the last guess regardless of whether it's correct.
293fn find_zero<T>(
294    monotonicity: Monotonicity,
295    initial_state: InitialState,
296    precision: Precision<'_, T>,
297    f: impl Fn(f64) -> (f64, T),
298) -> (f64, T) {
299    let mut min = initial_state.min;
300    let mut current = initial_state.start;
301    let mut max = initial_state.max;
302
303    loop {
304        let (v, ret) = f(current);
305
306        if let Precision::Image { amount, debug: _ } = precision {
307            if v.abs() < amount {
308                return (current, ret);
309            }
310        }
311
312        match (v.total_cmp(&0.), monotonicity) {
313            (Ordering::Less, Monotonicity::Increasing)
314            | (Ordering::Greater, Monotonicity::Decreasing) => min = current,
315            (Ordering::Equal, _) => return (current, ret),
316            (Ordering::Greater, Monotonicity::Increasing)
317            | (Ordering::Less, Monotonicity::Decreasing) => max = current,
318        }
319
320        if max != f64::INFINITY {
321            let new_current = min / 2. + max / 2.;
322
323            match precision {
324                Precision::Preimage(amount) => {
325                    if max - min < amount || new_current == current {
326                        return (current, ret);
327                    }
328                }
329                Precision::Image { amount: _, debug } => {
330                    if new_current == current {
331                        debug((min, f(min)), (max, f(max)));
332                        return (current, ret);
333                    }
334                }
335            }
336
337            current = new_current;
338        } else {
339            current *= 2.;
340        }
341    }
342}
343
#[cfg(test)]
mod tests {
    use alloc::borrow::ToOwned;
    use alloc::vec::Vec;
    use core::cell::OnceCell;

    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha12Rng;
    use rustfft::num_complex::{Complex, ComplexFloat};

    use crate::{DisplayMode, InitialState, Monotonicity, Precision, find_zero, quickselect};

    /// Exercises `magnitude`, `threshold` and `maybe_real_part`, first on hand-picked values and then on random data in both display modes.
    #[test]
    fn test_display_mode() {
        let sample = Complex::new(4., 3.);
        assert_eq!(DisplayMode::Abs.magnitude(sample), 5.);
        assert_eq!(DisplayMode::RealPart.magnitude(sample), 4.);

        // |6 + 8i| = 10, so a threshold of 5 halves it; |1 + i| is below 5 and gets zeroed.
        let mut seq = [Complex::new(6., 8.), Complex::new(1., 1.)];

        DisplayMode::Abs.threshold(&mut seq[..], 5.);
        // Must be a no-op in `Abs` mode.
        DisplayMode::Abs.maybe_real_part(&mut seq[..]);
        assert!((seq[0] - Complex::new(3., 4.)).abs() < 0.000000001);
        assert_eq!(seq[1], Complex::ZERO);

        // In `RealPart` mode only the real parts (2 and 0.5) are compared against the threshold.
        seq = [Complex::new(2., 9.), Complex::new(0.5, 1000.)];

        DisplayMode::RealPart.threshold(&mut seq[..], 1.);

        assert!((seq[0] - Complex::new(1., 4.5)).abs() < 0.000000001);
        assert_eq!(seq[1], Complex::ZERO);

        DisplayMode::RealPart.maybe_real_part(&mut seq[..]);
        assert!((seq[0] - Complex::new(1., 0.)).abs() < 0.000000001);

        let mut rng = ChaCha12Rng::from_seed(*b"Not all plants spread seeds, and");

        for mode in [DisplayMode::Abs, DisplayMode::RealPart] {
            // Real part is drawn before the imaginary part for every element.
            let mut seq = Vec::with_capacity(1000);
            for _ in 0..1000 {
                let re = rng.random::<f64>() - 0.5;
                let im = rng.random::<f64>() - 0.5;
                seq.push(Complex::new(re, im) * 128.);
            }

            let mut seq_new = seq.to_owned();

            mode.threshold(seq_new.as_mut_slice(), 32.);

            for (before, after) in seq.iter().zip(&seq_new) {
                let magnitude_before = mode.magnitude(*before);

                if magnitude_before < 32. {
                    assert_eq!(*after, Complex::ZERO);
                    continue;
                }

                // The magnitude shrinks by exactly the threshold...
                assert!(
                    (magnitude_before - mode.magnitude(*after) - 32.).abs() < 0.00000001,
                    "{before} {after} {mode:?}"
                );

                // ...while the phase is preserved.
                assert!(
                    (before.arg() - after.arg()).abs() < 0.00000001,
                    "{before} {after} {mode:?}"
                );

                if let DisplayMode::RealPart = mode {
                    assert_eq!(
                        before.re().signum(),
                        after.re().signum(),
                        "{before} {after} {mode:?}"
                    );
                }
            }

            let mut seq_maybe_re = seq_new.to_owned();

            mode.maybe_real_part(seq_maybe_re.as_mut_slice());

            for (before, after) in seq_new.iter().zip(&seq_maybe_re) {
                match mode {
                    DisplayMode::Abs => assert_eq!(before, after),
                    DisplayMode::RealPart => {
                        assert_eq!(after.im(), 0.);
                    }
                }
            }
        }
    }

    /// Exercises `find_zero` in both precision modes and both monotonicities, and checks that the debug callback fires when the search can't converge.
    #[test]
    fn test_binsearch() {
        type Point = (f64, (f64, ()));

        // Every search starts at 1 with a lower bound of 0 and no known upper bound.
        let initial = InitialState::new(1., 0., f64::INFINITY);

        let forbid_debug =
            |_: Point, _: Point| -> () { panic!("Debug should not have been called!") };

        let (v, ()) = find_zero(
            Monotonicity::Increasing,
            initial,
            Precision::Preimage(0.1),
            |v| (-5. + v, ()),
        );

        assert!(v > 5. - 0.1 && v < 5. + 0.1);

        let (v, ()) = find_zero(
            Monotonicity::Decreasing,
            initial,
            Precision::Preimage(0.1),
            |v| (4. - 0.1 * v, ()),
        );

        assert!(v > 40. - 0.1 && v < 40. + 0.1, "v = {v}");

        let (v, ()) = find_zero(
            Monotonicity::Increasing,
            initial,
            Precision::Image {
                amount: 0.1,
                debug: &forbid_debug,
            },
            |v| (-3. + 10. * v, ()),
        );

        assert!(v > 0.3 - 0.01 && v < 0.3 + 0.01, "v = {v}");

        let (v, ()) = find_zero(
            Monotonicity::Decreasing,
            initial,
            Precision::Image {
                amount: 0.1,
                debug: &forbid_debug,
            },
            |v| (5. - 5. * v, ()),
        );

        assert!(v > 1. - 0.02 && v < 1. + 0.02, "v = {v}");

        // A step function never gets within 0.1 of zero, so the search must give up and report it.
        let debug_called = OnceCell::new();
        let record_debug = |_: Point, _: Point| {
            debug_called.set(true).unwrap();
        };

        find_zero(
            Monotonicity::Decreasing,
            initial,
            Precision::Image {
                amount: 0.1,
                debug: &record_debug,
            },
            |v| (if v > 5. { 1. } else { -1. }, ()),
        );

        assert!(debug_called.get().is_some());
    }

    /// Checks the descending-order guarantees of `quickselect` on fixed and random inputs.
    #[test]
    fn test_quickselect() {
        fn verify(rng: &mut ChaCha12Rng, pos: usize, values: &[f64]) {
            // Tag each value with its original index so the failure output shows where it came from.
            let mut slice: Vec<(f64, usize)> = values
                .iter()
                .enumerate()
                .map(|(index, value)| (*value, index))
                .collect();

            quickselect(rng, &mut slice, |a, b| a.0.total_cmp(&b.0), pos);

            let selected = slice[pos].0;

            // Everything in front of the selected position must be at least as large...
            for (i, entry) in slice.iter().enumerate().take(pos) {
                assert!(entry.0 >= selected, "Pos: {pos}, Index: {i} - {slice:?}");
            }

            // ...and everything behind it at most as large.
            for (i, entry) in slice.iter().enumerate().skip(pos + 1) {
                assert!(entry.0 <= selected, "Pos: {pos}, Index: {i} - {slice:?}");
            }

            // A full descending sort must leave the same value at `pos`.
            slice.sort_by(|a, b| b.0.total_cmp(&a.0));

            assert_eq!(slice[pos].0, selected);
        }

        let mut rng = ChaCha12Rng::from_seed(*b"Not all seeds plant plants, some");

        verify(&mut rng, 2, &[5., 4., 3., 2., 1.]);
        verify(&mut rng, 2, &[1., 2., 3., 4., 5.]);
        verify(&mut rng, 3, &[1., 2., 1., 4., 3.]);

        for len in 1..=100 {
            let pos = rng.random_range(0..len);
            let data: Vec<f64> = (0..len).map(|_| rng.random()).collect();
            verify(&mut rng, pos, &data);
        }
    }
}