// gemm_common/horizontal_microkernel.rs

/// Defines a "horizontal" GEMM microkernel for a real scalar type.
///
/// `horizontal_kernel!([target_feature]?, name, M, N)` expands to a
/// `pub unsafe fn name(..)` that updates an `M x N` tile of `dst`. Output
/// element `(i, j)` is a length-`k` dot product between the `lhs` row that
/// starts at `lhs + i * lhs_rs` and the `rhs` column that starts at
/// `rhs + j * rhs_cs`. Both are read contiguously along the depth dimension.
/// Products are accumulated in chunks of `N` elements, and each accumulator
/// is reduced to a scalar with `reduce_sum` at the end.
///
/// The expansion is not self-contained. The invoking scope must provide:
/// - `T` (the scalar type) and `N` (elements per chunk);
/// - `splat`, `mul_add`, `partial_load`, `reduce_sum` (chunk/accumulator helpers);
/// - `scalar_mul`, `scalar_mul_add`, `scalar_add` (scalar helpers).
///
/// The invoking crate must also depend on `seq_macro`.
#[macro_export]
macro_rules! horizontal_kernel {
    ($([$target: tt])?, $name: ident, $m: tt, $n: tt $(,)?) => {
        /// GEMM microkernel generated by `horizontal_kernel!`.
        ///
        /// For every tile row `i` and column `j` it computes the dot product
        /// `acc = sum_{d < k} lhs[i * lhs_rs + d] * rhs[j * rhs_cs + d]`. This
        /// assumes `mul_add` is a lane-wise `a * b + c` and `reduce_sum` adds
        /// the lanes. It then stores `acc` into `dst[i * dst_rs + j * dst_cs]`
        /// according to `alpha_status`:
        /// - `0`: `dst = scalar_mul(beta, acc)`. `dst` is not read and `alpha`
        ///   is ignored.
        /// - `1`: `dst = scalar_mul_add(beta, acc, dst)`, presumably
        ///   `beta * acc + dst`.
        /// - otherwise: `dst = beta * acc + alpha * dst`.
        ///
        /// The `_conj_*` flags are accepted but never read; conjugation is a
        /// no-op for real scalars.
        ///
        /// # Safety
        ///
        /// - Each tile row of `lhs` and column of `rhs` must be valid for reads
        ///   of `k` contiguous `T`s. The tail (`k % N` elements) is read through
        ///   `partial_load`, whose exact footprint the invoking scope defines.
        /// - Every `dst` element of the tile must be valid for writes, and also
        ///   for reads when `alpha_status != 0`.
        /// - If a target feature was given, the running CPU must support it.
        $(#[target_feature(enable = $target)])?
        pub unsafe fn $name(
            k: usize,
            dst: *mut T,
            lhs: *const T,
            rhs: *const T,
            dst_cs: isize,
            dst_rs: isize,
            lhs_rs: isize,
            rhs_cs: isize,
            alpha: T,
            beta: T,
            alpha_status: u8,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
        ) {
            // One accumulator per output element, indexed `[column][row]`.
            let mut accum = [[splat(::core::mem::zeroed()); $m]; $n];

            // Per-row / per-column base pointers `lhs0.., rhs0..`. Single braces
            // are used so these bindings stay visible for the rest of the function.
            seq_macro::seq!(M_ITER in 0..$m {
                let lhs~M_ITER = lhs.wrapping_offset(M_ITER * lhs_rs);
            });
            seq_macro::seq!(N_ITER in 0..$n {
                let rhs~N_ITER = rhs.wrapping_offset(N_ITER * rhs_cs);
            });

            // Main loop over full `N`-element chunks of the depth dimension.
            // `[T; N]` has the same alignment as `T`, so this typed read is
            // aligned whenever the `T` pointers are.
            let mut depth = 0;
            while depth < k / N * N {
                seq_macro::seq!(M_ITER in 0..$m {
                    let lhs~M_ITER = *(lhs~M_ITER.add(depth) as *const [T; N]);
                });
                seq_macro::seq!(N_ITER in 0..$n {
                    let rhs~N_ITER = *(rhs~N_ITER.add(depth) as *const [T; N]);
                });

                seq_macro::seq!(M_ITER in 0..$m {
                    seq_macro::seq!(N_ITER in 0..$n {
                        accum[N_ITER][M_ITER] = mul_add(lhs~M_ITER, rhs~N_ITER, accum[N_ITER][M_ITER]);
                    });
                });

                depth += N;
            }

            // Tail of `k - depth < N` elements. For the result to be correct,
            // the unused lanes loaded by `partial_load` must contribute zero
            // to the product (presumably they are zero-filled; confirm).
            if depth < k {
                seq_macro::seq!(M_ITER in 0..$m {
                    let lhs~M_ITER = partial_load(lhs~M_ITER.add(depth), k - depth);
                });
                seq_macro::seq!(N_ITER in 0..$n {
                    let rhs~N_ITER = partial_load(rhs~N_ITER.add(depth), k - depth);
                });

                seq_macro::seq!(M_ITER in 0..$m {
                    seq_macro::seq!(N_ITER in 0..$n {
                        accum[N_ITER][M_ITER] = mul_add(lhs~M_ITER, rhs~N_ITER, accum[N_ITER][M_ITER]);
                    });
                });
            }

            // Reduce every accumulator to a single scalar.
            let mut accum_reduced: [[T; $m]; $n] = ::core::mem::zeroed();
            seq_macro::seq!(M_ITER in 0..$m {
                seq_macro::seq!(N_ITER in 0..$n {
                    accum_reduced[N_ITER][M_ITER] = reduce_sum(accum[N_ITER][M_ITER]);
                });
            });

            // Store phase. The doubled braces make every expansion its own
            // block, so each `let dst` shadow is scoped to a single element and
            // always offsets from the original `dst` pointer.
            if alpha_status == 0 {
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        let dst = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                        *dst = scalar_mul(beta, accum_reduced[N_ITER][M_ITER]);
                    }});
                }});
            } else if alpha_status == 1 {
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        let dst = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                        *dst = scalar_mul_add(
                            beta,
                            accum_reduced[N_ITER][M_ITER],
                            *dst,
                        );
                    }});
                }});
            } else {
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        let dst = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                        *dst = scalar_add(
                            scalar_mul(beta, accum_reduced[N_ITER][M_ITER]),
                            scalar_mul(alpha, *dst),
                        );
                    }});
                }});
            }
        }
    };
}

