1#![cfg_attr(feature = "nightly", feature(stdsimd), feature(avx512_target_feature))]
2#![cfg_attr(not(feature = "std"), no_std)]
3#![warn(rust_2018_idioms)]
4
5mod gemm;
6
7pub use crate::gemm::*;
8pub use gemm_common::Parallelism;
9
10pub use gemm_f16::f16;
11
#[cfg(test)]
mod tests {
    //! Consistency tests: every test fills the operand buffers with random
    //! values, runs the `gemm::gemm` entry point and the matching fallback
    //! routine on identical inputs, and asserts that both destination buffers
    //! agree element by element.
    //!
    //! The `dbg!` calls are intentional: the test harness only shows captured
    //! output on failure, and they are the only record of which
    //! shape / coefficient / flag combination tripped an assertion.

    use super::*;
    extern crate alloc;
    use alloc::{vec, vec::Vec};
    use num_traits::Float;

    /// Compares `gemm::gemm` with `gemm::gemm_fallback` for `f16` operands.
    ///
    /// Runs over every listed `(m, n, k)` shape, both `Parallelism::None` and
    /// `Parallelism::Rayon(0)`, and all `alpha` / `beta` pairs drawn from
    /// `{0.0, 1.0, 2.3}`. Results are compared with a tolerance of
    /// `max(0.1, 0.1 * |c|)`.
    #[test]
    fn test_gemm_f16() {
        // (m, n, k) shapes; several entries have a zero dimension so that the
        // empty cases are exercised too.
        let mnks = vec![
            (16, 2, 1),
            (0, 0, 4),
            (16, 1, 1),
            (16, 3, 1),
            (16, 4, 1),
            (16, 1, 2),
            (16, 2, 2),
            (16, 3, 2),
            (16, 4, 2),
            (16, 16, 1),
            (64, 64, 0),
            (256, 256, 256),
            (4096, 4096, 4),
            (64, 64, 4),
            (0, 64, 4),
            (64, 0, 4),
            (8, 16, 1),
            (16, 8, 1),
            (1, 1, 2),
            (4, 4, 4),
            (1024, 1024, 1),
            (1024, 1024, 4),
            (63, 1, 10),
            (63, 2, 10),
            (63, 3, 10),
            (63, 4, 10),
            (1, 63, 10),
            (2, 63, 10),
            (3, 63, 10),
            (4, 63, 10),
        ];

        for (m, n, k) in mnks {
            dbg!(m, n, k);
            for parallelism in [Parallelism::None, Parallelism::Rayon(0)] {
                for alpha in [0.0, 1.0, 2.3] {
                    for beta in [0.0, 1.0, 2.3] {
                        dbg!(alpha, beta, parallelism);
                        let alpha = f16::from_f32(alpha);
                        let beta = f16::from_f32(beta);

                        // Fresh random operands for every combination:
                        // m*k, k*n and m*n elements respectively.
                        let a_vec: Vec<f16> = (0..(m * k))
                            .map(|_| f16::from_f32(rand::random()))
                            .collect();
                        let b_vec: Vec<f16> = (0..(k * n))
                            .map(|_| f16::from_f32(rand::random()))
                            .collect();
                        let mut c_vec: Vec<f16> = (0..(m * n))
                            .map(|_| f16::from_f32(rand::random()))
                            .collect();
                        // Both routines must start from the same destination.
                        let mut d_vec = c_vec.clone();

                        // SAFETY: every pointer comes from a live `Vec` holding
                        // exactly m*k, k*n or m*n elements, and the strides
                        // passed alongside (`m`/`k` and `1`) describe a dense
                        // layout over those elements.
                        // NOTE(review): assumes the two integers after each
                        // pointer are (column stride, row stride) — confirm
                        // against the `gemm` signature.
                        unsafe {
                            gemm::gemm(
                                m,
                                n,
                                k,
                                c_vec.as_mut_ptr(),
                                m as isize,
                                1,
                                true,
                                a_vec.as_ptr(),
                                m as isize,
                                1,
                                b_vec.as_ptr(),
                                k as isize,
                                1,
                                alpha,
                                beta,
                                // Same positions as conj_dst / conj_lhs /
                                // conj_rhs in `test_gemm_cplx`.
                                false,
                                false,
                                false,
                                parallelism,
                            );

                            gemm::gemm_fallback(
                                m,
                                n,
                                k,
                                d_vec.as_mut_ptr(),
                                m as isize,
                                1,
                                true,
                                a_vec.as_ptr(),
                                m as isize,
                                1,
                                b_vec.as_ptr(),
                                k as isize,
                                1,
                                alpha,
                                beta,
                            );
                        }

                        let eps = f16::from_f32(1e-1);
                        for (c, d) in c_vec.iter().zip(d_vec.iter()) {
                            // Use whichever of the relative and the absolute
                            // bound is larger.
                            let eps_rel = c.abs() * eps;
                            let tolerance = if eps_rel > eps { eps_rel } else { eps };
                            assert_approx_eq::assert_approx_eq!(c, d, tolerance);
                        }
                    }
                }
            }
        }
    }

