1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(rust_2018_idioms)]
3
4mod gemm;
5
6#[cfg(feature = "f16")]
7pub use crate::gemm::f16;
8pub use crate::gemm::{c32, c64, gemm};
9pub use qlora_gemm_common::Parallelism;
10
11pub use qlora_gemm_common::gemm::{
12 get_lhs_packing_threshold_multi_thread, get_lhs_packing_threshold_single_thread,
13 get_rhs_packing_threshold, get_threading_threshold, set_lhs_packing_threshold_multi_thread,
14 set_lhs_packing_threshold_single_thread, set_rhs_packing_threshold, set_threading_threshold,
15 DEFAULT_LHS_PACKING_THRESHOLD_MULTI_THREAD, DEFAULT_LHS_PACKING_THRESHOLD_SINGLE_THREAD,
16 DEFAULT_RHS_PACKING_THRESHOLD, DEFAULT_THREADING_THRESHOLD,
17};
18pub use qlora_gemm_common::{get_wasm_simd128, set_wasm_simd128, DEFAULT_WASM_SIMD128};
19
20#[cfg(test)]
21mod tests {
22 use super::*;
23 extern crate alloc;
24 use alloc::{vec, vec::Vec};
25 use num_traits::Float;
26
27 #[test]
28 fn test_qlora_gemm_f16() {
29 let mut mnks = vec![];
30 mnks.push((4, 4, 4));
31 mnks.push((63, 2, 10));
32 mnks.push((16, 2, 1));
33 mnks.push((0, 0, 4));
34 mnks.push((16, 1, 1));
35 mnks.push((16, 3, 1));
36 mnks.push((16, 4, 1));
37 mnks.push((16, 1, 2));
38 mnks.push((16, 2, 2));
39 mnks.push((16, 3, 2));
40 mnks.push((16, 4, 2));
41 mnks.push((16, 16, 1));
42 mnks.push((64, 64, 0));
43 mnks.push((256, 256, 256));
44 mnks.push((4096, 4096, 4));
45 mnks.push((64, 64, 4));
46 mnks.push((0, 64, 4));
47 mnks.push((64, 0, 4));
48 mnks.push((8, 16, 1));
49 mnks.push((16, 8, 1));
50 mnks.push((1, 1, 2));
51 mnks.push((1024, 1024, 1));
52 mnks.push((1024, 1024, 4));
53 mnks.push((63, 1, 10));
54 mnks.push((63, 3, 10));
55 mnks.push((63, 4, 10));
56 mnks.push((1, 63, 10));
57 mnks.push((2, 63, 10));
58 mnks.push((3, 63, 10));
59 mnks.push((4, 63, 10));
60
61 for (m, n, k) in mnks {
62 #[cfg(feature = "std")]
63 dbg!(m, n, k);
64 for parallelism in [
65 Parallelism::None,
66 #[cfg(feature = "rayon")]
67 Parallelism::Rayon(0),
68 ] {
69 for alpha in [0.0, 1.0, 2.3] {
70 for beta in [0.0, 1.0, 2.3] {
71 #[cfg(feature = "std")]
72 dbg!(alpha, beta, parallelism);
73
74 for colmajor in [true, false] {
75 let alpha = f16::from_f32(alpha);
76 let beta = f16::from_f32(beta);
77 let a_vec: Vec<f16> = (0..(m * k))
78 .map(|_| f16::from_f32(rand::random()))
79 .collect();
80 let b_vec: Vec<f16> = (0..(k * n))
81 .map(|_| f16::from_f32(rand::random()))
82 .collect();
83 let mut c_vec: Vec<f16> = (0..(m * n))
84 .map(|_| f16::from_f32(rand::random()))
85 .collect();
86 let mut d_vec = c_vec.clone();
87
88 unsafe {
89 gemm::gemm(
90 m,
91 n,
92 k,
93 c_vec.as_mut_ptr(),
94 if colmajor { m } else { 1 } as isize,
95 if colmajor { 1 } else { n } as isize,
96 true,
97 a_vec.as_ptr(),
98 m as isize,
99 1,
100 b_vec.as_ptr(),
101 k as isize,
102 1,
103 alpha,
104 beta,
105 false,
106 false,
107 false,
108 parallelism,
109 );
110
111 gemm::gemm_fallback(
112 m,
113 n,
114 k,
115 d_vec.as_mut_ptr(),
116 if colmajor { m } else { 1 } as isize,
117 if colmajor { 1 } else { n } as isize,
118 true,
119 a_vec.as_ptr(),
120 m as isize,
121 1,
122 b_vec.as_ptr(),
123 k as isize,
124 1,
125 alpha,
126 beta,
127 );
128 }
129 let eps = f16::from_f32(1e-1);
130 for (c, d) in c_vec.iter().zip(d_vec.iter()) {
131 let eps_rel = c.abs() * eps;
132 let eps_abs = eps;
133 let eps = if eps_rel > eps_abs { eps_rel } else { eps_abs };
134 assert_approx_eq::assert_approx_eq!(c, d, eps);
135 }
136 }
137 }
138 }
139 }
140 }
141 }
142
143 #[test]
144 fn test_qlora_gemm_f32() {
145 set_wasm_simd128(true);
146
147 let mut mnks = vec![];
148 mnks.push((63, 2, 10));
149 mnks.push((1, 2, 10));
150 mnks.push((1, 63, 10));
151
152 mnks.push((2048, 255, 255));
154
155 mnks.push((256, 256, 256));
156 mnks.push((4096, 4096, 4));
157 mnks.push((64, 64, 4));
158 mnks.push((0, 64, 4));
159 mnks.push((64, 0, 4));
160 mnks.push((0, 0, 4));
161 mnks.push((64, 64, 0));
162 mnks.push((16, 1, 1));
163 mnks.push((16, 2, 1));
164 mnks.push((16, 3, 1));
165 mnks.push((16, 4, 1));
166 mnks.push((16, 1, 2));
167 mnks.push((16, 2, 2));
168 mnks.push((16, 3, 2));
169 mnks.push((16, 4, 2));
170 mnks.push((16, 16, 1));
171 mnks.push((8, 16, 1));
172 mnks.push((16, 8, 1));
173 mnks.push((1, 1, 2));
174 mnks.push((4, 4, 4));
175 mnks.push((1024, 1024, 1));
176 mnks.push((1024, 1024, 4));
177 mnks.push((63, 1, 10));
178 mnks.push((63, 3, 10));
179 mnks.push((63, 4, 10));
180 mnks.push((2, 63, 10));
181 mnks.push((3, 63, 10));
182 mnks.push((4, 63, 10));
183
184 for (m, n, k) in mnks {
185 #[cfg(feature = "std")]
186 dbg!(m, n, k);
187 for parallelism in [
188 Parallelism::None,
189 #[cfg(feature = "rayon")]
190 Parallelism::Rayon(0),
191 #[cfg(feature = "rayon")]
192 Parallelism::Rayon(128),
193 ] {
194 for alpha in [0.0, 1.0, 2.3] {
195 for beta in [0.0, 1.0, 2.3] {
196 #[cfg(feature = "std")]
197 dbg!(alpha, beta, parallelism);
198 for colmajor in [true, false] {
199 let a_vec: Vec<f32> = (0..(m * k)).map(|_| rand::random()).collect();
200 let b_vec: Vec<f32> = (0..(k * n)).map(|_| rand::random()).collect();
201 let mut c_vec: Vec<f32> =
202 (0..(m * n)).map(|_| rand::random()).collect();
203 let mut d_vec = c_vec.clone();
204
205 unsafe {
206 gemm::gemm(
207 m,
208 n,
209 k,
210 c_vec.as_mut_ptr(),
211 if colmajor { m } else { 1 } as isize,
212 if colmajor { 1 } else { n } as isize,
213 true,
214 a_vec.as_ptr(),
215 m as isize,
216 1,
217 b_vec.as_ptr(),
218 k as isize,
219 1,
220 alpha,
221 beta,
222 false,
223 false,
224 false,
225 parallelism,
226 );
227
228 gemm::gemm_fallback(
229 m,
230 n,
231 k,
232 d_vec.as_mut_ptr(),
233 if colmajor { m } else { 1 } as isize,
234 if colmajor { 1 } else { n } as isize,
235 true,
236 a_vec.as_ptr(),
237 m as isize,
238 1,
239 b_vec.as_ptr(),
240 k as isize,
241 1,
242 alpha,
243 beta,
244 );
245 }
246 for (c, d) in c_vec.iter().zip(d_vec.iter()) {
247 assert_approx_eq::assert_approx_eq!(c, d, 1e-3);
248 }
249 }
250 }
251 }
252 }
253 }
254 }
255
256 #[test]
257 fn test_qlora_gemm_f64() {
258 set_wasm_simd128(true);
259
260 let mut mnks = vec![];
261 mnks.push((63, 2, 10));
262 mnks.push((1, 2, 10));
263 mnks.push((1, 63, 10));
264
265 mnks.push((2048, 255, 255));
267
268 mnks.push((256, 256, 256));
269 mnks.push((4096, 4096, 4));
270 mnks.push((64, 64, 4));
271 mnks.push((0, 64, 4));
272 mnks.push((64, 0, 4));
273 mnks.push((0, 0, 4));
274 mnks.push((64, 64, 0));
275 mnks.push((16, 1, 1));
276 mnks.push((16, 2, 1));
277 mnks.push((16, 3, 1));
278 mnks.push((16, 4, 1));
279 mnks.push((16, 1, 2));
280 mnks.push((16, 2, 2));
281 mnks.push((16, 3, 2));
282 mnks.push((16, 4, 2));
283 mnks.push((16, 16, 1));
284 mnks.push((8, 16, 1));
285 mnks.push((16, 8, 1));
286 mnks.push((1, 1, 2));
287 mnks.push((4, 4, 4));
288 mnks.push((1024, 1024, 1));
289 mnks.push((1024, 1024, 4));
290 mnks.push((63, 1, 10));
291 mnks.push((63, 3, 10));
292 mnks.push((63, 4, 10));
293 mnks.push((2, 63, 10));
294 mnks.push((3, 63, 10));
295 mnks.push((4, 63, 10));
296
297 for (m, n, k) in mnks {
298 #[cfg(feature = "std")]
299 dbg!(m, n, k);
300 for parallelism in [
301 Parallelism::None,
302 #[cfg(feature = "rayon")]
303 Parallelism::Rayon(0),
304 #[cfg(feature = "rayon")]
305 Parallelism::Rayon(128),
306 ] {
307 for alpha in [0.0, 1.0, 2.3] {
308 for beta in [0.0, 1.0, 2.3] {
309 #[cfg(feature = "std")]
310 dbg!(alpha, beta, parallelism);
311 for colmajor in [true, false] {
312 let a_vec: Vec<f64> = (0..(m * k)).map(|_| rand::random()).collect();
313 let b_vec: Vec<f64> = (0..(k * n)).map(|_| rand::random()).collect();
314 let mut c_vec: Vec<f64> =
315 (0..(m * n)).map(|_| rand::random()).collect();
316 let mut d_vec = c_vec.clone();
317
318 unsafe {
319 gemm::gemm(
320 m,
321 n,
322 k,
323 c_vec.as_mut_ptr(),
324 if colmajor { m } else { 1 } as isize,
325 if colmajor { 1 } else { n } as isize,
326 true,
327 a_vec.as_ptr(),
328 m as isize,
329 1,
330 b_vec.as_ptr(),
331 k as isize,
332 1,
333 alpha,
334 beta,
335 false,
336 false,
337 false,
338 parallelism,
339 );
340
341 gemm::gemm_fallback(
342 m,
343 n,
344 k,
345 d_vec.as_mut_ptr(),
346 if colmajor { m } else { 1 } as isize,
347 if colmajor { 1 } else { n } as isize,
348 true,
349 a_vec.as_ptr(),
350 m as isize,
351 1,
352 b_vec.as_ptr(),
353 k as isize,
354 1,
355 alpha,
356 beta,
357 );
358 }
359 for (c, d) in c_vec.iter().zip(d_vec.iter()) {
360 assert_approx_eq::assert_approx_eq!(c, d);
361 }
362 }
363 }
364 }
365 }
366 }
367 }
368
369 #[test]
370 fn test_gemm_cplx32() {
371 let mut mnks = vec![];
372 mnks.push((4, 4, 4));
373 mnks.push((0, 64, 4));
374 mnks.push((64, 0, 4));
375 mnks.push((0, 0, 4));
376 mnks.push((64, 64, 4));
377 mnks.push((64, 64, 0));
378 mnks.push((6, 3, 1));
379 mnks.push((1, 1, 2));
380 mnks.push((128, 128, 128));
381 mnks.push((16, 1, 1));
382 mnks.push((16, 2, 1));
383 mnks.push((16, 3, 1));
384 mnks.push((16, 4, 1));
385 mnks.push((16, 1, 2));
386 mnks.push((16, 2, 2));
387 mnks.push((16, 3, 2));
388 mnks.push((16, 4, 2));
389 mnks.push((16, 16, 1));
390 mnks.push((8, 16, 1));
391 mnks.push((16, 8, 1));
392 mnks.push((1024, 1024, 4));
393 mnks.push((1024, 1024, 1));
394 mnks.push((63, 1, 10));
395 mnks.push((63, 2, 10));
396 mnks.push((63, 3, 10));
397 mnks.push((63, 4, 10));
398 mnks.push((1, 63, 10));
399 mnks.push((2, 63, 10));
400 mnks.push((3, 63, 10));
401 mnks.push((4, 63, 10));
402
403 for (m, n, k) in mnks {
404 #[cfg(feature = "std")]
405 dbg!(m, n, k);
406
407 let zero = c32::new(0.0, 0.0);
408 let one = c32::new(1.0, 0.0);
409 let arbitrary = c32::new(2.3, 4.1);
410 for alpha in [zero, one, arbitrary] {
411 for beta in [zero, one, arbitrary] {
412 #[cfg(feature = "std")]
413 dbg!(alpha, beta);
414 for conj_dst in [false, true] {
415 for conj_lhs in [false, true] {
416 for conj_rhs in [false, true] {
417 #[cfg(feature = "std")]
418 dbg!(conj_dst);
419 #[cfg(feature = "std")]
420 dbg!(conj_lhs);
421 #[cfg(feature = "std")]
422 dbg!(conj_rhs);
423 for colmajor in [true, false] {
424 let a_vec: Vec<f32> =
425 (0..(2 * m * k)).map(|_| rand::random()).collect();
426 let b_vec: Vec<f32> =
427 (0..(2 * k * n)).map(|_| rand::random()).collect();
428 let mut c_vec: Vec<f32> =
429 (0..(2 * m * n)).map(|_| rand::random()).collect();
430 let mut d_vec = c_vec.clone();
431
432 unsafe {
433 gemm::gemm(
434 m,
435 n,
436 k,
437 c_vec.as_mut_ptr() as *mut c32,
438 if colmajor { m } else { 1 } as isize,
439 if colmajor { 1 } else { n } as isize,
440 true,
441 a_vec.as_ptr() as *const c32,
442 m as isize,
443 1,
444 b_vec.as_ptr() as *const c32,
445 k as isize,
446 1,
447 alpha,
448 beta,
449 conj_dst,
450 conj_lhs,
451 conj_rhs,
452 #[cfg(feature = "rayon")]
453 Parallelism::Rayon(0),
454 #[cfg(not(feature = "rayon"))]
455 Parallelism::None,
456 );
457
458 gemm::gemm_cplx_fallback(
459 m,
460 n,
461 k,
462 d_vec.as_mut_ptr() as *mut c32,
463 if colmajor { m } else { 1 } as isize,
464 if colmajor { 1 } else { n } as isize,
465 true,
466 a_vec.as_ptr() as *const c32,
467 m as isize,
468 1,
469 b_vec.as_ptr() as *const c32,
470 k as isize,
471 1,
472 alpha,
473 beta,
474 conj_dst,
475 conj_lhs,
476 conj_rhs,
477 );
478 }
479 for (c, d) in c_vec.iter().zip(d_vec.iter()) {
480 assert_approx_eq::assert_approx_eq!(c, d, 1e-3);
481 }
482 }
483 }
484 }
485 }
486 }
487 }
488 }
489 }
490
491 #[test]
492 fn test_gemm_cplx64() {
493 let mut mnks = vec![];
494 mnks.push((4, 4, 4));
495 mnks.push((0, 64, 4));
496 mnks.push((64, 0, 4));
497 mnks.push((0, 0, 4));
498 mnks.push((64, 64, 4));
499 mnks.push((64, 64, 0));
500 mnks.push((6, 3, 1));
501 mnks.push((1, 1, 2));
502 mnks.push((128, 128, 128));
503 mnks.push((16, 1, 1));
504 mnks.push((16, 2, 1));
505 mnks.push((16, 3, 1));
506 mnks.push((16, 4, 1));
507 mnks.push((16, 1, 2));
508 mnks.push((16, 2, 2));
509 mnks.push((16, 3, 2));
510 mnks.push((16, 4, 2));
511 mnks.push((16, 16, 1));
512 mnks.push((8, 16, 1));
513 mnks.push((16, 8, 1));
514 mnks.push((1024, 1024, 4));
515 mnks.push((1024, 1024, 1));
516 mnks.push((63, 1, 10));
517 mnks.push((63, 2, 10));
518 mnks.push((63, 3, 10));
519 mnks.push((63, 4, 10));
520 mnks.push((1, 63, 10));
521 mnks.push((2, 63, 10));
522 mnks.push((3, 63, 10));
523 mnks.push((4, 63, 10));
524
525 for (m, n, k) in mnks {
526 #[cfg(feature = "std")]
527 dbg!(m, n, k);
528
529 let zero = c64::new(0.0, 0.0);
530 let one = c64::new(1.0, 0.0);
531 let arbitrary = c64::new(2.3, 4.1);
532 for alpha in [zero, one, arbitrary] {
533 for beta in [zero, one, arbitrary] {
534 #[cfg(feature = "std")]
535 dbg!(alpha, beta);
536 for conj_dst in [false, true] {
537 for conj_lhs in [false, true] {
538 for conj_rhs in [false, true] {
539 #[cfg(feature = "std")]
540 dbg!(conj_dst);
541 #[cfg(feature = "std")]
542 dbg!(conj_lhs);
543 #[cfg(feature = "std")]
544 dbg!(conj_rhs);
545 for colmajor in [true, false] {
546 let a_vec: Vec<f64> =
547 (0..(2 * m * k)).map(|_| rand::random()).collect();
548 let b_vec: Vec<f64> =
549 (0..(2 * k * n)).map(|_| rand::random()).collect();
550 let mut c_vec: Vec<f64> =
551 (0..(2 * m * n)).map(|_| rand::random()).collect();
552 let mut d_vec = c_vec.clone();
553
554 unsafe {
555 gemm::gemm(
556 m,
557 n,
558 k,
559 c_vec.as_mut_ptr() as *mut c64,
560 if colmajor { m } else { 1 } as isize,
561 if colmajor { 1 } else { n } as isize,
562 true,
563 a_vec.as_ptr() as *const c64,
564 m as isize,
565 1,
566 b_vec.as_ptr() as *const c64,
567 k as isize,
568 1,
569 alpha,
570 beta,
571 conj_dst,
572 conj_lhs,
573 conj_rhs,
574 #[cfg(feature = "rayon")]
575 Parallelism::Rayon(0),
576 #[cfg(not(feature = "rayon"))]
577 Parallelism::None,
578 );
579
580 gemm::gemm_cplx_fallback(
581 m,
582 n,
583 k,
584 d_vec.as_mut_ptr() as *mut c64,
585 if colmajor { m } else { 1 } as isize,
586 if colmajor { 1 } else { n } as isize,
587 true,
588 a_vec.as_ptr() as *const c64,
589 m as isize,
590 1,
591 b_vec.as_ptr() as *const c64,
592 k as isize,
593 1,
594 alpha,
595 beta,
596 conj_dst,
597 conj_lhs,
598 conj_rhs,
599 );
600 }
601 for (c, d) in c_vec.iter().zip(d_vec.iter()) {
602 assert_approx_eq::assert_approx_eq!(c, d);
603 }
604 }
605 }
606 }
607 }
608 }
609 }
610 }
611 }
612}