1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
use burn_tensor::Element;

use crate::{
    element::WgpuElement,
    kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource},
    tensor::WgpuTensor,
};
use std::marker::PhantomData;

use crate::kernel_wgsl;

use super::base::{matmul_tiling_2d_launch, B_K, B_M, B_N, WORKGROUP_SIZE};

// Declares `MatmulTiling2DVec4LhsRaw` from a WGSL template path. Its `source()`
// supplies the base template on which `MatmulTiling2DVec4Lhs` registers its
// concrete values.
// NOTE(review): the path is presumably resolved relative to this file by
// `kernel_wgsl!` — confirm against the macro definition.
kernel_wgsl!(
    MatmulTiling2DVec4LhsRaw,
    "../../../template/matmul/blocktiling_2d/vec4_lhs.wgsl"
);

/// Kernel source for the tiling 2d matmul variant built on the `vec4_lhs.wgsl`
/// template, parameterised by the element type `E`.
///
/// `E` is carried purely as a type-level marker; the struct stores no runtime
/// data, and the derived `new()` constructor is called without arguments.
#[derive(new, Debug)]
struct MatmulTiling2DVec4Lhs<E>
where
    E: WgpuElement,
{
    /// Zero-sized marker tying the kernel to its element type.
    _elem: PhantomData<E>,
}

impl<E: WgpuElement> DynamicKernelSource for MatmulTiling2DVec4Lhs<E> {
    /// Builds the kernel source by registering the block sizes, the workgroup
    /// dimensions and the type names on the raw WGSL template.
    fn source(&self) -> SourceTemplate {
        // Sizes derived from the block constants, named after the template
        // keys they are registered under.
        let bm_x_bk_4 = B_M * B_K / 4;
        let bk_x_bn = B_K * B_N;

        let template = MatmulTiling2DVec4LhsRaw::source();

        // Block dimensions and the products derived from them.
        let template = template
            .register("b_m", B_M.to_string())
            .register("b_n", B_N.to_string())
            .register("b_k", B_K.to_string())
            .register("bm_x_bk_4", bm_x_bk_4.to_string())
            .register("bk_x_bn", bk_x_bn.to_string());

        // Workgroup dimensions: x and y share `WORKGROUP_SIZE`, z is fixed to 1.
        let template = template
            .register("workgroup_size_x", WORKGROUP_SIZE.to_string())
            .register("workgroup_size_y", WORKGROUP_SIZE.to_string())
            .register("workgroup_size_z", "1".to_string());

        // Type names: the element type comes from `E`, the integer type is `i32`.
        template
            .register("elem", E::type_name())
            .register("int", "i32")
    }

    /// Returns the `Debug` rendering of `self` as the kernel identifier.
    ///
    /// NOTE(review): the only field is `PhantomData<E>`, so whether ids differ
    /// per element type depends on how `PhantomData` formats itself — confirm
    /// on the toolchain in use.
    fn id(&self) -> String {
        std::format!("{self:?}")
    }
}

/// Matrix multiplication using the tiling 2d algorithm, with the vec4
/// primitive applied to lhs only.
///
/// `lhs`, `rhs` and `out` are forwarded unchanged to `matmul_tiling_2d_launch`
/// together with a freshly built `MatmulTiling2DVec4Lhs` kernel; the tensor
/// returned by that launch is returned as-is.
pub fn matmul_tiling_2d_vec4_lhs<E, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: WgpuTensor<E, D>,
    out: WgpuTensor<E, D>,
) -> WgpuTensor<E, D>
where
    E: WgpuElement + Element,
{
    matmul_tiling_2d_launch(lhs, rhs, out, MatmulTiling2DVec4Lhs::<E>::new())
}

#[cfg(test)]
mod tests {
    use super::matmul_tiling_2d_vec4_lhs;
    use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims};

    /// Checks the vec4-lhs kernel through `same_as_reference` for an
    /// `[batch_1, batch_2, m, k]` lhs and a `[batch_1, batch_2, k, n]` rhs.
    fn check_against_reference(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) {
        same_as_reference(
            matmul_tiling_2d_vec4_lhs,
            [batch_1, batch_2, m, k],
            [batch_1, batch_2, k, n],
        );
    }

    // --- Shape coverage: (m, k, n, batch_1, batch_2) ---

    #[test]
    pub fn test_matmul_vec4_primitive_straightforward() {
        check_against_reference(1, 2, 1, 1, 1);
    }

    #[test]
    pub fn test_matmul_vec4_primitive_shapes_smaller_than_blocks() {
        check_against_reference(8, 8, 8, 1, 1);
    }

    #[test]
    pub fn test_matmul_vec4_primitive_n_smaller_than_m() {
        check_against_reference(8, 8, 3, 1, 1);
    }

    #[test]
    pub fn test_matmul_vec4_primitive_m_smaller_than_n() {
        check_against_reference(3, 8, 8, 1, 1);
    }

    #[test]
    pub fn test_matmul_vec4_primitive_k_smaller_than_m_n() {
        check_against_reference(8, 3, 8, 1, 1);
    }

    #[test]
    pub fn test_matmul_vec4_primitive_k_larger_than_m_n() {
        check_against_reference(8, 48, 8, 1, 1);
    }

    #[test]
    pub fn test_matmul_vec4_primitive_multibatch_1_dim() {
        check_against_reference(8, 8, 8, 3, 1);
    }

    #[test]
    pub fn test_matmul_vec4_primitive_multibatch_2_dims() {
        check_against_reference(8, 8, 8, 3, 4);
    }

    #[test]
    pub fn test_matmul_vec4_primitive_blocks_divide_shapes_unevenly() {
        check_against_reference(7, 7, 7, 1, 1);
    }

    #[test]
    pub fn test_matmul_vec4_primitive_medium() {
        check_against_reference(17, 16, 16, 1, 1);
    }

    #[test]
    pub fn test_matmul_vec4_primitive_large() {
        check_against_reference(134, 242, 250, 1, 1);
    }

    // --- Swapped-dimension coverage ---

    #[test]
    fn test_matmul_tiling_2d_vec4_primitive_swapped_batches_no_padding() {
        // The same dimension pair is passed for lhs and rhs.
        same_as_reference_swapped_dims(
            matmul_tiling_2d_vec4_lhs,
            [0, 1],
            [0, 1],
            [3, 2, 4, 4],
            [3, 2, 4, 4],
        );
    }

    #[test]
    fn test_matmul_tiling_2d_vec4_primitive_swapped_row_col_no_padding() {
        // NOTE(review): the lhs pair is [0, 0], i.e. a dimension paired with
        // itself — confirm this is intended given the test name.
        same_as_reference_swapped_dims(
            matmul_tiling_2d_vec4_lhs,
            [0, 0],
            [2, 3],
            [3, 2, 4, 4],
            [3, 2, 4, 4],
        );
    }

    #[test]
    fn test_matmul_tiling_2d_vec4_primitive_swapped_row_with_batch_no_padding() {
        same_as_reference_swapped_dims(
            matmul_tiling_2d_vec4_lhs,
            [0, 3],
            [0, 2],
            [4, 4, 4, 4],
            [4, 4, 4, 4],
        );
    }
}