    /// Compares `gemm::gemm` with `gemm::gemm_fallback` for `f64` operands.
    ///
    /// Runs over every listed `(m, n, k)` shape, three parallelism settings
    /// (`None`, `Rayon(0)`, `Rayon(128)`), and all `alpha` / `beta` pairs drawn
    /// from `{0.0, 1.0, 2.3}`. Results are compared with the default tolerance
    /// of `assert_approx_eq!`.
    #[test]
    fn test_gemm_real() {
        // (m, n, k) shapes; several entries have a zero dimension so that the
        // empty cases are exercised too.
        let mnks = vec![
            (2048, 255, 255),
            (256, 256, 256),
            (4096, 4096, 4),
            (64, 64, 4),
            (0, 64, 4),
            (64, 0, 4),
            (0, 0, 4),
            (64, 64, 0),
            (16, 1, 1),
            (16, 2, 1),
            (16, 3, 1),
            (16, 4, 1),
            (16, 1, 2),
            (16, 2, 2),
            (16, 3, 2),
            (16, 4, 2),
            (16, 16, 1),
            (8, 16, 1),
            (16, 8, 1),
            (1, 1, 2),
            (4, 4, 4),
            (1024, 1024, 1),
            (1024, 1024, 4),
            (63, 1, 10),
            (63, 2, 10),
            (63, 3, 10),
            (63, 4, 10),
            (1, 63, 10),
            (2, 63, 10),
            (3, 63, 10),
            (4, 63, 10),
        ];

        for (m, n, k) in mnks {
            dbg!(m, n, k);
            for parallelism in [
                Parallelism::None,
                Parallelism::Rayon(0),
                Parallelism::Rayon(128),
            ] {
                for alpha in [0.0, 1.0, 2.3] {
                    for beta in [0.0, 1.0, 2.3] {
                        dbg!(alpha, beta, parallelism);

                        // Fresh random operands for every combination:
                        // m*k, k*n and m*n elements respectively.
                        let a_vec: Vec<f64> = (0..(m * k)).map(|_| rand::random()).collect();
                        let b_vec: Vec<f64> = (0..(k * n)).map(|_| rand::random()).collect();
                        let mut c_vec: Vec<f64> = (0..(m * n)).map(|_| rand::random()).collect();
                        // Both routines must start from the same destination.
                        let mut d_vec = c_vec.clone();

                        // SAFETY: every pointer comes from a live `Vec` holding
                        // exactly m*k, k*n or m*n elements, and the strides
                        // passed alongside (`m`/`k` and `1`) describe a dense
                        // layout over those elements.
                        // NOTE(review): assumes the two integers after each
                        // pointer are (column stride, row stride) — confirm
                        // against the `gemm` signature.
                        unsafe {
                            gemm::gemm(
                                m,
                                n,
                                k,
                                c_vec.as_mut_ptr(),
                                m as isize,
                                1,
                                true,
                                a_vec.as_ptr(),
                                m as isize,
                                1,
                                b_vec.as_ptr(),
                                k as isize,
                                1,
                                alpha,
                                beta,
                                // Same positions as conj_dst / conj_lhs /
                                // conj_rhs in `test_gemm_cplx`.
                                false,
                                false,
                                false,
                                parallelism,
                            );

                            gemm::gemm_fallback(
                                m,
                                n,
                                k,
                                d_vec.as_mut_ptr(),
                                m as isize,
                                1,
                                true,
                                a_vec.as_ptr(),
                                m as isize,
                                1,
                                b_vec.as_ptr(),
                                k as isize,
                                1,
                                alpha,
                                beta,
                            );
                        }

                        for (c, d) in c_vec.iter().zip(d_vec.iter()) {
                            assert_approx_eq::assert_approx_eq!(c, d);
                        }
                    }
                }
            }
        }
    }

