// oxicuda_rand — matrix_gen.rs
//! Random matrix generation for statistical applications.
//!
//! Provides generators for several families of structured random matrices:
//!
//! - **Gaussian**: matrices with i.i.d. Gaussian entries
//! - **Wishart**: positive-definite matrices from the Wishart distribution
//! - **Orthogonal**: uniformly distributed orthogonal matrices (Haar measure)
//! - **SPD**: symmetric positive-definite matrices with controlled condition number
//! - **Correlation**: valid correlation matrices via the vine method
//!
//! All matrices are stored as flat `Vec<f64>` in row-major or column-major order.

use crate::error::{RandError, RandResult};
14
// ---------------------------------------------------------------------------
// CPU-side PRNG (SplitMix64)
// ---------------------------------------------------------------------------

/// SplitMix64 pseudo-random number generator for CPU-side sampling.
///
/// A small, fast 64-bit generator with good statistical quality; used here
/// both for seeding and for drawing uniform/normal variates on the host.
struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Constructs a generator whose internal state starts at `seed`.
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the state and returns the next pseudo-random `u64`.
    fn next_u64(&mut self) -> u64 {
        // Weyl-sequence increment followed by two multiply-xorshift mixing steps.
        self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut x = self.state;
        x = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        x = (x ^ (x >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        x ^ (x >> 31)
    }

    /// Draws a uniform `f64` in the half-open interval [0, 1).
    fn next_f64(&mut self) -> f64 {
        // Keep the top 53 bits so the value maps exactly onto the f64 mantissa.
        const SCALE: f64 = (1u64 << 53) as f64;
        (self.next_u64() >> 11) as f64 / SCALE
    }

    /// Draws two independent standard-normal variates via the Box-Muller transform.
    fn next_normal_pair(&mut self) -> (f64, f64) {
        loop {
            let u1 = self.next_f64();
            let u2 = self.next_f64();
            // Reject u1 == 0 so ln(u1) stays finite.
            if u1 > 0.0 {
                let radius = (-2.0 * u1.ln()).sqrt();
                let angle = 2.0 * std::f64::consts::PI * u2;
                return (radius * angle.cos(), radius * angle.sin());
            }
        }
    }

    /// Draws a single standard-normal variate (the pair's second value is dropped).
    fn next_normal(&mut self) -> f64 {
        self.next_normal_pair().0
    }
}
65
// ---------------------------------------------------------------------------
// MatrixLayout
// ---------------------------------------------------------------------------

/// Storage layout for matrix data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatrixLayout {
    /// Row-major (C-style) ordering. Element (i,j) is at index `i * cols + j`.
    RowMajor,
    /// Column-major (Fortran-style) ordering. Element (i,j) is at index `j * rows + i`.
    ColMajor,
}

impl std::fmt::Display for MatrixLayout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Resolve the label first, then emit it with a single write.
        let label = match self {
            Self::RowMajor => "RowMajor",
            Self::ColMajor => "ColMajor",
        };
        f.write_str(label)
    }
}
87
// ---------------------------------------------------------------------------
// RandomMatrix
// ---------------------------------------------------------------------------

