// matrixmultiply_mt/generic_kernel.rs

use crate::prefetch_read;
use crate::{generic_params::*, prefetch_write};
use num_traits::Float;
use std::cmp::min;
use typenum::Unsigned;
use typenum_loops::Loop;

/// Edge-tile kernel: computes `alpha * A * B` for one MR x NR register tile
/// but accumulates only the top-left `rows` x `cols` corner into `C`, so
/// partial tiles at the matrix border never write out of bounds.
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
pub unsafe fn masked_kernel<K: KernelConfig>(
    k: usize,
    alpha: K::T,
    a: *const K::T,
    b: *const K::T,
    c: *mut K::T,
    rsc: isize,
    csc: isize,
    rows: usize,
    cols: usize,
) {
    // Clamp the tile to the part of C that actually exists.
    let mr = min(K::MR::to_usize(), rows);
    let nr = min(K::NR::to_usize(), cols);
    prefetch_read(a as *mut i8);
    prefetch_read(b as *mut i8);
    write_prefetch::<K>(c, rsc, csc);
    if K::TR::to_usize() == 0 {
        // Row-major accumulator: ab[i][j] holds tile entry (i, j).
        let ab = kernel_compute::<K>(k, alpha, a, b);
        for j in 0..nr {
            for i in 0..mr {
                let cptr = c.offset(rsc * i as isize + csc * j as isize);
                *cptr += ab[i][j];
            }
        }
    } else {
        // Transposed accumulator: ab[j][i] holds tile entry (i, j).
        let ab = kernel_compute_trans::<K>(k, alpha, a, b);
        for j in 0..nr {
            for i in 0..mr {
                let cptr = c.offset(rsc * i as isize + csc * j as isize);
                *cptr += ab[j][i];
            }
        }
    }
}
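
// Hedged illustration of the masking step above: with MR = NR = 4 but
// rows = 3, cols = 2, only the 3 x 2 corner of the accumulator reaches C.
// A scalar model of the write-back (hypothetical helper, not part of the
// crate's API), assuming a row-major C with leading dimension `ldc`:
#[allow(dead_code)]
fn masked_write_back_sketch(ab: &[[f64; 4]; 4], c: &mut [f64], ldc: usize, rows: usize, cols: usize) {
    for j in 0..cols.min(4) {
        for i in 0..rows.min(4) {
            // C[i][j] for row-major layout sits at index i * ldc + j.
            c[i * ldc + j] += ab[i][j];
        }
    }
}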

/// Full-tile kernel: computes `alpha * A * B` for one complete MR x NR
/// register tile and accumulates it into `C`. `#[inline(never)]` keeps a
/// single shared copy of the generated tile code instead of inlining it
/// into every caller.
#[inline(never)]
pub unsafe fn kernel<K: KernelConfig>(
    k: usize,
    alpha: K::T,
    a: *const K::T,
    b: *const K::T,
    c: *mut K::T,
    rsc: isize,
    csc: isize,
) {
    prefetch_read(a as *mut i8);
    prefetch_read(b as *mut i8);
    write_prefetch::<K>(c, rsc, csc);
    if K::TR::to_usize() == 0 {
        let ab = kernel_compute::<K>(k, alpha, a, b);
        kernel_write::<K>(c, rsc, csc, &ab);
    } else {
        let ab = kernel_compute_trans::<K>(k, alpha, a, b);
        kernel_write_trans::<K>(c, rsc, csc, &ab);
    }
}
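
// Hedged usage sketch: `rsc`/`csc` are the row and column strides of `C` in
// elements, so for a row-major `C` with leading dimension `ldc` one would
// pass `rsc = ldc`, `csc = 1` (and `rsc = 1`, `csc = ldc` for column-major).
// Hypothetical wrapper, not part of the crate's API:
#[allow(dead_code)]
unsafe fn kernel_row_major_sketch<K: KernelConfig>(
    k: usize,
    alpha: K::T,
    a: *const K::T,
    b: *const K::T,
    c: *mut K::T,
    ldc: usize,
) {
    // `a` and `b` are assumed packed as in the real callers: K::MR elements
    // per k-step in `a`, K::NR per k-step in `b`.
    kernel::<K>(k, alpha, a, b, c, ldc as isize, 1);
}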

/// Accumulates the MR x NR product of the packed panels `a` and `b` over `k`
/// steps and scales by `alpha`, returning the tile in row-major order
/// (`ab[i][j]` is tile entry (i, j)).
#[inline(always)]
unsafe fn kernel_compute<K: KernelConfig>(
    k: usize,
    alpha: K::T,
    a: *const K::T,
    b: *const K::T,
) -> GA<GA<K::T, K::NR>, K::MR> {
    let mut ab = <GA<GA<K::T, K::NR>, K::MR>>::default();

    // The k-loop is unrolled by K::KU; each step consumes one MR-panel of
    // `a` and one NR-panel of `b`.
    K::KU::partial_unroll(k, &mut |l, _| {
        let a = a.add(l * K::MR::to_usize());
        let b = b.add(l * K::NR::to_usize());

        K::MR::full_unroll(&mut |i| {
            K::NR::full_unroll(&mut |j| {
                if K::FMA::to_usize() > 0 {
                    ab[i][j] = at::<K::T>(a, i).mul_add(at::<K::T>(b, j), ab[i][j]);
                } else {
                    ab[i][j] += at::<K::T>(a, i) * at::<K::T>(b, j);
                }
            });
        });
    });

    // Apply the scalar factor once, after the accumulation loop.
    K::MR::full_unroll(&mut |i| {
        K::NR::full_unroll(&mut |j| {
            ab[i][j] = ab[i][j] * alpha;
        });
    });

    ab
}
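
// For reference, a scalar model of what `kernel_compute` produces
// (hypothetical helper, not part of the crate's API). `a` holds k panels of
// `mr` values and `b` holds k panels of `nr` values, matching the pointer
// arithmetic above; rounding may differ slightly from the FMA path.
#[allow(dead_code)]
fn kernel_compute_sketch(k: usize, alpha: f32, a: &[f32], b: &[f32], mr: usize, nr: usize) -> Vec<f32> {
    assert_eq!(a.len(), k * mr);
    assert_eq!(b.len(), k * nr);
    let mut ab = vec![0.0f32; mr * nr]; // row-major mr x nr tile
    for l in 0..k {
        for i in 0..mr {
            for j in 0..nr {
                ab[i * nr + j] += a[l * mr + i] * b[l * nr + j];
            }
        }
    }
    for v in &mut ab {
        *v *= alpha;
    }
    ab
}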

/// Same accumulation as `kernel_compute`, but with the loop nest and the
/// accumulator transposed: `ab[j][i]` is tile entry (i, j). Which layout
/// the compiler vectorizes better depends on the configuration, hence the
/// `K::TR` switch.
#[inline(always)]
unsafe fn kernel_compute_trans<K: KernelConfig>(
    k: usize,
    alpha: K::T,
    a: *const K::T,
    b: *const K::T,
) -> GA<GA<K::T, K::MR>, K::NR> {
    let mut ab = <GA<GA<K::T, K::MR>, K::NR>>::default();

    K::KU::partial_unroll(k, &mut |l, _| {
        let a = a.add(l * K::MR::to_usize());
        let b = b.add(l * K::NR::to_usize());

        K::NR::full_unroll(&mut |j| {
            K::MR::full_unroll(&mut |i| {
                if K::FMA::to_usize() > 0 {
                    ab[j][i] = at::<K::T>(a, i).mul_add(at::<K::T>(b, j), ab[j][i]);
                } else {
                    ab[j][i] += at::<K::T>(a, i) * at::<K::T>(b, j);
                }
            });
        });
    });

    K::NR::full_unroll(&mut |j| {
        K::MR::full_unroll(&mut |i| {
            ab[j][i] = ab[j][i] * alpha;
        });
    });

    ab
}

/// Prefetches the destination tile of `C` for writing. When one stride is 1
/// the tile has contiguous runs, so one prefetch per run suffices; otherwise
/// every element is prefetched individually.
#[inline(always)]
unsafe fn write_prefetch<K: KernelConfig>(c: *mut K::T, rsc: isize, csc: isize) {
    if rsc == 1 {
        // Columns of the tile are contiguous: prefetch the start of each.
        K::NR::full_unroll(&mut |j| {
            prefetch_write(c.offset(csc * j as isize) as *mut i8);
        });
    } else if csc == 1 {
        // Rows of the tile are contiguous: prefetch the start of each.
        K::MR::full_unroll(&mut |i| {
            prefetch_write(c.offset(rsc * i as isize) as *mut i8);
        });
    } else {
        for i in 0..K::MR::to_usize() {
            for j in 0..K::NR::to_usize() {
                prefetch_write(c.offset(rsc * i as isize + csc * j as isize) as *mut i8);
            }
        }
    }
}
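
// `prefetch_read`/`prefetch_write` are defined elsewhere in the crate. A
// hedged sketch of what such a wrapper might look like on x86_64 (an
// assumption about the technique, not the crate's actual definition):
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
unsafe fn prefetch_write_sketch(p: *mut i8) {
    use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
    // Pull the cache line containing `p` toward the core ahead of the
    // read-modify-write of the C tile.
    _mm_prefetch::<_MM_HINT_T0>(p as *const i8);
}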

/// Writes a full row-major tile into `C`. The unit-stride cases are split
/// out so the compiler sees the constant stride and can vectorize the
/// stores.
#[inline(always)]
unsafe fn kernel_write<K: KernelConfig>(c: *mut K::T, rsc: isize, csc: isize, ab: &GA<GA<K::T, K::NR>, K::MR>) {
    if rsc == 1 {
        for i in 0..K::MR::to_usize() {
            for j in 0..K::NR::to_usize() {
                let v = c.offset(i as isize + csc * j as isize);
                *v += ab[i][j];
            }
        }
    } else if csc == 1 {
        for i in 0..K::MR::to_usize() {
            for j in 0..K::NR::to_usize() {
                let v = c.offset(rsc * i as isize + j as isize);
                *v += ab[i][j];
            }
        }
    } else {
        for i in 0..K::MR::to_usize() {
            for j in 0..K::NR::to_usize() {
                let v = c.offset(rsc * i as isize + csc * j as isize);
                *v += ab[i][j];
            }
        }
    }
}

/// Writes a transposed tile into `C`; `ab[j][i]` is tile entry (i, j). As in
/// `kernel_write`, the unit-stride cases are split out so the compiler can
/// specialize them.
#[inline(always)]
unsafe fn kernel_write_trans<K: KernelConfig>(c: *mut K::T, rsc: isize, csc: isize, ab: &GA<GA<K::T, K::MR>, K::NR>) {
    if rsc == 1 {
        for j in 0..K::NR::to_usize() {
            for i in 0..K::MR::to_usize() {
                let v = c.offset(i as isize + csc * j as isize);
                *v += ab[j][i];
            }
        }
    } else if csc == 1 {
        for j in 0..K::NR::to_usize() {
            for i in 0..K::MR::to_usize() {
                let v = c.offset(rsc * i as isize + j as isize);
                *v += ab[j][i];
            }
        }
    } else {
        for j in 0..K::NR::to_usize() {
            for i in 0..K::MR::to_usize() {
                let v = c.offset(rsc * i as isize + csc * j as isize);
                *v += ab[j][i];
            }
        }
    }
}

/// Reads element `i` through a raw pointer; a tiny helper that keeps the
/// unrolled loops above readable.
#[inline(always)]
unsafe fn at<T: Copy>(ptr: *const T, i: usize) -> T {
    *ptr.add(i)
}