    /// Compares `gemm::gemm` with `gemm::gemm_cplx_fallback` for `c64`
    /// operands.
    ///
    /// Runs over every listed `(m, n, k)` shape, all `alpha` / `beta` pairs
    /// drawn from `{0, 1, 2.3 + 4.1i}`, and all eight combinations of the
    /// `conj_dst` / `conj_lhs` / `conj_rhs` flags. Only
    /// `Parallelism::Rayon(0)` is exercised here.
    #[test]
    fn test_gemm_cplx() {
        // (m, n, k) shapes; several entries have a zero dimension so that the
        // empty cases are exercised too.
        let mnks = vec![
            (0, 64, 4),
            (64, 0, 4),
            (0, 0, 4),
            (64, 64, 4),
            (64, 64, 0),
            (6, 3, 1),
            (1, 1, 2),
            (128, 128, 128),
            (16, 1, 1),
            (16, 2, 1),
            (16, 3, 1),
            (16, 4, 1),
            (16, 1, 2),
            (16, 2, 2),
            (16, 3, 2),
            (16, 4, 2),
            (16, 16, 1),
            (8, 16, 1),
            (16, 8, 1),
            (4, 4, 4),
            (1024, 1024, 4),
            (1024, 1024, 1),
            (63, 1, 10),
            (63, 2, 10),
            (63, 3, 10),
            (63, 4, 10),
            (1, 63, 10),
            (2, 63, 10),
            (3, 63, 10),
            (4, 63, 10),
        ];

        // Coefficient candidates; they do not depend on the shape, so they are
        // built once up front.
        let zero = c64::new(0.0, 0.0);
        let one = c64::new(1.0, 0.0);
        let arbitrary = c64::new(2.3, 4.1);

        for (m, n, k) in mnks {
            dbg!(m, n, k);

            for alpha in [zero, one, arbitrary] {
                for beta in [zero, one, arbitrary] {
                    dbg!(alpha, beta);
                    for conj_dst in [false, true] {
                        for conj_lhs in [false, true] {
                            for conj_rhs in [false, true] {
                                dbg!(conj_dst);
                                dbg!(conj_lhs);
                                dbg!(conj_rhs);

                                // Buffers are allocated as plain `f64`s, two
                                // per complex element, and reinterpreted as
                                // `c64` at the call sites below.
                                let a_vec: Vec<f64> =
                                    (0..(2 * m * k)).map(|_| rand::random()).collect();
                                let b_vec: Vec<f64> =
                                    (0..(2 * k * n)).map(|_| rand::random()).collect();
                                let mut c_vec: Vec<f64> =
                                    (0..(2 * m * n)).map(|_| rand::random()).collect();
                                // Both routines must start from the same
                                // destination.
                                let mut d_vec = c_vec.clone();

                                // SAFETY: every pointer comes from a live `Vec`
                                // holding 2*m*k, 2*k*n or 2*m*n `f64`s, i.e.
                                // room for m*k, k*n or m*n complex values, and
                                // the strides passed alongside (`m`/`k` and
                                // `1`) describe a dense layout over them.
                                // NOTE(review): assumes `c64` is two
                                // consecutive `f64`s with `f64` alignment, and
                                // that the two integers after each pointer are
                                // (column stride, row stride) — confirm against
                                // the `c64` / `gemm` definitions.
                                unsafe {
                                    gemm::gemm(
                                        m,
                                        n,
                                        k,
                                        c_vec.as_mut_ptr() as *mut c64,
                                        m as isize,
                                        1,
                                        true,
                                        a_vec.as_ptr() as *const c64,
                                        m as isize,
                                        1,
                                        b_vec.as_ptr() as *const c64,
                                        k as isize,
                                        1,
                                        alpha,
                                        beta,
                                        conj_dst,
                                        conj_lhs,
                                        conj_rhs,
                                        Parallelism::Rayon(0),
                                    );

                                    gemm::gemm_cplx_fallback(
                                        m,
                                        n,
                                        k,
                                        d_vec.as_mut_ptr() as *mut c64,
                                        m as isize,
                                        1,
                                        true,
                                        a_vec.as_ptr() as *const c64,
                                        m as isize,
                                        1,
                                        b_vec.as_ptr() as *const c64,
                                        k as isize,
                                        1,
                                        alpha,
                                        beta,
                                        conj_dst,
                                        conj_lhs,
                                        conj_rhs,
                                    );
                                }

                                // Compared as raw `f64` components of the
                                // underlying buffers.
                                for (c, d) in c_vec.iter().zip(d_vec.iter()) {
                                    assert_approx_eq::assert_approx_eq!(c, d);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}