use crate::coo_array::CooArray;
use crate::csr_array::CsrArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::numeric::{Float, SparseElement};
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Div, Mul, Sub};
/// Concatenates sparse arrays horizontally (side by side, column-wise).
///
/// Array `k` is placed immediately to the right of array `k - 1`, so the result
/// has the shared row count and the sum of all column counts.
///
/// # Arguments
/// * `arrays` - Arrays to join; every one must have the same number of rows.
/// * `format` - Output format, `"csr"` or `"coo"` (matched case-insensitively).
///
/// # Errors
/// * `SparseError::ValueError` if `arrays` is empty or `format` is not recognised.
/// * `SparseError::DimensionMismatch` for the first array whose row count differs
///   from that of `arrays[0]`.
/// * Any error produced by the output format's `from_triplets` constructor.
#[allow(dead_code)]
pub fn hstack<'a, T>(
    arrays: &[&'a dyn SparseArray<T>],
    format: &str,
) -> SparseResult<Box<dyn SparseArray<T>>>
where
    T: 'a
        + Float
        + SparseElement
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Debug
        + Copy
        + 'static,
{
    if arrays.is_empty() {
        return Err(SparseError::ValueError(
            "Cannot stack empty list of arrays".to_string(),
        ));
    }

    // All arrays must share the row count of the first one.
    let nrows = arrays[0].shape().0;
    for arr in &arrays[1..] {
        let found = arr.shape().0;
        if found != nrows {
            return Err(SparseError::DimensionMismatch {
                expected: nrows,
                found,
            });
        }
    }
    let ncols: usize = arrays.iter().map(|arr| arr.shape().1).sum();

    let mut out_rows = Vec::new();
    let mut out_cols = Vec::new();
    let mut out_vals = Vec::new();
    let mut col_shift = 0;
    for arr in arrays {
        let (r, c, v) = arr.find();
        for ((&ri, &ci), &vi) in r.iter().zip(c.iter()).zip(v.iter()) {
            out_rows.push(ri);
            out_cols.push(ci + col_shift);
            out_vals.push(vi);
        }
        // Each later array is shifted right by the total width of those before it.
        col_shift += arr.shape().1;
    }

    let shape = (nrows, ncols);
    let stacked: Box<dyn SparseArray<T>> = match format.to_lowercase().as_str() {
        "csr" => Box::new(CsrArray::from_triplets(
            &out_rows, &out_cols, &out_vals, shape, false,
        )?),
        "coo" => Box::new(CooArray::from_triplets(
            &out_rows, &out_cols, &out_vals, shape, false,
        )?),
        _ => {
            return Err(SparseError::ValueError(format!(
                "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
            )))
        }
    };
    Ok(stacked)
}
/// Concatenates sparse arrays vertically (stacked on top of each other, row-wise).
///
/// Array `k` is placed directly below array `k - 1`, so the result has the shared
/// column count and the sum of all row counts.
///
/// # Arguments
/// * `arrays` - Arrays to join; every one must have the same number of columns.
/// * `format` - Output format, `"csr"` or `"coo"` (matched case-insensitively).
///
/// # Errors
/// * `SparseError::ValueError` if `arrays` is empty or `format` is not recognised.
/// * `SparseError::DimensionMismatch` for the first array whose column count
///   differs from that of `arrays[0]`.
/// * Any error produced by the output format's `from_triplets` constructor.
#[allow(dead_code)]
pub fn vstack<'a, T>(
    arrays: &[&'a dyn SparseArray<T>],
    format: &str,
) -> SparseResult<Box<dyn SparseArray<T>>>
where
    T: 'a
        + Float
        + SparseElement
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Debug
        + Copy
        + 'static,
{
    if arrays.is_empty() {
        return Err(SparseError::ValueError(
            "Cannot stack empty list of arrays".to_string(),
        ));
    }

    // All arrays must share the column count of the first one.
    let ncols = arrays[0].shape().1;
    for arr in &arrays[1..] {
        let found = arr.shape().1;
        if found != ncols {
            return Err(SparseError::DimensionMismatch {
                expected: ncols,
                found,
            });
        }
    }
    let nrows: usize = arrays.iter().map(|arr| arr.shape().0).sum();

    let mut out_rows = Vec::new();
    let mut out_cols = Vec::new();
    let mut out_vals = Vec::new();
    let mut row_shift = 0;
    for arr in arrays {
        let (r, c, v) = arr.find();
        for ((&ri, &ci), &vi) in r.iter().zip(c.iter()).zip(v.iter()) {
            out_rows.push(ri + row_shift);
            out_cols.push(ci);
            out_vals.push(vi);
        }
        // Each later array is shifted down by the total height of those above it.
        row_shift += arr.shape().0;
    }

    let shape = (nrows, ncols);
    let stacked: Box<dyn SparseArray<T>> = match format.to_lowercase().as_str() {
        "csr" => Box::new(CsrArray::from_triplets(
            &out_rows, &out_cols, &out_vals, shape, false,
        )?),
        "coo" => Box::new(CooArray::from_triplets(
            &out_rows, &out_cols, &out_vals, shape, false,
        )?),
        _ => {
            return Err(SparseError::ValueError(format!(
                "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
            )))
        }
    };
    Ok(stacked)
}
/// Builds a block-diagonal sparse array from the given arrays.
///
/// Array `k` occupies the diagonal block that starts just below and to the right
/// of array `k - 1`; everything outside the diagonal blocks is zero. The result's
/// shape is the sum of all row counts by the sum of all column counts.
///
/// # Arguments
/// * `arrays` - Diagonal blocks, in order; shapes may differ freely.
/// * `format` - Output format, `"csr"` or `"coo"` (matched case-insensitively).
///
/// # Errors
/// * `SparseError::ValueError` if `arrays` is empty or `format` is not recognised.
/// * Any error produced by the output format's `from_triplets` constructor.
#[allow(dead_code)]
pub fn block_diag<'a, T>(
    arrays: &[&'a dyn SparseArray<T>],
    format: &str,
) -> SparseResult<Box<dyn SparseArray<T>>>
where
    T: 'a
        + Float
        + SparseElement
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Debug
        + Copy
        + 'static,
{
    if arrays.is_empty() {
        return Err(SparseError::ValueError(
            "Cannot create block diagonal with empty list of arrays".to_string(),
        ));
    }

    let (total_rows, total_cols) = arrays.iter().fold((0usize, 0usize), |(r, c), arr| {
        let (h, w) = arr.shape();
        (r + h, c + w)
    });

    let mut out_rows = Vec::new();
    let mut out_cols = Vec::new();
    let mut out_vals = Vec::new();
    let (mut row_shift, mut col_shift) = (0, 0);
    for arr in arrays {
        let (h, w) = arr.shape();
        let (r, c, v) = arr.find();
        for ((&ri, &ci), &vi) in r.iter().zip(c.iter()).zip(v.iter()) {
            out_rows.push(ri + row_shift);
            out_cols.push(ci + col_shift);
            out_vals.push(vi);
        }
        // The next block starts at the bottom-right corner of this one.
        row_shift += h;
        col_shift += w;
    }

    let shape = (total_rows, total_cols);
    let result: Box<dyn SparseArray<T>> = match format.to_lowercase().as_str() {
        "csr" => Box::new(CsrArray::from_triplets(
            &out_rows, &out_cols, &out_vals, shape, false,
        )?),
        "coo" => Box::new(CooArray::from_triplets(
            &out_rows, &out_cols, &out_vals, shape, false,
        )?),
        _ => {
            return Err(SparseError::ValueError(format!(
                "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
            )))
        }
    };
    Ok(result)
}
/// Extracts the lower-triangular part of a sparse array.
///
/// A stored entry at `(row, col)` is kept when `col - row <= k`, i.e. when it lies
/// on or below the `k`-th diagonal:
/// * `k = 0` is the main diagonal;
/// * `k > 0` also keeps diagonals above it;
/// * `k < 0` keeps only diagonals strictly below it.
///
/// All other entries are dropped, and the result has the same shape as the input.
///
/// # Errors
/// * `SparseError::ValueError` if `format` is neither `"csr"` nor `"coo"`
///   (case-insensitive).
/// * Errors from the output format's `from_triplets` constructor are propagated.
#[allow(dead_code)]
pub fn tril<T>(
    array: &dyn SparseArray<T>,
    k: isize,
    format: &str,
) -> SparseResult<Box<dyn SparseArray<T>>>
where
    T: Float
        + SparseElement
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Debug
        + Copy
        + 'static,
{
    let shape = array.shape();
    let (rows, cols, data) = array.find();
    let mut tril_rows = Vec::new();
    let mut tril_cols = Vec::new();
    let mut tril_data = Vec::new();
    for ((&row, &col), &val) in rows.iter().zip(cols.iter()).zip(data.iter()) {
        // Compare the entry's diagonal offset with `k` rather than computing
        // `col - k`: the latter overflows for extreme `k` (e.g. `isize::MIN`),
        // while the difference of two indices up to `isize::MAX` always fits.
        let diagonal_offset = (col as isize) - (row as isize);
        if diagonal_offset <= k {
            tril_rows.push(row);
            tril_cols.push(col);
            tril_data.push(val);
        }
    }
    match format.to_lowercase().as_str() {
        "csr" => CsrArray::from_triplets(&tril_rows, &tril_cols, &tril_data, shape, false)
            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
        "coo" => CooArray::from_triplets(&tril_rows, &tril_cols, &tril_data, shape, false)
            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
        _ => Err(SparseError::ValueError(format!(
            "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
        ))),
    }
}
/// Extracts the upper-triangular part of a sparse array.
///
/// A stored entry at `(row, col)` is kept when `col - row >= k`, i.e. when it lies
/// on or above the `k`-th diagonal:
/// * `k = 0` is the main diagonal;
/// * `k > 0` keeps only diagonals strictly above it;
/// * `k < 0` also keeps diagonals below it.
///
/// All other entries are dropped, and the result has the same shape as the input.
///
/// # Errors
/// * `SparseError::ValueError` if `format` is neither `"csr"` nor `"coo"`
///   (case-insensitive).
/// * Errors from the output format's `from_triplets` constructor are propagated.
#[allow(dead_code)]
pub fn triu<T>(
    array: &dyn SparseArray<T>,
    k: isize,
    format: &str,
) -> SparseResult<Box<dyn SparseArray<T>>>
where
    T: Float
        + SparseElement
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Debug
        + Copy
        + 'static,
{
    let shape = array.shape();
    let (rows, cols, data) = array.find();
    let mut triu_rows = Vec::new();
    let mut triu_cols = Vec::new();
    let mut triu_data = Vec::new();
    for ((&row, &col), &val) in rows.iter().zip(cols.iter()).zip(data.iter()) {
        // Compare the entry's diagonal offset with `k` rather than computing
        // `col - k`: the latter overflows for extreme `k` (e.g. `isize::MIN`),
        // while the difference of two indices up to `isize::MAX` always fits.
        let diagonal_offset = (col as isize) - (row as isize);
        if diagonal_offset >= k {
            triu_rows.push(row);
            triu_cols.push(col);
            triu_data.push(val);
        }
    }
    match format.to_lowercase().as_str() {
        "csr" => CsrArray::from_triplets(&triu_rows, &triu_cols, &triu_data, shape, false)
            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
        "coo" => CooArray::from_triplets(&triu_rows, &triu_cols, &triu_data, shape, false)
            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
        _ => Err(SparseError::ValueError(format!(
            "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
        ))),
    }
}
/// Computes the Kronecker product `kron(a, b)` of two sparse arrays.
///
/// If `a` has shape `(p, q)` and `b` has shape `(r, s)`, the result has shape
/// `(p * r, q * s)`. Every stored entry `a[i, j]` contributes the block
/// `a[i, j] * b`, whose top-left corner sits at `(i * r, j * s)`.
///
/// # Errors
/// * `SparseError::ValueError` if `format` is neither `"csr"` nor `"coo"`
///   (case-insensitive).
/// * Errors from converting either operand with `to_coo` are propagated rather
///   than panicking.
/// * Errors from the output format's `from_triplets` constructor are propagated.
#[allow(dead_code)]
pub fn kron<'a, T>(
    a: &'a dyn SparseArray<T>,
    b: &'a dyn SparseArray<T>,
    format: &str,
) -> SparseResult<Box<dyn SparseArray<T>>>
where
    T: 'a
        + Float
        + SparseElement
        + Add<Output = T>
        + AddAssign
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Debug
        + Copy
        + 'static,
{
    let ashape = a.shape();
    let bshape = b.shape();
    let outputshape = (ashape.0 * bshape.0, ashape.1 * bshape.1);

    let (out_rows, out_cols, out_data): (Vec<usize>, Vec<usize>, Vec<T>) =
        if a.nnz() == 0 || b.nnz() == 0 {
            // A factor with no stored entries makes the whole product zero; skip the
            // COO conversions entirely.
            (Vec::new(), Vec::new(), Vec::new())
        } else {
            // Conversion failures are returned to the caller instead of panicking.
            let b_coo = b.to_coo()?;
            let (b_rows, b_cols, b_data) = b_coo.find();
            let a_coo = a.to_coo()?;
            let (a_rows, a_cols, a_data) = a_coo.find();

            let nnz_output = a_data.len() * b_data.len();
            let mut rows = Vec::with_capacity(nnz_output);
            let mut cols = Vec::with_capacity(nnz_output);
            let mut data = Vec::with_capacity(nnz_output);
            for ((&ar, &ac), &av) in a_rows.iter().zip(a_cols.iter()).zip(a_data.iter()) {
                for ((&br, &bc), &bv) in b_rows.iter().zip(b_cols.iter()).zip(b_data.iter()) {
                    rows.push(ar * bshape.0 + br);
                    cols.push(ac * bshape.1 + bc);
                    data.push(av * bv);
                }
            }
            (rows, cols, data)
        };

    match format.to_lowercase().as_str() {
        "csr" => CsrArray::from_triplets(&out_rows, &out_cols, &out_data, outputshape, false)
            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
        "coo" => CooArray::from_triplets(&out_rows, &out_cols, &out_data, outputshape, false)
            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
        _ => Err(SparseError::ValueError(format!(
            "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
        ))),
    }
}
/// Computes the Kronecker sum of two square sparse arrays.
///
/// For `a` of shape `(m, m)` and `b` of shape `(n, n)`, the result is the
/// `(m * n, m * n)` array `kron(I_n, a) + kron(b, I_m)`, where `I_k` is the
/// `k x k` identity. This is the convention used by `scipy.sparse.kronsum`.
///
/// Contributions from the two terms that land on the same position are summed
/// here. For example, this happens on the diagonal when both `a` and `b` store
/// diagonal entries. The triplets handed to the output constructor are therefore
/// sorted by `(row, col)` and free of duplicates.
///
/// # Errors
/// * `SparseError::ValueError` if either input is not square, or if `format` is
///   neither `"csr"` nor `"coo"` (case-insensitive).
/// * Errors from the output format's `from_triplets` constructor are propagated.
#[allow(dead_code)]
pub fn kronsum<'a, T>(
    a: &'a dyn SparseArray<T>,
    b: &'a dyn SparseArray<T>,
    format: &str,
) -> SparseResult<Box<dyn SparseArray<T>>>
where
    T: 'a
        + Float
        + SparseElement
        + Add<Output = T>
        + AddAssign
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Debug
        + Copy
        + 'static,
{
    use std::collections::BTreeMap;

    let ashape = a.shape();
    let bshape = b.shape();
    if ashape.0 != ashape.1 {
        return Err(SparseError::ValueError(
            "First matrix must be square".to_string(),
        ));
    }
    if bshape.0 != bshape.1 {
        return Err(SparseError::ValueError(
            "Second matrix must be square".to_string(),
        ));
    }
    let m = ashape.0;
    let n = bshape.0;
    let outputshape = (m * n, m * n);

    // Ordered accumulator keyed by (row, col). Overlapping contributions are summed,
    // and iteration yields the entries in row-major order.
    let mut acc: BTreeMap<(usize, usize), T> = BTreeMap::new();
    let mut accumulate = |row: usize, col: usize, val: T| {
        acc.entry((row, col))
            .and_modify(|cur| *cur += val)
            .or_insert(val);
    };

    // kron(I_n, a): `n` copies of `a` along the block diagonal.
    let (a_rows, a_cols, a_data) = a.find();
    for block in 0..n {
        let base = block * m;
        for ((&r, &c), &v) in a_rows.iter().zip(a_cols.iter()).zip(a_data.iter()) {
            accumulate(base + r, base + c, v);
        }
    }

    // kron(b, I_m): each entry b[r, c] becomes b[r, c] * I_m in block (r, c).
    let (b_rows, b_cols, b_data) = b.find();
    for ((&r, &c), &v) in b_rows.iter().zip(b_cols.iter()).zip(b_data.iter()) {
        for i in 0..m {
            accumulate(r * m + i, c * m + i, v);
        }
    }

    let nnz = acc.len();
    let mut rows = Vec::with_capacity(nnz);
    let mut cols = Vec::with_capacity(nnz);
    let mut data = Vec::with_capacity(nnz);
    for ((r, c), v) in acc {
        rows.push(r);
        cols.push(c);
        data.push(v);
    }

    // The triplets are already sorted and duplicate-free, so they are valid input
    // whatever the constructors' trailing flag means. `false` matches every other
    // call site in this file.
    match format.to_lowercase().as_str() {
        "csr" => CsrArray::from_triplets(&rows, &cols, &data, outputshape, false)
            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
        "coo" => CooArray::from_triplets(&rows, &cols, &data, outputshape, false)
            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
        _ => Err(SparseError::ValueError(format!(
            "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
        ))),
    }
}
/// Assembles a sparse array from a 2-D grid of sparse blocks.
///
/// `blocks[i][j]` is placed at block-row `i` and block-column `j`. A `None` entry
/// denotes an all-zero block. Its size is inferred from the other blocks in the
/// same block-row (height) and block-column (width).
///
/// # Errors
/// Returns `SparseError::ValueError` when:
/// * `blocks` is empty;
/// * its rows have differing lengths;
/// * blocks in one block-row disagree on their row count, or blocks in one
///   block-column disagree on their column count;
/// * a block-row or block-column contains only `None`;
/// * `format` is neither `"csr"` nor `"coo"` (case-insensitive).
///
/// Errors from the output format's `from_triplets` constructor are propagated.
#[allow(dead_code)]
pub fn bmat<'a, T>(
    blocks: &[Vec<Option<&'a dyn SparseArray<T>>>],
    format: &str,
) -> SparseResult<Box<dyn SparseArray<T>>>
where
    T: 'a
        + Float
        + SparseElement
        + Add<Output = T>
        + AddAssign
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Debug
        + Copy
        + 'static,
{
    if blocks.is_empty() {
        return Err(SparseError::ValueError(
            "Empty blocks array provided".to_string(),
        ));
    }
    let m = blocks.len();
    let n = blocks[0].len();
    for (i, row) in blocks.iter().enumerate() {
        if row.len() != n {
            return Err(SparseError::ValueError(format!(
                "Block row {i} has length {}, expected {n}",
                row.len()
            )));
        }
    }

    // Height of each block-row and width of each block-column, taken from the first
    // block seen. `None` means "no block seen yet". Using `Option` rather than a
    // `0` sentinel lets zero-sized blocks take part in the consistency checks.
    let mut row_sizes: Vec<Option<usize>> = vec![None; m];
    let mut col_sizes: Vec<Option<usize>> = vec![None; n];
    for (i, block_row) in blocks.iter().enumerate() {
        for (j, &slot) in block_row.iter().enumerate() {
            if let Some(block) = slot {
                let (h, w) = block.shape();
                match row_sizes[i] {
                    None => row_sizes[i] = Some(h),
                    Some(expected) if expected != h => {
                        return Err(SparseError::ValueError(format!(
                            "Inconsistent row dimensions in block row {i}. Expected {expected}, got {h}"
                        )));
                    }
                    Some(_) => {}
                }
                match col_sizes[j] {
                    None => col_sizes[j] = Some(w),
                    Some(expected) if expected != w => {
                        return Err(SparseError::ValueError(format!(
                            "Inconsistent column dimensions in block column {j}. Expected {expected}, got {w}"
                        )));
                    }
                    Some(_) => {}
                }
            }
        }
    }

    // Prefix sums of block sizes give each block's top-left offset. A block-row or
    // block-column with no block at all has no determinable size.
    let mut row_offsets = vec![0usize; m + 1];
    for (i, size) in row_sizes.iter().enumerate() {
        match *size {
            Some(h) => row_offsets[i + 1] = row_offsets[i] + h,
            None => {
                return Err(SparseError::ValueError(format!(
                    "Block row {i} has no arrays, cannot determine dimensions"
                )))
            }
        }
    }
    let mut col_offsets = vec![0usize; n + 1];
    for (j, size) in col_sizes.iter().enumerate() {
        match *size {
            Some(w) => col_offsets[j + 1] = col_offsets[j] + w,
            None => {
                return Err(SparseError::ValueError(format!(
                    "Block column {j} has no arrays, cannot determine dimensions"
                )))
            }
        }
    }
    let totalshape = (row_offsets[m], col_offsets[n]);

    // Every block-row holds at least one block at this point, so there is always
    // something to copy (the result may still have no stored entries).
    let mut rows = Vec::new();
    let mut cols = Vec::new();
    let mut data = Vec::new();
    for (i, block_row) in blocks.iter().enumerate() {
        for (j, &slot) in block_row.iter().enumerate() {
            if let Some(block) = slot {
                let (row_offset, col_offset) = (row_offsets[i], col_offsets[j]);
                let (block_rows, block_cols, block_data) = block.find();
                for ((&r, &c), &v) in block_rows
                    .iter()
                    .zip(block_cols.iter())
                    .zip(block_data.iter())
                {
                    rows.push(r + row_offset);
                    cols.push(c + col_offset);
                    data.push(v);
                }
            }
        }
    }

    match format.to_lowercase().as_str() {
        "csr" => CsrArray::from_triplets(&rows, &cols, &data, totalshape, false)
            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
        "coo" => CooArray::from_triplets(&rows, &cols, &data, totalshape, false)
            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
        _ => Err(SparseError::ValueError(format!(
            "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
        ))),
    }
}
/// Returns `true` when `array` is, up to `T::epsilon()`, the square identity matrix.
///
/// The check requires:
/// * a square shape;
/// * exactly `n` stored entries, both per `nnz` and per `find`;
/// * every stored entry on the main diagonal, with a value within `T::epsilon()`
///   of one;
/// * each diagonal position stored at most once.
///
/// The last condition means an array that repeats one diagonal coordinate while
/// omitting another is rejected.
#[allow(dead_code)]
fn is_identity_matrix<T>(array: &dyn SparseArray<T>) -> bool
where
    T: Float + SparseElement + Debug + Copy + 'static,
{
    let (nrows, ncols) = array.shape();
    if nrows != ncols {
        return false;
    }
    let n = nrows;
    if array.nnz() != n {
        return false;
    }
    let (rows, cols, data) = array.find();
    if rows.len() != n || cols.len() != n || data.len() != n {
        return false;
    }
    // With exactly `n` entries, requiring each diagonal index to appear at most
    // once guarantees that every index in 0..n is covered exactly once.
    let mut seen = vec![false; n];
    for ((&r, &c), &v) in rows.iter().zip(cols.iter()).zip(data.iter()) {
        if r != c || r >= n || seen[r] {
            return false;
        }
        if (v - T::sparse_one()).abs() > T::epsilon() {
            return false;
        }
        seen[r] = true;
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::construct::eye_array;

    /// `n x n` CSR identity matrix used as a building block by most tests.
    fn eye(n: usize) -> Box<dyn SparseArray<f64>> {
        eye_array::<f64>(n, "csr").expect("Test operation failed")
    }

    /// 3x3 COO matrix with every entry equal to one, stored in row-major order.
    fn ones_3x3() -> CooArray<f64> {
        let rows: Vec<usize> = (0..9).map(|k| k / 3).collect();
        let cols: Vec<usize> = (0..9).map(|k| k % 3).collect();
        let data = vec![1.0; 9];
        CooArray::from_triplets(&rows, &cols, &data, (3, 3), false)
            .expect("Test operation failed")
    }

    /// Asserts `arr.get(row, col) == value` for every `(row, col, value)` triple.
    fn assert_entries(arr: &dyn SparseArray<f64>, expected: &[(usize, usize, f64)]) {
        for &(row, col, value) in expected {
            assert_eq!(arr.get(row, col), value, "mismatch at ({row}, {col})");
        }
    }

    #[test]
    fn test_hstack() {
        let (left, right) = (eye(2), eye(2));
        let stacked = hstack(&[&*left, &*right], "csr").expect("Test: hstack failed");
        assert_eq!(stacked.shape(), (2, 4));
        assert_entries(
            &*stacked,
            &[
                (0, 0, 1.0),
                (1, 1, 1.0),
                (0, 2, 1.0),
                (1, 3, 1.0),
                (0, 1, 0.0),
                (0, 3, 0.0),
            ],
        );
    }

    #[test]
    fn test_vstack() {
        let (top, bottom) = (eye(2), eye(2));
        let stacked = vstack(&[&*top, &*bottom], "csr").expect("Test: vstack failed");
        assert_eq!(stacked.shape(), (4, 2));
        assert_entries(
            &*stacked,
            &[
                (0, 0, 1.0),
                (1, 1, 1.0),
                (2, 0, 1.0),
                (3, 1, 1.0),
                (0, 1, 0.0),
                (1, 0, 0.0),
            ],
        );
    }

    #[test]
    fn test_block_diag() {
        let (small, large) = (eye(2), eye(3));
        let diag = block_diag(&[&*small, &*large], "csr").expect("Test: block_diag failed");
        assert_eq!(diag.shape(), (5, 5));
        assert_entries(
            &*diag,
            &[
                (0, 0, 1.0),
                (1, 1, 1.0),
                (2, 2, 1.0),
                (3, 3, 1.0),
                (4, 4, 1.0),
                (0, 2, 0.0),
                (2, 0, 0.0),
            ],
        );
    }

    #[test]
    fn test_kron() {
        // Identity (x) identity is the larger identity.
        let (i1, i2) = (eye(2), eye(2));
        let product = kron(&*i1, &*i2, "csr").expect("Test: kron failed");
        assert_eq!(product.shape(), (4, 4));
        assert_entries(
            &*product,
            &[
                (0, 0, 1.0),
                (1, 1, 1.0),
                (2, 2, 1.0),
                (3, 3, 1.0),
                (0, 1, 0.0),
                (0, 2, 0.0),
                (1, 0, 0.0),
            ],
        );

        // General case: [[1, 2], [3, 0]] (x) diag(4, 5).
        let lhs_rows: Vec<usize> = vec![0, 0, 1];
        let lhs_cols: Vec<usize> = vec![0, 1, 0];
        let lhs_vals = vec![1.0, 2.0, 3.0];
        let lhs = CooArray::from_triplets(&lhs_rows, &lhs_cols, &lhs_vals, (2, 2), false)
            .expect("Test operation failed");
        let rhs_rows: Vec<usize> = vec![0, 1];
        let rhs_cols: Vec<usize> = vec![0, 1];
        let rhs_vals = vec![4.0, 5.0];
        let rhs = CooArray::from_triplets(&rhs_rows, &rhs_cols, &rhs_vals, (2, 2), false)
            .expect("Test operation failed");
        let product = kron(&lhs, &rhs, "csr").expect("Test: kron failed");
        assert_eq!(product.shape(), (4, 4));
        assert_entries(
            &*product,
            &[
                (0, 0, 4.0),
                (0, 2, 8.0),
                (1, 1, 5.0),
                (1, 3, 10.0),
                (2, 0, 12.0),
                (3, 1, 15.0),
                (0, 1, 0.0),
                (0, 3, 0.0),
                (2, 1, 0.0),
                (2, 2, 0.0),
                (2, 3, 0.0),
                (3, 0, 0.0),
                (3, 2, 0.0),
                (3, 3, 0.0),
            ],
        );
    }

    #[test]
    fn test_kronsum() {
        let (a, b) = (eye(2), eye(2));
        for &fmt in &["csr", "coo"] {
            let sum = kronsum(&*a, &*b, fmt).expect("Test: kronsum failed");
            assert_eq!(sum.shape(), (4, 4));
            let (rows, _cols, data) = sum.find();
            assert!(!rows.is_empty());
            assert!(!data.is_empty());
        }
    }

    #[test]
    fn test_tril() {
        let full = ones_3x3();

        let lower = tril(&full, 0, "csr").expect("Test: tril failed");
        assert_eq!(lower.shape(), (3, 3));
        assert_entries(
            &*lower,
            &[
                (0, 0, 1.0),
                (1, 0, 1.0),
                (1, 1, 1.0),
                (2, 0, 1.0),
                (2, 1, 1.0),
                (2, 2, 1.0),
                (0, 1, 0.0),
                (0, 2, 0.0),
                (1, 2, 0.0),
            ],
        );

        // k = 1 also keeps the first superdiagonal.
        let lower_k1 = tril(&full, 1, "csr").expect("Test: tril failed");
        assert_entries(
            &*lower_k1,
            &[
                (0, 0, 1.0),
                (0, 1, 1.0),
                (0, 2, 0.0),
                (1, 1, 1.0),
                (1, 2, 1.0),
            ],
        );
    }

    #[test]
    fn test_triu() {
        let full = ones_3x3();

        let upper = triu(&full, 0, "csr").expect("Test: triu failed");
        assert_eq!(upper.shape(), (3, 3));
        assert_entries(
            &*upper,
            &[
                (0, 0, 1.0),
                (0, 1, 1.0),
                (0, 2, 1.0),
                (1, 1, 1.0),
                (1, 2, 1.0),
                (2, 2, 1.0),
                (1, 0, 0.0),
                (2, 0, 0.0),
                (2, 1, 0.0),
            ],
        );

        // k = -1 also keeps the first subdiagonal.
        let upper_km1 = triu(&full, -1, "csr").expect("Test: triu failed");
        assert_entries(
            &*upper_km1,
            &[
                (0, 0, 1.0),
                (1, 0, 1.0),
                (2, 0, 0.0),
                (1, 1, 1.0),
                (2, 1, 1.0),
            ],
        );
    }

    #[test]
    fn test_bmat() {
        let a = eye(2);
        let b = eye(2);

        // Fully populated 2x2 grid of identity blocks.
        let full_grid = vec![vec![Some(&*a), Some(&*b)], vec![Some(&*b), Some(&*a)]];
        let c1 = bmat(&full_grid, "csr").expect("Test: bmat failed");
        assert_eq!(c1.shape(), (4, 4));
        assert_entries(
            &*c1,
            &[
                (0, 0, 1.0),
                (1, 1, 1.0),
                (2, 2, 1.0),
                (3, 3, 1.0),
                (0, 2, 1.0),
                (1, 3, 1.0),
                (2, 0, 1.0),
                (3, 1, 1.0),
                (0, 1, 0.0),
                (0, 3, 0.0),
                (2, 1, 0.0),
                (2, 3, 0.0),
            ],
        );

        // A `None` block is an all-zero block sized from its neighbours.
        let gappy_grid = vec![vec![Some(&*a), Some(&*b)], vec![None, Some(&*a)]];
        let c2 = bmat(&gappy_grid, "csr").expect("Test: bmat failed");
        assert_eq!(c2.shape(), (4, 4));
        assert_entries(
            &*c2,
            &[
                (0, 0, 1.0),
                (1, 1, 1.0),
                (2, 0, 0.0),
                (2, 1, 0.0),
                (2, 2, 1.0),
                (3, 3, 1.0),
            ],
        );

        // Same layout built from independently created identities.
        let b1 = eye(2);
        let b2 = eye(2);
        let fresh_grid = vec![vec![Some(&*b1), Some(&*b2)], vec![Some(&*b2), Some(&*b1)]];
        let c3 = bmat(&fresh_grid, "csr").expect("Test: bmat failed");
        assert_eq!(c3.shape(), (4, 4));
        assert_entries(
            &*c3,
            &[(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)],
        );
    }
}