use crate::error::{RandError, RandResult};
/// Small deterministic pseudo-random generator built around one 64-bit state word.
///
/// The same seed always yields the same stream, which is what makes the
/// seeded matrix generators in this file reproducible.
struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Constant added (with wrap-around) to the state before every draw.
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator whose initial state is exactly `seed`.
    fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advances the state and returns the next 64 mixed bits.
    fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::INCREMENT);
        let stage0 = self.state;
        let stage1 = (stage0 ^ (stage0 >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        let stage2 = (stage1 ^ (stage1 >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        stage2 ^ (stage2 >> 31)
    }

    /// Returns a uniform sample in `[0, 1)` built from the top 53 bits of one draw.
    fn next_f64(&mut self) -> f64 {
        let mantissa = self.next_u64() >> 11;
        let scale = (1u64 << 53) as f64;
        mantissa as f64 / scale
    }

    /// Returns two standard-normal samples produced by the Box-Muller transform.
    ///
    /// Two uniforms are consumed per attempt; the attempt is retried when the
    /// first uniform is exactly zero, because its logarithm would be infinite.
    fn next_normal_pair(&mut self) -> (f64, f64) {
        loop {
            // Both uniforms are drawn before the check so a rejected attempt
            // still advances the stream by two draws.
            let (radial_u, angular_u) = (self.next_f64(), self.next_f64());
            if radial_u <= 0.0 {
                continue;
            }
            let radius = (-2.0 * radial_u.ln()).sqrt();
            let angle = 2.0 * std::f64::consts::PI * angular_u;
            return (radius * angle.cos(), radius * angle.sin());
        }
    }

    /// Returns one standard-normal sample (the second of the generated pair is discarded).
    fn next_normal(&mut self) -> f64 {
        let (first, _discarded) = self.next_normal_pair();
        first
    }
}
/// Ordering of a matrix's elements inside its flat storage buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatrixLayout {
    /// Elements of one row are contiguous.
    RowMajor,
    /// Elements of one column are contiguous.
    ColMajor,
}

impl std::fmt::Display for MatrixLayout {
    /// Writes the variant name (`RowMajor` or `ColMajor`) verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            MatrixLayout::RowMajor => "RowMajor",
            MatrixLayout::ColMajor => "ColMajor",
        };
        f.write_str(label)
    }
}
/// Dense `f64` matrix stored in a flat buffer together with its shape and
/// the [`MatrixLayout`] used to map `(row, column)` pairs to buffer indices.
#[derive(Debug, Clone)]
pub struct RandomMatrix {
    rows: usize,
    cols: usize,
    data: Vec<f64>,
    layout: MatrixLayout,
}

impl RandomMatrix {
    /// Wraps an existing buffer as a `rows x cols` matrix.
    ///
    /// # Errors
    ///
    /// Returns `RandError::InvalidSize` when `rows * cols` overflows `usize`
    /// or when `data.len()` differs from `rows * cols`.
    pub fn new(rows: usize, cols: usize, data: Vec<f64>, layout: MatrixLayout) -> RandResult<Self> {
        // Checked multiplication: a wrapped product could otherwise accept a
        // buffer whose length has nothing to do with the requested shape.
        let expected = match rows.checked_mul(cols) {
            Some(len) => len,
            None => {
                return Err(RandError::InvalidSize(format!(
                    "matrix dimensions {rows}x{cols} overflow usize"
                )));
            }
        };
        if data.len() != expected {
            return Err(RandError::InvalidSize(format!(
                "data length {} does not match {}x{} = {}",
                data.len(),
                rows,
                cols,
                expected
            )));
        }
        Ok(Self {
            rows,
            cols,
            data,
            layout,
        })
    }

    /// Creates a `rows x cols` matrix filled with zeros.
    pub fn zeros(rows: usize, cols: usize, layout: MatrixLayout) -> Self {
        Self {
            rows,
            cols,
            data: vec![0.0; rows * cols],
            layout,
        }
    }

