nmr_schedule/pdf/mod.rs
1//! Implements Probability Distribution Functions
2//!
3//! These are used to indicate to schedule generators where samples should be taken. This module contains the core PDF implementation as well as some presets representing common PDFs.
4
5use core::ops::{Deref, Range};
6
7use alloc::{borrow::ToOwned, boxed::Box, rc::Rc, sync::Arc, vec::Vec};
8use ndarray::{Dim, Dimension, Ix};
9use once_cell::race::OnceBox;
10use rand::Rng;
11
12mod presets;
13
14pub use presets::*;
15
16use crate::{InitialState, Monotonicity, find_zero};
17
18/// Represents a Probability Distribution Function.
19///
20/// PDFs are used to indicate to schedule generators where samples should be taken.
21pub struct Pdf {
22    inner: PdfImpl,
23    // Memoized results of `self.get_distribution()`, `self.get_integral()`, and `self.probabilities()`
24    distribution: OnceBox<Vec<f64>>, // Only used in the continuous representation
25    integral: OnceBox<Vec<f64>>,
26    probabilities: OnceBox<Probabilities>,
27}
28
29impl core::fmt::Debug for Pdf {
30    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
31        match &self.inner {
32            PdfImpl::Continuous { len, integral: _ } => f
33                .debug_struct("Pdf::Continuous")
34                .field("len", len)
35                .field("integral", &"*closure*")
36                .finish(),
37            PdfImpl::Discrete { values } => f
38                .debug_struct("Pdf::Discrete")
39                .field("values", values)
40                .finish(),
41        }
42    }
43}
44
45/// The opaque inner representation of a PDF
46enum PdfImpl {
47    Continuous {
48        len: usize,
49        integral: Arc<dyn Fn(f64) -> f64 + Send + Sync>,
50    },
51    Discrete {
52        values: Vec<f64>,
53    },
54}
55
56impl Pdf {
57    /// Create a PDF from a continuous representation.
58    ///
59    /// Expects the *integral* of the PDF being represented. This is often called the Cumulative Distribution Function or CDF. The function will receive inputs from `0..=len`. The `len` parameter is the length of the PDF in number of samples.
60    ///
61    /// The CDF is expected to be monotonically increasing, and `integral(len as f64) - integral(0.)` must be one. These are the necessary conditions for a function to be a valid CDF.
62    ///
63    /// It is not assumed that `integral(0.) = 0.`.
64    ///
65    /// # Example
66    ///
67    /// ```
68    /// # use nmr_schedule::pdf::*;
69    /// // This represents an unweighted PDF because the integral of a constant function is linear.
70    /// // You can use the `unweighted` preset instead if you would like
71    /// let pdf = Pdf::from_integral(|v| v / 256., 256);
72    /// ```
73    ///
74    /// # Panics
75    ///
76    /// This function will assert that
77    ///   - `integral(len as f64) - integral(0.) ≈ 1.` with some margin for floating point error allowed;
78    ///   - the CDF is monotonically increasing for particular intervals; and
79    ///   - `len` is non-zero.
80    pub fn from_integral(integral: impl Fn(f64) -> f64 + 'static + Send + Sync, len: usize) -> Pdf {
81        assert!(len > 0, "Cannot create a PDF of length zero.");
82        assert!(
83            (integral(len as f64) - integral(0.) - 1.).abs() < 0.001,
84            "The difference between the start and end of the CDF is not one."
85        );
86        assert!(
87            (0..len).all(|i| integral(i as f64) <= integral(i as f64 + 1.)),
88            "The CDF is not increasing."
89        );
90
91        Pdf {
92            inner: PdfImpl::Continuous {
93                len,
94                integral: Arc::from(integral),
95            },
96            distribution: OnceBox::new(),
97            integral: OnceBox::new(),
98            probabilities: OnceBox::new(),
99        }
100    }
101
102    /// Create a PDF from a discrete representation.
103    ///
104    /// Expects a list of probabilities, one for each sample, that are all non-negative and sum to one.
105    ///
106    /// # Example
107    ///
108    /// ```
109    /// # use nmr_schedule::pdf::*;
110    /// let pdf = Pdf::from_discrete(vec![0.125, 0.5, 0.375]);
111    /// ```
112    ///
113    /// # Panics
114    ///
115    /// This function will assert that
116    ///   - the sum of all of the probabilities is one, with some margin for floating point error allowed;
117    ///   - all probabilities are non-negative; and
118    ///   - the length of the PDF is non-zero.
119    pub fn from_discrete(discrete: Vec<f64>) -> Pdf {
120        assert!(!discrete.is_empty(), "Cannot create a PDF of length zero.");
121        assert!(
122            (discrete.iter().sum::<f64>() - 1.).abs() < 0.000000001,
123            "The probabilities must sum to one."
124        );
125        assert!(
126            discrete.iter().all(|v| *v >= 0.),
127            "The probabilities must all be positive."
128        );
129
130        Pdf {
131            inner: PdfImpl::Discrete { values: discrete },
132            distribution: OnceBox::new(),
133            integral: OnceBox::new(),
134            probabilities: OnceBox::new(),
135        }
136    }
137
138    /// Calculate the discrete representation of the PDF.
139    ///
140    /// This method is memoized.
141    ///
142    /// # Example
143    ///
144    /// ```
145    /// # use nmr_schedule::pdf::*;
146    /// assert!(
147    ///     unweighted(3)
148    ///         .get_distribution()
149    ///         .iter()
150    ///         .zip([1. / 3., 1. / 3., 1. / 3.])
151    ///         .all(|(l, r)| (l - r).abs() < 0.00001),
152    /// );
153    /// ```
154    pub fn get_distribution(&self) -> &[f64] {
155        match &self.inner {
156            PdfImpl::Continuous { len, integral } => self.distribution.get_or_init(|| {
157                Box::new(
158                    (0..*len)
159                        .map(|pos| integral(pos as f64 + 1.) - integral(pos as f64))
160                        .collect(),
161                )
162            }),
163            PdfImpl::Discrete { values } => values,
164        }
165    }
166
167    /// Calculate the integral of the PDF for whole number values
168    ///
169    /// The output has length `self.len() + 1`. The value at index `0` is guaranteed to be zero and the output at index `self.len()` is guaranteed to be one with a small margin for floating point error.
170    ///
171    /// This method is memoized.
172    ///
173    /// # Example
174    ///
175    /// ```
176    /// # use nmr_schedule::pdf::*;
177    /// assert!(
178    ///     Pdf::from_discrete(vec![0.3, 0.6, 0.1])
179    ///         .get_integral()
180    ///         .iter()
181    ///         .zip([0., 0.3, 0.9, 1.])
182    ///         .all(|(l, r)| (l - r).abs() < 0.00001),
183    /// );
184    /// ```
185    #[allow(clippy::missing_panics_doc)] // Panics are bugs
186    pub fn get_integral(&self) -> &[f64] {
187        self.integral.get_or_init(|| {
188            let distribution = self.get_distribution();
189
190            let mut sum = 0.;
191
192            let mut integral = distribution
193                .iter()
194                .map(|v| {
195                    let prev = sum;
196                    sum += v;
197                    prev
198                })
199                .collect::<Vec<_>>();
200
201            integral.push(sum);
202
203            debug_assert!(*integral.first().unwrap() == 0.);
204            debug_assert!((integral.last().unwrap() - 1.).abs() < 0.00001);
205
206            Box::new(integral)
207        })
208    }
209
210    /// Slice a PDF to fit a certain range
211    ///
212    /// The PDF will be automatically rescaled to still sum to one.
213    ///
214    /// # Example
215    ///
216    /// ```
217    /// # use nmr_schedule::pdf::*;
218    /// assert!(
219    ///     exponential(256, 4.)
220    ///         .slice(0..128)
221    ///         .get_distribution()
222    ///         .iter()
223    ///         .zip(exponential(128, 2.)
224    ///            .get_distribution()
225    ///         )
226    ///         .all(|(l, r)| (l - r).abs() < 0.00001),
227    /// );
228    /// ```
229    pub fn slice(&self, range: Range<usize>) -> Pdf {
230        match &self.inner {
231            PdfImpl::Continuous { len: _, integral } => {
232                let scale = 1. / (integral(range.end as f64) - integral(range.start as f64));
233
234                let integral_for_closure = Arc::clone(integral);
235
236                Pdf::from_integral(
237                    move |v| integral_for_closure(v - range.start as f64) * scale,
238                    range.len(),
239                )
240            }
241            PdfImpl::Discrete { values } => {
242                let mut sliced = values[range].to_owned();
243                let sum = sliced.iter().sum::<f64>();
244                sliced.iter_mut().for_each(|v| *v /= sum);
245                Pdf::from_discrete(sliced)
246            }
247        }
248    }
249
250    /// Returns the length of the PDF in number of samples.
251    ///
252    /// # Example
253    ///
254    /// ```
255    /// # use nmr_schedule::pdf::*;
256    /// assert_eq!(linear(256).len(), 256);
257    /// ```
258    #[allow(clippy::len_without_is_empty)] // A PDF cannot be empty
259    pub fn len(&self) -> usize {
260        match &self.inner {
261            PdfImpl::Continuous { len, integral: _ } => *len,
262            PdfImpl::Discrete { values } => values.len(),
263        }
264    }
265
266    /// Calculate the probabilities of selecting each sample position under random sampling given the number of samples to select.
267    ///
268    /// # Example
269    /// ```
270    /// # use nmr_schedule::pdf::*;
271    /// // With unweighted sampling and selecting 128 out of 256,
272    /// // each sample position has a 1/2 chance of being selected.
273    /// assert!(
274    ///     unweighted(256)
275    ///         .probabilities(128)
276    ///         .iter()
277    ///         .all(|v| (*v - 0.5).abs() < 0.000001)
278    /// );
279    /// ```
280    ///
281    /// Note that this only gives a very good approximation because the true value is (as of now) computationally infeasible to find.
282    ///
283    /// This method is memoized.
284    #[allow(clippy::missing_panics_doc)] // Panicking is a bug
285    pub fn probabilities(&self, count: usize) -> &Probabilities {
286        self.probabilities.get_or_init(|| {
287            let (_power, probabilities) = find_zero(
288                Monotonicity::Increasing,
289                InitialState::new(count as f64, 0., f64::INFINITY),
290                crate::Precision::Image {
291                    amount: 0.0000001,
292                    debug: (&|prob1, prob2| -> () { panic!("{prob1:?}\n{prob2:?}") }) as &_,
293                },
294                |power| {
295                    let probabilities = self
296                        .get_distribution()
297                        .iter()
298                        .map(move |v| 1. - (1. - v).powf(power))
299                        .collect::<Vec<_>>();
300
301                    (
302                        probabilities.iter().sum::<f64>() - (count as f64),
303                        Probabilities { probabilities },
304                    )
305                },
306            );
307
308            Box::new(probabilities)
309        })
310    }
311
312    /// Sample the PDF using the given `Rng`
313    ///
314    /// Returns the zero-based index of the position that was sampled.
315    pub fn sample_pdf<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
316        let choice = rng.random();
317
318        debug_assert!((0. ..1.).contains(&choice));
319
320        match self.get_integral()[1..].binary_search_by(|v| v.total_cmp(&choice)) {
321            Ok(v) => v,
322            Err(v) => v,
323        }
324    }
325
326    /// Return a continuous representation of the integral of the PDF. Applying `0.` to the returned function will return `0.`.
327    ///
328    /// This function is implemented for discretely represented PDFs by interpolating `self.get_integral()` using Catmull-Rom to give a once-differentiable interpolation.
329    ///
330    /// # Example
331    ///
332    /// ```
333    /// # use core::f64::consts::*;
334    /// # use nmr_schedule::pdf::*;
335    /// let pdf = qsin(256, QSinBias::Low, PI);
336    /// let integral = pdf.continuous_integral();
337    ///
338    /// assert_eq!(integral(0.), 0.);
339    /// assert_eq!(integral(256.), 1.);
340    /// let c = FRAC_PI_2 - 1.;
341    /// assert_eq!(integral(128.), (FRAC_PI_4 + FRAC_PI_4.cos()) / c - 1. / c);
342    /// ```
343    ///
344    /// # Panics
345    ///
346    /// The function returned will panic if it is called with values outside of the range of the PDF.
347    pub fn continuous_integral<'a>(&'a self) -> Rc<dyn Fn(f64) -> f64 + 'a> {
348        fn if_out_of_range(v: f64, len: usize) -> Option<f64> {
349            if v < 0. {
350                return Some(0.);
351            }
352
353            if v > len as f64 {
354                return Some(1.);
355            }
356
357            None
358        }
359
360        match &self.inner {
361            PdfImpl::Continuous { len, integral } => {
362                let integral_for_closure = Arc::clone(integral);
363                let start_value = integral_for_closure(0.);
364                Rc::new(move |v| {
365                    if let Some(v) = if_out_of_range(v, *len) {
366                        return v;
367                    }
368
369                    integral_for_closure(v) - start_value
370                })
371            }
372            PdfImpl::Discrete { values: _ } => {
373                let integral = self.get_integral();
374                Rc::from(|v: f64| {
375                    if let Some(v) = if_out_of_range(v, integral.len() - 1) {
376                        return v;
377                    }
378
379                    let idx_below = v.floor() as usize;
380
381                    let output_vector = [
382                        if idx_below == 0 {
383                            0.
384                        } else {
385                            integral[idx_below - 1]
386                        },
387                        integral[idx_below],
388                        *integral.get(idx_below + 1).unwrap_or(&1.),
389                        *integral.get(idx_below + 2).unwrap_or(&1.),
390                    ];
391
392                    let time = v.fract();
393
394                    let time_vector = [1., time, time * time, time * time * time];
395
396                    // Interpolate using Catmull-Rom (thanks freya! (https://youtu.be/jvPPXbo87ds?t=2967))
397
398                    const SPLINE: [[f64; 4]; 4] = [
399                        [0., 1., 0., 0.],
400                        [-0.5, 0., 0.5, 0.],
401                        [1., -2.5, 2., -0.5],
402                        [-0.5, 1.5, -1.5, 0.5],
403                    ];
404
405                    SPLINE
406                        .iter()
407                        .map(|v| v.iter().zip(output_vector).map(|(a, b)| a * b).sum::<f64>())
408                        .zip(time_vector)
409                        .map(|(a, b)| a * b)
410                        .sum()
411                })
412            }
413        }
414    }
415}
416
417/// A value that can generate PDFs of arbitrary dimensions.
418///
419/// There exists a blanket implementation for closures that input a length and return a PDF of that length. The closure will be called once per dimension.
420pub trait PdfGenerator<Dim: Dimension> {
421    /// Get a PDF with the requested dimensions; returns one 1D PDF per dimension.
422    ///
423    /// Asserts that the implementation gave a PDF of the correct dimensions.
424    ///
425    /// Implementors should not override this method and instead implement `_get_unchecked`.
426    fn get(&self, dim: Dim) -> Vec<Pdf> {
427        let pdfs = self._get_unchecked(dim.to_owned());
428
429        assert_eq!(
430            pdfs.len(),
431            dim.ndim(),
432            "The number of dimensions of the PDF are different from the dimensions specified."
433        );
434
435        for (i, pdf) in pdfs.iter().enumerate() {
436            assert_eq!(
437                dim[i],
438                pdf.len(),
439                "The length of dimension {i} is different from what was specified."
440            );
441        }
442
443        pdfs
444    }
445
446    /// Return one 1D PDF for each dimension requested. Callers should not call this function directly and instead call `get`.
447    fn _get_unchecked(&self, dim: Dim) -> Vec<Pdf>;
448}
449
450impl<Dim: Dimension, F: Fn(usize) -> Pdf> PdfGenerator<Dim> for F {
451    fn _get_unchecked(&self, dim: Dim) -> Vec<Pdf> {
452        let mut out = Vec::with_capacity(dim.ndim());
453
454        for i in 0..dim.ndim() {
455            out.push(self(dim[i]));
456        }
457
458        out
459    }
460}
461
462impl<Dim: Dimension> PdfGenerator<Dim> for fn(Dim) -> Vec<Pdf> {
463    fn _get_unchecked(&self, dim: Dim) -> Vec<Pdf> {
464        self(dim)
465    }
466}
467
468impl<const N: usize> PdfGenerator<Dim<[Ix; N]>> for [fn(usize) -> Pdf; N]
469where
470    Dim<[Ix; N]>: Dimension,
471{
472    fn _get_unchecked(&self, dim: Dim<[Ix; N]>) -> Vec<Pdf> {
473        let mut out = Vec::with_capacity(N);
474
475        for (i, f) in self.iter().enumerate() {
476            out.push((f)(dim[i]));
477        }
478
479        out
480    }
481}
482
/// A list of probabilities, each one associated with a sample position.
///
/// May represent the probability of selecting a sample or of a pattern being present at a given position.
/// Dereferences to `[f64]`, so it can be indexed and iterated like a slice.
#[derive(Debug, Clone)]
pub struct Probabilities {
    probabilities: Vec<f64>,
}

