microgemm 0.3.1

General matrix multiplication with custom configuration in Rust. Supports no_std and no_alloc environments.
Documentation
use crate::{
    kernels::dbg_check_microkernel_inputs,
    typenum::{U16, U2, U32, U4, U8},
    Kernel, One, Zero,
};
use core::marker::PhantomData;
use core::ops::{Add, Mul};

/// Accumulates the outer products of matching `DIM`-element chunks of `lhs`
/// and `rhs` into `cols`.
///
/// For every chunk index `k` and every `i, j < DIM` this performs
/// `cols[j * DIM + i] = cols[j * DIM + i] + lhs[k * DIM + i] * rhs[k * DIM + j]`.
/// Chunk `j` of `cols` is therefore column `j` of a `DIM x DIM` column-major
/// block. Existing values in `cols` are added to, not overwritten.
///
/// # Panics
///
/// Panics if `cols.len() != DIM * DIM`, if `lhs.len()` is not a multiple of
/// `DIM`, or if `lhs.len() != rhs.len()`.
fn loop_micropanels<T, const DIM: usize>(lhs: &[T], rhs: &[T], cols: &mut [T])
where
    T: Copy + Add<Output = T> + Mul<Output = T>,
{
    assert_eq!(cols.len(), DIM * DIM);
    assert_eq!(lhs.len() % DIM, 0);
    assert_eq!(lhs.len(), rhs.len());

    let panel_pairs = lhs.chunks_exact(DIM).zip(rhs.chunks_exact(DIM));
    for (lhs_col, rhs_row) in panel_pairs {
        // Rank-1 update: column `j` of `cols` gains `lhs_col * rhs_row[j]`.
        for (dst_col, &rhs_val) in cols.chunks_exact_mut(DIM).zip(rhs_row) {
            for (acc, &lhs_val) in dst_col.iter_mut().zip(lhs_col) {
                *acc = *acc + lhs_val * rhs_val;
            }
        }
    }
}

/// Combines an accumulated block with the destination:
/// `dst[i] = alpha * cols[i] + beta * dst[i]` for every element.
///
/// The old `dst` value is always read and multiplied by `beta`, even when
/// `beta` is zero. For floating-point `T`, a NaN or infinity already in `dst`
/// therefore still propagates.
///
/// # Panics
///
/// Panics if `dst.len() != DIM * DIM` or if `cols.len() != dst.len()`.
fn write_cols_to_colmajor<T, const DIM: usize>(dst: &mut [T], cols: &[T], alpha: T, beta: T)
where
    T: Copy + Add<Output = T> + Mul<Output = T>,
{
    assert_eq!(dst.len(), DIM * DIM);
    assert_eq!(cols.len(), dst.len());

    for (slot, &acc) in dst.iter_mut().zip(cols.iter()) {
        let scaled_acc = alpha * acc;
        let scaled_old = beta * *slot;
        *slot = scaled_acc + scaled_old;
    }
}