    /// Creates the `n x n` identity matrix.
    ///
    /// # Errors
    ///
    /// Returns `RandError::InvalidSize` when `n` is zero.
    pub fn identity(n: usize, layout: MatrixLayout) -> RandResult<Self> {
        if n == 0 {
            return Err(RandError::InvalidSize(
                "identity matrix dimension must be positive".to_string(),
            ));
        }
        let mut data = vec![0.0; n * n];
        // The diagonal sits at the same flat indices in both layouts.
        for i in 0..n {
            data[i * n + i] = 1.0;
        }
        Ok(Self {
            rows: n,
            cols: n,
            data,
            layout,
        })
    }

    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Storage layout of the flat buffer.
    pub fn layout(&self) -> MatrixLayout {
        self.layout
    }

    /// Read-only view of the flat buffer, in the order given by [`Self::layout`].
    pub fn data(&self) -> &[f64] {
        &self.data
    }

    /// Mutable view of the flat buffer; the length cannot be changed through it.
    pub fn data_mut(&mut self) -> &mut [f64] {
        &mut self.data
    }

    /// Consumes the matrix and returns its flat buffer.
    pub fn into_data(self) -> Vec<f64> {
        self.data
    }

    /// Returns the element at row `i`, column `j`.
    ///
    /// # Panics
    ///
    /// Panics when the computed flat index is outside the buffer. The two
    /// coordinates are not range-checked individually, so an out-of-range
    /// coordinate that still maps inside the buffer reads another element.
    pub fn get(&self, i: usize, j: usize) -> f64 {
        match self.layout {
            MatrixLayout::RowMajor => self.data[i * self.cols + j],
            MatrixLayout::ColMajor => self.data[j * self.rows + i],
        }
    }

    /// Stores `value` at row `i`, column `j`.
    ///
    /// # Panics
    ///
    /// Panics when the computed flat index is outside the buffer; see
    /// [`Self::get`] for the caveat about individual coordinates.
    pub fn set(&mut self, i: usize, j: usize, value: f64) {
        match self.layout {
            MatrixLayout::RowMajor => self.data[i * self.cols + j] = value,
            MatrixLayout::ColMajor => self.data[j * self.rows + i] = value,
        }
    }

    /// Returns `true` when the matrix has as many rows as columns.
    pub fn is_square(&self) -> bool {
        self.rows == self.cols
    }