/// A random matrix stored as a flat vector.
///
/// Provides basic accessors for the matrix dimensions and elements.
/// The data is stored as a contiguous `Vec<f64>` in the specified layout.
#[derive(Debug, Clone)]
pub struct RandomMatrix {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Flat data storage; length is always `rows * cols`.
    data: Vec<f64>,
    /// Memory layout (row- or column-major) used when indexing `data`.
    layout: MatrixLayout,
}
107
108impl RandomMatrix {
109    /// Creates a new `RandomMatrix` from existing data.
110    ///
111    /// # Errors
112    ///
113    /// Returns `RandError::InvalidSize` if `data.len() != rows * cols`.
114    pub fn new(rows: usize, cols: usize, data: Vec<f64>, layout: MatrixLayout) -> RandResult<Self> {
115        if data.len() != rows * cols {
116            return Err(RandError::InvalidSize(format!(
117                "data length {} does not match {}x{} = {}",
118                data.len(),
119                rows,
120                cols,
121                rows * cols
122            )));
123        }
124        Ok(Self {
125            rows,
126            cols,
127            data,
128            layout,
129        })
130    }
131
132    /// Creates a zero-filled matrix.
133    pub fn zeros(rows: usize, cols: usize, layout: MatrixLayout) -> Self {
134        Self {
135            rows,
136            cols,
137            data: vec![0.0; rows * cols],
138            layout,
139        }
140    }
141
142    /// Creates an identity matrix (must be square).
143    ///
144    /// # Errors
145    ///
146    /// Returns `RandError::InvalidSize` if `n == 0`.
147    pub fn identity(n: usize, layout: MatrixLayout) -> RandResult<Self> {
148        if n == 0 {
149            return Err(RandError::InvalidSize(
150                "identity matrix dimension must be positive".to_string(),
151            ));
152        }
153        let mut data = vec![0.0; n * n];
154        for i in 0..n {
155            data[i * n + i] = 1.0;
156        }
157        // If ColMajor, the identity is the same (diagonal is at i*n+i in both layouts)
158        Ok(Self {
159            rows: n,
160            cols: n,
161            data,
162            layout,
163        })
164    }
165
166    /// Returns the number of rows.
167    pub fn rows(&self) -> usize {
168        self.rows
169    }
170
171    /// Returns the number of columns.
172    pub fn cols(&self) -> usize {
173        self.cols
174    }
175
176    /// Returns the matrix layout.
177    pub fn layout(&self) -> MatrixLayout {
178        self.layout
179    }
180
181    /// Returns a reference to the underlying data.
182    pub fn data(&self) -> &[f64] {
183        &self.data
184    }
185
186    /// Returns a mutable reference to the underlying data.
187    pub fn data_mut(&mut self) -> &mut [f64] {
188        &mut self.data
189    }
190
191    /// Consumes the matrix and returns the underlying data.
192    pub fn into_data(self) -> Vec<f64> {
193        self.data
194    }
195
196    /// Returns the element at position (i, j).
197    ///
198    /// # Panics
199    ///
200    /// Panics if `i >= rows` or `j >= cols`.
201    pub fn get(&self, i: usize, j: usize) -> f64 {
202        match self.layout {
203            MatrixLayout::RowMajor => self.data[i * self.cols + j],
204            MatrixLayout::ColMajor => self.data[j * self.rows + i],
205        }
206    }
207
208    /// Sets the element at position (i, j).
209    ///
210    /// # Panics
211    ///
212    /// Panics if `i >= rows` or `j >= cols`.
213    pub fn set(&mut self, i: usize, j: usize, value: f64) {
214        match self.layout {
215            MatrixLayout::RowMajor => self.data[i * self.cols + j] = value,
216            MatrixLayout::ColMajor => self.data[j * self.rows + i] = value,
217        }
218    }
219
220    /// Returns `true` if the matrix is square.
221    pub fn is_square(&self) -> bool {
222        self.rows == self.cols
223    }
224
225    /// Computes the Frobenius norm of the matrix.
226    pub fn frobenius_norm(&self) -> f64 {
227        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
228    }
229
230    /// Returns the transpose of this matrix.
231    pub fn transpose(&self) -> Self {
232        let data = transpose(&self.data, self.rows, self.cols);
233        Self {
234            rows: self.cols,
235            cols: self.rows,
236            data,
237            layout: self.layout,
238        }
239    }
240}
241
242impl std::fmt::Display for RandomMatrix {
243    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        write!(
245            f,
246            "RandomMatrix({}x{}, {})",
247            self.rows, self.cols, self.layout
248        )
249    }
250}
251
// ---------------------------------------------------------------------------
// Helper functions: linear algebra primitives
// ---------------------------------------------------------------------------

