// gemm_common/horizontal_microkernel.rs
/// Generates a real-valued "horizontal" micro-kernel function named `$name`.
///
/// The generated function accumulates an `$m x $n` tile of dot products of
/// length `k`. Operand `i` of `lhs` starts at `lhs + i * lhs_rs`, operand `j`
/// of `rhs` starts at `rhs + j * rhs_cs`, and both are walked contiguously
/// along the depth dimension. Each result is then combined with
/// `dst[dst_cs * j + dst_rs * i]` according to `alpha_status`.
///
/// # Invocation
///
/// `horizontal_kernel!([target_feature], name, m, n)`. The bracketed target
/// feature is optional, but the comma that follows it is always required by
/// the matcher (i.e. `horizontal_kernel!(, name, m, n)` when omitted).
///
/// # Expansion-site requirements
///
/// The following names are used unqualified and must be in scope where the
/// macro is expanded: the scalar type `T`, the lane count `N`, and the
/// functions `splat`, `mul_add`, `partial_load`, `reduce_sum`, `scalar_mul`,
/// `scalar_mul_add` and `scalar_add`. The `seq_macro` crate must be resolvable
/// by path.
#[macro_export]
macro_rules! horizontal_kernel {
    ($([$target: tt])?, $name: ident, $m: tt, $n: tt $(,)?) => {
        /// Horizontal (dot-product style) real micro-kernel.
        ///
        /// Write-back behaviour, with `acc` the reduced dot product of one
        /// `lhs`/`rhs` operand pair:
        /// - `alpha_status == 0`: `*dst = scalar_mul(beta, acc)`; `dst` is not read.
        /// - `alpha_status == 1`: `*dst = scalar_mul_add(beta, acc, *dst)`.
        /// - otherwise: `*dst = scalar_add(scalar_mul(beta, acc), scalar_mul(alpha, *dst))`.
        ///
        /// The `_conj_*` flags are accepted for signature compatibility and
        /// are ignored by this kernel.
        ///
        /// # Safety
        ///
        /// - Every `lhs` and `rhs` operand must be valid for reads of `k`
        ///   contiguous elements of `T`.
        /// - Every addressed `dst` element must be valid for writes, and for
        ///   reads whenever `alpha_status != 0`.
        /// - If the macro was given a target feature, the CPU executing this
        ///   function must support it.
        $(#[target_feature(enable = $target)])?
        pub unsafe fn $name(
            k: usize,
            dst: *mut T,
            lhs: *const T,
            rhs: *const T,
            dst_cs: isize,
            dst_rs: isize,
            lhs_rs: isize,
            rhs_cs: isize,
            alpha: T,
            beta: T,
            alpha_status: u8,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
        ) {
            // One vector accumulator per output element, indexed `[n][m]`,
            // initialised by broadcasting the all-zero bit pattern.
            let mut accum = [[splat(::core::mem::zeroed()); $m]; $n];

            // Base pointer of each lhs / rhs operand.
            seq_macro::seq!(M_ITER in 0..$m {
                let lhs~M_ITER = lhs.wrapping_offset(M_ITER * lhs_rs);
            });
            seq_macro::seq!(N_ITER in 0..$n {
                let rhs~N_ITER = rhs.wrapping_offset(N_ITER * rhs_cs);
            });

            // Largest multiple of `N` not exceeding `k`: the part of the depth
            // range that can be consumed with full-width loads.
            let full_depth = k / N * N;

            let mut depth = 0;
            while depth < full_depth {
                // Shadow the base pointers with the `N` values loaded at `depth`.
                seq_macro::seq!(M_ITER in 0..$m {
                    let lhs~M_ITER = *(lhs~M_ITER.add(depth) as *const [T; N]);
                });
                seq_macro::seq!(N_ITER in 0..$n {
                    let rhs~N_ITER = *(rhs~N_ITER.add(depth) as *const [T; N]);
                });

                seq_macro::seq!(M_ITER in 0..$m {
                    seq_macro::seq!(N_ITER in 0..$n {
                        accum[N_ITER][M_ITER] = mul_add(lhs~M_ITER, rhs~N_ITER, accum[N_ITER][M_ITER]);
                    });
                });

                depth += N;
            }

            // Tail: fewer than `N` elements remain.
            // NOTE(review): presumably `partial_load` reads only `k - depth`
            // elements and pads the remaining lanes so they do not contribute
            // to the sum - confirm against its definition.
            if depth < k {
                seq_macro::seq!(M_ITER in 0..$m {
                    let lhs~M_ITER = partial_load(lhs~M_ITER.add(depth), k - depth);
                });
                seq_macro::seq!(N_ITER in 0..$n {
                    let rhs~N_ITER = partial_load(rhs~N_ITER.add(depth), k - depth);
                });

                seq_macro::seq!(M_ITER in 0..$m {
                    seq_macro::seq!(N_ITER in 0..$n {
                        accum[N_ITER][M_ITER] = mul_add(lhs~M_ITER, rhs~N_ITER, accum[N_ITER][M_ITER]);
                    });
                });
            }

            // Collapse every vector accumulator to a single scalar.
            let mut accum_reduced: [[T; $m]; $n] = core::mem::zeroed();
            seq_macro::seq!(M_ITER in 0..$m {
                seq_macro::seq!(N_ITER in 0..$n {
                    accum_reduced[N_ITER][M_ITER] = reduce_sum(accum[N_ITER][M_ITER]);
                });
            });

            // Write-back. The doubled braces give each unrolled iteration its
            // own scope, so the shadowing `let dst` does not leak into the next.
            if alpha_status == 0 {
                // Overwrite: the previous contents of `dst` are never read.
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        let dst = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                        *dst = scalar_mul(beta, accum_reduced[N_ITER][M_ITER]);
                    }});
                }});
            } else if alpha_status == 1 {
                // Accumulate onto the existing `dst` without scaling it.
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        let dst = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                        *dst = scalar_mul_add(
                            beta,
                            accum_reduced[N_ITER][M_ITER],
                            *dst,
                        );
                    }});
                }});
            } else {
                // General case: scale the existing `dst` by `alpha` first.
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        let dst = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                        *dst = scalar_add(
                            scalar_mul(beta, accum_reduced[N_ITER][M_ITER]),
                            scalar_mul(alpha, *dst),
                        );
                    }});
                }});
            }
        }
    };
}
101
/// Generates a complex-valued "horizontal" micro-kernel function named `$name`.
///
/// The generated function accumulates an `$m x $n` tile of dot products of
/// length `k` over `num_complex::Complex<T>` operands. Operand `i` of `lhs`
/// starts at `lhs + i * lhs_rs`, operand `j` of `rhs` starts at
/// `rhs + j * rhs_cs`, and both are walked contiguously along the depth
/// dimension. Each result is then combined with
/// `dst[dst_cs * j + dst_rs * i]` according to `alpha_status`.
///
/// # Invocation
///
/// `horizontal_cplx_kernel!([target_feature], name, m, n)`. The bracketed
/// target feature is optional, but the comma that follows it is always
/// required by the matcher.
///
/// # Expansion-site requirements
///
/// The following names are used unqualified and must be in scope where the
/// macro is expanded: the scalar type `T`, the vector type `Pack`, the lane
/// count `CPLX_N`, and the functions `splat`, `mul_add`, `conj_mul_add`,
/// `conj`, `partial_load` and `reduce_sum`. The `seq_macro` and `num_complex`
/// crates must be resolvable by path.
#[macro_export]
macro_rules! horizontal_cplx_kernel {
    ($([$target: tt])?, $name: ident, $m: tt, $n: tt $(,)?) => {
        /// Horizontal (dot-product style) complex micro-kernel.
        ///
        /// Write-back behaviour, with `acc` the reduced dot product of one
        /// `lhs`/`rhs` operand pair:
        /// - `alpha_status == 0`: `*dst = beta * acc`; `dst` is not read.
        /// - `alpha_status == 1`: `*dst = beta * acc + *dst`.
        /// - otherwise: `*dst = beta * acc + alpha * *dst`.
        ///
        /// `conj_lhs` / `conj_rhs` select conjugation of the respective
        /// operand; `_conj_dst` is accepted for signature compatibility and
        /// is ignored.
        ///
        /// # Safety
        ///
        /// - Every `lhs` and `rhs` operand must be valid for reads of `k`
        ///   contiguous complex elements.
        /// - Every addressed `dst` element must be valid for writes, and for
        ///   reads whenever `alpha_status != 0`.
        /// - If the macro was given a target feature, the CPU executing this
        ///   function must support it.
        $(#[target_feature(enable = $target)])?
        pub unsafe fn $name(
            k: usize,
            dst: *mut num_complex::Complex<T>,
            lhs: *const num_complex::Complex<T>,
            rhs: *const num_complex::Complex<T>,
            dst_cs: isize,
            dst_rs: isize,
            lhs_rs: isize,
            rhs_cs: isize,
            alpha: num_complex::Complex<T>,
            beta: num_complex::Complex<T>,
            alpha_status: u8,
            _conj_dst: bool,
            conj_lhs: bool,
            conj_rhs: bool,
        ) {
            // One vector accumulator per output element, indexed `[n][m]`,
            // initialised by broadcasting the all-zero bit pattern.
            let mut accum = [[splat(::core::mem::zeroed()); $m]; $n];

            // Base pointer of each lhs / rhs operand.
            seq_macro::seq!(M_ITER in 0..$m {
                let lhs~M_ITER = lhs.wrapping_offset(M_ITER * lhs_rs);
            });
            seq_macro::seq!(N_ITER in 0..$n {
                let rhs~N_ITER = rhs.wrapping_offset(N_ITER * rhs_cs);
            });

            // Fold the two conjugation flags into "conjugate lhs inside the
            // multiply" plus "conjugate the whole accumulator afterwards",
            // so only two inner-loop variants are needed:
            //   conj(a) * conj(b) == conj(a * b)
            //   a * conj(b)       == conj(conj(a) * b)
            // NOTE(review): this relies on `conj_mul_add` conjugating its
            // first operand - confirm against its definition.
            let (conj_lhs, conj_all) = match (conj_lhs, conj_rhs) {
                (true, true) => (false, true),
                (false, true) => (true, true),
                (true, false) => (true, false),
                (false, false) => (false, false),
            };

            // Largest multiple of `CPLX_N` not exceeding `k`: the part of the
            // depth range that can be consumed with full-width loads.
            let full_depth = k / CPLX_N * CPLX_N;

            // NOTE(review): the full-width loads below dereference the operand
            // pointers as `*const Pack`; this assumes `Pack` needs no stricter
            // alignment than `Complex<T>` - TODO confirm.
            if conj_lhs {
                let mut depth = 0;
                while depth < full_depth {
                    seq_macro::seq!(M_ITER in 0..$m {
                        let lhs~M_ITER = *(lhs~M_ITER.add(depth) as *const Pack);
                    });
                    seq_macro::seq!(N_ITER in 0..$n {
                        let rhs~N_ITER = *(rhs~N_ITER.add(depth) as *const Pack);
                    });

                    seq_macro::seq!(M_ITER in 0..$m {
                        seq_macro::seq!(N_ITER in 0..$n {
                            accum[N_ITER][M_ITER] = conj_mul_add(lhs~M_ITER, rhs~N_ITER, accum[N_ITER][M_ITER]);
                        });
                    });

                    depth += CPLX_N;
                }

                // Tail: fewer than `CPLX_N` elements remain.
                if depth < k {
                    seq_macro::seq!(M_ITER in 0..$m {
                        let lhs~M_ITER = partial_load(lhs~M_ITER.add(depth), k - depth);
                    });
                    seq_macro::seq!(N_ITER in 0..$n {
                        let rhs~N_ITER = partial_load(rhs~N_ITER.add(depth), k - depth);
                    });

                    seq_macro::seq!(M_ITER in 0..$m {
                        seq_macro::seq!(N_ITER in 0..$n {
                            accum[N_ITER][M_ITER] = conj_mul_add(lhs~M_ITER, rhs~N_ITER, accum[N_ITER][M_ITER]);
                        });
                    });
                }
            } else {
                let mut depth = 0;
                while depth < full_depth {
                    seq_macro::seq!(M_ITER in 0..$m {
                        let lhs~M_ITER = *(lhs~M_ITER.add(depth) as *const Pack);
                    });
                    seq_macro::seq!(N_ITER in 0..$n {
                        let rhs~N_ITER = *(rhs~N_ITER.add(depth) as *const Pack);
                    });

                    seq_macro::seq!(M_ITER in 0..$m {
                        seq_macro::seq!(N_ITER in 0..$n {
                            accum[N_ITER][M_ITER] = mul_add(lhs~M_ITER, rhs~N_ITER, accum[N_ITER][M_ITER]);
                        });
                    });

                    depth += CPLX_N;
                }

                // Tail: fewer than `CPLX_N` elements remain.
                if depth < k {
                    seq_macro::seq!(M_ITER in 0..$m {
                        let lhs~M_ITER = partial_load(lhs~M_ITER.add(depth), k - depth);
                    });
                    seq_macro::seq!(N_ITER in 0..$n {
                        let rhs~N_ITER = partial_load(rhs~N_ITER.add(depth), k - depth);
                    });

                    seq_macro::seq!(M_ITER in 0..$m {
                        seq_macro::seq!(N_ITER in 0..$n {
                            accum[N_ITER][M_ITER] = mul_add(lhs~M_ITER, rhs~N_ITER, accum[N_ITER][M_ITER]);
                        });
                    });
                }
            }

            // Apply the deferred conjugation to every accumulator.
            if conj_all {
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        accum[N_ITER][M_ITER] = conj(accum[N_ITER][M_ITER]);
                    }});
                }});
            }

            // Collapse every vector accumulator to a single complex scalar.
            let mut accum_reduced: [[num_complex::Complex<T>; $m]; $n] = core::mem::zeroed();
            seq_macro::seq!(M_ITER in 0..$m {
                seq_macro::seq!(N_ITER in 0..$n {
                    accum_reduced[N_ITER][M_ITER] = reduce_sum(accum[N_ITER][M_ITER]);
                });
            });

            // Write-back. The doubled braces give each unrolled iteration its
            // own scope, so the shadowing `let dst` does not leak into the next.
            if alpha_status == 0 {
                // Overwrite: the previous contents of `dst` are never read.
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        let dst = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                        *dst = beta * accum_reduced[N_ITER][M_ITER];
                    }});
                }});
            } else if alpha_status == 1 {
                // Accumulate onto the existing `dst` without scaling it.
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        let dst = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                        *dst = (beta * accum_reduced[N_ITER][M_ITER]) + *dst;
                    }});
                }});
            } else {
                // General case: scale the existing `dst` by `alpha` first.
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        let dst = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                        *dst = (beta * accum_reduced[N_ITER][M_ITER]) + (alpha * *dst);
                    }});
                }});
            }
        }
    };
}