    /// Square root of the sum of squared elements.
    pub fn frobenius_norm(&self) -> f64 {
        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// Returns the transposed matrix, keeping the same storage layout.
    pub fn transpose(&self) -> Self {
        // The free `transpose` function interprets its input as row-major.
        // A column-major `rows x cols` buffer is byte-for-byte a row-major
        // `cols x rows` buffer, so the dimensions are swapped for that layout;
        // passing `(rows, cols)` unconditionally scrambles non-square
        // column-major matrices.
        let data = match self.layout {
            MatrixLayout::RowMajor => transpose(&self.data, self.rows, self.cols),
            MatrixLayout::ColMajor => transpose(&self.data, self.cols, self.rows),
        };
        Self {
            rows: self.cols,
            cols: self.rows,
            data,
            layout: self.layout,
        }
    }
}
impl std::fmt::Display for RandomMatrix {
    /// Writes a one-line summary with the shape and layout; element values
    /// are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (height, width, ordering) = (self.rows, self.cols, self.layout);
        f.write_fmt(format_args!(
            "RandomMatrix({}x{}, {})",
            height, width, ordering
        ))
    }
}
/// Transposes a row-major `rows x cols` matrix into a row-major `cols x rows` one.
///
/// Only the first `rows * cols` elements of `matrix` are read.
///
/// # Panics
///
/// Panics when `matrix` holds fewer than `rows * cols` elements.
pub fn transpose(matrix: &[f64], rows: usize, cols: usize) -> Vec<f64> {
    let mut flipped = Vec::with_capacity(rows * cols);
    // Each source column becomes one contiguous output row.
    for col in 0..cols {
        flipped.extend((0..rows).map(|row| matrix[row * cols + col]));
    }
    flipped
}
/// Multiplies a row-major `m x k` matrix `a` by a row-major `k x n` matrix `b`.
///
/// Returns the row-major `m x n` product. Note the argument order: the
/// output dimensions `m`, `n` come first and the shared dimension `k` last.
///
/// # Panics
///
/// Panics when `a` or `b` is too short for the requested dimensions.
pub fn matrix_multiply(a: &[f64], b: &[f64], m: usize, n: usize, k: usize) -> Vec<f64> {
    let mut product = vec![0.0; m * n];
    for row in 0..m {
        let out_row = &mut product[row * n..(row + 1) * n];
        // Walk `b` row by row so both `b` and the output are read contiguously.
        for inner in 0..k {
            let scale = a[row * k + inner];
            let b_row = &b[inner * n..(inner + 1) * n];
            for (acc, rhs) in out_row.iter_mut().zip(b_row) {
                *acc += scale * rhs;
            }
        }
    }
    product
}
/// Computes the lower-triangular Cholesky factor `L` of a row-major `n x n`
/// matrix, so that `L * L^T` reproduces the input.
///
/// Only the diagonal and the entries below it are read; the input is not
/// checked for symmetry. Entries of the result above the diagonal are zero.
///
/// # Errors
///
/// * `RandError::InvalidSize` when `matrix.len() != n * n`.
/// * `RandError::InternalError` when a pivot is not a strictly positive,
///   finite number, i.e. the matrix is not positive definite or contains
///   NaN / infinite values.
pub fn cholesky_decompose(matrix: &[f64], n: usize) -> RandResult<Vec<f64>> {
    if matrix.len() != n * n {
        return Err(RandError::InvalidSize(format!(
            "expected {}x{} = {} elements, got {}",
            n,
            n,
            n * n,
            matrix.len()
        )));
    }
    let mut l = vec![0.0; n * n];
    for j in 0..n {
        let mut sum = 0.0;
        for k in 0..j {
            sum += l[j * n + k] * l[j * n + k];
        }
        let diag = matrix[j * n + j] - sum;
        // `diag <= 0.0` alone is false for NaN, which would let a NaN pivot
        // through and return `Ok` with a NaN-filled factor; an infinite pivot
        // poisons the remaining columns the same way.
        if !diag.is_finite() || diag <= 0.0 {
            return Err(RandError::InternalError(format!(
                "Cholesky decomposition failed: matrix is not positive definite \
                 (diagonal element {} became {:.6e})",
                j, diag
            )));
        }
        let pivot = diag.sqrt();
        l[j * n + j] = pivot;
        // Fill the rest of column j below the diagonal.
        for i in (j + 1)..n {
            let mut sum = 0.0;
            for k in 0..j {
                sum += l[i * n + k] * l[j * n + k];
            }
            l[i * n + j] = (matrix[i * n + j] - sum) / pivot;
        }
    }
    Ok(l)
}
/// Generator for matrices with independent normally distributed entries.
pub struct GaussianMatrixGenerator;

impl GaussianMatrixGenerator {
    /// Builds a row-major `rows x cols` matrix whose entries are
    /// `mean + stddev * z` for standard-normal samples `z`.
    ///
    /// The output is fully determined by `seed`. Samples are produced in
    /// pairs; when the element count is odd, the final element comes from a
    /// fresh pair whose second value is discarded.
    pub fn generate(rows: usize, cols: usize, mean: f64, stddev: f64, seed: u64) -> RandomMatrix {
        let mut rng = SplitMix64::new(seed);
        let total = rows * cols;
        let rescale = |z: f64| mean + stddev * z;

        let mut data = Vec::with_capacity(total);
        for _ in 0..total / 2 {
            let (first, second) = rng.next_normal_pair();
            data.extend_from_slice(&[rescale(first), rescale(second)]);
        }
        if total % 2 == 1 {
            data.push(rescale(rng.next_normal()));
        }

        RandomMatrix {
            rows,
            cols,
            data,
            layout: MatrixLayout::RowMajor,
        }
    }
}
/// Generator for random matrices of the form `X^T X`, where the rows of `X`
/// are Gaussian vectors shaped by a caller-supplied scale matrix.
pub struct WishartGenerator;

