use faer::sparse::SparseColMat;
use faer::sparse::Triplet;
use rand::{Rng, RngExt};
use rand_distr::{Distribution, Uniform};
use crate::error::SparseError;
/// Parameters for `generate_random_symmetric`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RandomMatrixConfig {
    /// Number of rows (and columns) of the square matrix; must be > 0.
    pub size: usize,
    /// Requested number of stored nonzeros, counting the `size` diagonal entries.
    /// The off-diagonal share is rounded down to an even count (entries come in mirrored
    /// pairs) and capped at `size * (size - 1)`.
    pub target_nnz: usize,
    /// `true` for a positive-definite matrix; `false` for an indefinite one, which
    /// requires `size >= 2`.
    pub positive_definite: bool,
}
/// Appends one diagonal triplet per row so that every row becomes strictly diagonally dominant.
///
/// Row `i` receives magnitude `row_abs_sum[i] + margin`, where `margin` is `1.0` plus one
/// `f64` drawn from `rng` (drawn once per row, in row order). When `positive_definite` is
/// true every diagonal entry is positive. Otherwise rows `0..n / 2` are positive and the
/// remaining rows are negated, which gives an indefinite matrix for `n >= 2`.
fn add_diagonal(
    triplets: &mut Vec<Triplet<usize, usize, f64>>,
    row_abs_sum: &[f64],
    positive_definite: bool,
    rng: &mut impl Rng,
) {
    let positive_rows = row_abs_sum.len() / 2;
    triplets.extend(row_abs_sum.iter().enumerate().map(|(row, &off_diag_sum)| {
        // The margin is drawn for every row, whatever sign that row ends up with.
        let margin = 1.0 + rng.random::<f64>();
        let magnitude = off_diag_sum + margin;
        let value = if !positive_definite && row >= positive_rows {
            -magnitude
        } else {
            magnitude
        };
        Triplet::new(row, row, value)
    }));
}
/// Rejects matrix sizes that the generators cannot satisfy.
///
/// A size of 0 is always invalid. A 1x1 matrix has a single diagonal entry and so cannot
/// be indefinite, so `size == 1` is accepted only when `positive_definite` is true.
fn validate_size(size: usize, positive_definite: bool) -> Result<(), SparseError> {
    let reason = match (size, positive_definite) {
        (0, _) => "matrix size must be > 0",
        (1, false) => "cannot generate 1x1 indefinite matrix",
        _ => return Ok(()),
    };
    Err(SparseError::InvalidInput {
        reason: reason.to_string(),
    })
}
/// Picks `count` distinct strictly-upper-triangular positions `(lo, hi)` with `lo < hi < n`.
///
/// Always returns exactly `count` positions. The caller must ensure
/// `count <= n * (n - 1) / 2`.
fn sample_offdiag_positions(n: usize, count: usize, rng: &mut impl Rng) -> Vec<(usize, usize)> {
    if count == 0 {
        return Vec::new();
    }
    // `count > 0` implies `n >= 2`, so `n - 1` cannot underflow here.
    let max_pairs = n * (n - 1) / 2;
    debug_assert!(count <= max_pairs, "requested more off-diagonal pairs than fit");
    if count * 2 > max_pairs {
        // Dense request: rejection sampling would stall as free slots run out. Instead,
        // enumerate every candidate (fewer than 2 * count of them) and keep a random
        // subset via a partial Fisher-Yates shuffle.
        let mut all: Vec<(usize, usize)> = (0..n)
            .flat_map(|lo| (lo + 1..n).map(move |hi| (lo, hi)))
            .collect();
        for k in 0..count {
            let pick = Uniform::new(k, all.len()).unwrap().sample(rng);
            all.swap(k, pick);
        }
        all.truncate(count);
        all
    } else {
        // Sparse request: at least half of all candidate slots always stay free, so each
        // draw has a constant chance of success and the loop terminates quickly.
        let idx_dist = Uniform::new(0, n).unwrap();
        let mut placed = std::collections::HashSet::with_capacity(count);
        let mut positions = Vec::with_capacity(count);
        while positions.len() < count {
            let i = idx_dist.sample(rng);
            let j = idx_dist.sample(rng);
            if i == j {
                continue;
            }
            let pair = (i.min(j), i.max(j));
            // `insert` returns false when the position was already taken.
            if placed.insert(pair) {
                positions.push(pair);
            }
        }
        positions
    }
}

