use algebra_sparse::CsrMatrixViewMethods;
use algebra_sparse::traits::IntoView;
use approx::assert_relative_eq;
type Real = f32;
type DMatrix = nalgebra::DMatrix<Real>;
type CsrMatrix = algebra_sparse::CsrMatrix<Real>;
type DiagonalBlockMatrix = algebra_sparse::DiagonalBlockMatrix<Real>;
type DiagonalBlockMatrixView<'a> = algebra_sparse::DiagonalBlockMatrixView<'a, Real>;
#[test]
fn test_csr_mul_bd() {
    /// Checks that the sparse product `a * b` equals the dense product of the
    /// same operands.
    fn check(a: CsrMatrix, b: DiagonalBlockMatrixView) {
        let expected = a.to_dense() * b.to_dense();
        let actual = a.as_view() * b;
        assert_relative_eq!(actual.to_dense(), expected);
    }

    // 6x6 dense source filled with 1..=36 (column-major, as `from_vec` reads it),
    // cut into diagonal blocks of sizes 1, 2 and 3.
    let source: Vec<Real> = (1..=36).map(|v| v as Real).collect();
    let bd_mat = gen_block_diag_mat(&DMatrix::from_vec(6, 6, source), &[1, 2, 3]);

    check(create_csr(3, 6), bd_mat.into_view());
    check(create_csr2(), create_diag_mat2().into_view());
}
#[test]
fn test_csr_bd_csc() {
    // Computes J * H * J^T through the sparse kernels (CSR * block-diagonal,
    // then CSR * transposed view) and checks it against a dense reference.
    // Previously this test only printed results and could never fail on a
    // wrong answer.
    let j = create_csr2();
    let h = create_diag_mat2();
    let h = h.into_view();
    let j = j.as_view();
    println!("J = {}", j.to_dense());
    println!("H = {}", h.to_dense());

    // Compute J * H once and reuse it for both the printout and the next product.
    let jh = j * h;
    println!("JH = {}", jh.to_dense());
    let r1 = (jh.as_view() * j.transpose()).to_dense();
    println!("{}", r1);

    // The dense reference uses nalgebra's own transpose so it does not depend on
    // the sparse transpose being correct.
    let j_dense = j.to_dense();
    let expected = &j_dense * h.to_dense() * j_dense.transpose();
    assert_relative_eq!(r1, expected);
}
#[test]
fn test_csr_mul_csc() {
    // Multiplies a CSR matrix by its own transposed view (the CSC operand) and
    // compares the sparse result with the dense reference product.
    let lhs = create_csr(3, 6);
    let rhs = lhs.as_view().transpose();
    let expected = lhs.to_dense() * rhs.to_dense();
    let actual = lhs.as_view() * rhs;
    assert_relative_eq!(actual.to_dense(), expected);
}
fn create_csr(nrows: usize, ncols: usize) -> CsrMatrix {
let mut data = vec![];
let mut v = 1.0;
for i in 0..nrows {
for j in 0..ncols {
if i != j && j % 2 == 0 {
data.push(0.0);
} else {
data.push(v);
}
v += 1.0;
}
}
let m = DMatrix::from_vec(nrows, ncols, data);
CsrMatrix::from_dense(m.as_view())
}
/// Builds a fixed 6x6 CSR test matrix from hard-coded dense values
/// (given row by row).
fn create_csr2() -> CsrMatrix {
    // The literals are rounded values near 1/sqrt(2) and 1/(2*sqrt(2)). They are
    // kept exactly as written rather than replaced by std constants, which is
    // why the clippy lint is silenced.
    #[allow(clippy::approx_constant)]
    #[rustfmt::skip]
    let dense = DMatrix::from_row_slice(6, 6, &[
        0.7071068, 0.0, 0.7071067, -0.35355335, 0.0, 0.3535534,
        0.0, 1.0, 0.0, 0.5, 0.0, -0.5,
        -0.7071067, 0.0, 0.7071068, -0.35355344, 0.7071067, -0.35355332,
        0.7071068, 0.0, 0.7071067, 0.35355335, 0.0, -0.3535534,
        0.0, 1.0, 0.0, 0.5, 0.0, -0.5,
        -0.7071067, 0.0, 0.7071068, 0.35355338, 0.7071067, 0.35355338,
    ]);
    CsrMatrix::from_dense(dense.as_view())
}
fn create_diag_mat2() -> DiagonalBlockMatrix {
let diag_elements = vec![
0.0010000002,
0.0010000002,
0.0010000002,
0.0059999996,
0.0059999996,
0.0059999996,
];
DiagonalBlockMatrix::from_block_values(diag_elements, &[1, 1, 1, 1, 1, 1])
}
/// Extracts consecutive square diagonal blocks of `dense_mat` into a
/// block-diagonal matrix.
///
/// Blocks of sizes `block_sizes` are taken back to back along the main
/// diagonal. Each block's values are copied in nalgebra's iteration order
/// (column-major). NOTE(review): this presumably matches the layout
/// `from_block_values` expects; confirm against the crate.
/// nalgebra's `view_range` panics if the blocks run past `dense_mat`'s bounds.
fn gen_block_diag_mat(dense_mat: &DMatrix, block_sizes: &[usize]) -> DiagonalBlockMatrix {
    let mut values = Vec::new();
    let mut start = 0;
    for &size in block_sizes {
        let end = start + size;
        let block = dense_mat.view_range(start..end, start..end);
        values.extend(block.iter().copied());
        start = end;
    }
    DiagonalBlockMatrix::from_block_values(values, block_sizes)
}