impl WishartGenerator {
    /// Draws a `dim x dim` row-major matrix `W = X^T X` with `X = Z L^T`.
    ///
    /// `Z` is a `dof x dim` matrix of standard-normal samples derived from
    /// `seed`, and `L` is the Cholesky factor of `scale` (row-major,
    /// `dim x dim`).
    ///
    /// # Errors
    ///
    /// Returns `RandError::InvalidSize` when `dim` is zero, when `dof < dim`,
    /// or when `scale` does not hold `dim * dim` elements. Propagates the
    /// error from [`cholesky_decompose`] when `scale` cannot be factored.
    pub fn generate(dim: usize, dof: usize, scale: &[f64], seed: u64) -> RandResult<RandomMatrix> {
        if dim == 0 {
            return Err(RandError::InvalidSize(
                "Wishart dimension must be positive".to_string(),
            ));
        }
        if dof < dim {
            return Err(RandError::InvalidSize(format!(
                "degrees of freedom ({dof}) must be >= dimension ({dim}) for positive definiteness"
            )));
        }
        if scale.len() != dim * dim {
            return Err(RandError::InvalidSize(format!(
                "scale matrix must have {} elements, got {}",
                dim * dim,
                scale.len()
            )));
        }
        let l = cholesky_decompose(scale, dim)?;
        let z = GaussianMatrixGenerator::generate(dof, dim, 0.0, 1.0, seed);
        // X = Z * L^T  (dof x dim).
        let l_transposed = transpose(&l, dim, dim);
        let x = matrix_multiply(z.data(), &l_transposed, dof, dim, dim);
        // W = X^T * X  (dim x dim).
        let xt = transpose(&x, dof, dim);
        let w = matrix_multiply(&xt, &x, dim, dim, dof);
        RandomMatrix::new(dim, dim, w, MatrixLayout::RowMajor)
    }
}
/// Generator for square matrices with orthonormal columns.
pub struct OrthogonalMatrixGenerator;

impl OrthogonalMatrixGenerator {
    /// Builds a row-major `dim x dim` matrix by orthonormalising the columns
    /// of a seeded standard-normal matrix.
    ///
    /// `dim == 0` yields an empty matrix and `dim == 1` always yields `[1.0]`;
    /// neither case consumes any random numbers.
    pub fn generate(dim: usize, seed: u64) -> RandomMatrix {
        let data = match dim {
            0 => Vec::new(),
            1 => vec![1.0],
            _ => {
                let gaussian = GaussianMatrixGenerator::generate(dim, dim, 0.0, 1.0, seed);
                modified_gram_schmidt(gaussian.data(), dim)
            }
        };
        RandomMatrix {
            rows: dim,
            cols: dim,
            data,
            layout: MatrixLayout::RowMajor,
        }
    }
}
/// Orthonormalises the columns of a row-major `n x n` matrix with the
/// modified Gram-Schmidt procedure and returns the result in row-major order.
///
/// A column whose norm is at most `1e-15` is left unnormalised, so a
/// rank-deficient input does not produce an orthogonal matrix.
///
/// # Panics
///
/// Panics when `a` holds fewer than `n * n` elements.
fn modified_gram_schmidt(a: &[f64], n: usize) -> Vec<f64> {
    // Work on explicit column vectors so each projection touches contiguous memory.
    let mut cols: Vec<Vec<f64>> = (0..n)
        .map(|j| (0..n).map(|i| a[i * n + j]).collect())
        .collect();

    for j in 0..n {
        // Split so column j can be read while the later columns are mutated;
        // this avoids cloning the basis column once per remaining column.
        let (head, tail) = cols.split_at_mut(j + 1);
        let current = &mut head[j];

        let norm = current.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm > 1e-15 {
            for elem in current.iter_mut() {
                *elem /= norm;
            }
        }

        let basis: &[f64] = current.as_slice();
        // Remove the component along `basis` from every later column.
        for later in tail.iter_mut() {
            let dot: f64 = basis.iter().zip(later.iter()).map(|(b, v)| b * v).sum();
            for (elem, b) in later.iter_mut().zip(basis) {
                *elem -= dot * b;
            }
        }
    }

    let mut q = vec![0.0; n * n];
    for (j, col) in cols.iter().enumerate() {
        for (i, value) in col.iter().enumerate() {
            q[i * n + j] = *value;
        }
    }
    q
}
/// Generator for matrices of the form `Q D Q^T` with a prescribed spread of
/// diagonal values in `D`.
pub struct SymmetricPositiveDefiniteGenerator;

