microgemm 0.2.0-alpha

General matrix multiplication with custom configuration in Rust. Supports no_std and no_alloc environments.
Documentation
use super::NeonKernel4x4;
use crate::{kernels::dbg_check_microkernel_inputs, typenum::U4, Kernel, MatMut, MatRef};
use core::arch::aarch64::{
    vaddq_f32, vfmaq_laneq_f32, vld1q_f32, vmovq_n_f32, vmulq_n_f32, vst1q_f32,
};

impl Kernel for NeonKernel4x4<f32> {
    type Scalar = f32;
    type Mr = U4;
    type Nr = U4;

    fn microkernel(
        &self,
        alpha: f32,
        lhs: &MatRef<f32>,
        rhs: &MatRef<f32>,
        beta: f32,
        dst: &mut MatMut<f32>,
    ) {
        dbg_check_microkernel_inputs(self, lhs, rhs, dst);
        let kc = lhs.ncols();
        neon_4x4_microkernel_f32(
            kc,
            alpha,
            lhs.as_slice(),
            rhs.as_slice(),
            beta,
            dst.as_mut_slice(),
        );
    }
}

fn neon_4x4_microkernel_f32(
    kc: usize,
    alpha: f32,
    lhs: &[f32],
    rhs: &[f32],
    beta: f32,
    dst_colmajor: &mut [f32],
) {
    assert_eq!(lhs.len(), rhs.len());
    assert_eq!(lhs.len(), 4 * kc);

    unsafe { inner(kc, alpha, lhs.as_ptr(), rhs.as_ptr(), beta, dst_colmajor) };

    unsafe fn inner(
        kc: usize,
        alpha: f32,
        mut left: *const f32,
        mut right: *const f32,
        beta: f32,
        dst: &mut [f32],
    ) {
        let mut cols0 = [vmovq_n_f32(0f32); 4];
        let mut cols1 = [vmovq_n_f32(0f32); 4];
        let mut cols2 = [vmovq_n_f32(0f32); 4];
        let mut cols3 = [vmovq_n_f32(0f32); 4];

        macro_rules! accum_lane {
            ($cols:ident, $a:ident, $b:ident) => {
                $cols[0] = vfmaq_laneq_f32::<0>($cols[0], $a, $b);
                $cols[1] = vfmaq_laneq_f32::<1>($cols[1], $a, $b);
                $cols[2] = vfmaq_laneq_f32::<2>($cols[2], $a, $b);
                $cols[3] = vfmaq_laneq_f32::<3>($cols[3], $a, $b);
            };
        }

        for _ in 0..kc / 4 {
            let a = vld1q_f32(left);
            let b = vld1q_f32(right);
            accum_lane!(cols0, a, b);

            let a = vld1q_f32(left.add(4));
            let b = vld1q_f32(right.add(4));
            accum_lane!(cols1, a, b);

            let a = vld1q_f32(left.add(8));
            let b = vld1q_f32(right.add(8));
            accum_lane!(cols2, a, b);

            let a = vld1q_f32(left.add(12));
            let b = vld1q_f32(right.add(12));
            accum_lane!(cols3, a, b);

            left = left.add(16);
            right = right.add(16);
        }

        for _ in 0..kc % 4 {
            let a = vld1q_f32(left);
            let b = vld1q_f32(right);
            accum_lane!(cols0, a, b);

            left = left.add(4);
            right = right.add(4);
        }

        for row in 0..4 {
            let sum = vaddq_f32(
                vaddq_f32(cols0[row], cols1[row]),
                vaddq_f32(cols2[row], cols3[row]),
            );
            cols0[row] = vmulq_n_f32(sum, alpha);
        }

        let it = dst.chunks_exact_mut(4).zip(cols0);
        if beta == 0f32 {
            it.for_each(|(to, from)| {
                vst1q_f32(to.as_mut_ptr(), from);
            });
        } else {
            it.for_each(|(to, from)| {
                let mut tmp = [0f32; 4];
                vst1q_f32(tmp.as_mut_ptr(), from);
                to.iter_mut().zip(tmp).for_each(|(y, x)| {
                    *y = x + beta * *y;
                });
            });
        }
    }
}

#[cfg(test)]
mod tests {
    use std::arch::is_aarch64_feature_detected;

    use super::*;
    use crate::kernels::GenericKernel4x4;
    use crate::utils::*;

    use rand::{thread_rng, Rng};

    fn neon_kernel<T>() -> NeonKernel4x4<T> {
        if is_aarch64_feature_detected!("neon") {
            unsafe { NeonKernel4x4::new() }
        } else {
            panic!("neon feature is not supported");
        }
    }

    #[test]
    fn test_neon_naive_f32() {
        let kernel = &neon_kernel::<f32>();
        let mut rng = thread_rng();
        let cmp = |expect: &[f32], got: &[f32]| {
            let eps = 80.0 * f32::EPSILON;
            assert_approx_eq(expect, got, eps);
        };
        for _ in 0..40 {
            let scalar = || rng.gen_range(-1.0..1.0);
            random_kernel_test(kernel, scalar, cmp);
        }
    }

    #[test]
    fn test_neon_f32() {
        const DIM: usize = 4;

        let generic_kernel = &GenericKernel4x4::<f32>::new();
        let neon_kernel = &neon_kernel::<f32>();

        let mut rng = thread_rng();
        let cmp = |expect: &[f32], got: &[f32]| {
            let eps = 75.0 * f32::EPSILON;
            assert_approx_eq(expect, got, eps);
        };

        for _ in 0..60 {
            let mc = rng.gen_range(1..20) * DIM;
            let nc = rng.gen_range(1..20) * DIM;
            let scalar = || rng.gen_range(-1.0..1.0);

            cmp_kernels_with_random_data(generic_kernel, neon_kernel, scalar, cmp, mc, nc);
        }
    }
}