Skip to main content

qlora_gemm/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(rust_2018_idioms)]
3
4mod gemm;
5
6#[cfg(feature = "f16")]
7pub use crate::gemm::f16;
8pub use crate::gemm::{c32, c64, gemm};
9pub use qlora_gemm_common::Parallelism;
10
11pub use qlora_gemm_common::gemm::{
12    get_lhs_packing_threshold_multi_thread, get_lhs_packing_threshold_single_thread,
13    get_rhs_packing_threshold, get_threading_threshold, set_lhs_packing_threshold_multi_thread,
14    set_lhs_packing_threshold_single_thread, set_rhs_packing_threshold, set_threading_threshold,
15    DEFAULT_LHS_PACKING_THRESHOLD_MULTI_THREAD, DEFAULT_LHS_PACKING_THRESHOLD_SINGLE_THREAD,
16    DEFAULT_RHS_PACKING_THRESHOLD, DEFAULT_THREADING_THRESHOLD,
17};
18pub use qlora_gemm_common::{get_wasm_simd128, set_wasm_simd128, DEFAULT_WASM_SIMD128};
19
20#[cfg(test)]
21mod tests {
22    use super::*;
23    extern crate alloc;
24    use alloc::{vec, vec::Vec};
25    use num_traits::Float;
26
27    #[test]
28    fn test_qlora_gemm_f16() {
29        let mut mnks = vec![];
30        mnks.push((4, 4, 4));
31        mnks.push((63, 2, 10));
32        mnks.push((16, 2, 1));
33        mnks.push((0, 0, 4));
34        mnks.push((16, 1, 1));
35        mnks.push((16, 3, 1));
36        mnks.push((16, 4, 1));
37        mnks.push((16, 1, 2));
38        mnks.push((16, 2, 2));
39        mnks.push((16, 3, 2));
40        mnks.push((16, 4, 2));
41        mnks.push((16, 16, 1));
42        mnks.push((64, 64, 0));
43        mnks.push((256, 256, 256));
44        mnks.push((4096, 4096, 4));
45        mnks.push((64, 64, 4));
46        mnks.push((0, 64, 4));
47        mnks.push((64, 0, 4));
48        mnks.push((8, 16, 1));
49        mnks.push((16, 8, 1));
50        mnks.push((1, 1, 2));
51        mnks.push((1024, 1024, 1));
52        mnks.push((1024, 1024, 4));
53        mnks.push((63, 1, 10));
54        mnks.push((63, 3, 10));
55        mnks.push((63, 4, 10));
56        mnks.push((1, 63, 10));
57        mnks.push((2, 63, 10));
58        mnks.push((3, 63, 10));
59        mnks.push((4, 63, 10));
60
61        for (m, n, k) in mnks {
62            #[cfg(feature = "std")]
63            dbg!(m, n, k);
64            for parallelism in [
65                Parallelism::None,
66                #[cfg(feature = "rayon")]
67                Parallelism::Rayon(0),
68            ] {
69                for alpha in [0.0, 1.0, 2.3] {
70                    for beta in [0.0, 1.0, 2.3] {
71                        #[cfg(feature = "std")]
72                        dbg!(alpha, beta, parallelism);
73
74                        for colmajor in [true, false] {
75                            let alpha = f16::from_f32(alpha);
76                            let beta = f16::from_f32(beta);
77                            let a_vec: Vec<f16> = (0..(m * k))
78                                .map(|_| f16::from_f32(rand::random()))
79                                .collect();
80                            let b_vec: Vec<f16> = (0..(k * n))
81                                .map(|_| f16::from_f32(rand::random()))
82                                .collect();
83                            let mut c_vec: Vec<f16> = (0..(m * n))
84                                .map(|_| f16::from_f32(rand::random()))
85                                .collect();
86                            let mut d_vec = c_vec.clone();
87
88                            unsafe {
89                                gemm::gemm(
90                                    m,
91                                    n,
92                                    k,
93                                    c_vec.as_mut_ptr(),
94                                    if colmajor { m } else { 1 } as isize,
95                                    if colmajor { 1 } else { n } as isize,
96                                    true,
97                                    a_vec.as_ptr(),
98                                    m as isize,
99                                    1,
100                                    b_vec.as_ptr(),
101                                    k as isize,
102                                    1,
103                                    alpha,
104                                    beta,
105                                    false,
106                                    false,
107                                    false,
108                                    parallelism,
109                                );
110
111                                gemm::gemm_fallback(
112                                    m,
113                                    n,
114                                    k,
115                                    d_vec.as_mut_ptr(),
116                                    if colmajor { m } else { 1 } as isize,
117                                    if colmajor { 1 } else { n } as isize,
118                                    true,
119                                    a_vec.as_ptr(),
120                                    m as isize,
121                                    1,
122                                    b_vec.as_ptr(),
123                                    k as isize,
124                                    1,
125                                    alpha,
126                                    beta,
127                                );
128                            }
129                            let eps = f16::from_f32(1e-1);
130                            for (c, d) in c_vec.iter().zip(d_vec.iter()) {
131                                let eps_rel = c.abs() * eps;
132                                let eps_abs = eps;
133                                let eps = if eps_rel > eps_abs { eps_rel } else { eps_abs };
134                                assert_approx_eq::assert_approx_eq!(c, d, eps);
135                            }
136                        }
137                    }
138                }
139            }
140        }
141    }
142
143    #[test]
144    fn test_qlora_gemm_f32() {
145        set_wasm_simd128(true);
146
147        let mut mnks = vec![];
148        mnks.push((63, 2, 10));
149        mnks.push((1, 2, 10));
150        mnks.push((1, 63, 10));
151
152        // large m to trigger parallelized rhs packing with big number of threads and small n
153        mnks.push((2048, 255, 255));
154
155        mnks.push((256, 256, 256));
156        mnks.push((4096, 4096, 4));
157        mnks.push((64, 64, 4));
158        mnks.push((0, 64, 4));
159        mnks.push((64, 0, 4));
160        mnks.push((0, 0, 4));
161        mnks.push((64, 64, 0));
162        mnks.push((16, 1, 1));
163        mnks.push((16, 2, 1));
164        mnks.push((16, 3, 1));
165        mnks.push((16, 4, 1));
166        mnks.push((16, 1, 2));
167        mnks.push((16, 2, 2));
168        mnks.push((16, 3, 2));
169        mnks.push((16, 4, 2));
170        mnks.push((16, 16, 1));
171        mnks.push((8, 16, 1));
172        mnks.push((16, 8, 1));
173        mnks.push((1, 1, 2));
174        mnks.push((4, 4, 4));
175        mnks.push((1024, 1024, 1));
176        mnks.push((1024, 1024, 4));
177        mnks.push((63, 1, 10));
178        mnks.push((63, 3, 10));
179        mnks.push((63, 4, 10));
180        mnks.push((2, 63, 10));
181        mnks.push((3, 63, 10));
182        mnks.push((4, 63, 10));
183
184        for (m, n, k) in mnks {
185            #[cfg(feature = "std")]
186            dbg!(m, n, k);
187            for parallelism in [
188                Parallelism::None,
189                #[cfg(feature = "rayon")]
190                Parallelism::Rayon(0),
191                #[cfg(feature = "rayon")]
192                Parallelism::Rayon(128),
193            ] {
194                for alpha in [0.0, 1.0, 2.3] {
195                    for beta in [0.0, 1.0, 2.3] {
196                        #[cfg(feature = "std")]
197                        dbg!(alpha, beta, parallelism);
198                        for colmajor in [true, false] {
199                            let a_vec: Vec<f32> = (0..(m * k)).map(|_| rand::random()).collect();
200                            let b_vec: Vec<f32> = (0..(k * n)).map(|_| rand::random()).collect();
201                            let mut c_vec: Vec<f32> =
202                                (0..(m * n)).map(|_| rand::random()).collect();
203                            let mut d_vec = c_vec.clone();
204
205                            unsafe {
206                                gemm::gemm(
207                                    m,
208                                    n,
209                                    k,
210                                    c_vec.as_mut_ptr(),
211                                    if colmajor { m } else { 1 } as isize,
212                                    if colmajor { 1 } else { n } as isize,
213                                    true,
214                                    a_vec.as_ptr(),
215                                    m as isize,
216                                    1,
217                                    b_vec.as_ptr(),
218                                    k as isize,
219                                    1,
220                                    alpha,
221                                    beta,
222                                    false,
223                                    false,
224                                    false,
225                                    parallelism,
226                                );
227
228                                gemm::gemm_fallback(
229                                    m,
230                                    n,
231                                    k,
232                                    d_vec.as_mut_ptr(),
233                                    if colmajor { m } else { 1 } as isize,
234                                    if colmajor { 1 } else { n } as isize,
235                                    true,
236                                    a_vec.as_ptr(),
237                                    m as isize,
238                                    1,
239                                    b_vec.as_ptr(),
240                                    k as isize,
241                                    1,
242                                    alpha,
243                                    beta,
244                                );
245                            }
246                            for (c, d) in c_vec.iter().zip(d_vec.iter()) {
247                                assert_approx_eq::assert_approx_eq!(c, d, 1e-3);
248                            }
249                        }
250                    }
251                }
252            }
253        }
254    }
255
256    #[test]
257    fn test_qlora_gemm_f64() {
258        set_wasm_simd128(true);
259
260        let mut mnks = vec![];
261        mnks.push((63, 2, 10));
262        mnks.push((1, 2, 10));
263        mnks.push((1, 63, 10));
264
265        // large m to trigger parallelized rhs packing with big number of threads and small n
266        mnks.push((2048, 255, 255));
267
268        mnks.push((256, 256, 256));
269        mnks.push((4096, 4096, 4));
270        mnks.push((64, 64, 4));
271        mnks.push((0, 64, 4));
272        mnks.push((64, 0, 4));
273        mnks.push((0, 0, 4));
274        mnks.push((64, 64, 0));
275        mnks.push((16, 1, 1));
276        mnks.push((16, 2, 1));
277        mnks.push((16, 3, 1));
278        mnks.push((16, 4, 1));
279        mnks.push((16, 1, 2));
280        mnks.push((16, 2, 2));
281        mnks.push((16, 3, 2));
282        mnks.push((16, 4, 2));
283        mnks.push((16, 16, 1));
284        mnks.push((8, 16, 1));
285        mnks.push((16, 8, 1));
286        mnks.push((1, 1, 2));
287        mnks.push((4, 4, 4));
288        mnks.push((1024, 1024, 1));
289        mnks.push((1024, 1024, 4));
290        mnks.push((63, 1, 10));
291        mnks.push((63, 3, 10));
292        mnks.push((63, 4, 10));
293        mnks.push((2, 63, 10));
294        mnks.push((3, 63, 10));
295        mnks.push((4, 63, 10));
296
297        for (m, n, k) in mnks {
298            #[cfg(feature = "std")]
299            dbg!(m, n, k);
300            for parallelism in [
301                Parallelism::None,
302                #[cfg(feature = "rayon")]
303                Parallelism::Rayon(0),
304                #[cfg(feature = "rayon")]
305                Parallelism::Rayon(128),
306            ] {
307                for alpha in [0.0, 1.0, 2.3] {
308                    for beta in [0.0, 1.0, 2.3] {
309                        #[cfg(feature = "std")]
310                        dbg!(alpha, beta, parallelism);
311                        for colmajor in [true, false] {
312                            let a_vec: Vec<f64> = (0..(m * k)).map(|_| rand::random()).collect();
313                            let b_vec: Vec<f64> = (0..(k * n)).map(|_| rand::random()).collect();
314                            let mut c_vec: Vec<f64> =
315                                (0..(m * n)).map(|_| rand::random()).collect();
316                            let mut d_vec = c_vec.clone();
317
318                            unsafe {
319                                gemm::gemm(
320                                    m,
321                                    n,
322                                    k,
323                                    c_vec.as_mut_ptr(),
324                                    if colmajor { m } else { 1 } as isize,
325                                    if colmajor { 1 } else { n } as isize,
326                                    true,
327                                    a_vec.as_ptr(),
328                                    m as isize,
329                                    1,
330                                    b_vec.as_ptr(),
331                                    k as isize,
332                                    1,
333                                    alpha,
334                                    beta,
335                                    false,
336                                    false,
337                                    false,
338                                    parallelism,
339                                );
340
341                                gemm::gemm_fallback(
342                                    m,
343                                    n,
344                                    k,
345                                    d_vec.as_mut_ptr(),
346                                    if colmajor { m } else { 1 } as isize,
347                                    if colmajor { 1 } else { n } as isize,
348                                    true,
349                                    a_vec.as_ptr(),
350                                    m as isize,
351                                    1,
352                                    b_vec.as_ptr(),
353                                    k as isize,
354                                    1,
355                                    alpha,
356                                    beta,
357                                );
358                            }
359                            for (c, d) in c_vec.iter().zip(d_vec.iter()) {
360                                assert_approx_eq::assert_approx_eq!(c, d);
361                            }
362                        }
363                    }
364                }
365            }
366        }
367    }
368
369    #[test]
370    fn test_gemm_cplx32() {
371        let mut mnks = vec![];
372        mnks.push((4, 4, 4));
373        mnks.push((0, 64, 4));
374        mnks.push((64, 0, 4));
375        mnks.push((0, 0, 4));
376        mnks.push((64, 64, 4));
377        mnks.push((64, 64, 0));
378        mnks.push((6, 3, 1));
379        mnks.push((1, 1, 2));
380        mnks.push((128, 128, 128));
381        mnks.push((16, 1, 1));
382        mnks.push((16, 2, 1));
383        mnks.push((16, 3, 1));
384        mnks.push((16, 4, 1));
385        mnks.push((16, 1, 2));
386        mnks.push((16, 2, 2));
387        mnks.push((16, 3, 2));
388        mnks.push((16, 4, 2));
389        mnks.push((16, 16, 1));
390        mnks.push((8, 16, 1));
391        mnks.push((16, 8, 1));
392        mnks.push((1024, 1024, 4));
393        mnks.push((1024, 1024, 1));
394        mnks.push((63, 1, 10));
395        mnks.push((63, 2, 10));
396        mnks.push((63, 3, 10));
397        mnks.push((63, 4, 10));
398        mnks.push((1, 63, 10));
399        mnks.push((2, 63, 10));
400        mnks.push((3, 63, 10));
401        mnks.push((4, 63, 10));
402
403        for (m, n, k) in mnks {
404            #[cfg(feature = "std")]
405            dbg!(m, n, k);
406
407            let zero = c32::new(0.0, 0.0);
408            let one = c32::new(1.0, 0.0);
409            let arbitrary = c32::new(2.3, 4.1);
410            for alpha in [zero, one, arbitrary] {
411                for beta in [zero, one, arbitrary] {
412                    #[cfg(feature = "std")]
413                    dbg!(alpha, beta);
414                    for conj_dst in [false, true] {
415                        for conj_lhs in [false, true] {
416                            for conj_rhs in [false, true] {
417                                #[cfg(feature = "std")]
418                                dbg!(conj_dst);
419                                #[cfg(feature = "std")]
420                                dbg!(conj_lhs);
421                                #[cfg(feature = "std")]
422                                dbg!(conj_rhs);
423                                for colmajor in [true, false] {
424                                    let a_vec: Vec<f32> =
425                                        (0..(2 * m * k)).map(|_| rand::random()).collect();
426                                    let b_vec: Vec<f32> =
427                                        (0..(2 * k * n)).map(|_| rand::random()).collect();
428                                    let mut c_vec: Vec<f32> =
429                                        (0..(2 * m * n)).map(|_| rand::random()).collect();
430                                    let mut d_vec = c_vec.clone();
431
432                                    unsafe {
433                                        gemm::gemm(
434                                            m,
435                                            n,
436                                            k,
437                                            c_vec.as_mut_ptr() as *mut c32,
438                                            if colmajor { m } else { 1 } as isize,
439                                            if colmajor { 1 } else { n } as isize,
440                                            true,
441                                            a_vec.as_ptr() as *const c32,
442                                            m as isize,
443                                            1,
444                                            b_vec.as_ptr() as *const c32,
445                                            k as isize,
446                                            1,
447                                            alpha,
448                                            beta,
449                                            conj_dst,
450                                            conj_lhs,
451                                            conj_rhs,
452                                            #[cfg(feature = "rayon")]
453                                            Parallelism::Rayon(0),
454                                            #[cfg(not(feature = "rayon"))]
455                                            Parallelism::None,
456                                        );
457
458                                        gemm::gemm_cplx_fallback(
459                                            m,
460                                            n,
461                                            k,
462                                            d_vec.as_mut_ptr() as *mut c32,
463                                            if colmajor { m } else { 1 } as isize,
464                                            if colmajor { 1 } else { n } as isize,
465                                            true,
466                                            a_vec.as_ptr() as *const c32,
467                                            m as isize,
468                                            1,
469                                            b_vec.as_ptr() as *const c32,
470                                            k as isize,
471                                            1,
472                                            alpha,
473                                            beta,
474                                            conj_dst,
475                                            conj_lhs,
476                                            conj_rhs,
477                                        );
478                                    }
479                                    for (c, d) in c_vec.iter().zip(d_vec.iter()) {
480                                        assert_approx_eq::assert_approx_eq!(c, d, 1e-3);
481                                    }
482                                }
483                            }
484                        }
485                    }
486                }
487            }
488        }
489    }
490
491    #[test]
492    fn test_gemm_cplx64() {
493        let mut mnks = vec![];
494        mnks.push((4, 4, 4));
495        mnks.push((0, 64, 4));
496        mnks.push((64, 0, 4));
497        mnks.push((0, 0, 4));
498        mnks.push((64, 64, 4));
499        mnks.push((64, 64, 0));
500        mnks.push((6, 3, 1));
501        mnks.push((1, 1, 2));
502        mnks.push((128, 128, 128));
503        mnks.push((16, 1, 1));
504        mnks.push((16, 2, 1));
505        mnks.push((16, 3, 1));
506        mnks.push((16, 4, 1));
507        mnks.push((16, 1, 2));
508        mnks.push((16, 2, 2));
509        mnks.push((16, 3, 2));
510        mnks.push((16, 4, 2));
511        mnks.push((16, 16, 1));
512        mnks.push((8, 16, 1));
513        mnks.push((16, 8, 1));
514        mnks.push((1024, 1024, 4));
515        mnks.push((1024, 1024, 1));
516        mnks.push((63, 1, 10));
517        mnks.push((63, 2, 10));
518        mnks.push((63, 3, 10));
519        mnks.push((63, 4, 10));
520        mnks.push((1, 63, 10));
521        mnks.push((2, 63, 10));
522        mnks.push((3, 63, 10));
523        mnks.push((4, 63, 10));
524
525        for (m, n, k) in mnks {
526            #[cfg(feature = "std")]
527            dbg!(m, n, k);
528
529            let zero = c64::new(0.0, 0.0);
530            let one = c64::new(1.0, 0.0);
531            let arbitrary = c64::new(2.3, 4.1);
532            for alpha in [zero, one, arbitrary] {
533                for beta in [zero, one, arbitrary] {
534                    #[cfg(feature = "std")]
535                    dbg!(alpha, beta);
536                    for conj_dst in [false, true] {
537                        for conj_lhs in [false, true] {
538                            for conj_rhs in [false, true] {
539                                #[cfg(feature = "std")]
540                                dbg!(conj_dst);
541                                #[cfg(feature = "std")]
542                                dbg!(conj_lhs);
543                                #[cfg(feature = "std")]
544                                dbg!(conj_rhs);
545                                for colmajor in [true, false] {
546                                    let a_vec: Vec<f64> =
547                                        (0..(2 * m * k)).map(|_| rand::random()).collect();
548                                    let b_vec: Vec<f64> =
549                                        (0..(2 * k * n)).map(|_| rand::random()).collect();
550                                    let mut c_vec: Vec<f64> =
551                                        (0..(2 * m * n)).map(|_| rand::random()).collect();
552                                    let mut d_vec = c_vec.clone();
553
554                                    unsafe {
555                                        gemm::gemm(
556                                            m,
557                                            n,
558                                            k,
559                                            c_vec.as_mut_ptr() as *mut c64,
560                                            if colmajor { m } else { 1 } as isize,
561                                            if colmajor { 1 } else { n } as isize,
562                                            true,
563                                            a_vec.as_ptr() as *const c64,
564                                            m as isize,
565                                            1,
566                                            b_vec.as_ptr() as *const c64,
567                                            k as isize,
568                                            1,
569                                            alpha,
570                                            beta,
571                                            conj_dst,
572                                            conj_lhs,
573                                            conj_rhs,
574                                            #[cfg(feature = "rayon")]
575                                            Parallelism::Rayon(0),
576                                            #[cfg(not(feature = "rayon"))]
577                                            Parallelism::None,
578                                        );
579
580                                        gemm::gemm_cplx_fallback(
581                                            m,
582                                            n,
583                                            k,
584                                            d_vec.as_mut_ptr() as *mut c64,
585                                            if colmajor { m } else { 1 } as isize,
586                                            if colmajor { 1 } else { n } as isize,
587                                            true,
588                                            a_vec.as_ptr() as *const c64,
589                                            m as isize,
590                                            1,
591                                            b_vec.as_ptr() as *const c64,
592                                            k as isize,
593                                            1,
594                                            alpha,
595                                            beta,
596                                            conj_dst,
597                                            conj_lhs,
598                                            conj_rhs,
599                                        );
600                                    }
601                                    for (c, d) in c_vec.iter().zip(d_vec.iter()) {
602                                        assert_approx_eq::assert_approx_eq!(c, d);
603                                    }
604                                }
605                            }
606                        }
607                    }
608                }
609            }
610        }
611    }
612}