/// Transposes an `m x n` row-major matrix into an `n x m` row-major matrix.
pub fn transpose(matrix: &[f64], rows: usize, cols: usize) -> Vec<f64> {
    let mut result = vec![0.0; rows * cols];
    // Walk the source linearly and scatter each element to its mirrored slot.
    for (idx, &value) in matrix.iter().enumerate() {
        let (i, j) = (idx / cols, idx % cols);
        result[j * rows + i] = value;
    }
    result
}
266
/// Computes C = A * B where A is `m x k` and B is `k x n` (row-major).
///
/// The output C is `m x n` in row-major order.
pub fn matrix_multiply(a: &[f64], b: &[f64], m: usize, n: usize, k: usize) -> Vec<f64> {
    let mut c = vec![0.0; m * n];
    if n == 0 {
        // chunks_mut requires a non-zero chunk size; the product is empty anyway.
        return c;
    }
    // i-k-j loop order: stream row i of A against row p of B, accumulating
    // into row i of C. Accumulation order over p matches the naive triple loop.
    for (row, c_row) in c.chunks_mut(n).enumerate() {
        for (p, &a_val) in a[row * k..row * k + k].iter().enumerate() {
            let b_row = &b[p * n..p * n + n];
            for (c_val, &b_val) in c_row.iter_mut().zip(b_row) {
                *c_val += a_val * b_val;
            }
        }
    }
    c
}
282
283/// Computes the Cholesky decomposition of a symmetric positive-definite matrix.
284///
285/// Given an `n x n` SPD matrix A (row-major), returns the lower-triangular
286/// matrix L such that A = L * L^T.
287///
288/// # Errors
289///
290/// Returns `RandError::InternalError` if the matrix is not positive definite
291/// (i.e., a diagonal element becomes non-positive during decomposition).
292pub fn cholesky_decompose(matrix: &[f64], n: usize) -> RandResult<Vec<f64>> {
293    if matrix.len() != n * n {
294        return Err(RandError::InvalidSize(format!(
295            "expected {}x{} = {} elements, got {}",
296            n,
297            n,
298            n * n,
299            matrix.len()
300        )));
301    }
302
303    let mut l = vec![0.0; n * n];
304
305    for j in 0..n {
306        // Diagonal element
307        let mut sum = 0.0;
308        for k in 0..j {
309            sum += l[j * n + k] * l[j * n + k];
310        }
311        let diag = matrix[j * n + j] - sum;
312        if diag <= 0.0 {
313            return Err(RandError::InternalError(format!(
314                "Cholesky decomposition failed: matrix is not positive definite \
315                 (diagonal element {} became {:.6e})",
316                j, diag
317            )));
318        }
319        l[j * n + j] = diag.sqrt();
320
321        // Off-diagonal elements in column j
322        for i in (j + 1)..n {
323            let mut sum = 0.0;
324            for k in 0..j {
325                sum += l[i * n + k] * l[j * n + k];
326            }
327            l[i * n + j] = (matrix[i * n + j] - sum) / l[j * n + j];
328        }
329    }
330
331    Ok(l)
332}
333
334// ---------------------------------------------------------------------------
335// GaussianMatrixGenerator
336// ---------------------------------------------------------------------------
337
338/// Generates matrices with independent and identically distributed Gaussian entries.
339///
340/// Each entry is drawn from N(mean, stddev^2) using the Box-Muller transform
341/// applied to a SplitMix64 PRNG.
342pub struct GaussianMatrixGenerator;
343
344impl GaussianMatrixGenerator {
345    /// Generates an `rows x cols` matrix with i.i.d. Gaussian entries.
346    ///
347    /// # Arguments
348    ///
349    /// * `rows` - Number of rows
350    /// * `cols` - Number of columns
351    /// * `mean` - Mean of the Gaussian distribution
352    /// * `stddev` - Standard deviation of the Gaussian distribution
353    /// * `seed` - Random seed
354    pub fn generate(rows: usize, cols: usize, mean: f64, stddev: f64, seed: u64) -> RandomMatrix {
355        let mut rng = SplitMix64::new(seed);
356        let total = rows * cols;
357        let mut data = Vec::with_capacity(total);
358
359        // Generate pairs via Box-Muller
360        let mut i = 0;
361        while i + 1 < total {
362            let (z0, z1) = rng.next_normal_pair();
363            data.push(mean + stddev * z0);
364            data.push(mean + stddev * z1);
365            i += 2;
366        }
367        // Handle odd remaining element
368        if data.len() < total {
369            let z = rng.next_normal();
370            data.push(mean + stddev * z);
371        }
372
373        RandomMatrix {
374            rows,
375            cols,
376            data,
377            layout: MatrixLayout::RowMajor,
378        }
379    }
380}
381
382// ---------------------------------------------------------------------------
383// WishartGenerator
384// ---------------------------------------------------------------------------
385
386/// Generates random matrices from the Wishart distribution.
387///
388/// The Wishart distribution W ~ Wishart(Sigma, n) is the distribution of
389/// X^T * X where the rows of X are i.i.d. multivariate normal N(0, Sigma).
390///
391/// The resulting matrix is symmetric positive definite when `dof >= dim`.
392pub struct WishartGenerator;
393
394impl WishartGenerator {
395    /// Generates a Wishart-distributed random matrix.
396    ///
397    /// # Arguments
398    ///
399    /// * `dim` - Dimension p of the p x p output matrix
400    /// * `dof` - Degrees of freedom n (must be >= dim for positive definiteness)
401    /// * `scale` - The p x p scale matrix Sigma in row-major order
402    /// * `seed` - Random seed
403    ///
404    /// # Errors
405    ///
406    /// Returns `RandError::InvalidSize` if `dof < dim` or `scale` has wrong length.
407    /// Returns `RandError::InternalError` if the scale matrix is not positive definite.
408    pub fn generate(dim: usize, dof: usize, scale: &[f64], seed: u64) -> RandResult<RandomMatrix> {
409        if dim == 0 {
410            return Err(RandError::InvalidSize(
411                "Wishart dimension must be positive".to_string(),
412            ));
413        }
414        if dof < dim {
415            return Err(RandError::InvalidSize(format!(
416                "degrees of freedom ({dof}) must be >= dimension ({dim}) for positive definiteness"
417            )));
418        }
419        if scale.len() != dim * dim {
420            return Err(RandError::InvalidSize(format!(
421                "scale matrix must have {} elements, got {}",
422                dim * dim,
423                scale.len()
424            )));
425        }
426
427        // Cholesky decomposition: Sigma = L * L^T
428        let l = cholesky_decompose(scale, dim)?;
429
430        // Generate Z: dof x dim matrix of standard normals
431        let z = GaussianMatrixGenerator::generate(dof, dim, 0.0, 1.0, seed);
432
433        // X = Z * L^T (each row of X ~ N(0, Sigma))
434        let lt = transpose(&l, dim, dim);
435        let x = matrix_multiply(z.data(), &lt, dof, dim, dim);
436
437        // W = X^T * X (dim x dim SPD matrix)
438        let xt = transpose(&x, dof, dim);
439        let w = matrix_multiply(&xt, &x, dim, dim, dof);
440
441        RandomMatrix::new(dim, dim, w, MatrixLayout::RowMajor)
442    }
443}
444
445// ---------------------------------------------------------------------------
446// OrthogonalMatrixGenerator
447// ---------------------------------------------------------------------------
448
449/// Generates random orthogonal matrices uniformly distributed on O(n)
450/// according to the Haar measure.
451///
452/// The method is: generate a Gaussian random matrix, compute its QR
453/// decomposition via modified Gram-Schmidt, and return Q.
454pub struct OrthogonalMatrixGenerator;
455
456impl OrthogonalMatrixGenerator {
457    /// Generates a random `dim x dim` orthogonal matrix.
458    ///
459    /// # Arguments
460    ///
461    /// * `dim` - Dimension of the square orthogonal matrix
462    /// * `seed` - Random seed
463    pub fn generate(dim: usize, seed: u64) -> RandomMatrix {
464        if dim == 0 {
465            return RandomMatrix {
466                rows: 0,
467                cols: 0,
468                data: Vec::new(),
469                layout: MatrixLayout::RowMajor,
470            };
471        }
472        if dim == 1 {
473            return RandomMatrix {
474                rows: 1,
475                cols: 1,
476                data: vec![1.0],
477                layout: MatrixLayout::RowMajor,
478            };
479        }
480
481        // Generate a random Gaussian matrix
482        let a = GaussianMatrixGenerator::generate(dim, dim, 0.0, 1.0, seed);
483
484        // QR decomposition via modified Gram-Schmidt
485        let q = modified_gram_schmidt(a.data(), dim);
486
487        RandomMatrix {
488            rows: dim,
489            cols: dim,
490            data: q,
491            layout: MatrixLayout::RowMajor,
492        }
493    }
494}
495
/// Modified Gram-Schmidt QR decomposition.
///
/// Given an `n x n` matrix A (row-major), returns Q (n x n row-major)
/// such that Q is orthogonal (for full-rank A).
///
/// Columns of A are processed left to right: each is normalized and then
/// projected out of every column to its right.
fn modified_gram_schmidt(a: &[f64], n: usize) -> Vec<f64> {
    // Work on explicit column vectors so the inner loops are contiguous.
    let mut cols: Vec<Vec<f64>> = (0..n)
        .map(|j| (0..n).map(|i| a[i * n + j]).collect())
        .collect();

    for j in 0..n {
        // Normalize column j. The near-zero guard only triggers for
        // rank-deficient input (probability zero for Gaussian matrices).
        let norm = cols[j].iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm > 1e-15 {
            for elem in &mut cols[j] {
                *elem /= norm;
            }
        }

        // Project column j out of all later columns. `split_at_mut` lets us
        // borrow column j immutably while mutating the columns after it,
        // avoiding the per-column clone the borrow checker previously forced.
        let (head, tail) = cols.split_at_mut(j + 1);
        let basis = &head[j];
        for col in tail.iter_mut() {
            let dot: f64 = basis.iter().zip(col.iter()).map(|(q, v)| q * v).sum();
            for (elem, &q) in col.iter_mut().zip(basis.iter()) {
                *elem -= dot * q;
            }
        }
    }

    // Repack the orthonormal columns into a row-major matrix.
    let mut q = vec![0.0; n * n];
    for (j, col) in cols.iter().enumerate() {
        for (i, &value) in col.iter().enumerate() {
            q[i * n + j] = value;
        }
    }
    q
}
537
538// ---------------------------------------------------------------------------
539// SymmetricPositiveDefiniteGenerator
540// ---------------------------------------------------------------------------
541
542/// Generates random symmetric positive-definite (SPD) matrices with
543/// controlled condition number.
544///
545/// The method is: generate a random orthogonal Q, a positive diagonal D
546/// with eigenvalues logarithmically spaced between 1 and `condition_number`,
547/// then return Q * D * Q^T.
548pub struct SymmetricPositiveDefiniteGenerator;
549
550impl SymmetricPositiveDefiniteGenerator {
551    /// Generates a random `dim x dim` SPD matrix.
552    ///
553    /// # Arguments
554    ///
555    /// * `dim` - Dimension of the square SPD matrix
556    /// * `condition_number` - Ratio of largest to smallest eigenvalue (must be >= 1.0)
557    /// * `seed` - Random seed
558    ///
559    /// # Errors
560    ///
561    /// Returns `RandError::InvalidSize` if `dim == 0` or `condition_number < 1.0`.
562    pub fn generate(dim: usize, condition_number: f64, seed: u64) -> RandResult<RandomMatrix> {
563        if dim == 0 {
564            return Err(RandError::InvalidSize(
565                "SPD dimension must be positive".to_string(),
566            ));
567        }
568        if condition_number < 1.0 {
569            return Err(RandError::InvalidSize(format!(
570                "condition number must be >= 1.0, got {condition_number}"
571            )));
572        }
573
574        // Generate orthogonal matrix Q
575        let q_mat = OrthogonalMatrixGenerator::generate(dim, seed);
576        let q = q_mat.data();
577
578        // Generate diagonal D with eigenvalues log-spaced from 1 to condition_number
579        let mut d = vec![0.0; dim];
580        if dim == 1 {
581            d[0] = 1.0;
582        } else {
583            let log_min = 0.0_f64; // ln(1) = 0
584            let log_max = condition_number.ln();
585            for (i, d_i) in d.iter_mut().enumerate() {
586                let t = i as f64 / (dim - 1) as f64;
587                *d_i = (log_min + t * (log_max - log_min)).exp();
588            }
589        }
590
591        // Compute Q * D * Q^T
592        // First: Q * D  (scale each column of Q by d[j])
593        let mut qd = vec![0.0; dim * dim];
594        for i in 0..dim {
595            for j in 0..dim {
596                qd[i * dim + j] = q[i * dim + j] * d[j];
597            }
598        }
599
600        // Then: (Q*D) * Q^T
601        let qt = transpose(q, dim, dim);
602        let result = matrix_multiply(&qd, &qt, dim, dim, dim);
603
604        RandomMatrix::new(dim, dim, result, MatrixLayout::RowMajor)
605    }
606}
607
608// ---------------------------------------------------------------------------
609// CorrelationMatrixGenerator
610// ---------------------------------------------------------------------------
611
612/// Generates random correlation matrices using the vine method.
613///
614/// A correlation matrix is symmetric positive semi-definite with ones on the
615/// diagonal and off-diagonal entries in [-1, 1]. The vine method generates
616/// partial correlations uniformly and constructs a valid correlation matrix.
617pub struct CorrelationMatrixGenerator;
618
619impl CorrelationMatrixGenerator {
620    /// Generates a random `dim x dim` correlation matrix.
621    ///
622    /// # Arguments
623    ///
624    /// * `dim` - Dimension of the square correlation matrix
625    /// * `seed` - Random seed
626    ///
627    /// # Errors
628    ///
629    /// Returns `RandError::InvalidSize` if `dim == 0`.
630    pub fn generate(dim: usize, seed: u64) -> RandResult<RandomMatrix> {
631        if dim == 0 {
632            return Err(RandError::InvalidSize(
633                "correlation matrix dimension must be positive".to_string(),
634            ));
635        }
636        if dim == 1 {
637            return RandomMatrix::new(1, 1, vec![1.0], MatrixLayout::RowMajor);
638        }
639
640        let mut rng = SplitMix64::new(seed);
641
642        // Vine method: build correlation matrix from partial correlations.
643        // We construct the matrix via a lower-triangular factor.
644        //
645        // For each column k and row i > k:
646        //   Generate a partial correlation p_{ik} in (-1, 1)
647        //   Build the factor L such that C = L * L^T is a correlation matrix.
648
649        let mut l = vec![0.0; dim * dim];
650
651        // First column of L
652        l[0] = 1.0; // L[0,0] = 1
653        for i in 1..dim {
654            // Generate partial correlation for (i, 0)
655            let p = 2.0 * rng.next_f64() - 1.0;
656            l[i * dim] = p; // L[i,0] = p
657        }
658
659        // Remaining columns
660        for k in 1..dim {
661            // L[k,k] = sqrt(1 - sum of L[k,j]^2 for j < k)
662            let mut sum_sq = 0.0;
663            for j in 0..k {
664                sum_sq += l[k * dim + j] * l[k * dim + j];
665            }
666            let rem = 1.0 - sum_sq;
667            l[k * dim + k] = if rem > 0.0 { rem.sqrt() } else { 0.0 };
668
669            // For rows i > k
670            for i in (k + 1)..dim {
671                // Remaining "radius" for row i
672                let mut sum_sq_i = 0.0;
673                for j in 0..k {
674                    sum_sq_i += l[i * dim + j] * l[i * dim + j];
675                }
676                let rem_i = 1.0 - sum_sq_i;
677                if rem_i <= 0.0 {
678                    l[i * dim + k] = 0.0;
679                    continue;
680                }
681
682                // Generate partial correlation
683                let p = 2.0 * rng.next_f64() - 1.0;
684                l[i * dim + k] = p * rem_i.sqrt();
685            }
686        }
687
688        // C = L * L^T
689        let lt = transpose(&l, dim, dim);
690        let c = matrix_multiply(&l, &lt, dim, dim, dim);
691
692        RandomMatrix::new(dim, dim, c, MatrixLayout::RowMajor)
693    }
694}
695
// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Shared absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;

    /// Check if a matrix is approximately symmetric.
    fn is_symmetric(m: &RandomMatrix, tol: f64) -> bool {
        if !m.is_square() {
            return false;
        }
        let n = m.rows();
        for i in 0..n {
            for j in (i + 1)..n {
                if (m.get(i, j) - m.get(j, i)).abs() > tol {
                    return false;
                }
            }
        }
        true
    }

    /// Check if all diagonal elements of a matrix are approximately 1.
    fn has_unit_diagonal(m: &RandomMatrix, tol: f64) -> bool {
        let n = m.rows().min(m.cols());
        for i in 0..n {
            if (m.get(i, i) - 1.0).abs() > tol {
                return false;
            }
        }
        true
    }

    /// Check if a matrix is positive definite by attempting Cholesky decomposition.
    fn is_positive_definite(m: &RandomMatrix) -> bool {
        if !m.is_square() {
            return false;
        }
        cholesky_decompose(m.data(), m.rows()).is_ok()
    }

    // -----------------------------------------------------------------------
    // Gaussian matrix tests
    // -----------------------------------------------------------------------

    #[test]
    fn gaussian_correct_dimensions() {
        let m = GaussianMatrixGenerator::generate(5, 3, 0.0, 1.0, 42);
        assert_eq!(m.rows(), 5);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.data().len(), 15);
    }

    #[test]
    fn gaussian_mean_and_variance() {
        let m = GaussianMatrixGenerator::generate(1000, 1000, 2.5, 0.5, 123);
        let n = m.data().len() as f64;
        let mean = m.data().iter().sum::<f64>() / n;
        let variance = m.data().iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;

        // With 1M samples, mean should be close to 2.5 and variance close to 0.25
        assert!((mean - 2.5).abs() < 0.01, "mean = {mean}");
        assert!((variance - 0.25).abs() < 0.01, "variance = {variance}");
    }

    #[test]
    fn gaussian_deterministic_with_seed() {
        // Same seed must reproduce the exact same matrix.
        let m1 = GaussianMatrixGenerator::generate(10, 10, 0.0, 1.0, 999);
        let m2 = GaussianMatrixGenerator::generate(10, 10, 0.0, 1.0, 999);
        assert_eq!(m1.data(), m2.data());
    }

    // -----------------------------------------------------------------------
    // Cholesky decomposition tests
    // -----------------------------------------------------------------------

    #[test]
    fn cholesky_identity() {
        let identity = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let l = cholesky_decompose(&identity, 3).expect("cholesky should succeed");
        // L should be identity
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (l[i * 3 + j] - expected).abs() < TOL,
                    "L[{i},{j}] = {} expected {expected}",
                    l[i * 3 + j]
                );
            }
        }
    }

    #[test]
    fn cholesky_reconstruction() {
        // A = [[4, 2], [2, 3]]
        let a = vec![4.0, 2.0, 2.0, 3.0];
        let l = cholesky_decompose(&a, 2).expect("cholesky should succeed");
        // Verify A = L * L^T
        let lt = transpose(&l, 2, 2);
        let reconstructed = matrix_multiply(&l, &lt, 2, 2, 2);
        for i in 0..4 {
            assert!(
                (reconstructed[i] - a[i]).abs() < TOL,
                "element {i}: {} vs {}",
                reconstructed[i],
                a[i]
            );
        }
    }

    #[test]
    fn cholesky_not_positive_definite() {
        // Not positive definite: eigenvalues of [[1,2],[2,1]] are 3 and -1.
        let a = vec![1.0, 2.0, 2.0, 1.0];
        assert!(cholesky_decompose(&a, 2).is_err());
    }

    // -----------------------------------------------------------------------
    // Matrix multiply and transpose tests
    // -----------------------------------------------------------------------

    #[test]
    fn transpose_round_trip() {
        // Transposing twice must return the original matrix.
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
        let at = transpose(&a, 2, 3); // 3x2
        let att = transpose(&at, 3, 2); // 2x3
        assert_eq!(a, att);
    }

    #[test]
    fn matrix_multiply_identity() {
        let a = vec![1.0, 2.0, 3.0, 4.0]; // 2x2
        let id = vec![1.0, 0.0, 0.0, 1.0]; // 2x2
        let result = matrix_multiply(&a, &id, 2, 2, 2);
        for i in 0..4 {
            assert!((result[i] - a[i]).abs() < TOL);
        }
    }

    // -----------------------------------------------------------------------
    // Orthogonal matrix tests
    // -----------------------------------------------------------------------

    #[test]
    fn orthogonal_qtq_is_identity() {
        // Orthogonality: Q^T * Q must equal the identity.
        let q = OrthogonalMatrixGenerator::generate(5, 42);
        let qt = q.transpose();
        let qtq = matrix_multiply(qt.data(), q.data(), 5, 5, 5);

        for i in 0..5 {
            for j in 0..5 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (qtq[i * 5 + j] - expected).abs() < 1e-12,
                    "Q^T*Q[{i},{j}] = {} expected {expected}",
                    qtq[i * 5 + j]
                );
            }
        }
    }

    #[test]
    fn orthogonal_determinant_abs_one() {
        // For a 2x2 orthogonal matrix, |det| = 1
        let q = OrthogonalMatrixGenerator::generate(2, 77);
        let det = q.get(0, 0) * q.get(1, 1) - q.get(0, 1) * q.get(1, 0);
        assert!(
            (det.abs() - 1.0).abs() < 1e-12,
            "|det| = {} expected 1.0",
            det.abs()
        );
    }

    // -----------------------------------------------------------------------
    // Wishart matrix tests
    // -----------------------------------------------------------------------

    #[test]
    fn wishart_is_symmetric_and_spd() {
        let dim = 4;
        let dof = 10;
        // Use identity as scale matrix
        let mut scale = vec![0.0; dim * dim];
        for i in 0..dim {
            scale[i * dim + i] = 1.0;
        }
        let w = WishartGenerator::generate(dim, dof, &scale, 42).expect("wishart should succeed");
        assert_eq!(w.rows(), dim);
        assert_eq!(w.cols(), dim);
        assert!(
            is_symmetric(&w, 1e-10),
            "Wishart matrix should be symmetric"
        );
        assert!(is_positive_definite(&w), "Wishart matrix should be SPD");
    }

    #[test]
    fn wishart_dof_less_than_dim_errors() {
        // dof < dim cannot yield a positive-definite matrix; must error.
        let scale = vec![1.0, 0.0, 0.0, 1.0];
        let result = WishartGenerator::generate(2, 1, &scale, 42);
        assert!(result.is_err());
    }

    // -----------------------------------------------------------------------
    // SPD matrix tests
    // -----------------------------------------------------------------------

    #[test]
    fn spd_is_symmetric_and_positive_definite() {
        let m = SymmetricPositiveDefiniteGenerator::generate(5, 100.0, 42)
            .expect("spd gen should succeed");
        assert!(is_symmetric(&m, 1e-10), "SPD matrix should be symmetric");
        assert!(
            is_positive_definite(&m),
            "SPD matrix should be positive definite"
        );
    }

    #[test]
    fn spd_condition_number_bound() {
        let dim = 4;
        let kappa = 10.0;
        let m = SymmetricPositiveDefiniteGenerator::generate(dim, kappa, 42)
            .expect("spd gen should succeed");

        // The eigenvalues are the diagonal of D: [1, ..., kappa]
        // We can verify by checking the trace and Frobenius norm are consistent
        let trace: f64 = (0..dim).map(|i| m.get(i, i)).sum();
        assert!(trace > 0.0, "trace should be positive");
    }

    #[test]
    fn spd_invalid_condition_number() {
        // Condition numbers below 1.0 are rejected.
        let result = SymmetricPositiveDefiniteGenerator::generate(3, 0.5, 42);
        assert!(result.is_err());
    }

    // -----------------------------------------------------------------------
    // Correlation matrix tests
    // -----------------------------------------------------------------------

    #[test]
    fn correlation_unit_diagonal() {
        let c =
            CorrelationMatrixGenerator::generate(5, 42).expect("correlation gen should succeed");
        assert!(
            has_unit_diagonal(&c, 1e-10),
            "correlation matrix should have unit diagonal"
        );
    }

    #[test]
    fn correlation_is_symmetric_and_psd() {
        let c =
            CorrelationMatrixGenerator::generate(5, 42).expect("correlation gen should succeed");
        assert!(
            is_symmetric(&c, 1e-10),
            "correlation matrix should be symmetric"
        );
        // Correlation matrices are PSD; the vine method should produce PD in practice
        assert!(
            is_positive_definite(&c),
            "correlation matrix should be positive semi-definite"
        );
    }

    #[test]
    fn correlation_entries_bounded() {
        let c =
            CorrelationMatrixGenerator::generate(6, 123).expect("correlation gen should succeed");
        for val in c.data() {
            assert!(
                *val >= -1.0 - 1e-12 && *val <= 1.0 + 1e-12,
                "correlation entry {val} out of [-1, 1]"
            );
        }
    }
}