use crate::construct;
use crate::error::SparseResult;
use crate::sym_coo::{SymCooArray, SymCooMatrix};
use crate::sym_csr::{SymCsrArray, SymCsrMatrix};
use crate::sym_sparray::SymSparseArray;
use scirs2_core::numeric::{Float, SparseElement};
use std::fmt::Debug;
use std::ops::{Add, Div, Mul, Sub};
/// Builds an `n x n` symmetric identity matrix in the requested storage format.
///
/// Only the diagonal is populated, so exactly `n` ones are handed to the
/// underlying symmetric matrix constructor.
///
/// # Arguments
///
/// * `n` - Matrix dimension (number of rows and columns).
/// * `format` - Storage format, matched case-insensitively: `"csr"` or `"coo"`.
///
/// # Errors
///
/// Returns `SparseError::ValueError` for an unrecognised `format`; errors from
/// `SymCsrMatrix::new` / `SymCooMatrix::new` are propagated.
#[allow(dead_code)]
pub fn eye_sym_array<T>(n: usize, format: &str) -> SparseResult<Box<dyn SymSparseArray<T>>>
where
    T: Float
        + SparseElement
        + Div<Output = T>
        + scirs2_core::simd_ops::SimdUnifiedOps
        + Send
        + Sync
        + 'static,
{
    // Every stored value is one; nothing lives off the diagonal.
    let ones = vec![T::sparse_one(); n];
    // Positions 0..n, used for CSR column indices and for COO rows/columns.
    let diagonal_positions = || (0..n).collect::<Vec<usize>>();

    match format.to_lowercase().as_str() {
        "csr" => {
            // Each row holds exactly one entry, so the row pointers are 0, 1, ..., n.
            let row_starts: Vec<usize> = (0..=n).collect();
            let matrix = SymCsrMatrix::new(ones, row_starts, diagonal_positions(), (n, n))?;
            Ok(Box::new(SymCsrArray::new(matrix)))
        }
        "coo" => {
            let matrix =
                SymCooMatrix::new(ones, diagonal_positions(), diagonal_positions(), (n, n))?;
            Ok(Box::new(SymCooArray::new(matrix)))
        }
        _ => Err(crate::error::SparseError::ValueError(format!(
            "Unknown format: {format}. Supported formats are 'csr' and 'coo'"
        ))),
    }
}
#[allow(dead_code)]
pub fn tridiagonal_sym_array<T>(
diag: &[T],
offdiag: &[T],
format: &str,
) -> SparseResult<Box<dyn SymSparseArray<T>>>
where
T: Float
+ SparseElement
+ Div<Output = T>
+ scirs2_core::simd_ops::SimdUnifiedOps
+ Send
+ Sync
+ 'static,
{
let n = diag.len();
if offdiag.len() != n - 1 {
return Err(crate::error::SparseError::ValueError(format!(
"Off-diagonal array must have length n-1 ({}), got {}",
n - 1,
offdiag.len()
)));
}
match format.to_lowercase().as_str() {
"csr" => {
let mut data = Vec::with_capacity(n + 2 * (n - 1));
let mut indices = Vec::with_capacity(n + 2 * (n - 1));
let mut indptr = Vec::with_capacity(n + 1);
indptr.push(0);
let mut nnz = 0;
if !SparseElement::is_zero(&diag[0]) {
data.push(diag[0]);
indices.push(0);
nnz += 1;
}
indptr.push(nnz);
for i in 1..n - 1 {
if !SparseElement::is_zero(&offdiag[i - 1]) {
data.push(offdiag[i - 1]);
indices.push(i - 1);
nnz += 1;
}
if !SparseElement::is_zero(&diag[i]) {
data.push(diag[i]);
indices.push(i);
nnz += 1;
}
indptr.push(nnz);
}
if n > 1 {
if !SparseElement::is_zero(&offdiag[n - 2]) {
data.push(offdiag[n - 2]);
indices.push(n - 2);
nnz += 1;
}
if !SparseElement::is_zero(&diag[n - 1]) {
data.push(diag[n - 1]);
indices.push(n - 1);
nnz += 1;
}
indptr.push(nnz);
}
let sym_csr = SymCsrMatrix::new(data, indptr, indices, (n, n))?;
Ok(Box::new(SymCsrArray::new(sym_csr)))
}
"coo" => {
let mut data = Vec::new();
let mut rows = Vec::new();
let mut cols = Vec::new();
for (i, &diag_val) in diag.iter().enumerate().take(n) {
if !SparseElement::is_zero(&diag_val) {
data.push(diag_val);
rows.push(i);
cols.push(i);
}
}
for (i, &offdiag_val) in offdiag.iter().enumerate().take(n - 1) {
if !SparseElement::is_zero(&offdiag_val) {
data.push(offdiag_val);
rows.push(i + 1);
cols.push(i);
}
}
let sym_coo = SymCooMatrix::new(data, rows, cols, (n, n))?;
Ok(Box::new(SymCooArray::new(sym_coo)))
}
_ => Err(crate::error::SparseError::ValueError(format!(
"Unknown format: {format}. Supported formats are 'csr' and 'coo'"
))),
}
}
/// Builds a symmetric banded matrix from its main diagonal and sub-diagonals.
///
/// `diagonals[0]` is the main diagonal and `diagonals[k]` is the k-th
/// sub-diagonal, whose entry `i` lands at `(i + k, i)` (and, by symmetry, at
/// `(i, i + k)`). Only the lower triangle is passed to the underlying symmetric
/// constructor, and explicit zeros are not stored.
///
/// # Arguments
///
/// * `diagonals` - Non-empty list of diagonals; `diagonals[k]` must have length `n - k`,
///   so at most `n + 1` diagonals can be supplied.
/// * `n` - Matrix dimension.
/// * `format` - Storage format, matched case-insensitively: `"csr"` or `"coo"`.
///
/// # Errors
///
/// Returns `SparseError::ValueError` if `diagonals` is empty, if a diagonal lies
/// outside the matrix or has the wrong length, or if `format` is not recognised.
/// Errors from `SymCsrMatrix::new` / `SymCooMatrix::new` are propagated.
#[allow(dead_code)]
pub fn banded_sym_array<T>(
    diagonals: &[Vec<T>],
    n: usize,
    format: &str,
) -> SparseResult<Box<dyn SymSparseArray<T>>>
where
    T: Float
        + SparseElement
        + Div<Output = T>
        + scirs2_core::simd_ops::SimdUnifiedOps
        + Send
        + Sync
        + 'static,
{
    if diagonals.is_empty() {
        return Err(crate::error::SparseError::ValueError(
            "At least one diagonal must be provided".to_string(),
        ));
    }
    for (i, diag) in diagonals.iter().enumerate() {
        // `checked_sub` guards against bands wider than the matrix, where a plain
        // `n - i` would underflow (a panic in debug builds).
        let expected_len = match n.checked_sub(i) {
            Some(len) => len,
            None => {
                return Err(crate::error::SparseError::ValueError(format!(
                    "Diagonal {i} lies outside a {n}x{n} matrix; at most {} diagonals are allowed",
                    n + 1
                )));
            }
        };
        if diag.len() != expected_len {
            return Err(crate::error::SparseError::ValueError(format!(
                "Diagonal {i} should have length {expected_len}, got {}",
                diag.len()
            )));
        }
    }
    match format.to_lowercase().as_str() {
        "coo" => {
            let mut data = Vec::new();
            let mut rows = Vec::new();
            let mut cols = Vec::new();
            // k = 0 emits the main diagonal first, then each sub-diagonal in turn.
            for (k, diag) in diagonals.iter().enumerate() {
                for (i, &diag_val) in diag.iter().enumerate() {
                    if !SparseElement::is_zero(&diag_val) {
                        data.push(diag_val);
                        rows.push(i + k);
                        cols.push(i);
                    }
                }
            }
            let sym_coo = SymCooMatrix::new(data, rows, cols, (n, n))?;
            Ok(Box::new(SymCooArray::new(sym_coo)))
        }
        "csr" => {
            // Number of sub-diagonals stored below the main diagonal.
            let bandwidth = diagonals.len() - 1;
            let mut data = Vec::new();
            let mut indices = Vec::new();
            let mut indptr = Vec::with_capacity(n + 1);
            indptr.push(0);
            for i in 0..n {
                // Lower-triangle part of row i, in increasing column order. Column j
                // sits on sub-diagonal k = i - j (1 <= k <= bandwidth) at offset j.
                for j in i.saturating_sub(bandwidth)..i {
                    let val = diagonals[i - j][j];
                    if !SparseElement::is_zero(&val) {
                        data.push(val);
                        indices.push(j);
                    }
                }
                if !SparseElement::is_zero(&diagonals[0][i]) {
                    data.push(diagonals[0][i]);
                    indices.push(i);
                }
                indptr.push(data.len());
            }
            let sym_csr = SymCsrMatrix::new(data, indptr, indices, (n, n))?;
            Ok(Box::new(SymCsrArray::new(sym_csr)))
        }
        _ => Err(crate::error::SparseError::ValueError(format!(
            "Unknown format: {format}. Supported formats are 'csr' and 'coo'"
        ))),
    }
}
#[allow(dead_code)]
pub fn random_sym_array<T>(
n: usize,
density: f64,
format: &str,
) -> SparseResult<Box<dyn SymSparseArray<T>>>
where
T: Float
+ SparseElement
+ Div<Output = T>
+ scirs2_core::simd_ops::SimdUnifiedOps
+ Send
+ Sync
+ 'static,
{
if !(0.0..=1.0).contains(&density) {
return Err(crate::error::SparseError::ValueError(
"Density must be between 0.0 and 1.0".to_string(),
));
}
let lower_tri_size = n * (n + 1) / 2;
let _nnz_lower = (lower_tri_size as f64 * density).round() as usize;
let random_array = construct::random_array::<T>((n, n), density, None, format)?;
let coo = random_array.to_coo().map_err(|e| {
crate::error::SparseError::ValueError(format!("Failed to convert random array to COO: {e}"))
})?;
let (rows, cols, data) = coo.find();
match format.to_lowercase().as_str() {
"csr" | "coo" => {
let sym_array = SymCooArray::from_triplets(
&rows.to_vec(),
&cols.to_vec(),
&data.to_vec(),
(n, n),
true,
)?;
if format.to_lowercase() == "csr" {
Ok(Box::new(sym_array.to_sym_csr()?))
} else {
Ok(Box::new(sym_array))
}
}
_ => Err(crate::error::SparseError::ValueError(format!(
"Unknown format: {format}. Supported formats are 'csr' and 'coo'"
))),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_eye_sym_array() {
        // CSR identity: three stored ones on the diagonal, zeros elsewhere.
        let csr = eye_sym_array::<f64>(3, "csr").expect("Operation failed");
        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 3);
        assert_eq!(csr.nnz_stored(), 3);
        for k in 0..3 {
            assert_eq!(csr.get(k, k), 1.0);
        }
        assert_eq!(csr.get(0, 1), 0.0);

        // COO identity: same logical contents.
        let coo = eye_sym_array::<f64>(3, "coo").expect("Operation failed");
        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 3);
        for k in 0..3 {
            assert_eq!(coo.get(k, k), 1.0);
        }
        assert_eq!(coo.get(0, 1), 0.0);
    }

    #[test]
    fn test_tridiagonal_sym_array() {
        let main_diag = vec![2.0; 4];
        let sub_diag = vec![1.0; 3];

        let csr = tridiagonal_sym_array(&main_diag, &sub_diag, "csr").expect("Operation failed");
        assert_eq!(csr.shape(), (4, 4));
        // 4 diagonal entries plus 3 mirrored off-diagonal pairs.
        assert_eq!(csr.nnz(), 10);
        for k in 0..4 {
            assert_eq!(csr.get(k, k), 2.0);
        }
        for k in 0..3 {
            assert_eq!(csr.get(k, k + 1), 1.0);
            assert_eq!(csr.get(k + 1, k), 1.0);
        }
        // Entries outside the band must read back as zero.
        for &(r, c) in &[(0, 2), (0, 3), (1, 3)] {
            assert_eq!(csr.get(r, c), 0.0);
        }

        let coo = tridiagonal_sym_array(&main_diag, &sub_diag, "coo").expect("Operation failed");
        assert_eq!(coo.shape(), (4, 4));
        assert_eq!(coo.nnz(), 10);
        assert_eq!(coo.get(0, 0), 2.0);
        assert_eq!(coo.get(0, 1), 1.0);
        assert_eq!(coo.get(1, 0), 1.0);
    }

    #[test]
    fn test_banded_sym_array() {
        // Main diagonal, first and second sub-diagonals of a 5x5 matrix.
        let bands = vec![vec![2.0; 5], vec![1.0; 4], vec![0.5; 3]];

        let csr = banded_sym_array(&bands, 5, "csr").expect("Operation failed");
        assert_eq!(csr.shape(), (5, 5));
        for k in 0..5 {
            assert_eq!(csr.get(k, k), 2.0);
        }
        for k in 0..4 {
            assert_eq!(csr.get(k, k + 1), 1.0);
            assert_eq!(csr.get(k + 1, k), 1.0);
        }
        for k in 0..3 {
            assert_eq!(csr.get(k, k + 2), 0.5);
            assert_eq!(csr.get(k + 2, k), 0.5);
        }
        // Beyond the bandwidth everything is zero.
        for &(r, c) in &[(0, 3), (0, 4), (1, 4)] {
            assert_eq!(csr.get(r, c), 0.0);
        }

        let coo = banded_sym_array(&bands, 5, "coo").expect("Operation failed");
        assert_eq!(coo.shape(), (5, 5));
        assert_eq!(coo.get(0, 0), 2.0);
        assert_eq!(coo.get(0, 1), 1.0);
        assert_eq!(coo.get(0, 2), 0.5);
    }

    #[test]
    fn test_random_sym_array() {
        let size = 5;
        let fill = 0.8;

        // Random generation failing in the CSR case is reported but not fatal.
        let csr = match random_sym_array::<f64>(size, fill, "csr") {
            Ok(array) => array,
            Err(e) => {
                println!("Warning: Random generation failed with error: {e}");
                return;
            }
        };
        assert_eq!(csr.shape(), (size, size));
        assert!(csr.is_symmetric());
        for row in 0..size {
            for col in 0..row {
                assert_relative_eq!(csr.get(row, col), csr.get(col, row), epsilon = 1e-10);
            }
        }

        let coo = random_sym_array::<f64>(size, fill, "coo").expect("Operation failed");
        assert_eq!(coo.shape(), (size, size));
        assert!(coo.is_symmetric());
    }
}