use num_traits::Zero;
/// Dense 2x2 block of `i32` values.
///
/// Serves as the scalar type stored inside the sparse matrices built by the
/// test in this file.
#[derive(Debug, Default, Clone, PartialEq)]
struct Mat {
    /// The four entries, addressed as `data[row][col]`.
    data: [[i32; 2]; 2],
}
impl Mat {
fn new(data: [[i32; 2]; 2]) -> Self {
Self { data }
}
}
impl Zero for Mat {
fn zero() -> Self {
Self {
data: [[0_i32; 2]; 2],
}
}
fn is_zero(&self) -> bool {
self == &Self::zero()
}
}
impl std::ops::Add for Mat {
    type Output = Self;

    /// Entry-wise sum of the two blocks.
    fn add(self, other: Self) -> Self {
        // Accumulate into the left operand, which we already own.
        let mut sum = self;
        for (sum_row, other_row) in sum.data.iter_mut().zip(other.data.iter()) {
            for (entry, rhs) in sum_row.iter_mut().zip(other_row.iter()) {
                *entry += *rhs;
            }
        }
        sum
    }
}
impl std::ops::Mul for &Mat {
type Output = Mat;
fn mul(self, other: Self) -> Self::Output {
let mut out = Self::Output::zero();
for i in 0..2 {
for j in 0..2 {
for k in 0..2 {
out.data[i][j] += self.data[i][k] * other.data[k][j];
}
}
}
out
}
}
impl sprs::MulAcc for Mat {
    /// Adds the matrix product `a * b` onto `self`, entry by entry.
    fn mul_acc(&mut self, a: &Self, b: &Self) {
        for (acc_row, a_row) in self.data.iter_mut().zip(a.data.iter()) {
            for (col, acc) in acc_row.iter_mut().enumerate() {
                // Walk a row of `a` alongside the rows of `b`, picking the
                // current column out of each row of `b`.
                for (a_entry, b_row) in a_row.iter().zip(b.data.iter()) {
                    *acc += a_entry * b_row[col];
                }
            }
        }
    }
}
/// Multiplies two sparse matrices whose stored values are 2x2 `Mat` blocks
/// and checks both the sparsity structure and the block values of the result.
#[test]
fn block_matrix_multiply() {
    let block_a = Mat::new([[1, 2], [3, 4]]);
    let block_b = Mat::new([[0, -3], [-2, -7]]);

    // Sanity check of the dense block product on its own.
    let expected_square = Mat::new([[7, 10], [15, 22]]);
    assert_eq!(&block_a * &block_a, expected_square);

    // Left operand. Assuming CSR storage (what `CsMat::new` builds in sprs):
    // row 0 stores one block at column 1, row 1 stores blocks at columns 0, 1.
    let lhs = sprs::CsMat::new(
        (2, 2),
        vec![0, 1, 3],
        vec![1, 0, 1],
        vec![block_a.clone(), block_a, block_b],
    );

    // Right operand. Under the same assumption, both blocks sit in row 0 and
    // row 1 is empty.
    let rhs = sprs::CsMat::new(
        (2, 2),
        vec![0, 2, 2],
        vec![0, 1],
        vec![Mat::new([[2, 0], [7, -4]]), Mat::new([[0, -99], [9, -7]])],
    );

    let product = &lhs * &rhs;

    // Structure: nothing stored in the first slice, two entries in the second.
    assert_eq!(product.indptr().raw_storage(), &[0, 0, 2]);
    assert_eq!(product.indices(), &[0, 1]);

    // Values: the two stored blocks of the product.
    let blocks = product.data();
    assert_eq!(blocks.len(), 2);
    assert_eq!(blocks[0], Mat::new([[16, -8], [34, -16]]));
    assert_eq!(blocks[1], Mat::new([[18, -113], [36, -325]]));
}