qfall_math/rational/mat_q/sample/
gauss.rs

// Copyright © 2024 Marvin Beckmann
//
// This file is part of qFALL-math.
//
// qFALL-math is free software: you can redistribute it and/or modify it under
// the terms of the Mozilla Public License Version 2.0 as published by the
// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.

//! This module contains sampling algorithms for Gaussian distributions over [`MatQ`].

11use crate::{
12    error::MathError,
13    rational::{MatQ, Q},
14    traits::{MatrixDimensions, MatrixGetEntry, MatrixSetEntry},
15};
16use probability::{
17    prelude::{Gaussian, Sample},
18    source,
19};
20use rand::RngCore;
21use std::fmt::Display;
22
23impl MatQ {
24    /// Chooses a [`MatQ`] instance according to the continuous Gaussian distribution.
25    /// Here, each entry is chosen according to the provided distribution.
26    ///
27    /// Parameters:
28    /// - `center`: specifies the center for each entry of the matrix individually
29    /// - `sigma`: specifies the standard deviation
30    ///
31    /// Returns new [`MatQ`] sample chosen according to the specified continuous Gaussian
32    /// distribution or a [`MathError`] if the specified parameters were not chosen
33    /// appropriately (`sigma > 0`).
34    ///
35    /// # Examples
36    /// ```
37    /// use qfall_math::rational::MatQ;
38    ///
39    /// let sample = MatQ::sample_gauss(&MatQ::new(5, 5), 1).unwrap();
40    /// ```
41    ///
42    /// # Errors and Failures
43    /// - Returns a [`MathError`] of type [`NonPositive`](MathError::NonPositive)
44    ///   if `sigma <= 0`.
45    pub fn sample_gauss(center: &MatQ, sigma: impl Into<f64>) -> Result<MatQ, MathError> {
46        let mut out = MatQ::new(center.get_num_rows(), center.get_num_columns());
47        let sigma = sigma.into();
48
49        for i in 0..out.get_num_rows() {
50            for j in 0..out.get_num_columns() {
51                let center_entry_ij = center.get_entry(i, j)?;
52                let sample = Q::sample_gauss(center_entry_ij, sigma)?;
53                unsafe { out.set_entry_unchecked(i, j, sample) };
54            }
55        }
56
57        Ok(out)
58    }
59
60    /// Chooses a [`MatQ`] instance according to the continuous Gaussian distribution.
61    /// Here, each entry is chosen according to the provided distribution and each entry
62    /// is sampled with the same center.
63    ///
64    /// Parameters:
65    /// - `num_rows`: specifies the number of rows of the sampled matrix
66    /// - `num_cols`: specifies the number of columns of the sampled matrix
67    /// - `center`: specifies the same center for each entry of the matrix
68    /// - `sigma`: specifies the standard deviation
69    ///
70    /// Returns new [`MatQ`] sample chosen according to the specified continuous Gaussian
71    /// distribution or a [`MathError`] if the specified parameters were not chosen
72    /// appropriately (`sigma > 0`).
73    ///
74    /// # Examples
75    /// ```
76    /// use qfall_math::rational::{Q, MatQ};
77    ///
78    /// let center = Q::from((5,2));
79    ///
80    /// let sample = MatQ::sample_gauss_same_center(5, 5, &center, 1).unwrap();
81    /// ```
82    ///
83    /// # Errors and Failures
84    /// - Returns a [`MathError`] of type [`NonPositive`](MathError::NonPositive)
85    ///   if `sigma <= 0`.
86    ///
87    /// # Panics ...
88    /// - if the number of rows or columns is negative, `0`, or does not fit into an [`i64`].
89    pub fn sample_gauss_same_center(
90        num_rows: impl TryInto<i64> + Display,
91        num_cols: impl TryInto<i64> + Display,
92        center: impl Into<Q>,
93        sigma: impl Into<f64>,
94    ) -> Result<MatQ, MathError> {
95        let mut out = MatQ::new(num_rows, num_cols);
96        let (center, sigma) = (center.into(), sigma.into());
97        if sigma <= 0.0 {
98            return Err(MathError::NonPositive(format!(
99                "The sigma has to be positive and not zero, but the provided value is {sigma}."
100            )));
101        }
102        let mut rng = rand::rng();
103        let mut source = source::default(rng.next_u64());
104
105        // Instead of sampling with a center of c, we sample with center 0 and add the
106        // center later. These are equivalent and this way we can sample in larger ranges
107        let sampler = Gaussian::new(0.0, sigma);
108
109        for i in 0..out.get_num_rows() {
110            for j in 0..out.get_num_columns() {
111                let mut sample = Q::from(sampler.sample(&mut source));
112                sample += &center;
113                unsafe { out.set_entry_unchecked(i, j, sample) };
114            }
115        }
116
117        Ok(out)
118    }
119}
120
#[cfg(test)]
mod test_sample_gauss {
    use crate::{rational::MatQ, traits::MatrixDimensions};

    /// Sampling must be rejected whenever the standard deviation `sigma` is zero or negative.
    #[test]
    fn non_positive_sigma() {
        let zero_center = MatQ::new(5, 5);

        [0, -1].iter().copied().for_each(|invalid_sigma| {
            let result = MatQ::sample_gauss(&zero_center, invalid_sigma);
            assert!(result.is_err());
        });
    }

    /// A sample must have exactly as many rows and columns as the matrix of centers.
    #[test]
    fn correct_dimension() {
        let shapes = [(5, 5), (1, 10), (10, 1)];

        for &(num_rows, num_cols) in shapes.iter() {
            let center = MatQ::new(num_rows, num_cols);
            let sample = MatQ::sample_gauss(&center, 1).unwrap();

            let expected = (center.get_num_rows(), center.get_num_columns());
            let actual = (sample.get_num_rows(), sample.get_num_columns());
            assert_eq!(expected, actual);
        }
    }
}
145
#[cfg(test)]
mod test_sample_gauss_same_center {
    use crate::{rational::MatQ, traits::MatrixDimensions};

    /// Sampling must be rejected whenever the standard deviation `sigma` is zero or negative.
    #[test]
    fn non_positive_sigma() {
        [0, -1].iter().copied().for_each(|invalid_sigma| {
            let result = MatQ::sample_gauss_same_center(5, 5, 0, invalid_sigma);
            assert!(result.is_err());
        });
    }

    /// The sampled matrix must have exactly the requested number of rows and columns.
    #[test]
    fn correct_dimension() {
        let shapes: [(i64, i64); 3] = [(5, 5), (1, 10), (10, 1)];

        for &(num_rows, num_cols) in shapes.iter() {
            let sample = MatQ::sample_gauss_same_center(num_rows, num_cols, 0, 1).unwrap();

            assert_eq!(
                (num_rows, num_cols),
                (sample.get_num_rows(), sample.get_num_columns())
            );
        }
    }

    /// Requesting a negative number of rows must panic.
    #[test]
    #[should_panic]
    fn negative_number_rows() {
        let _sample = MatQ::sample_gauss_same_center(-1, 1, 0, 1).unwrap();
    }

    /// Requesting a negative number of columns must panic.
    #[test]
    #[should_panic]
    fn negative_number_columns() {
        let _sample = MatQ::sample_gauss_same_center(1, -1, 0, 1).unwrap();
    }
}