microgemm/
kernel.rs

1use crate::{gemm_with_kernel, MatMut, MatRef, PackSizes};
2use core::ops::Mul;
3use generic_array::{
4    typenum::{Prod, Unsigned},
5    ArrayLength,
6};
7use num_traits::{One, Zero};
8
9#[cfg(test)]
10use allocator_api2::alloc::Allocator;
11
12pub trait Kernel
13where
14    Self::Scalar: Copy + Zero + One,
15{
16    type Scalar;
17    type Mr: ArrayLength + Multiply<Self::Nr>;
18    type Nr: ArrayLength;
19
20    const MR: usize = Self::Mr::USIZE;
21    const NR: usize = Self::Nr::USIZE;
22
23    fn microkernel(
24        &self,
25        alpha: Self::Scalar,
26        lhs: MatRef<Self::Scalar>,
27        rhs: MatRef<Self::Scalar>,
28        beta: Self::Scalar,
29        dst: &mut MatMut<Self::Scalar>,
30    );
31
32    #[inline]
33    #[allow(clippy::too_many_arguments)]
34    fn gemm(
35        &self,
36        alpha: Self::Scalar,
37        a: MatRef<Self::Scalar>,
38        b: MatRef<Self::Scalar>,
39        beta: Self::Scalar,
40        c: &mut MatMut<Self::Scalar>,
41        pack_sizes: PackSizes,
42        packing_buf: &mut [Self::Scalar],
43    ) {
44        gemm_with_kernel(self, alpha, a, b, beta, c, pack_sizes, packing_buf);
45    }
46
47    #[cfg(test)]
48    #[allow(clippy::too_many_arguments)]
49    fn gemm_in(
50        &self,
51        alloc: impl Allocator,
52        alpha: Self::Scalar,
53        a: MatRef<Self::Scalar>,
54        b: MatRef<Self::Scalar>,
55        beta: Self::Scalar,
56        c: &mut MatMut<Self::Scalar>,
57        pack_sizes: PackSizes,
58    ) {
59        use allocator_api2::vec::Vec;
60
61        let size = pack_sizes.buf_len();
62        let mut v = Vec::with_capacity_in(size, alloc);
63        v.resize(size, Self::Scalar::zero());
64        self.gemm(alpha, a, b, beta, c, pack_sizes, v.as_mut_slice());
65    }
66
67    fn mr(&self) -> usize {
68        Self::MR
69    }
70    fn nr(&self) -> usize {
71        Self::NR
72    }
73}
74
75pub trait Multiply<Rhs> {
76    type Output: ArrayLength;
77}
78
79impl<Lhs, Rhs> Multiply<Rhs> for Lhs
80where
81    Lhs: Mul<Rhs>,
82    Prod<Lhs, Rhs>: ArrayLength,
83{
84    type Output = Prod<Lhs, Rhs>;
85}