/// Defines a "horizontal" GEMM microkernel for `num_complex::Complex<T>`.
///
/// `horizontal_cplx_kernel!([target_feature]?, name, M, N)` expands to a
/// `pub unsafe fn name(..)` that updates an `M x N` tile of `dst`. Output
/// element `(i, j)` is a length-`k` dot product between the `lhs` row that
/// starts at `lhs + i * lhs_rs` and the `rhs` column that starts at
/// `rhs + j * rhs_cs`. Either operand can optionally be conjugated. Both are
/// read contiguously along the depth dimension in chunks of `CPLX_N` complex
/// elements.
///
/// The expansion is not self-contained. The invoking scope must provide:
/// - `T` (the real component type), `CPLX_N` (complex elements per chunk)
///   and `Pack` (the chunk type read from memory);
/// - `splat`, `mul_add`, `conj_mul_add`, `conj`, `partial_load`, `reduce_sum`.
///
/// The invoking crate must also depend on `seq_macro` and `num_complex`.
#[macro_export]
macro_rules! horizontal_cplx_kernel {
    ($([$target: tt])?, $name: ident, $m: tt, $n: tt $(,)?) => {
        /// Complex GEMM microkernel generated by `horizontal_cplx_kernel!`.
        ///
        /// For every tile row `i` and column `j` it computes
        /// `acc = sum_{d < k} op_l(lhs[i * lhs_rs + d]) * op_r(rhs[j * rhs_cs + d])`.
        /// Here `op_l` / `op_r` mean complex conjugation when `conj_lhs` /
        /// `conj_rhs` is set. This assumes `conj_mul_add(a, b, c)` is
        /// `conj(a) * b + c` lane-wise and `mul_add(a, b, c)` is `a * b + c`.
        /// It then stores `acc` into `dst[i * dst_rs + j * dst_cs]` according
        /// to `alpha_status`:
        /// - `0`: `dst = beta * acc`. `dst` is not read and `alpha` is ignored.
        /// - `1`: `dst = beta * acc + dst`.
        /// - otherwise: `dst = beta * acc + alpha * dst`.
        ///
        /// `_conj_dst` is accepted but never read by this kernel.
        ///
        /// # Safety
        ///
        /// - Each tile row of `lhs` and column of `rhs` must be valid for reads
        ///   of `k` contiguous complex elements. The tail (`k % CPLX_N`
        ///   elements) is read through `partial_load`, whose exact footprint
        ///   the invoking scope defines.
        /// - Every `dst` element of the tile must be valid for writes, and also
        ///   for reads when `alpha_status != 0`.
        /// - If a target feature was given, the running CPU must support it.
        $(#[target_feature(enable = $target)])?
        pub unsafe fn $name(
            k: usize,
            dst: *mut num_complex::Complex<T>,
            lhs: *const num_complex::Complex<T>,
            rhs: *const num_complex::Complex<T>,
            dst_cs: isize,
            dst_rs: isize,
            lhs_rs: isize,
            rhs_cs: isize,
            alpha: num_complex::Complex<T>,
            beta: num_complex::Complex<T>,
            alpha_status: u8,
            _conj_dst: bool,
            conj_lhs: bool,
            conj_rhs: bool,
        ) {
            // One accumulator per output element, indexed `[column][row]`.
            let mut accum = [[splat(::core::mem::zeroed()); $m]; $n];

            // Per-row / per-column base pointers `lhs0.., rhs0..`. Single braces
            // are used so these bindings stay visible for the rest of the function.
            seq_macro::seq!(M_ITER in 0..$m {
                let lhs~M_ITER = lhs.wrapping_offset(M_ITER * lhs_rs);
            });
            seq_macro::seq!(N_ITER in 0..$n {
                let rhs~N_ITER = rhs.wrapping_offset(N_ITER * rhs_cs);
            });

            // Fold the rhs conjugation into a final conjugation of the result,
            // using conj(a)*conj(b) = conj(a*b) and a*conj(b) = conj(conj(a)*b).
            // Afterwards `conj_lhs` selects `conj_mul_add` vs `mul_add` in the
            // inner loop, and `conj_all` conjugates the accumulators at the end.
            let (conj_lhs, conj_all) = match (conj_lhs, conj_rhs) {
                (true, true) => (false, true),
                (false, true) => (true, true),
                (true, false) => (true, false),
                (false, false) => (false, false),
            };

            // The chunk loads below use `read_unaligned`. The pointers are only
            // known to be aligned for `Complex<T>`, and `Pack` may need stricter
            // alignment (e.g. if it is a SIMD vector type), in which case a plain
            // `*` dereference would be UB. `read_unaligned` yields the same
            // value without that requirement.
            if conj_lhs {
                let mut depth = 0;
                while depth < k / CPLX_N * CPLX_N {
                    seq_macro::seq!(M_ITER in 0..$m {
                        let lhs~M_ITER = (lhs~M_ITER.add(depth) as *const Pack).read_unaligned();
                    });
                    seq_macro::seq!(N_ITER in 0..$n {
                        let rhs~N_ITER = (rhs~N_ITER.add(depth) as *const Pack).read_unaligned();
                    });

                    seq_macro::seq!(M_ITER in 0..$m {
                        seq_macro::seq!(N_ITER in 0..$n {
                            accum[N_ITER][M_ITER] = conj_mul_add(lhs~M_ITER, rhs~N_ITER, accum[N_ITER][M_ITER]);
                        });
                    });

                    depth += CPLX_N;
                }

                // Tail of `k - depth < CPLX_N` elements. The unused lanes loaded
                // by `partial_load` must contribute zero (presumably zero-filled;
                // confirm).
                if depth < k {
                    seq_macro::seq!(M_ITER in 0..$m {
                        let lhs~M_ITER = partial_load(lhs~M_ITER.add(depth), k - depth);
                    });
                    seq_macro::seq!(N_ITER in 0..$n {
                        let rhs~N_ITER = partial_load(rhs~N_ITER.add(depth), k - depth);
                    });

                    seq_macro::seq!(M_ITER in 0..$m {
                        seq_macro::seq!(N_ITER in 0..$n {
                            accum[N_ITER][M_ITER] = conj_mul_add(lhs~M_ITER, rhs~N_ITER, accum[N_ITER][M_ITER]);
                        });
                    });
                }
            } else {
                let mut depth = 0;
                while depth < k / CPLX_N * CPLX_N {
                    seq_macro::seq!(M_ITER in 0..$m {
                        let lhs~M_ITER = (lhs~M_ITER.add(depth) as *const Pack).read_unaligned();
                    });
                    seq_macro::seq!(N_ITER in 0..$n {
                        let rhs~N_ITER = (rhs~N_ITER.add(depth) as *const Pack).read_unaligned();
                    });

                    seq_macro::seq!(M_ITER in 0..$m {
                        seq_macro::seq!(N_ITER in 0..$n {
                            accum[N_ITER][M_ITER] = mul_add(lhs~M_ITER, rhs~N_ITER, accum[N_ITER][M_ITER]);
                        });
                    });

                    depth += CPLX_N;
                }

                // Tail of `k - depth < CPLX_N` elements (see note above).
                if depth < k {
                    seq_macro::seq!(M_ITER in 0..$m {
                        let lhs~M_ITER = partial_load(lhs~M_ITER.add(depth), k - depth);
                    });
                    seq_macro::seq!(N_ITER in 0..$n {
                        let rhs~N_ITER = partial_load(rhs~N_ITER.add(depth), k - depth);
                    });

                    seq_macro::seq!(M_ITER in 0..$m {
                        seq_macro::seq!(N_ITER in 0..$n {
                            accum[N_ITER][M_ITER] = mul_add(lhs~M_ITER, rhs~N_ITER, accum[N_ITER][M_ITER]);
                        });
                    });
                }
            }

            // Apply the deferred conjugation from the flag rewrite above.
            if conj_all {
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        accum[N_ITER][M_ITER] = conj(accum[N_ITER][M_ITER]);
                    }});
                }});
            }

            // Reduce every accumulator to a single complex scalar.
            let mut accum_reduced: [[num_complex::Complex<T>; $m]; $n] = ::core::mem::zeroed();
            seq_macro::seq!(M_ITER in 0..$m {
                seq_macro::seq!(N_ITER in 0..$n {
                    accum_reduced[N_ITER][M_ITER] = reduce_sum(accum[N_ITER][M_ITER]);
                });
            });

            // Store phase. The doubled braces make every expansion its own
            // block, so each `let dst` shadow is scoped to a single element and
            // always offsets from the original `dst` pointer.
            if alpha_status == 0 {
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        let dst = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                        *dst = beta * accum_reduced[N_ITER][M_ITER];
                    }});
                }});
            } else if alpha_status == 1 {
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        let dst = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                        *dst = (beta * accum_reduced[N_ITER][M_ITER]) + *dst;
                    }});
                }});
            } else {
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        let dst = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                        *dst = (beta * accum_reduced[N_ITER][M_ITER]) + (alpha * *dst);
                    }});
                }});
            }
        }
    };
}