1use crate::{gemm_with_kernel, MatMut, MatRef, PackSizes};
2use core::ops::Mul;
3use generic_array::{
4 typenum::{Prod, Unsigned},
5 ArrayLength,
6};
7use num_traits::{One, Zero};
8
9#[cfg(test)]
10use allocator_api2::alloc::Allocator;
11
12pub trait Kernel
13where
14 Self::Scalar: Copy + Zero + One,
15{
16 type Scalar;
17 type Mr: ArrayLength + Multiply<Self::Nr>;
18 type Nr: ArrayLength;
19
20 const MR: usize = Self::Mr::USIZE;
21 const NR: usize = Self::Nr::USIZE;
22
23 fn microkernel(
24 &self,
25 alpha: Self::Scalar,
26 lhs: MatRef<Self::Scalar>,
27 rhs: MatRef<Self::Scalar>,
28 beta: Self::Scalar,
29 dst: &mut MatMut<Self::Scalar>,
30 );
31
32 #[inline]
33 #[allow(clippy::too_many_arguments)]
34 fn gemm(
35 &self,
36 alpha: Self::Scalar,
37 a: MatRef<Self::Scalar>,
38 b: MatRef<Self::Scalar>,
39 beta: Self::Scalar,
40 c: &mut MatMut<Self::Scalar>,
41 pack_sizes: PackSizes,
42 packing_buf: &mut [Self::Scalar],
43 ) {
44 gemm_with_kernel(self, alpha, a, b, beta, c, pack_sizes, packing_buf);
45 }
46
47 #[cfg(test)]
48 #[allow(clippy::too_many_arguments)]
49 fn gemm_in(
50 &self,
51 alloc: impl Allocator,
52 alpha: Self::Scalar,
53 a: MatRef<Self::Scalar>,
54 b: MatRef<Self::Scalar>,
55 beta: Self::Scalar,
56 c: &mut MatMut<Self::Scalar>,
57 pack_sizes: PackSizes,
58 ) {
59 use allocator_api2::vec::Vec;
60
61 let size = pack_sizes.buf_len();
62 let mut v = Vec::with_capacity_in(size, alloc);
63 v.resize(size, Self::Scalar::zero());
64 self.gemm(alpha, a, b, beta, c, pack_sizes, v.as_mut_slice());
65 }
66
67 fn mr(&self) -> usize {
68 Self::MR
69 }
70 fn nr(&self) -> usize {
71 Self::NR
72 }
73}
74
75pub trait Multiply<Rhs> {
76 type Output: ArrayLength;
77}
78
79impl<Lhs, Rhs> Multiply<Rhs> for Lhs
80where
81 Lhs: Mul<Rhs>,
82 Prod<Lhs, Rhs>: ArrayLength,
83{
84 type Output = Prod<Lhs, Rhs>;
85}