/// Defines a zero-sized, portable `$dim x $dim` kernel struct `$struct<T>`
/// and implements [`Kernel`] for it.
///
/// `$dimty` is used for `Kernel::Mr` and `Kernel::Nr`, while `$dim` sizes the
/// local accumulator. The two are expected to denote the same number
/// (e.g. `2` and `U2`); the macro does not check this.
macro_rules! impl_generic_square_kernel {
    ($struct:ident, $dim:literal, $dimty:ty) => {
        #[doc = concat!(
            "Portable ",
            stringify!($dim),
            "x",
            stringify!($dim),
            " microkernel, generic over the scalar type `T`."
        )]
        ///
        /// Uses only `Copy`, `Zero`, `One`, `Add` and `Mul` on `T` and no
        /// target-specific instructions.
        pub struct $struct<T>(PhantomData<T>);

        impl<T> $struct<T> {
            /// Creates the kernel. It is zero-sized and carries no state.
            pub const fn new() -> Self {
                Self(PhantomData)
            }
        }

        // These impls are written by hand rather than derived. `#[derive]` would
        // require `T: Clone` / `T: Copy` / `T: Debug` / `T: Default`, even though
        // the only field is a `PhantomData<T>`, which implements all four for
        // every `T`.
        impl<T> Clone for $struct<T> {
            fn clone(&self) -> Self {
                *self
            }
        }

        impl<T> Copy for $struct<T> {}

        impl<T> Default for $struct<T> {
            fn default() -> Self {
                Self::new()
            }
        }

        impl<T> ::core::fmt::Debug for $struct<T> {
            // Same output shape as the derived impl: `Name(PhantomData<..>)`.
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.debug_tuple(stringify!($struct)).field(&self.0).finish()
            }
        }

        impl<T> Kernel for $struct<T>
        where
            T: Copy + Zero + One + Add<Output = T> + Mul<Output = T>,
        {
            type Scalar = T;
            type Mr = $dimty;
            type Nr = $dimty;

            fn microkernel(
                &self,
                alpha: Self::Scalar,
                lhs: crate::MatRef<Self::Scalar>,
                rhs: crate::MatRef<Self::Scalar>,
                beta: Self::Scalar,
                dst: &mut crate::MatMut<Self::Scalar>,
            ) {
                dbg_check_microkernel_inputs(self, lhs, rhs, dst);

                const DIM: usize = $dim;
                // Accumulate `lhs * rhs` into a zero-initialised local DIM x DIM
                // buffer first, then fold it into `dst` in one alpha/beta pass.
                let mut cols = [T::zero(); DIM * DIM];
                loop_micropanels::<_, DIM>(lhs.as_slice(), rhs.as_slice(), &mut cols);
                write_cols_to_colmajor::<_, DIM>(dst.as_mut_slice(), &cols, alpha, beta);
            }
        }
    };
}

// Instantiate the portable square kernels for each supported block size.
// The literal is the runtime dimension used for the local accumulator. The
// typenum type is the same value at the type level, used for `Kernel::Mr` and
// `Kernel::Nr`.
impl_generic_square_kernel!(GenericKernel2x2, 2, U2);
impl_generic_square_kernel!(GenericKernel4x4, 4, U4);
impl_generic_square_kernel!(GenericKernel8x8, 8, U8);
impl_generic_square_kernel!(GenericKernel16x16, 16, U16);
impl_generic_square_kernel!(GenericKernel32x32, 32, U32);

#[cfg(test)]
mod proptests {
    use super::*;
    use crate::{
        std_prelude::*,
        utils::{is_debug_build, proptest_kernel, ProptestKernelCfg},
    };
    use proptest::prelude::*;

    /// Property-test configuration shared by every `i32` kernel test.
    ///
    /// Matrices are capped at 17 per side in debug builds and 83 in release
    /// builds. Pack dimensions are capped at `2 * max_dim + 1`, and scalars are
    /// drawn from `-11..11`.
    fn cfg_i32() -> ProptestKernelCfg<i32> {
        let max_dim = if is_debug_build() { 17 } else { 83 };
        let max_pack_dim = 2 * max_dim + 1;
        ProptestKernelCfg::default()
            .with_max_matrix_dim(max_dim)
            .with_max_pack_dim(max_pack_dim)
            .with_scalar((-11..11).boxed())
    }

    /// Emits one `#[test]` per `name: Kernel` pair. Each test runs the shared
    /// `i32` property test against a freshly constructed kernel.
    macro_rules! proptest_i32_kernels {
        ($($test_name:ident: $kernel:ident),+ $(,)?) => {
            $(
                #[test]
                fn $test_name() {
                    proptest_kernel(&$kernel::new(), cfg_i32()).unwrap();
                }
            )+
        };
    }

    proptest_i32_kernels! {
        proptest_generic_kernel_2x2_i32: GenericKernel2x2,
        proptest_generic_kernel_4x4_i32: GenericKernel4x4,
        proptest_generic_kernel_8x8_i32: GenericKernel8x8,
        proptest_generic_kernel_16x16_i32: GenericKernel16x16,
        proptest_generic_kernel_32x32_i32: GenericKernel32x32,
    }
}