impl Deref for Probabilities {
    type Target = [f64];

    fn deref(&self) -> &Self::Target {
        &self.probabilities
    }
}

impl Probabilities {
    /// Returns the probability of a kernel (sampling pattern) appearing at any given position assuming samples are selected independently.
    ///
    /// Entry `i` of the result is the product, over every offset `j` of the kernel, of
    /// `self[i + j]` where the kernel is `true` and `1 - self[i + j]` where it is `false`. The
    /// result has `self.len() - kernel.len() + 1` entries, or none if the kernel is longer than
    /// `self`.
    ///
    /// The independence condition is not valid for most NUS algorithms, but this method will still give an idea of where patterns are most likely to show up.
    ///
    /// # Example
    ///
    /// ```
    /// # use nmr_schedule::pdf::*;
    /// # use core::f64::consts::*;
    /// let pdf = qsin(256, QSinBias::Low, PI);
    /// let probabilities = pdf.probabilities(64);
    /// let kdf = probabilities.kernel_density([true, false, true]);
    ///
    /// // The [true, false, true] pattern can't possibly show up at the
    /// // second to last position since there are only two samples left
    /// // which isn't enough room to fit a pattern of length 3.
    /// assert_eq!(kdf.len(), 254);
    /// assert!((kdf[32] - 0.13).abs() < 0.01); // Pattern is relatively likely to appear at the start
    /// assert!(kdf[250] < 0.001); // Pattern is relatively unlikely to appear at the end
    /// ```
    ///
    /// # Panics
    /// The function will panic if the length of the kernel is zero.
    pub fn kernel_density(&self, kernel: impl AsRef<[bool]>) -> Probabilities {
        let kernel = kernel.as_ref();

        assert!(
            !kernel.is_empty(),
            "The kernel must contain at least one entry."
        );

        Probabilities {
            probabilities: self
                .probabilities
                .windows(kernel.len())
                .map(|window| {
                    // Chance that every position in the window matches the kernel, treating
                    // positions as independent.
                    window
                        .iter()
                        .zip(kernel)
                        .map(|(&p, &selected)| if selected { p } else { 1. - p })
                        .product::<f64>()
                })
                .collect(),
        }
    }
}
542}
543
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::{Pdf, QSinBias, qsin};

    // A CDF that only climbs to one half over the range must be rejected.
    #[test]
    #[should_panic(expected = "The difference between the start and end of the CDF is not one.")]
    fn from_integral_not_zero_to_one() {
        let _ = Pdf::from_integral(|v| v / 512., 256);
    }

    // 2x² - x ends at one but dips below zero first, so it is not increasing.
    #[test]
    #[should_panic(expected = "The CDF is not increasing.")]
    fn from_integral_not_monotonic() {
        let _ = Pdf::from_integral(|v| 2. * (v / 256.).powi(2) - v / 256., 256);
    }

    #[test]
    #[should_panic(expected = "Cannot create a PDF of length zero.")]
    fn from_integral_len_zero() {
        let _ = Pdf::from_integral(|_| 0., 0);
    }

    #[test]
    #[should_panic(expected = "The probabilities must sum to one.")]
    fn from_discrete_not_sum_to_one() {
        let _ = Pdf::from_discrete(vec![1., 2., 3.]);
    }

    // Sums to one, so only the sign check can reject it.
    #[test]
    #[should_panic(expected = "The probabilities must all be positive.")]
    fn from_discrete_not_all_positive() {
        let _ = Pdf::from_discrete(vec![1., -1., 1.]);
    }

    #[test]
    #[should_panic(expected = "Cannot create a PDF of length zero.")]
    fn from_discrete_len_zero() {
        let _ = Pdf::from_discrete(vec![]);
    }

    // Inputs below zero clamp to zero.
    #[test]
    fn continuous_integral_sample_lz() {
        let pdf = qsin(256, QSinBias::Low, 3.);
        let integral = pdf.continuous_integral();
        assert_eq!(integral(-1.), 0.);
    }

    // Inputs past the end clamp to one.
    #[test]
    fn continuous_integral_sample_gt_len() {
        let pdf = qsin(256, QSinBias::Low, 3.);
        let integral = pdf.continuous_integral();
        assert_eq!(integral(257.), 1.);
    }
}
599}