microgemm 0.3.1

General matrix multiplication with custom configuration in Rust. Supports no_std and no_alloc environments.
Documentation
use crate::std_prelude::*;
use crate::{MatMut, MatRef, Zero};
use core::ops::{Add, Mul};

/// Reference (unblocked) general matrix multiply: `c <- alpha * a * b + beta * c`.
///
/// Each output element is computed independently as the dot product of a row
/// of `a` with a column of `b`; there is no blocking or packing. The dot
/// product is accumulated left to right starting from its first term, and is
/// `T::zero()` when the inner dimension is zero.
///
/// The existing contents of `c` are always read and scaled by `beta`, even
/// when `beta` is zero.
///
/// # Panics
///
/// Panics unless the shapes are compatible, i.e. `a` is `m x k`, `b` is
/// `k x n` and `c` is `m x n`.
pub fn naive_gemm<T>(alpha: T, a: MatRef<T>, b: MatRef<T>, beta: T, c: &mut MatMut<T>)
where
    T: Copy + Add<Output = T> + Mul<Output = T> + Zero,
{
    assert_eq!(a.nrows(), c.nrows());
    assert_eq!(b.ncols(), c.ncols());
    assert_eq!(a.ncols(), b.nrows());

    let (m, k, n) = (a.nrows(), a.ncols(), b.ncols());

    // Row `row` of `a` dotted with column `col` of `b`. Seeding the sum with
    // the first product (instead of folding from zero) keeps the exact
    // left-to-right summation of the terms themselves.
    let dot = |row, col| {
        let mut terms = (0..k).map(|p| a.get(row, p) * b.get(p, col));
        match terms.next() {
            Some(first) => terms.fold(first, |acc, term| acc + term),
            None => T::zero(),
        }
    };

    for i in 0..m {
        for j in 0..n {
            let value = dot(i, j);
            let z = c.get_mut(i, j);
            *z = alpha * value + beta * *z;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Hand-checked 2x3 * 3x2 product with alpha = 1 and beta = 0, written into
    // both a row-major and a column-major output.
    #[rustfmt::skip]
    #[test]
    fn fixed_1() {
        let a = [
            1, 2, 3,
            4, 5, 6,
        ];
        let b = [
            10, 11,
            20, 21,
            30, 31,
        ];
        let a = MatRef::row_major(2, 3, &a);
        let b = MatRef::row_major(3, 2, &b);

        let mut c = [-1; 4];
        let mut c = MatMut::row_major(2, 2, c.as_mut());
        let expect = [
            140, 146,
            320, 335,
        ];
        naive_gemm(1, a, b, 0, c.as_mut());
        assert_eq!(c.as_slice(), expect);

        let mut c = [-1; 4];
        let mut c = MatMut::col_major(2, 2, c.as_mut());
        let expect = [
            140, 320,
            146, 335,
        ];
        naive_gemm(1, a, b, 0, c.as_mut());
        assert_eq!(c.as_slice(), expect);
    }

    // Non-trivial alpha and beta with mixed input layouts (row-major `a`,
    // column-major `b`).
    #[rustfmt::skip]
    #[test]
    fn fixed_2() {
        let alpha = 3;
        let beta = -4;

        let a = [
            1, 2, 3,
            4, 5, 6,
        ];
        let b = [
            2, 3, 4,
            5, 6, 7,
        ];
        let mut c = [
            -4, 1,
            -5, -6,
        ];
        let expect = [
            beta * c[0] + alpha * (2 + 2 * 3 + 3 * 4),
            beta * c[1] + alpha * (5 + 2 * 6 + 3 * 7),
            beta * c[2] + alpha * (4 * 2 + 5 * 3 + 6 * 4),
            beta * c[3] + alpha * (4 * 5 + 5 * 6 + 6 * 7),
        ];

        let a = MatRef::row_major(2, 3, a.as_ref());
        let b = MatRef::col_major(3, 2, b.as_ref());
        let mut c = MatMut::row_major(2, 2, c.as_mut());

        naive_gemm(alpha, a, b, beta, c.as_mut());
        assert_eq!(c.as_slice(), expect);
    }

    // Inputs with zero and negative entries; `c` is pre-filled with -9 so a
    // stale value would show up in the result.
    #[rustfmt::skip]
    #[test]
    fn fixed_3() {
        let a = [
            1, 0, 2,
            0, -1, 3,
        ];
        let a = MatRef::row_major(2, 3, a.as_ref());
        let b = [
            2, -1,
            0, 5,
            1, 1,
        ];
        let b = MatRef::row_major(3, 2, b.as_ref());

        let mut c = [-9; 2 * 2];
        let c = &mut MatMut::row_major(2, 2, c.as_mut());
        let expect = [
            4, 1,
            3, -2,
        ];
        naive_gemm(1, a, b, 0, c);
        assert_eq!(c.as_slice(), expect);
    }

    // Inner dimension of 1 (an outer product): every dot product has exactly
    // one term. Also checks a non-zero beta against a column-major output.
    #[rustfmt::skip]
    #[test]
    fn outer_product_with_beta() {
        let alpha = 2;
        let beta = 5;

        // `a` is 2x1 (column vector), `b` is 1x2 (row vector).
        let a = [2, 3];
        let b = [4, 5];

        // `c` is stored column-major: [c00, c10, c01, c11].
        let mut c = [1, -1, 2, 3];
        let expect = [
            beta * c[0] + alpha * (2 * 4),
            beta * c[1] + alpha * (3 * 4),
            beta * c[2] + alpha * (2 * 5),
            beta * c[3] + alpha * (3 * 5),
        ];

        let a = MatRef::col_major(2, 1, a.as_ref());
        let b = MatRef::row_major(1, 2, b.as_ref());
        let mut c = MatMut::col_major(2, 2, c.as_mut());

        naive_gemm(alpha, a, b, beta, c.as_mut());
        assert_eq!(c.as_slice(), expect);
    }
}