impl SymmetricPositiveDefiniteGenerator {
    /// Builds a row-major `dim x dim` matrix `Q D Q^T`.
    ///
    /// `Q` comes from [`OrthogonalMatrixGenerator::generate`] with the same
    /// `seed`. `D` is diagonal with entries spaced evenly on a logarithmic
    /// scale from `1.0` to `condition_number` (a single `1.0` when
    /// `dim == 1`).
    ///
    /// # Errors
    ///
    /// Returns `RandError::InvalidSize` when `dim` is zero or when
    /// `condition_number` is NaN, infinite, or smaller than `1.0`.
    pub fn generate(dim: usize, condition_number: f64, seed: u64) -> RandResult<RandomMatrix> {
        if dim == 0 {
            return Err(RandError::InvalidSize(
                "SPD dimension must be positive".to_string(),
            ));
        }
        // `condition_number < 1.0` alone is false for NaN and lets +inf
        // through; both would fill the result with NaN entries.
        if !condition_number.is_finite() || condition_number < 1.0 {
            return Err(RandError::InvalidSize(format!(
                "condition number must be >= 1.0, got {condition_number}"
            )));
        }
        let q_mat = OrthogonalMatrixGenerator::generate(dim, seed);
        let q = q_mat.data();

        let mut d = vec![0.0; dim];
        if dim == 1 {
            d[0] = 1.0;
        } else {
            // Interpolate in log space between ln(1) = 0 and ln(condition_number).
            let log_min = 0.0_f64;
            let log_max = condition_number.ln();
            for (i, d_i) in d.iter_mut().enumerate() {
                let t = i as f64 / (dim - 1) as f64;
                *d_i = (log_min + t * (log_max - log_min)).exp();
            }
        }

        // Q * D: scale column j of Q by d[j].
        let mut qd = vec![0.0; dim * dim];
        for i in 0..dim {
            for j in 0..dim {
                qd[i * dim + j] = q[i * dim + j] * d[j];
            }
        }
        let qt = transpose(q, dim, dim);
        let result = matrix_multiply(&qd, &qt, dim, dim, dim);
        RandomMatrix::new(dim, dim, result, MatrixLayout::RowMajor)
    }
}
pub struct CorrelationMatrixGenerator;
impl CorrelationMatrixGenerator {
pub fn generate(dim: usize, seed: u64) -> RandResult<RandomMatrix> {
if dim == 0 {
return Err(RandError::InvalidSize(
"correlation matrix dimension must be positive".to_string(),
));
}
if dim == 1 {
return RandomMatrix::new(1, 1, vec![1.0], MatrixLayout::RowMajor);
}
let mut rng = SplitMix64::new(seed);
let mut l = vec![0.0; dim * dim];
l[0] = 1.0; for i in 1..dim {
let p = 2.0 * rng.next_f64() - 1.0;
l[i * dim] = p; }
for k in 1..dim {
let mut sum_sq = 0.0;
for j in 0..k {
sum_sq += l[k * dim + j] * l[k * dim + j];
}
let rem = 1.0 - sum_sq;
l[k * dim + k] = if rem > 0.0 { rem.sqrt() } else { 0.0 };
for i in (k + 1)..dim {
let mut sum_sq_i = 0.0;
for j in 0..k {
sum_sq_i += l[i * dim + j] * l[i * dim + j];
}
let rem_i = 1.0 - sum_sq_i;
if rem_i <= 0.0 {
l[i * dim + k] = 0.0;
continue;
}
let p = 2.0 * rng.next_f64() - 1.0;
l[i * dim + k] = p * rem_i.sqrt();
}
}
let lt = transpose(&l, dim, dim);
let c = matrix_multiply(&l, <, dim, dim, dim);
RandomMatrix::new(dim, dim, c, MatrixLayout::RowMajor)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;

    /// Returns `true` when `m` is square and `m[i][j]` equals `m[j][i]` within `tol`.
    fn is_symmetric(m: &RandomMatrix, tol: f64) -> bool {
        if !m.is_square() {
            return false;
        }
        let n = m.rows();
        for i in 0..n {
            for j in (i + 1)..n {
                if (m.get(i, j) - m.get(j, i)).abs() > tol {
                    return false;
                }
            }
        }
        true
    }

    /// Returns `true` when every diagonal entry of `m` equals `1.0` within `tol`.
    fn has_unit_diagonal(m: &RandomMatrix, tol: f64) -> bool {
        let n = m.rows().min(m.cols());
        for i in 0..n {
            if (m.get(i, i) - 1.0).abs() > tol {
                return false;
            }
        }
        true
    }

    /// Returns `true` when `m` is square and its Cholesky factorisation succeeds.
    fn is_positive_definite(m: &RandomMatrix) -> bool {
        if !m.is_square() {
            return false;
        }
        cholesky_decompose(m.data(), m.rows()).is_ok()
    }

    #[test]
    fn gaussian_correct_dimensions() {
        let m = GaussianMatrixGenerator::generate(5, 3, 0.0, 1.0, 42);
        assert_eq!(m.rows(), 5);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.data().len(), 15);
    }

