// matrixmultiply_mt/generic_kernel.rs

1// Original work Copyright 2016 bluss
2// Modified work Copyright 2016 J. Millard.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9use crate::prefetch_read;
10use crate::{generic_params::*, prefetch_write};
11use num_traits::Float;
12use std::cmp::min;
13use typenum::Unsigned;
14use typenum_loops::Loop;
15
16/// Call the GEMM kernel with a "masked" output C.
17///
18/// Simply redirect the MR by NR kernel output to the passed
19/// in `mask_buf`, and copy the non masked region to the real
20/// C.
21///
22/// + rows: rows of kernel unmasked
23/// + cols: cols of kernel unmasked
24//#[inline(always)]
25#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
26pub unsafe fn masked_kernel<K: KernelConfig>(
27	k: usize,
28	alpha: K::T,
29	a: *const K::T,
30	b: *const K::T,
31	c: *mut K::T,
32	rsc: isize,
33	csc: isize,
34	rows: usize,
35	cols: usize,
36) {
37	let mr = min(K::MR::to_usize(), rows);
38	let nr = min(K::NR::to_usize(), cols);
39	prefetch_read(a as *mut i8);
40	prefetch_read(b as *mut i8);
41	write_prefetch::<K>(c, rsc, csc);
42	if K::TR::to_usize() == 0 {
43		let ab = kernel_compute::<K>(k, alpha, a, b);
44		for j in 0..nr {
45			for i in 0..mr {
46				let cptr = c.offset(rsc * i as isize + csc * j as isize);
47				*cptr += ab[i][j];
48			}
49		}
50	} else {
51		let ab = kernel_compute_trans::<K>(k, alpha, a, b);
52		for j in 0..nr {
53			for i in 0..mr {
54				let cptr = c.offset(rsc * i as isize + csc * j as isize);
55				*cptr += ab[j][i];
56			}
57		}
58	}
59}
60
61/// matrix multiplication kernel
62///
63/// This does the matrix multiplication:
64///
65/// C ← α A B + β C
66///
67/// + k: length of data in a, b
68/// + a, b are packed
69/// + c has general strides
70/// + rsc: row stride of c
71/// + csc: col stride of c
72/// + if beta is 0, then c does not need to be initialized
73#[inline(never)]
74pub unsafe fn kernel<K: KernelConfig>(
75	k: usize,
76	alpha: K::T,
77	a: *const K::T,
78	b: *const K::T,
79	c: *mut K::T,
80	rsc: isize,
81	csc: isize,
82) {
83	prefetch_read(a as *mut i8);
84	prefetch_read(b as *mut i8);
85	write_prefetch::<K>(c, rsc, csc);
86	if K::TR::to_usize() == 0 {
87		let ab = kernel_compute::<K>(k, alpha, a, b);
88		kernel_write::<K>(c, rsc, csc, &ab);
89	} else {
90		let ab = kernel_compute_trans::<K>(k, alpha, a, b);
91		kernel_write_trans::<K>(c, rsc, csc, &ab);
92	}
93}
94
95/// Split out compute for better vectorisation
96#[inline(always)]
97unsafe fn kernel_compute<K: KernelConfig>(
98	k: usize,
99	alpha: K::T,
100	a: *const K::T,
101	b: *const K::T,
102) -> GA<GA<K::T, K::NR>, K::MR> {
103	// Compute matrix multiplication into ab[i][j]
104	let mut ab = <GA<GA<K::T, K::NR>, K::MR>>::default();
105
106	K::KU::partial_unroll(k, &mut |l, _| {
107		let a = a.add(l * K::MR::to_usize());
108		let b = b.add(l * K::NR::to_usize());
109
110		K::MR::full_unroll(&mut |i| {
111			K::NR::full_unroll(&mut |j| {
112				if K::FMA::to_usize() > 0 {
113					ab[i][j] = at::<K::T>(a, i).mul_add(at::<K::T>(b, j), ab[i][j]);
114				} else {
115					ab[i][j] += at::<K::T>(a, i) * at::<K::T>(b, j);
116				}
117			});
118		});
119	});
120
121	K::MR::full_unroll(&mut |i| {
122		K::NR::full_unroll(&mut |j| {
123			ab[i][j] = ab[i][j] * alpha;
124		});
125	});
126
127	// for i in 0..K::MR::to_usize() {
128	// 	for j in 0..K::NR::to_usize() {
129	// 		ab[i][j] = ab[i][j]*alpha;
130	// 	}
131	// }
132
133	ab
134}
135
136/// Split out compute for better vectorisation
137#[inline(always)]
138unsafe fn kernel_compute_trans<K: KernelConfig>(
139	k: usize,
140	alpha: K::T,
141	a: *const K::T,
142	b: *const K::T,
143) -> GA<GA<K::T, K::MR>, K::NR> {
144	// Compute matrix multiplication into ab[i][j]
145	let mut ab = <GA<GA<K::T, K::MR>, K::NR>>::default();
146
147	K::KU::partial_unroll(k, &mut |l, _| {
148		let a = a.add(l * K::MR::to_usize());
149		let b = b.add(l * K::NR::to_usize());
150
151		K::NR::full_unroll(&mut |j| {
152			K::MR::full_unroll(&mut |i| {
153				if K::FMA::to_usize() > 0 {
154					ab[j][i] = at::<K::T>(a, i).mul_add(at::<K::T>(b, j), ab[j][i]);
155				} else {
156					ab[j][i] += at::<K::T>(a, i) * at::<K::T>(b, j);
157				}
158			});
159		});
160	});
161
162	K::NR::full_unroll(&mut |j| {
163		K::MR::full_unroll(&mut |i| {
164			ab[j][i] = ab[j][i] * alpha;
165		});
166	});
167
168	// for j in 0..K::NR::to_usize() {
169	// 	for i in 0..K::MR::to_usize() {
170	// 		ab[j][i] = ab[j][i]*alpha;
171	// 	}
172	// }
173
174	ab
175}
176
177/// prefetch locations of C which will be written too
178#[inline(always)]
179unsafe fn write_prefetch<K: KernelConfig>(c: *mut K::T, rsc: isize, csc: isize) {
180	if rsc == 1 {
181		K::NR::full_unroll(&mut |j| {
182			prefetch_write(c.offset(csc * j as isize) as *mut i8);
183		});
184	} else if csc == 1 {
185		K::MR::full_unroll(&mut |i| {
186			prefetch_write(c.offset(rsc * i as isize) as *mut i8);
187		});
188	} else {
189		for i in 0..K::MR::to_usize() {
190			for j in 0..K::NR::to_usize() {
191				prefetch_write(c.offset(rsc * i as isize + csc * j as isize) as *mut i8);
192			}
193		}
194	}
195}
196
197/// Choose writes to C in a cache/vectorisation friendly manner if possible
198#[inline(always)]
199unsafe fn kernel_write<K: KernelConfig>(c: *mut K::T, rsc: isize, csc: isize, ab: &GA<GA<K::T, K::NR>, K::MR>) {
200	if rsc == 1 {
201		for i in 0..K::MR::to_usize() {
202			for j in 0..K::NR::to_usize() {
203				let v = c.offset(1 * i as isize + csc * j as isize);
204				*v += ab[i][j];
205			}
206		}
207	} else if csc == 1 {
208		for i in 0..K::MR::to_usize() {
209			for j in 0..K::NR::to_usize() {
210				let v = c.offset(rsc * i as isize + 1 * j as isize);
211				*v += ab[i][j];
212			}
213		}
214	} else {
215		for i in 0..K::MR::to_usize() {
216			for j in 0..K::NR::to_usize() {
217				let v = c.offset(rsc * i as isize + csc * j as isize);
218				*v += ab[i][j];
219			}
220		}
221	}
222}
223
224/// Choose writes to C in a cache/vectorisation friendly manner if possible
225#[inline(always)]
226unsafe fn kernel_write_trans<K: KernelConfig>(c: *mut K::T, rsc: isize, csc: isize, ab: &GA<GA<K::T, K::MR>, K::NR>) {
227	if rsc == 1 {
228		for j in 0..K::NR::to_usize() {
229			for i in 0..K::MR::to_usize() {
230				let v = c.offset(1 * i as isize + csc * j as isize);
231				*v += ab[j][i];
232			}
233		}
234	} else if csc == 1 {
235		for j in 0..K::NR::to_usize() {
236			for i in 0..K::MR::to_usize() {
237				let v = c.offset(rsc * i as isize + 1 * j as isize);
238				*v += ab[j][i];
239			}
240		}
241	} else {
242		for j in 0..K::NR::to_usize() {
243			for i in 0..K::MR::to_usize() {
244				let v = c.offset(rsc * i as isize + csc * j as isize);
245				*v += ab[j][i];
246			}
247		}
248	}
249}
250
/// Read the element `i` places past `ptr` (caller guarantees the
/// offset is in bounds of the packed panel).
#[inline(always)]
unsafe fn at<T: Copy>(ptr: *const T, i: usize) -> T {
	*ptr.add(i)
}