qfall_math/rational/mat_q/sample/gauss.rs
1// Copyright © 2024 Marvin Beckmann
2//
3// This file is part of qFALL-math.
4//
5// qFALL-math is free software: you can redistribute it and/or modify it under
6// the terms of the Mozilla Public License Version 2.0 as published by the
7// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
9//! This module contains sampling algorithms for Gaussian distributions over [`MatQ`].
10
11use crate::{
12 error::MathError,
13 rational::{MatQ, Q},
14 traits::{MatrixDimensions, MatrixGetEntry, MatrixSetEntry},
15};
16use probability::{
17 prelude::{Gaussian, Sample},
18 source,
19};
20use rand::RngCore;
21use std::fmt::Display;
22
23impl MatQ {
24 /// Chooses a [`MatQ`] instance according to the continuous Gaussian distribution.
25 /// Here, each entry is chosen according to the provided distribution.
26 ///
27 /// Parameters:
28 /// - `center`: specifies the center for each entry of the matrix individually
29 /// - `sigma`: specifies the standard deviation
30 ///
31 /// Returns new [`MatQ`] sample chosen according to the specified continuous Gaussian
32 /// distribution or a [`MathError`] if the specified parameters were not chosen
33 /// appropriately (`sigma > 0`).
34 ///
35 /// # Examples
36 /// ```
37 /// use qfall_math::rational::MatQ;
38 ///
39 /// let sample = MatQ::sample_gauss(&MatQ::new(5, 5), 1).unwrap();
40 /// ```
41 ///
42 /// # Errors and Failures
43 /// - Returns a [`MathError`] of type [`NonPositive`](MathError::NonPositive)
44 /// if `sigma <= 0`.
45 pub fn sample_gauss(center: &MatQ, sigma: impl Into<f64>) -> Result<MatQ, MathError> {
46 let mut out = MatQ::new(center.get_num_rows(), center.get_num_columns());
47 let sigma = sigma.into();
48
49 for i in 0..out.get_num_rows() {
50 for j in 0..out.get_num_columns() {
51 let center_entry_ij = center.get_entry(i, j)?;
52 let sample = Q::sample_gauss(center_entry_ij, sigma)?;
53 unsafe { out.set_entry_unchecked(i, j, sample) };
54 }
55 }
56
57 Ok(out)
58 }
59
60 /// Chooses a [`MatQ`] instance according to the continuous Gaussian distribution.
61 /// Here, each entry is chosen according to the provided distribution and each entry
62 /// is sampled with the same center.
63 ///
64 /// Parameters:
65 /// - `num_rows`: specifies the number of rows of the sampled matrix
66 /// - `num_cols`: specifies the number of columns of the sampled matrix
67 /// - `center`: specifies the same center for each entry of the matrix
68 /// - `sigma`: specifies the standard deviation
69 ///
70 /// Returns new [`MatQ`] sample chosen according to the specified continuous Gaussian
71 /// distribution or a [`MathError`] if the specified parameters were not chosen
72 /// appropriately (`sigma > 0`).
73 ///
74 /// # Examples
75 /// ```
76 /// use qfall_math::rational::{Q, MatQ};
77 ///
78 /// let center = Q::from((5,2));
79 ///
80 /// let sample = MatQ::sample_gauss_same_center(5, 5, ¢er, 1).unwrap();
81 /// ```
82 ///
83 /// # Errors and Failures
84 /// - Returns a [`MathError`] of type [`NonPositive`](MathError::NonPositive)
85 /// if `sigma <= 0`.
86 ///
87 /// # Panics ...
88 /// - if the number of rows or columns is negative, `0`, or does not fit into an [`i64`].
89 pub fn sample_gauss_same_center(
90 num_rows: impl TryInto<i64> + Display,
91 num_cols: impl TryInto<i64> + Display,
92 center: impl Into<Q>,
93 sigma: impl Into<f64>,
94 ) -> Result<MatQ, MathError> {
95 let mut out = MatQ::new(num_rows, num_cols);
96 let (center, sigma) = (center.into(), sigma.into());
97 if sigma <= 0.0 {
98 return Err(MathError::NonPositive(format!(
99 "The sigma has to be positive and not zero, but the provided value is {sigma}."
100 )));
101 }
102 let mut rng = rand::rng();
103 let mut source = source::default(rng.next_u64());
104
105 // Instead of sampling with a center of c, we sample with center 0 and add the
106 // center later. These are equivalent and this way we can sample in larger ranges
107 let sampler = Gaussian::new(0.0, sigma);
108
109 for i in 0..out.get_num_rows() {
110 for j in 0..out.get_num_columns() {
111 let mut sample = Q::from(sampler.sample(&mut source));
112 sample += ¢er;
113 unsafe { out.set_entry_unchecked(i, j, sample) };
114 }
115 }
116
117 Ok(out)
118 }
119}
120
#[cfg(test)]
mod test_sample_gauss {
    use crate::{rational::MatQ, traits::MatrixDimensions};

    /// Ensure that an error is returned if `sigma` is not positive
    #[test]
    fn non_positive_sigma() {
        let center = MatQ::new(5, 5);
        for sigma in [0, -1] {
            assert!(MatQ::sample_gauss(&center, sigma).is_err())
        }
    }

    /// Ensure that the samples have the same dimensions as the provided center matrix
    #[test]
    fn correct_dimension() {
        for (x, y) in [(5, 5), (1, 10), (10, 1)] {
            let center = MatQ::new(x, y);
            let sample = MatQ::sample_gauss(&center, 1).unwrap();
            assert_eq!(center.get_num_rows(), sample.get_num_rows());
            assert_eq!(center.get_num_columns(), sample.get_num_columns());
        }
    }
}
145
#[cfg(test)]
mod test_sample_gauss_same_center {
    use crate::{rational::MatQ, traits::MatrixDimensions};

    /// Checks that zero and negative standard deviations are rejected with an error
    #[test]
    fn non_positive_sigma() {
        [0, -1].into_iter().for_each(|bad_sigma| {
            let result = MatQ::sample_gauss_same_center(5, 5, 0, bad_sigma);
            assert!(result.is_err())
        });
    }

    /// Checks that the sampled matrix has exactly the requested number of rows and columns
    #[test]
    fn correct_dimension() {
        let shapes = [(5, 5), (1, 10), (10, 1)];
        for &(rows, cols) in shapes.iter() {
            let matrix = MatQ::sample_gauss_same_center(rows, cols, 0, 1).unwrap();
            assert_eq!(rows, matrix.get_num_rows());
            assert_eq!(cols, matrix.get_num_columns());
        }
    }

    /// Checks that requesting a negative row count panics
    #[test]
    #[should_panic]
    fn negative_number_rows() {
        let (rows, cols) = (-1, 1);
        let _ = MatQ::sample_gauss_same_center(rows, cols, 0, 1).unwrap();
    }

    /// Checks that requesting a negative column count panics
    #[test]
    #[should_panic]
    fn negative_number_columns() {
        let (rows, cols) = (1, -1);
        let _ = MatQ::sample_gauss_same_center(rows, cols, 0, 1).unwrap();
    }
}