    #[test]
    fn gaussian_mean_and_variance() {
        let m = GaussianMatrixGenerator::generate(1000, 1000, 2.5, 0.5, 123);
        let n = m.data().len() as f64;
        let mean = m.data().iter().sum::<f64>() / n;
        let variance = m.data().iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        assert!((mean - 2.5).abs() < 0.01, "mean = {mean}");
        assert!((variance - 0.25).abs() < 0.01, "variance = {variance}");
    }

    #[test]
    fn gaussian_deterministic_with_seed() {
        let m1 = GaussianMatrixGenerator::generate(10, 10, 0.0, 1.0, 999);
        let m2 = GaussianMatrixGenerator::generate(10, 10, 0.0, 1.0, 999);
        assert_eq!(m1.data(), m2.data());
    }

    #[test]
    fn cholesky_identity() {
        let identity = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let l = cholesky_decompose(&identity, 3).expect("cholesky should succeed");
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (l[i * 3 + j] - expected).abs() < TOL,
                    "L[{i},{j}] = {} expected {expected}",
                    l[i * 3 + j]
                );
            }
        }
    }

    #[test]
    fn cholesky_reconstruction() {
        let a = vec![4.0, 2.0, 2.0, 3.0];
        let l = cholesky_decompose(&a, 2).expect("cholesky should succeed");
        let l_transposed = transpose(&l, 2, 2);
        let reconstructed = matrix_multiply(&l, &l_transposed, 2, 2, 2);
        for i in 0..4 {
            assert!(
                (reconstructed[i] - a[i]).abs() < TOL,
                "element {i}: {} vs {}",
                reconstructed[i],
                a[i]
            );
        }
    }

    #[test]
    fn cholesky_not_positive_definite() {
        let a = vec![1.0, 2.0, 2.0, 1.0];
        assert!(cholesky_decompose(&a, 2).is_err());
    }

    #[test]
    fn transpose_round_trip() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let at = transpose(&a, 2, 3);
        let att = transpose(&at, 3, 2);
        assert_eq!(a, att);
    }

    #[test]
    fn matrix_multiply_identity() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let id = vec![1.0, 0.0, 0.0, 1.0];
        let result = matrix_multiply(&a, &id, 2, 2, 2);
        for i in 0..4 {
            assert!((result[i] - a[i]).abs() < TOL);
        }
    }

    #[test]
    fn orthogonal_qtq_is_identity() {
        let q = OrthogonalMatrixGenerator::generate(5, 42);
        let qt = q.transpose();
        let qtq = matrix_multiply(qt.data(), q.data(), 5, 5, 5);
        for i in 0..5 {
            for j in 0..5 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (qtq[i * 5 + j] - expected).abs() < 1e-12,
                    "Q^T*Q[{i},{j}] = {} expected {expected}",
                    qtq[i * 5 + j]
                );
            }
        }
    }

    #[test]
    fn orthogonal_determinant_abs_one() {
        let q = OrthogonalMatrixGenerator::generate(2, 77);
        let det = q.get(0, 0) * q.get(1, 1) - q.get(0, 1) * q.get(1, 0);
        assert!(
            (det.abs() - 1.0).abs() < 1e-12,
            "|det| = {} expected 1.0",
            det.abs()
        );
    }

    #[test]
    fn wishart_is_symmetric_and_spd() {
        let dim = 4;
        let dof = 10;
        let mut scale = vec![0.0; dim * dim];
        for i in 0..dim {
            scale[i * dim + i] = 1.0;
        }
        let w = WishartGenerator::generate(dim, dof, &scale, 42).expect("wishart should succeed");
        assert_eq!(w.rows(), dim);
        assert_eq!(w.cols(), dim);
        assert!(
            is_symmetric(&w, 1e-10),
            "Wishart matrix should be symmetric"
        );
        assert!(is_positive_definite(&w), "Wishart matrix should be SPD");
    }

    #[test]
    fn wishart_dof_less_than_dim_errors() {
        let scale = vec![1.0, 0.0, 0.0, 1.0];
        let result = WishartGenerator::generate(2, 1, &scale, 42);
        assert!(result.is_err());
    }

    #[test]
    fn spd_is_symmetric_and_positive_definite() {
        let m = SymmetricPositiveDefiniteGenerator::generate(5, 100.0, 42)
            .expect("spd gen should succeed");
        assert!(is_symmetric(&m, 1e-10), "SPD matrix should be symmetric");
        assert!(
            is_positive_definite(&m),
            "SPD matrix should be positive definite"
        );
    }

    #[test]
    fn spd_condition_number_bound() {
        let dim = 4;
        let kappa = 10.0;
        let m = SymmetricPositiveDefiniteGenerator::generate(dim, kappa, 42)
            .expect("spd gen should succeed");
        let trace: f64 = (0..dim).map(|i| m.get(i, i)).sum();
        assert!(trace > 0.0, "trace should be positive");
    }

    #[test]
    fn spd_invalid_condition_number() {
        let result = SymmetricPositiveDefiniteGenerator::generate(3, 0.5, 42);
        assert!(result.is_err());
    }

    #[test]
    fn correlation_unit_diagonal() {
        let c =
            CorrelationMatrixGenerator::generate(5, 42).expect("correlation gen should succeed");
        assert!(
            has_unit_diagonal(&c, 1e-10),
            "correlation matrix should have unit diagonal"
        );
    }

    #[test]
    fn correlation_is_symmetric_and_psd() {
        let c =
            CorrelationMatrixGenerator::generate(5, 42).expect("correlation gen should succeed");
        assert!(
            is_symmetric(&c, 1e-10),
            "correlation matrix should be symmetric"
        );
        assert!(
            is_positive_definite(&c),
            "correlation matrix should be positive semi-definite"
        );
    }

    #[test]
    fn correlation_entries_bounded() {
        let c =
            CorrelationMatrixGenerator::generate(6, 123).expect("correlation gen should succeed");
        for val in c.data() {
            assert!(
                *val >= -1.0 - 1e-12 && *val <= 1.0 + 1e-12,
                "correlation entry {val} out of [-1, 1]"
            );
        }
    }
}