use crate::{
error::MathError,
rational::{MatQ, Q},
traits::{MatrixDimensions, MatrixGetEntry, MatrixSetEntry},
};
use probability::{
prelude::{Gaussian, Sample},
source,
};
use rand::Rng;
use std::fmt::Display;
impl MatQ {
    /// Samples a matrix whose entries are drawn independently from continuous Gaussian
    /// distributions, each one centered at the corresponding entry of `center`.
    ///
    /// Parameters:
    /// - `center`: the matrix of centers; the result has the same number of rows and
    ///   columns as `center`
    /// - `sigma`: the spread shared by every entry, forwarded unchanged to
    ///   [`Q::sample_gauss`] for each entry
    ///
    /// Returns a freshly sampled [`MatQ`] or a [`MathError`].
    ///
    /// # Errors and Failures
    /// - Returns [`MathError::NonPositive`] if `sigma` is zero, negative, or NaN.
    /// - Propagates any error returned while reading an entry of `center` or while
    ///   sampling a single entry with [`Q::sample_gauss`].
    pub fn sample_gauss(center: &MatQ, sigma: impl Into<f64>) -> Result<MatQ, MathError> {
        let sigma = sigma.into();
        // Validate once up front: a bad sigma is reported before the output matrix is
        // allocated, and NaN is rejected exactly like in `sample_gauss_same_center`.
        Self::check_sigma_gauss(sigma)?;

        let mut out = MatQ::new(center.get_num_rows(), center.get_num_columns());
        for i in 0..out.get_num_rows() {
            for j in 0..out.get_num_columns() {
                let center_entry_ij = center.get_entry(i, j)?;
                let sample = Q::sample_gauss(center_entry_ij, sigma)?;
                // SAFETY: `i` and `j` iterate over `out`'s own row and column ranges, so
                // the position is in bounds (assumed to be the only requirement of
                // `set_entry_unchecked` — confirm against its `# Safety` docs).
                unsafe { out.set_entry_unchecked(i, j, sample) };
            }
        }
        Ok(out)
    }

    /// Samples a `num_rows x num_cols` matrix whose entries are drawn independently from
    /// the same continuous Gaussian distribution centered at `center`.
    ///
    /// Each entry is `center + x`, where `x` is drawn from `Gaussian::new(0.0, sigma)` of
    /// the `probability` crate, using one source seeded from `rand::rng()` for the whole
    /// matrix.
    ///
    /// Parameters:
    /// - `num_rows`: number of rows of the result, forwarded to [`MatQ::new`]
    /// - `num_cols`: number of columns of the result, forwarded to [`MatQ::new`]
    /// - `center`: the common center of every entry
    /// - `sigma`: the standard deviation of the distribution
    ///
    /// Returns a freshly sampled [`MatQ`] or a [`MathError`].
    ///
    /// # Errors and Failures
    /// - Returns [`MathError::NonPositive`] if `sigma` is zero, negative, or NaN.
    ///
    /// # Panics ...
    /// - if the requested dimensions are invalid, e.g. a negative number of rows or
    ///   columns (the panic originates in [`MatQ::new`]).
    pub fn sample_gauss_same_center(
        num_rows: impl TryInto<i64> + Display,
        num_cols: impl TryInto<i64> + Display,
        center: impl Into<Q>,
        sigma: impl Into<f64>,
    ) -> Result<MatQ, MathError> {
        let mut out = MatQ::new(num_rows, num_cols);
        let (center, sigma) = (center.into(), sigma.into());
        Self::check_sigma_gauss(sigma)?;

        // One seeded source and one sampler are reused for every entry.
        let mut rng = rand::rng();
        let mut source = source::default(rng.next_u64());
        let sampler = Gaussian::new(0.0, sigma);
        for i in 0..out.get_num_rows() {
            for j in 0..out.get_num_columns() {
                let mut sample = Q::from(sampler.sample(&mut source));
                sample += &center;
                // SAFETY: `i` and `j` iterate over `out`'s own row and column ranges, so
                // the position is in bounds (assumed to be the only requirement of
                // `set_entry_unchecked` — confirm against its `# Safety` docs).
                unsafe { out.set_entry_unchecked(i, j, sample) };
            }
        }
        Ok(out)
    }

    /// Checks that `sigma` can be used as the spread of a Gaussian distribution.
    ///
    /// # Errors and Failures
    /// - Returns [`MathError::NonPositive`] if `sigma` is zero, negative, or NaN.
    ///   NaN is tested explicitly because every comparison with NaN is `false`, so a
    ///   lone `sigma <= 0.0` check would let it through.
    ///
    /// NOTE(review): an infinite `sigma` still passes; confirm whether non-finite
    /// samples are acceptable input for `Q::from`.
    fn check_sigma_gauss(sigma: f64) -> Result<(), MathError> {
        if sigma.is_nan() || sigma <= 0.0 {
            return Err(MathError::NonPositive(format!(
                "The sigma has to be positive and not zero, but the provided value is {sigma}."
            )));
        }
        Ok(())
    }
}
/// Tests for `MatQ::sample_gauss`.
#[cfg(test)]
mod test_sample_gauss {
    use crate::{rational::MatQ, traits::MatrixDimensions};

    /// Zero and negative values of sigma must be rejected with an error.
    #[test]
    fn non_positive_sigma() {
        let center = MatQ::new(5, 5);
        for sigma in [0, -1] {
            assert!(MatQ::sample_gauss(&center, sigma).is_err())
        }
    }

    /// The sampled matrix must have the same dimensions as the matrix of centers,
    /// including for single-row and single-column shapes.
    #[test]
    fn correct_dimension() {
        for (x, y) in [(5, 5), (1, 10), (10, 1)] {
            let center = MatQ::new(x, y);
            let sample = MatQ::sample_gauss(&center, 1).unwrap();
            assert_eq!(center.get_num_rows(), sample.get_num_rows());
            assert_eq!(center.get_num_columns(), sample.get_num_columns());
        }
    }
}
/// Tests for `MatQ::sample_gauss_same_center`.
#[cfg(test)]
mod test_sample_gauss_same_center {
    use crate::{rational::MatQ, traits::MatrixDimensions};

    /// A sigma of zero or below must produce an error instead of a matrix.
    #[test]
    fn non_positive_sigma() {
        let invalid_sigmas = [0, -1];
        for bad_sigma in invalid_sigmas {
            let result = MatQ::sample_gauss_same_center(5, 5, 0, bad_sigma);
            assert!(result.is_err())
        }
    }

    /// The result must have exactly the requested number of rows and columns.
    #[test]
    fn correct_dimension() {
        let shapes = [(5, 5), (1, 10), (10, 1)];
        for (rows, cols) in shapes {
            let matrix = MatQ::sample_gauss_same_center(rows, cols, 0, 1).unwrap();
            assert_eq!(rows, matrix.get_num_rows());
            assert_eq!(cols, matrix.get_num_columns());
        }
    }

    /// Requesting a negative number of rows must panic.
    #[test]
    #[should_panic]
    fn negative_number_rows() {
        let (rows, cols) = (-1, 1);
        let _ = MatQ::sample_gauss_same_center(rows, cols, 0, 1).unwrap();
    }

    /// Requesting a negative number of columns must panic.
    #[test]
    #[should_panic]
    fn negative_number_columns() {
        let (rows, cols) = (1, -1);
        let _ = MatQ::sample_gauss_same_center(rows, cols, 0, 1).unwrap();
    }
}