nmr_schedule/pdf/mod.rs
1//! Implements Probability Distribution Functions
2//!
3//! These are used to indicate to schedule generators where samples should be taken. This module contains the core PDF implementation as well as some presets representing common PDFs.
4
5use core::ops::{Deref, Range};
6
7use alloc::{borrow::ToOwned, boxed::Box, rc::Rc, sync::Arc, vec::Vec};
8use ndarray::{Dim, Dimension, Ix};
9use once_cell::race::OnceBox;
10use rand::Rng;
11
12mod presets;
13
14pub use presets::*;
15
16use crate::{InitialState, Monotonicity, find_zero};
17
18/// Represents a Probability Distribution Function.
19///
20/// PDFs are used to indicate to schedule generators where samples should be taken.
21pub struct Pdf {
22 inner: PdfImpl,
23 // Memoized results of `self.get_distribution()`, `self.get_integral()`, and `self.probabilities()`
24 distribution: OnceBox<Vec<f64>>, // Only used in the continuous representation
25 integral: OnceBox<Vec<f64>>,
26 probabilities: OnceBox<Probabilities>,
27}
28
29impl core::fmt::Debug for Pdf {
30 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
31 match &self.inner {
32 PdfImpl::Continuous { len, integral: _ } => f
33 .debug_struct("Pdf::Continuous")
34 .field("len", len)
35 .field("integral", &"*closure*")
36 .finish(),
37 PdfImpl::Discrete { values } => f
38 .debug_struct("Pdf::Discrete")
39 .field("values", values)
40 .finish(),
41 }
42 }
43}
44
/// The private representation backing a [`Pdf`].
enum PdfImpl {
    /// A PDF described by its integral (CDF).
    Continuous {
        /// Length of the PDF in number of samples.
        len: usize,
        /// The CDF; shared so that derived PDFs can reuse it without copying.
        integral: Arc<dyn Fn(f64) -> f64 + Send + Sync>,
    },
    /// A PDF described by one probability per sample position.
    Discrete {
        /// The per-sample probabilities.
        values: Vec<f64>,
    },
}
55
56impl Pdf {
57 /// Create a PDF from a continuous representation.
58 ///
59 /// Expects the *integral* of the PDF being represented. This is often called the Cumulative Distribution Function or CDF. The function will receive inputs from `0..=len`. The `len` parameter is the length of the PDF in number of samples.
60 ///
61 /// The CDF is expected to be monotonically increasing, and `integral(len as f64) - integral(0.)` must be one. These are the necessary conditions for a function to be a valid CDF.
62 ///
63 /// It is not assumed that `integral(0.) = 0.`.
64 ///
65 /// # Example
66 ///
67 /// ```
68 /// # use nmr_schedule::pdf::*;
69 /// // This represents an unweighted PDF because the integral of a constant function is linear.
70 /// // You can use the `unweighted` preset instead if you would like
71 /// let pdf = Pdf::from_integral(|v| v / 256., 256);
72 /// ```
73 ///
74 /// # Panics
75 ///
76 /// This function will assert that
77 /// - `integral(len as f64) - integral(0.) ≈ 1.` with some margin for floating point error allowed;
78 /// - the CDF is monotonically increasing for particular intervals; and
79 /// - `len` is non-zero.
80 pub fn from_integral(integral: impl Fn(f64) -> f64 + 'static + Send + Sync, len: usize) -> Pdf {
81 assert!(len > 0, "Cannot create a PDF of length zero.");
82 assert!(
83 (integral(len as f64) - integral(0.) - 1.).abs() < 0.001,
84 "The difference between the start and end of the CDF is not one."
85 );
86 assert!(
87 (0..len).all(|i| integral(i as f64) <= integral(i as f64 + 1.)),
88 "The CDF is not increasing."
89 );
90
91 Pdf {
92 inner: PdfImpl::Continuous {
93 len,
94 integral: Arc::from(integral),
95 },
96 distribution: OnceBox::new(),
97 integral: OnceBox::new(),
98 probabilities: OnceBox::new(),
99 }
100 }
101
102 /// Create a PDF from a discrete representation.
103 ///
104 /// Expects a list of probabilities, one for each sample, that are all non-negative and sum to one.
105 ///
106 /// # Example
107 ///
108 /// ```
109 /// # use nmr_schedule::pdf::*;
110 /// let pdf = Pdf::from_discrete(vec![0.125, 0.5, 0.375]);
111 /// ```
112 ///
113 /// # Panics
114 ///
115 /// This function will assert that
116 /// - the sum of all of the probabilities is one, with some margin for floating point error allowed;
117 /// - all probabilities are non-negative; and
118 /// - the length of the PDF is non-zero.
119 pub fn from_discrete(discrete: Vec<f64>) -> Pdf {
120 assert!(!discrete.is_empty(), "Cannot create a PDF of length zero.");
121 assert!(
122 (discrete.iter().sum::<f64>() - 1.).abs() < 0.000000001,
123 "The probabilities must sum to one."
124 );
125 assert!(
126 discrete.iter().all(|v| *v >= 0.),
127 "The probabilities must all be positive."
128 );
129
130 Pdf {
131 inner: PdfImpl::Discrete { values: discrete },
132 distribution: OnceBox::new(),
133 integral: OnceBox::new(),
134 probabilities: OnceBox::new(),
135 }
136 }
137
138 /// Calculate the discrete representation of the PDF.
139 ///
140 /// This method is memoized.
141 ///
142 /// # Example
143 ///
144 /// ```
145 /// # use nmr_schedule::pdf::*;
146 /// assert!(
147 /// unweighted(3)
148 /// .get_distribution()
149 /// .iter()
150 /// .zip([1. / 3., 1. / 3., 1. / 3.])
151 /// .all(|(l, r)| (l - r).abs() < 0.00001),
152 /// );
153 /// ```
154 pub fn get_distribution(&self) -> &[f64] {
155 match &self.inner {
156 PdfImpl::Continuous { len, integral } => self.distribution.get_or_init(|| {
157 Box::new(
158 (0..*len)
159 .map(|pos| integral(pos as f64 + 1.) - integral(pos as f64))
160 .collect(),
161 )
162 }),
163 PdfImpl::Discrete { values } => values,
164 }
165 }
166
167 /// Calculate the integral of the PDF for whole number values
168 ///
169 /// The output has length `self.len() + 1`. The value at index `0` is guaranteed to be zero and the output at index `self.len()` is guaranteed to be one with a small margin for floating point error.
170 ///
171 /// This method is memoized.
172 ///
173 /// # Example
174 ///
175 /// ```
176 /// # use nmr_schedule::pdf::*;
177 /// assert!(
178 /// Pdf::from_discrete(vec![0.3, 0.6, 0.1])
179 /// .get_integral()
180 /// .iter()
181 /// .zip([0., 0.3, 0.9, 1.])
182 /// .all(|(l, r)| (l - r).abs() < 0.00001),
183 /// );
184 /// ```
185 #[allow(clippy::missing_panics_doc)] // Panics are bugs
186 pub fn get_integral(&self) -> &[f64] {
187 self.integral.get_or_init(|| {
188 let distribution = self.get_distribution();
189
190 let mut sum = 0.;
191
192 let mut integral = distribution
193 .iter()
194 .map(|v| {
195 let prev = sum;
196 sum += v;
197 prev
198 })
199 .collect::<Vec<_>>();
200
201 integral.push(sum);
202
203 debug_assert!(*integral.first().unwrap() == 0.);
204 debug_assert!((integral.last().unwrap() - 1.).abs() < 0.00001);
205
206 Box::new(integral)
207 })
208 }
209
210 /// Slice a PDF to fit a certain range
211 ///
212 /// The PDF will be automatically rescaled to still sum to one.
213 ///
214 /// # Example
215 ///
216 /// ```
217 /// # use nmr_schedule::pdf::*;
218 /// assert!(
219 /// exponential(256, 4.)
220 /// .slice(0..128)
221 /// .get_distribution()
222 /// .iter()
223 /// .zip(exponential(128, 2.)
224 /// .get_distribution()
225 /// )
226 /// .all(|(l, r)| (l - r).abs() < 0.00001),
227 /// );
228 /// ```
229 pub fn slice(&self, range: Range<usize>) -> Pdf {
230 match &self.inner {
231 PdfImpl::Continuous { len: _, integral } => {
232 let scale = 1. / (integral(range.end as f64) - integral(range.start as f64));
233
234 let integral_for_closure = Arc::clone(integral);
235
236 Pdf::from_integral(
237 move |v| integral_for_closure(v - range.start as f64) * scale,
238 range.len(),
239 )
240 }
241 PdfImpl::Discrete { values } => {
242 let mut sliced = values[range].to_owned();
243 let sum = sliced.iter().sum::<f64>();
244 sliced.iter_mut().for_each(|v| *v /= sum);
245 Pdf::from_discrete(sliced)
246 }
247 }
248 }
249
250 /// Returns the length of the PDF in number of samples.
251 ///
252 /// # Example
253 ///
254 /// ```
255 /// # use nmr_schedule::pdf::*;
256 /// assert_eq!(linear(256).len(), 256);
257 /// ```
258 #[allow(clippy::len_without_is_empty)] // A PDF cannot be empty
259 pub fn len(&self) -> usize {
260 match &self.inner {
261 PdfImpl::Continuous { len, integral: _ } => *len,
262 PdfImpl::Discrete { values } => values.len(),
263 }
264 }
265
266 /// Calculate the probabilities of selecting each sample position under random sampling given the number of samples to select.
267 ///
268 /// # Example
269 /// ```
270 /// # use nmr_schedule::pdf::*;
271 /// // With unweighted sampling and selecting 128 out of 256,
272 /// // each sample position has a 1/2 chance of being selected.
273 /// assert!(
274 /// unweighted(256)
275 /// .probabilities(128)
276 /// .iter()
277 /// .all(|v| (*v - 0.5).abs() < 0.000001)
278 /// );
279 /// ```
280 ///
281 /// Note that this only gives a very good approximation because the true value is (as of now) computationally infeasible to find.
282 ///
283 /// This method is memoized.
284 #[allow(clippy::missing_panics_doc)] // Panicking is a bug
285 pub fn probabilities(&self, count: usize) -> &Probabilities {
286 self.probabilities.get_or_init(|| {
287 let (_power, probabilities) = find_zero(
288 Monotonicity::Increasing,
289 InitialState::new(count as f64, 0., f64::INFINITY),
290 crate::Precision::Image {
291 amount: 0.0000001,
292 debug: (&|prob1, prob2| -> () { panic!("{prob1:?}\n{prob2:?}") }) as &_,
293 },
294 |power| {
295 let probabilities = self
296 .get_distribution()
297 .iter()
298 .map(move |v| 1. - (1. - v).powf(power))
299 .collect::<Vec<_>>();
300
301 (
302 probabilities.iter().sum::<f64>() - (count as f64),
303 Probabilities { probabilities },
304 )
305 },
306 );
307
308 Box::new(probabilities)
309 })
310 }
311
312 /// Sample the PDF using the given `Rng`
313 ///
314 /// Returns the zero-based index of the position that was sampled.
315 pub fn sample_pdf<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
316 let choice = rng.random();
317
318 debug_assert!((0. ..1.).contains(&choice));
319
320 match self.get_integral()[1..].binary_search_by(|v| v.total_cmp(&choice)) {
321 Ok(v) => v,
322 Err(v) => v,
323 }
324 }
325
326 /// Return a continuous representation of the integral of the PDF. Applying `0.` to the returned function will return `0.`.
327 ///
328 /// This function is implemented for discretely represented PDFs by interpolating `self.get_integral()` using Catmull-Rom to give a once-differentiable interpolation.
329 ///
330 /// # Example
331 ///
332 /// ```
333 /// # use core::f64::consts::*;
334 /// # use nmr_schedule::pdf::*;
335 /// let pdf = qsin(256, QSinBias::Low, PI);
336 /// let integral = pdf.continuous_integral();
337 ///
338 /// assert_eq!(integral(0.), 0.);
339 /// assert_eq!(integral(256.), 1.);
340 /// let c = FRAC_PI_2 - 1.;
341 /// assert_eq!(integral(128.), (FRAC_PI_4 + FRAC_PI_4.cos()) / c - 1. / c);
342 /// ```
343 ///
344 /// # Panics
345 ///
346 /// The function returned will panic if it is called with values outside of the range of the PDF.
347 pub fn continuous_integral<'a>(&'a self) -> Rc<dyn Fn(f64) -> f64 + 'a> {
348 fn if_out_of_range(v: f64, len: usize) -> Option<f64> {
349 if v < 0. {
350 return Some(0.);
351 }
352
353 if v > len as f64 {
354 return Some(1.);
355 }
356
357 None
358 }
359
360 match &self.inner {
361 PdfImpl::Continuous { len, integral } => {
362 let integral_for_closure = Arc::clone(integral);
363 let start_value = integral_for_closure(0.);
364 Rc::new(move |v| {
365 if let Some(v) = if_out_of_range(v, *len) {
366 return v;
367 }
368
369 integral_for_closure(v) - start_value
370 })
371 }
372 PdfImpl::Discrete { values: _ } => {
373 let integral = self.get_integral();
374 Rc::from(|v: f64| {
375 if let Some(v) = if_out_of_range(v, integral.len() - 1) {
376 return v;
377 }
378
379 let idx_below = v.floor() as usize;
380
381 let output_vector = [
382 if idx_below == 0 {
383 0.
384 } else {
385 integral[idx_below - 1]
386 },
387 integral[idx_below],
388 *integral.get(idx_below + 1).unwrap_or(&1.),
389 *integral.get(idx_below + 2).unwrap_or(&1.),
390 ];
391
392 let time = v.fract();
393
394 let time_vector = [1., time, time * time, time * time * time];
395
396 // Interpolate using Catmull-Rom (thanks freya! (https://youtu.be/jvPPXbo87ds?t=2967))
397
398 const SPLINE: [[f64; 4]; 4] = [
399 [0., 1., 0., 0.],
400 [-0.5, 0., 0.5, 0.],
401 [1., -2.5, 2., -0.5],
402 [-0.5, 1.5, -1.5, 0.5],
403 ];
404
405 SPLINE
406 .iter()
407 .map(|v| v.iter().zip(output_vector).map(|(a, b)| a * b).sum::<f64>())
408 .zip(time_vector)
409 .map(|(a, b)| a * b)
410 .sum()
411 })
412 }
413 }
414 }
415}
416
417/// A value that can generate PDFs of arbitrary dimensions.
418///
419/// There exists a blanket implementation for closures that input a length and return a PDF of that length. The closure will be called once per dimension.
420pub trait PdfGenerator<Dim: Dimension> {
421 /// Get a PDF with the requested dimensions; returns one 1D PDF per dimension.
422 ///
423 /// Asserts that the implementation gave a PDF of the correct dimensions.
424 ///
425 /// Implementors should not override this method and instead implement `_get_unchecked`.
426 fn get(&self, dim: Dim) -> Vec<Pdf> {
427 let pdfs = self._get_unchecked(dim.to_owned());
428
429 assert_eq!(
430 pdfs.len(),
431 dim.ndim(),
432 "The number of dimensions of the PDF are different from the dimensions specified."
433 );
434
435 for (i, pdf) in pdfs.iter().enumerate() {
436 assert_eq!(
437 dim[i],
438 pdf.len(),
439 "The length of dimension {i} is different from what was specified."
440 );
441 }
442
443 pdfs
444 }
445
446 /// Return one 1D PDF for each dimension requested. Callers should not call this function directly and instead call `get`.
447 fn _get_unchecked(&self, dim: Dim) -> Vec<Pdf>;
448}
449
450impl<Dim: Dimension, F: Fn(usize) -> Pdf> PdfGenerator<Dim> for F {
451 fn _get_unchecked(&self, dim: Dim) -> Vec<Pdf> {
452 let mut out = Vec::with_capacity(dim.ndim());
453
454 for i in 0..dim.ndim() {
455 out.push(self(dim[i]));
456 }
457
458 out
459 }
460}
461
462impl<Dim: Dimension> PdfGenerator<Dim> for fn(Dim) -> Vec<Pdf> {
463 fn _get_unchecked(&self, dim: Dim) -> Vec<Pdf> {
464 self(dim)
465 }
466}
467
468impl<const N: usize> PdfGenerator<Dim<[Ix; N]>> for [fn(usize) -> Pdf; N]
469where
470 Dim<[Ix; N]>: Dimension,
471{
472 fn _get_unchecked(&self, dim: Dim<[Ix; N]>) -> Vec<Pdf> {
473 let mut out = Vec::with_capacity(N);
474
475 for (i, f) in self.iter().enumerate() {
476 out.push((f)(dim[i]));
477 }
478
479 out
480 }
481}
482
/// A list of probabilities with one entry per sample position.
///
/// Depending on where it comes from, an entry is either the probability that the sample is selected or the probability that a pattern is present at that position.
#[derive(Debug, Clone)]
pub struct Probabilities {
    /// The probability for each position, in position order.
    probabilities: Vec<f64>,
}
490
491impl Deref for Probabilities {
492 type Target = [f64];
493
494 fn deref(&self) -> &Self::Target {
495 &self.probabilities
496 }
497}
498
499impl Probabilities {
500 /// Returns the probability of a kernel (sampling pattern) appearing at any given position assuming samples are selected independently.
501 ///
502 /// The independence condition is not valid for most NUS algorithms, but this method will still give an idea of where patterns are most likely to show up.
503 ///
504 /// # Example
505 ///
506 /// ```
507 /// # use nmr_schedule::pdf::*;
508 /// # use core::f64::consts::*;
509 /// let pdf = qsin(256, QSinBias::Low, PI);
510 /// let probabilities = pdf.probabilities(64);
511 /// let kdf = probabilities.kernel_density([true, false, true]);
512 ///
513 /// // The [true, false, true] pattern can't possibly show up at the
514 /// // second to last position since there are only two samples left
515 /// // which isn't enough room to fit a pattern of length 3.
516 /// assert_eq!(kdf.len(), 254);
517 /// assert!((kdf[32] - 0.13).abs() < 0.01); // Pattern is relatively likely to appear at the start
518 /// assert!(kdf[250] < 0.001); // Pattern is relatively unlikely to appear at the end
519 /// ```
520 ///
521 /// # Panics
522 /// The function will panic if the length of the kernel is zero.
523 pub fn kernel_density(&self, kernel: impl AsRef<[bool]>) -> Probabilities {
524 let kernel = kernel.as_ref();
525
526 assert!(!kernel.is_empty());
527
528 Probabilities {
529 probabilities: self
530 .probabilities
531 .windows(kernel.len())
532 .map(|v| {
533 v.iter()
534 .zip(kernel.iter())
535 .map(|(value, kernel_val)| if *kernel_val { *value } else { 1. - *value })
536 .reduce(|a, v| a * v)
537 .unwrap()
538 })
539 .collect(),
540 }
541 }
542}
543
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::pdf::QSinBias;

    use super::{Pdf, qsin};

    /// A CDF that only rises by one half over its domain is rejected.
    #[test]
    #[should_panic(expected = "The difference between the start and end of the CDF is not one.")]
    fn from_integral_not_zero_to_one() {
        let half_rise = |v: f64| v / 512.;
        Pdf::from_integral(half_rise, 256);
    }

    /// A CDF that rises by one overall but dips at the start is rejected.
    #[test]
    #[should_panic(expected = "The CDF is not increasing.")]
    fn from_integral_not_monotonic() {
        let dips_first = |v: f64| 2. * (v / 256.).powi(2) - v / 256.;
        Pdf::from_integral(dips_first, 256);
    }

    /// The length check fires for a continuous PDF of length zero.
    #[test]
    #[should_panic(expected = "Cannot create a PDF of length zero.")]
    fn from_integral_len_zero() {
        let flat = |_: f64| 0.;
        Pdf::from_integral(flat, 0);
    }

    /// Probabilities that sum to six are rejected.
    #[test]
    #[should_panic(expected = "The probabilities must sum to one.")]
    fn from_discrete_not_sum_to_one() {
        let values = vec![1., 2., 3.];
        Pdf::from_discrete(values);
    }

    /// Probabilities that sum to one but contain a negative entry are rejected.
    #[test]
    #[should_panic(expected = "The probabilities must all be positive.")]
    fn from_discrete_not_all_positive() {
        let values = vec![1., -1., 1.];
        Pdf::from_discrete(values);
    }

    /// The length check fires for an empty list of probabilities.
    #[test]
    #[should_panic(expected = "Cannot create a PDF of length zero.")]
    fn from_discrete_len_zero() {
        let values = vec![];
        Pdf::from_discrete(values);
    }

    /// Inputs below the PDF evaluate to zero.
    #[test]
    fn continuous_integral_sample_lz() {
        let pdf = qsin(256, QSinBias::Low, 3.);
        let integral = pdf.continuous_integral();
        assert_eq!(0., integral(-1.));
    }

    /// Inputs above the PDF evaluate to one.
    #[test]
    fn continuous_integral_sample_gt_len() {
        let pdf = qsin(256, QSinBias::Low, 3.);
        let integral = pdf.continuous_integral();
        assert_eq!(1., integral(257.));
    }
}