/// Generates a random sparse symmetric matrix of order `config.size`.
///
/// Off-diagonal values are drawn uniformly from `[-1, 1)` and placed at distinct random
/// positions `(i, j)` / `(j, i)`. The number of off-diagonal pairs is
/// `min((target_nnz - size) / 2, size * (size - 1) / 2)` (0 when `target_nnz <= size`),
/// and that many pairs are always placed. The diagonal is then filled so that every row is
/// strictly diagonally dominant. All diagonal entries are positive (positive definite) when
/// `config.positive_definite` is set; otherwise the first `size / 2` are positive and the
/// rest negative (indefinite).
///
/// # Errors
///
/// Returns `SparseError::InvalidInput` if `config.size` is 0, if it is 1 with
/// `positive_definite == false`, or if faer rejects the assembled triplets.
pub fn generate_random_symmetric(
    config: &RandomMatrixConfig,
    rng: &mut impl Rng,
) -> Result<SparseColMat<usize, f64>, SparseError> {
    let n = config.size;
    validate_size(n, config.positive_definite)?;

    // `target_nnz` includes the n diagonal entries; every off-diagonal pair adds two.
    let target_offdiag_pairs = config.target_nnz.saturating_sub(n) / 2;
    let max_offdiag_pairs = n * (n - 1) / 2;
    let actual_pairs = target_offdiag_pairs.min(max_offdiag_pairs);

    let positions = sample_offdiag_positions(n, actual_pairs, rng);

    let val_dist = Uniform::new(-1.0f64, 1.0).unwrap();
    let mut triplets: Vec<Triplet<usize, usize, f64>> =
        Vec::with_capacity(2 * actual_pairs + n);
    let mut row_abs_sum = vec![0.0f64; n];
    for (lo, hi) in positions {
        let v: f64 = val_dist.sample(rng);
        triplets.push(Triplet::new(lo, hi, v));
        triplets.push(Triplet::new(hi, lo, v));
        // Each mirrored entry counts against its own row's dominance budget.
        row_abs_sum[lo] += v.abs();
        row_abs_sum[hi] += v.abs();
    }

    add_diagonal(&mut triplets, &row_abs_sum, config.positive_definite, rng);
    SparseColMat::try_new_from_triplets(n, n, &triplets).map_err(|e| SparseError::InvalidInput {
        reason: format!("failed to create sparse matrix from triplets: {:?}", e),
    })
}
/// Generates a symmetric "arrow" matrix: a full first row and first column plus the diagonal.
///
/// Each spoke value `a[0][j] = a[j][0]` (for `j >= 1`) is drawn uniformly from `[0.1, 1.0)`.
/// The diagonal is then set so every row is strictly diagonally dominant. The matrix is
/// positive definite when `positive_definite` is true, and indefinite otherwise.
///
/// # Errors
///
/// Returns `SparseError::InvalidInput` for `size == 0`, for `size == 1` with
/// `positive_definite == false`, or if faer rejects the assembled triplets.
pub fn generate_arrow(
    size: usize,
    positive_definite: bool,
    rng: &mut impl Rng,
) -> Result<SparseColMat<usize, f64>, SparseError> {
    validate_size(size, positive_definite)?;

    let spoke_dist = Uniform::new(0.1f64, 1.0).unwrap();
    // One weight per column 1..size, drawn in column order.
    let spokes: Vec<f64> = (1..size).map(|_| spoke_dist.sample(&mut *rng)).collect();

    let mut row_abs_sum = vec![0.0f64; size];
    let mut triplets: Vec<Triplet<usize, usize, f64>> = Vec::new();
    for (offset, &weight) in spokes.iter().enumerate() {
        let col = offset + 1;
        triplets.push(Triplet::new(0, col, weight));
        triplets.push(Triplet::new(col, 0, weight));
        row_abs_sum[0] += weight.abs();
        row_abs_sum[col] += weight.abs();
    }

    add_diagonal(&mut triplets, &row_abs_sum, positive_definite, rng);
    SparseColMat::try_new_from_triplets(size, size, &triplets).map_err(|e| {
        SparseError::InvalidInput {
            reason: format!("failed to create arrow matrix from triplets: {:?}", e),
        }
    })
}
/// Generates a symmetric tridiagonal matrix of order `size`.
///
/// Each coupling `a[i][i+1] = a[i+1][i]` is drawn uniformly from `[0.1, 1.0)`. The diagonal
/// is then set so every row is strictly diagonally dominant. The matrix is positive definite
/// when `positive_definite` is true, and indefinite otherwise.
///
/// # Errors
///
/// Returns `SparseError::InvalidInput` for `size == 0`, for `size == 1` with
/// `positive_definite == false`, or if faer rejects the assembled triplets.
pub fn generate_tridiagonal(
    size: usize,
    positive_definite: bool,
    rng: &mut impl Rng,
) -> Result<SparseColMat<usize, f64>, SparseError> {
    validate_size(size, positive_definite)?;

    let coupling_dist = Uniform::new(0.1f64, 1.0).unwrap();
    let mut row_abs_sum = vec![0.0f64; size];
    let mut triplets: Vec<Triplet<usize, usize, f64>> = Vec::new();
    // Walk adjacent (row, row + 1) pairs; empty when size == 1.
    for (row, next) in (0..size).zip(1..size) {
        let coupling: f64 = coupling_dist.sample(rng);
        triplets.push(Triplet::new(row, next, coupling));
        triplets.push(Triplet::new(next, row, coupling));
        row_abs_sum[row] += coupling.abs();
        row_abs_sum[next] += coupling.abs();
    }

    add_diagonal(&mut triplets, &row_abs_sum, positive_definite, rng);
    SparseColMat::try_new_from_triplets(size, size, &triplets).map_err(|e| {
        SparseError::InvalidInput {
            reason: format!("failed to create tridiagonal matrix from triplets: {:?}", e),
        }
    })
}
/// Generates a symmetric banded matrix: `a[i][j]` is nonzero only when `|i - j| <= bandwidth`.
///
/// Every in-band off-diagonal value is drawn uniformly from `[0.1, 1.0)` and mirrored across
/// the diagonal. The diagonal is then set so every row is strictly diagonally dominant. The
/// matrix is positive definite when `positive_definite` is true, and indefinite otherwise.
/// A `bandwidth` of 0 yields a diagonal matrix. Any `bandwidth >= size - 1` (including
/// `usize::MAX`) yields a fully dense one.
///
/// # Errors
///
/// Returns `SparseError::InvalidInput` for `size == 0`, for `size == 1` with
/// `positive_definite == false`, or if faer rejects the assembled triplets.
pub fn generate_banded(
    size: usize,
    bandwidth: usize,
    positive_definite: bool,
    rng: &mut impl Rng,
) -> Result<SparseColMat<usize, f64>, SparseError> {
    validate_size(size, positive_definite)?;
    let n = size;
    let val_dist = Uniform::new(0.1f64, 1.0).unwrap();

    // Clamp once: a band wider than n - 1 adds nothing, and the clamp keeps
    // `i + band + 1` from overflowing for huge bandwidths.
    let band = bandwidth.min(n - 1);
    // Exact number of strict-upper in-band pairs: row i contributes min(band, n - 1 - i).
    let pairs = band * n - band * (band + 1) / 2;
    let mut triplets: Vec<Triplet<usize, usize, f64>> = Vec::with_capacity(2 * pairs + n);
    let mut row_abs_sum = vec![0.0f64; n];

    for i in 0..n {
        // Only the strict upper band is sampled; each value is mirrored below the diagonal.
        let j_end = (i + band + 1).min(n);
        for j in i + 1..j_end {
            let v: f64 = val_dist.sample(rng);
            triplets.push(Triplet::new(i, j, v));
            triplets.push(Triplet::new(j, i, v));
            row_abs_sum[i] += v.abs();
            row_abs_sum[j] += v.abs();
        }
    }

    add_diagonal(&mut triplets, &row_abs_sum, positive_definite, rng);
    SparseColMat::try_new_from_triplets(n, n, &triplets).map_err(|e| SparseError::InvalidInput {
        reason: format!("failed to create banded matrix from triplets: {:?}", e),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Fixed-seed RNG so every test run (and every failure) is reproducible.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    /// Asserts that `m` is square, exactly symmetric, and strictly diagonally dominant by
    /// rows (`|a_ii| > sum_{j != i} |a_ij|`).
    ///
    /// Every generator sets each diagonal entry's magnitude to the row's off-diagonal absolute
    /// sum plus a margin of at least 1, so this must hold for all of them. Strict diagonal
    /// dominance is also what guarantees the matrix is nonsingular.
    fn assert_symmetric_and_diagonally_dominant(m: &SparseColMat<usize, f64>) {
        let n = m.nrows();
        assert_eq!(m.ncols(), n, "matrix should be square");
        let dense = m.to_dense();
        for i in 0..n {
            let mut off_diag_abs_sum = 0.0f64;
            for j in 0..n {
                assert_eq!(
                    dense[(i, j)],
                    dense[(j, i)],
                    "not symmetric at ({}, {})",
                    i,
                    j
                );
                if i != j {
                    off_diag_abs_sum += dense[(i, j)].abs();
                }
            }
            assert!(
                dense[(i, i)].abs() > off_diag_abs_sum,
                "row {} is not strictly diagonally dominant: |diag| = {}, off-diagonal sum = {}",
                i,
                dense[(i, i)].abs(),
                off_diag_abs_sum
            );
        }
    }

    #[test]
    fn random_pd_matrix_properties() {
        let mut rng = seeded_rng();
        let config = RandomMatrixConfig {
            size: 100,
            target_nnz: 500,
            positive_definite: true,
        };
        let m = generate_random_symmetric(&config, &mut rng).expect("generation failed");
        assert_eq!(m.nrows(), 100);
        assert_eq!(m.ncols(), 100);
        assert_symmetric_and_diagonally_dominant(&m);
        let dense = m.to_dense();
        for i in 0..100 {
            assert!(
                dense[(i, i)] > 0.0,
                "diagonal entry {} should be positive for PD",
                i
            );
        }
    }

    #[test]
    fn random_indefinite_matrix_has_mixed_signs() {
        let mut rng = seeded_rng();
        let config = RandomMatrixConfig {
            size: 50,
            target_nnz: 200,
            positive_definite: false,
        };
        let m = generate_random_symmetric(&config, &mut rng).expect("generation failed");
        assert_symmetric_and_diagonally_dominant(&m);
        let dense = m.to_dense();
        let mut has_positive = false;
        let mut has_negative = false;
        for i in 0..50 {
            if dense[(i, i)] > 0.0 {
                has_positive = true;
            }
            if dense[(i, i)] < 0.0 {
                has_negative = true;
            }
        }
        assert!(has_positive, "should have positive diagonal entries");
        assert!(has_negative, "should have negative diagonal entries");
    }

    #[test]
    fn generate_arrow_pattern() {
        let mut rng = seeded_rng();
        let m = generate_arrow(20, true, &mut rng).expect("arrow generation failed");
        assert_eq!(m.nrows(), 20);
        assert_symmetric_and_diagonally_dominant(&m);
        let dense = m.to_dense();
        for j in 1..20 {
            assert!(
                dense[(0, j)] != 0.0,
                "first row entry (0, {}) should be nonzero",
                j
            );
            assert!(
                dense[(j, 0)] != 0.0,
                "first col entry ({}, 0) should be nonzero",
                j
            );
        }
        for i in 1..20 {
            for j in 1..20 {
                if i != j {
                    assert_eq!(
                        dense[(i, j)],
                        0.0,
                        "off-diagonal ({}, {}) should be zero in arrow tail",
                        i,
                        j
                    );
                }
            }
        }
    }

    #[test]
    fn generate_tridiagonal_pattern() {
        let mut rng = seeded_rng();
        let m = generate_tridiagonal(30, true, &mut rng).expect("tridiagonal generation failed");
        assert_eq!(m.nrows(), 30);
        assert_symmetric_and_diagonally_dominant(&m);
        let dense = m.to_dense();
        for i in 0..30 {
            let mut nnz_count = 0;
            for j in 0..30 {
                if dense[(i, j)] != 0.0 {
                    nnz_count += 1;
                    assert!(
                        (i as isize - j as isize).unsigned_abs() <= 1,
                        "nonzero at ({}, {}) outside tridiagonal band",
                        i,
                        j
                    );
                }
            }
            assert!(
                nnz_count <= 3,
                "row {} has {} nonzeros, expected <= 3",
                i,
                nnz_count
            );
        }
    }

    #[test]
    fn generate_banded_pattern() {
        let mut rng = seeded_rng();
        let m = generate_banded(40, 3, true, &mut rng).expect("banded generation failed");
        assert_eq!(m.nrows(), 40);
        assert_symmetric_and_diagonally_dominant(&m);
        let dense = m.to_dense();
        for i in 0..40 {
            for j in 0..40 {
                if (i as isize - j as isize).unsigned_abs() > 3 {
                    assert_eq!(
                        dense[(i, j)],
                        0.0,
                        "nonzero at ({}, {}) outside bandwidth 3",
                        i,
                        j
                    );
                }
            }
        }
    }

    #[test]
    #[ignore = "timing-dependent; run explicitly with `cargo test -- --ignored`"]
    fn generation_performance_under_1s() {
        let mut rng = seeded_rng();
        let config = RandomMatrixConfig {
            size: 1000,
            target_nnz: 10000,
            positive_definite: true,
        };
        let start = std::time::Instant::now();
        let _m = generate_random_symmetric(&config, &mut rng).expect("generation failed");
        let elapsed = start.elapsed();
        assert!(
            elapsed.as_secs_f64() < 1.0,
            "generation took {:.3}s, expected < 1s",
            elapsed.as_secs_f64()
        );
    }

    #[test]
    fn infeasible_config_returns_error() {
        let mut rng = seeded_rng();
        let config = RandomMatrixConfig {
            size: 1,
            target_nnz: 1,
            positive_definite: false,
        };
        let result = generate_random_symmetric(&config, &mut rng);
        assert!(result.is_err(), "1x1 indefinite should return error");
    }

    #[test]
    fn excessive_nnz_clamped() {
        // A target far beyond what a 5x5 matrix can hold must be clamped, not rejected.
        let mut rng = seeded_rng();
        let config = RandomMatrixConfig {
            size: 5,
            target_nnz: 1000,
            positive_definite: true,
        };
        let m = generate_random_symmetric(&config, &mut rng).expect("generation failed");
        assert_symmetric_and_diagonally_dominant(&m);
        let dense = m.to_dense();
        let mut actual_nnz = 0;
        for i in 0..5 {
            for j in 0..5 {
                if dense[(i, j)] != 0.0 {
                    actual_nnz += 1;
                }
            }
        }
        // The 5 diagonal entries are always present and off-diagonal entries come in
        // mirrored pairs, so the count is 5 plus an even number, capped at 25.
        assert!(
            (5..=25).contains(&actual_nnz) && (actual_nnz - 5) % 2 == 0,
            "unexpected nnz {} for a clamped 5x5 symmetric matrix",
            actual_nnz
        );
    }

    #[test]
    fn zero_size_returns_error() {
        let mut rng = seeded_rng();
        assert!(generate_arrow(0, true, &mut rng).is_err());
        assert!(generate_tridiagonal(0, true, &mut rng).is_err());
        assert!(generate_banded(0, 1, true, &mut rng).is_err());
        assert!(
            generate_random_symmetric(
                &RandomMatrixConfig {
                    size: 0,
                    target_nnz: 0,
                    positive_definite: true
                },
                &mut rng,
            )
            .is_err()
        );
    }

    #[test]
    fn size_1_pd_succeeds() {
        let mut rng = seeded_rng();
        let m = generate_arrow(1, true, &mut rng).expect("1x1 arrow should work");
        assert_eq!(m.nrows(), 1);
        let m = generate_tridiagonal(1, true, &mut rng).expect("1x1 tridiagonal should work");
        assert_eq!(m.nrows(), 1);
        let m = generate_banded(1, 0, true, &mut rng).expect("1x1 banded should work");
        assert_eq!(m.nrows(), 1);
    }

    #[test]
    fn size_1_indefinite_returns_error() {
        let mut rng = seeded_rng();
        assert!(generate_arrow(1, false, &mut rng).is_err());
        assert!(generate_tridiagonal(1, false, &mut rng).is_err());
        assert!(generate_banded(1, 0, false, &mut rng).is_err());
    }

    #[test]
    fn banded_zero_bandwidth_is_diagonal() {
        let mut rng = seeded_rng();
        let m = generate_banded(10, 0, true, &mut rng).expect("zero bandwidth should work");
        assert_symmetric_and_diagonally_dominant(&m);
        let dense = m.to_dense();
        for i in 0..10 {
            for j in 0..10 {
                if i != j {
                    assert_eq!(
                        dense[(i, j)],
                        0.0,
                        "off-diagonal ({}, {}) should be zero",
                        i,
                        j
                    );
                }
            }
        }
    }

    #[test]
    fn banded_large_bandwidth_fills_densely() {
        let mut rng = seeded_rng();
        let m = generate_banded(5, 10, true, &mut rng).expect("large bandwidth should work");
        assert_symmetric_and_diagonally_dominant(&m);
        let dense = m.to_dense();
        for i in 0..5 {
            for j in 0..5 {
                assert!(
                    dense[(i, j)] != 0.0,
                    "entry ({}, {}) should be nonzero",
                    i,
                    j
                );
            }
        }
    }
}