pire_dev/
lib.rs

1#![allow(non_camel_case_types)]
2#![allow(dead_code)]
3#![allow(unused)]
4
// Backends are loaded at runtime through `libloading` rather than linked natively; this is
// more convenient, and acceptable because this crate exists only for testing.
7
8use libc::{c_double, c_float, c_int, c_schar, c_short, c_ushort, c_void};
9
10use num_complex::{c32, c64, Complex32, Complex64};
11
12use half::f16;
13use once_cell::sync::Lazy;
14
/// Storage order of the matrices handed to a CBLAS routine (discriminants match `cblas.h`).
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[allow(clippy::enum_variant_names)]
pub enum CBLAS_LAYOUT {
    CblasRowMajor = 101,
    CblasColMajor = 102,
}
pub use self::CBLAS_LAYOUT::*;

/// Opaque placeholder mirroring BLIS's `cntx_t` context type.
/// NOTE(review): presumably only ever handled behind a pointer — TODO confirm.
#[repr(C)]
pub struct cntx_t(i32);

/// Per-operand transposition flag for CBLAS routines (discriminants match `cblas.h`).
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[allow(clippy::enum_variant_names)]
pub enum CBLAS_TRANSPOSE {
    CblasNoTrans = 111,
    CblasTrans = 112,
    CblasConjTrans = 113,
}
pub use self::CBLAS_TRANSPOSE::*;

/// Selects how the `oc` offset argument of MKL's integer GEMMs
/// (`cblas_gemm_s8u8s32`, `cblas_gemm_s16s16s32`) is applied to `C`.
///
/// Derives the same comparison traits as the other CBLAS enums so all three can be
/// compared and hashed uniformly.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[allow(clippy::enum_variant_names)]
pub enum CBLAS_OFFSET {
    CblasRowOffset = 171,
    CblasColOffset = 172,
    CblasFixOffset = 173,
}
pub use self::CBLAS_OFFSET::*;
46
47type SGEMM_FN_TYPE = unsafe extern "C" fn(
48    CBLAS_LAYOUT,
49    CBLAS_TRANSPOSE,
50    CBLAS_TRANSPOSE,
51    c_int,
52    c_int,
53    c_int,
54    c_float,
55    *const c_float,
56    c_int,
57    *const c_float,
58    c_int,
59    c_float,
60    *mut c_float,
61    c_int,
62);
63
64type SGEMM_B_FN_TYPE = unsafe extern "C" fn(
65    CBLAS_LAYOUT,
66    *const CBLAS_TRANSPOSE,
67    *const CBLAS_TRANSPOSE,
68    *const c_int,
69    *const c_int,
70    *const c_int,
71    *const c_float,
72    *const *const c_float,
73    *const c_int,
74    *const *const c_float,
75    *const c_int,
76    *const c_float,
77    *const *mut c_float,
78    *const c_int,
79    c_int,
80    *const c_int,
81);
82
83type DGEMM_FN_TYPE = unsafe extern "C" fn(
84    CBLAS_LAYOUT,
85    CBLAS_TRANSPOSE,
86    CBLAS_TRANSPOSE,
87    c_int,
88    c_int,
89    c_int,
90    c_double,
91    *const c_double,
92    c_int,
93    *const c_double,
94    c_int,
95    c_double,
96    *mut c_double,
97    c_int,
98);
99
100type CGEMM_FN_TYPE = unsafe extern "C" fn(
101    CBLAS_LAYOUT,
102    CBLAS_TRANSPOSE,
103    CBLAS_TRANSPOSE,
104    c_int,
105    c_int,
106    c_int,
107    *const c_void,
108    *const c_void,
109    c_int,
110    *const c_void,
111    c_int,
112    *const c_void,
113    *mut c_void,
114    c_int,
115);
116
117type ZGEMM_FN_TYPE = unsafe extern "C" fn(
118    CBLAS_LAYOUT,
119    CBLAS_TRANSPOSE,
120    CBLAS_TRANSPOSE,
121    c_int,
122    c_int,
123    c_int,
124    *const c_void,
125    *const c_void,
126    c_int,
127    *const c_void,
128    c_int,
129    *const c_void,
130    *mut c_void,
131    c_int,
132);
133
134type HGEMM_FN_TYPE = unsafe extern "C" fn(
135    CBLAS_LAYOUT,
136    CBLAS_TRANSPOSE,
137    CBLAS_TRANSPOSE,
138    c_int,
139    c_int,
140    c_int,
141    c_ushort,
142    *const c_ushort,
143    c_int,
144    *const c_ushort,
145    c_int,
146    c_ushort,
147    *mut c_ushort,
148    c_int,
149);
150
151type GEMM_I8_FN_TYPE = unsafe extern "C" fn(
152    CBLAS_LAYOUT,
153    CBLAS_TRANSPOSE,
154    CBLAS_TRANSPOSE,
155    CBLAS_OFFSET,
156    c_int,
157    c_int,
158    c_int,
159    c_float,
160    *const c_void,
161    c_int,
162    c_schar,
163    *const c_void,
164    c_int,
165    c_schar,
166    c_float,
167    *mut c_int,
168    c_int,
169    *const c_int,
170);
171
172type GEMM_I16_FN_TYPE = unsafe extern "C" fn(
173    CBLAS_LAYOUT,
174    CBLAS_TRANSPOSE,
175    CBLAS_TRANSPOSE,
176    CBLAS_OFFSET,
177    c_int,
178    c_int,
179    c_int,
180    c_float,
181    *const c_short,
182    c_int,
183    c_short,
184    *const c_short,
185    c_int,
186    c_short,
187    c_float,
188    *mut c_int,
189    c_int,
190    *const c_int,
191);
192
193const PROJECT_DIR: &str = core::env!("CARGO_MANIFEST_DIR");
194
195// TODO: Add more reasonalble deafult paths for different os,s windows/unix
196pub static CBLAS_LIBRARY_MKL: Lazy<libloading::Library> = Lazy::new(|| unsafe {
197    let default_mkl_path = format!("{PROJECT_DIR}/../../.env/Library/bin/mkl_rt.2.dll");
198    let mkl_path = std::env::var("PIRE_MKL_PATH").unwrap_or(default_mkl_path);
199    libloading::Library::new(mkl_path).unwrap()
200});
201
202pub static CBLAS_LIBRARY_OPENBLAS: Lazy<libloading::Library> = Lazy::new(|| unsafe {
203    let default_openblas_path = format!("{PROJECT_DIR}/../../openblas/openblas.dll");
204    let openblas_path = std::env::var("PIRE_OPENBLAS_PATH").unwrap_or(default_openblas_path);
205    libloading::Library::new(openblas_path).unwrap()
206});
207
208pub static CBLAS_LIBRARY_BLIS: Lazy<libloading::Library> = Lazy::new(|| unsafe {
209    let default_blis_path = format!("{PROJECT_DIR}/../../blis/blis.dll");
210    let blis_path = std::env::var("PIRE_BLIS_PATH").unwrap_or(default_blis_path);
211    libloading::Library::new(blis_path).unwrap()
212});
213
214pub static CBLAS_SGEMM_MKL: Lazy<libloading::Symbol<'static, SGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
215    let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_sgemm").unwrap();
216    cblas_gemm
217});
218
219pub static CBLAS_SGEMM_OPENBLAS: Lazy<libloading::Symbol<'static, SGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
220    let cblas_gemm = CBLAS_LIBRARY_OPENBLAS.get(b"cblas_sgemm").unwrap();
221    cblas_gemm
222});
223
224pub static CBLAS_SGEMM_BLIS: Lazy<libloading::Symbol<'static, SGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
225    let cblas_gemm = CBLAS_LIBRARY_BLIS.get(b"cblas_sgemm").unwrap();
226    cblas_gemm
227});
228
229pub static CBLAS_SGEMM_B_MKL: Lazy<libloading::Symbol<'static, SGEMM_B_FN_TYPE>> = Lazy::new(|| unsafe {
230    let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_sgemm_batch").unwrap();
231    cblas_gemm
232});
233
234pub static CBLAS_DGEMM_MKL: Lazy<libloading::Symbol<'static, DGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
235    let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_dgemm").unwrap();
236    cblas_gemm
237});
238
239pub static CBLAS_DGEMM_OPENBLAS: Lazy<libloading::Symbol<'static, DGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
240    let cblas_gemm = CBLAS_LIBRARY_OPENBLAS.get(b"cblas_dgemm").unwrap();
241    cblas_gemm
242});
243
244pub static CBLAS_DGEMM_BLIS: Lazy<libloading::Symbol<'static, DGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
245    let cblas_gemm = CBLAS_LIBRARY_BLIS.get(b"cblas_dgemm").unwrap();
246    cblas_gemm
247});
248
249pub static CBLAS_CGEMM_MKL: Lazy<libloading::Symbol<'static, CGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
250    let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_cgemm").unwrap();
251    cblas_gemm
252});
253
254pub static CBLAS_CGEMM_OPENBLAS: Lazy<libloading::Symbol<'static, CGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
255    let cblas_gemm = CBLAS_LIBRARY_OPENBLAS.get(b"cblas_cgemm").unwrap();
256    cblas_gemm
257});
258
259pub static CBLAS_CGEMM_BLIS: Lazy<libloading::Symbol<'static, CGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
260    let cblas_gemm = CBLAS_LIBRARY_BLIS.get(b"cblas_cgemm").unwrap();
261    cblas_gemm
262});
263
264pub static CBLAS_ZGEMM_MKL: Lazy<libloading::Symbol<'static, ZGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
265    let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_zgemm").unwrap();
266    cblas_gemm
267});
268
269pub static CBLAS_ZGEMM_OPENBLAS: Lazy<libloading::Symbol<'static, ZGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
270    let cblas_gemm = CBLAS_LIBRARY_OPENBLAS.get(b"cblas_zgemm").unwrap();
271    cblas_gemm
272});
273
274pub static CBLAS_ZGEMM_BLIS: Lazy<libloading::Symbol<'static, ZGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
275    let cblas_gemm = CBLAS_LIBRARY_BLIS.get(b"cblas_zgemm").unwrap();
276    cblas_gemm
277});
278
279pub static CBLAS_HGEMM_MKL: Lazy<libloading::Symbol<'static, HGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
280    let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_hgemm").unwrap();
281    cblas_gemm
282});
283
284pub static CBLAS_GEMM_I8: Lazy<libloading::Symbol<'static, GEMM_I8_FN_TYPE>> = Lazy::new(|| unsafe {
285    let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_gemm_s8u8s32").unwrap();
286    cblas_gemm
287});
288
289pub static CBLAS_GEMM_I16: Lazy<libloading::Symbol<'static, GEMM_I16_FN_TYPE>> = Lazy::new(|| unsafe {
290    let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_gemm_s16s16s32").unwrap();
291    cblas_gemm
292});
293
/// BLAS implementation that the `cblas_*` wrappers in this file dispatch to.
///
/// `Copy`, comparable and printable, so one backend value can be reused across calls and
/// shown in test diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CBlasBackend {
    Mkl,
    Blis,
    OpenBlas,
}
299
300pub unsafe fn cblas_sgemm(
301    layout: CBLAS_LAYOUT,
302    transa: CBLAS_TRANSPOSE,
303    transb: CBLAS_TRANSPOSE,
304    m: c_int,
305    n: c_int,
306    k: c_int,
307    alpha: c_float,
308    a: *const c_float,
309    lda: c_int,
310    b: *const c_float,
311    ldb: c_int,
312    beta: c_float,
313    c: *mut c_float,
314    ldc: c_int,
315    backend: CBlasBackend,
316) {
317    match backend {
318        CBlasBackend::Mkl => {
319            CBLAS_SGEMM_MKL(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
320        }
321        CBlasBackend::Blis => {
322            CBLAS_SGEMM_BLIS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
323        }
324        CBlasBackend::OpenBlas => {
325            CBLAS_SGEMM_OPENBLAS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
326        }
327    }
328}
329
330pub unsafe fn cblas_dgemm(
331    layout: CBLAS_LAYOUT,
332    transa: CBLAS_TRANSPOSE,
333    transb: CBLAS_TRANSPOSE,
334    m: c_int,
335    n: c_int,
336    k: c_int,
337    alpha: c_double,
338    a: *const c_double,
339    lda: c_int,
340    b: *const c_double,
341    ldb: c_int,
342    beta: c_double,
343    c: *mut c_double,
344    ldc: c_int,
345    backend: CBlasBackend,
346) {
347    match backend {
348        CBlasBackend::Mkl => {
349            CBLAS_DGEMM_MKL(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
350        }
351        CBlasBackend::Blis => {
352            CBLAS_DGEMM_BLIS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
353        }
354        CBlasBackend::OpenBlas => {
355            CBLAS_DGEMM_OPENBLAS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
356        }
357    }
358}
359
360pub unsafe fn cblas_cgemm(
361    layout: CBLAS_LAYOUT,
362    transa: CBLAS_TRANSPOSE,
363    transb: CBLAS_TRANSPOSE,
364    m: c_int,
365    n: c_int,
366    k: c_int,
367    alpha: *const c_void,
368    a: *const c_void,
369    lda: c_int,
370    b: *const c_void,
371    ldb: c_int,
372    beta: *const c_void,
373    c: *mut c_void,
374    ldc: c_int,
375    backend: CBlasBackend,
376) {
377    match backend {
378        CBlasBackend::Mkl => {
379            CBLAS_CGEMM_MKL(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
380        }
381        CBlasBackend::Blis => {
382            CBLAS_CGEMM_BLIS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
383        }
384        CBlasBackend::OpenBlas => {
385            CBLAS_CGEMM_OPENBLAS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
386        }
387    }
388}
389
390pub unsafe fn cblas_zgemm(
391    layout: CBLAS_LAYOUT,
392    transa: CBLAS_TRANSPOSE,
393    transb: CBLAS_TRANSPOSE,
394    m: c_int,
395    n: c_int,
396    k: c_int,
397    alpha: *const c_void,
398    a: *const c_void,
399    lda: c_int,
400    b: *const c_void,
401    ldb: c_int,
402    beta: *const c_void,
403    c: *mut c_void,
404    ldc: c_int,
405    backend: CBlasBackend,
406) {
407    match backend {
408        CBlasBackend::Mkl => {
409            CBLAS_ZGEMM_MKL(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
410        }
411        CBlasBackend::Blis => {
412            CBLAS_ZGEMM_BLIS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
413        }
414        CBlasBackend::OpenBlas => {
415            CBLAS_ZGEMM_OPENBLAS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
416        }
417    }
418}
419pub unsafe fn cblas_hgemm(
420    layout: CBLAS_LAYOUT,
421    transa: CBLAS_TRANSPOSE,
422    transb: CBLAS_TRANSPOSE,
423    m: c_int,
424    n: c_int,
425    k: c_int,
426    alpha: c_ushort,
427    a: *const c_ushort,
428    lda: c_int,
429    b: *const c_ushort,
430    ldb: c_int,
431    beta: c_ushort,
432    c: *mut c_ushort,
433    ldc: c_int,
434    backend: CBlasBackend,
435) {
436    match backend {
437        CBlasBackend::Mkl => {
438            CBLAS_HGEMM_MKL(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
439        }
440        CBlasBackend::Blis => {
441            unimplemented!()
442        }
443        CBlasBackend::OpenBlas => {
444            unimplemented!()
445        }
446    }
447}
448
449pub unsafe fn cblas_gemm_s8u8s32(
450    layout: CBLAS_LAYOUT,
451    transa: CBLAS_TRANSPOSE,
452    transb: CBLAS_TRANSPOSE,
453    offsetc: CBLAS_OFFSET,
454    m: c_int,
455    n: c_int,
456    k: c_int,
457    alpha: c_float,
458    a: *const c_void,
459    lda: c_int,
460    oa: c_schar,
461    b: *const c_void,
462    ldb: c_int,
463    ob: c_schar,
464    beta: c_float,
465    c: *mut c_int,
466    ldc: c_int,
467    oc: *const c_int,
468    backend: CBlasBackend,
469) {
470    match backend {
471        CBlasBackend::Mkl => {
472            CBLAS_GEMM_I8(layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
473        }
474        CBlasBackend::Blis => {
475            unimplemented!()
476        }
477        CBlasBackend::OpenBlas => {
478            unimplemented!()
479        }
480    }
481}
482
483#[allow(clippy::too_many_arguments)]
484pub unsafe fn cblas_gemm_s16s16s32(
485    layout: CBLAS_LAYOUT,
486    transa: CBLAS_TRANSPOSE,
487    transb: CBLAS_TRANSPOSE,
488    offsetc: CBLAS_OFFSET,
489    m: c_int,
490    n: c_int,
491    k: c_int,
492    alpha: c_float,
493    a: *const c_short,
494    lda: c_int,
495    oa: c_short,
496    b: *const c_short,
497    ldb: c_int,
498    ob: c_short,
499    beta: c_float,
500    c: *mut c_int,
501    ldc: c_int,
502    oc: *const c_int,
503    backend: CBlasBackend,
504) {
505    match backend {
506        CBlasBackend::Mkl => {
507            CBLAS_GEMM_I16(layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
508        }
509        CBlasBackend::Blis => {
510            unimplemented!()
511        }
512        CBlasBackend::OpenBlas => {
513            unimplemented!()
514        }
515    }
516}
517
518pub unsafe fn cblas_sgemm_batch(
519    layout: CBLAS_LAYOUT,
520    transa: *const CBLAS_TRANSPOSE,
521    transb: *const CBLAS_TRANSPOSE,
522    m: *const c_int,
523    n: *const c_int,
524    k: *const c_int,
525    alpha: *const c_float,
526    a: *const *const c_float,
527    lda: *const c_int,
528    b: *const *const c_float,
529    ldb: *const c_int,
530    beta: *const c_float,
531    c: *const *mut c_float,
532    ldc: *const c_int,
533    group_count: c_int,
534    group_size: *const c_int,
535    backend: CBlasBackend,
536) {
537    let lib = libloading::Library::new("C:/Users/I011745/Desktop/corenum/pire/.env/Library/bin/mkl_rt.2.dll").unwrap();
538    let cblas_sgemm_batch: libloading::Symbol<
539        unsafe extern "C" fn(
540            CBLAS_LAYOUT,
541            *const CBLAS_TRANSPOSE,
542            *const CBLAS_TRANSPOSE,
543            *const c_int,
544            *const c_int,
545            *const c_int,
546            *const c_float,
547            *const *const c_float,
548            *const c_int,
549            *const *const c_float,
550            *const c_int,
551            *const c_float,
552            *const *mut c_float,
553            *const c_int,
554            c_int,
555            *const c_int,
556        ),
557    > = lib.get(b"cblas_sgemm_batch").unwrap();
558    cblas_sgemm_batch(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, group_size);
559
560    match backend {
561        CBlasBackend::Mkl => {
562            CBLAS_SGEMM_B_MKL(
563                layout,
564                transa,
565                transb,
566                m,
567                n,
568                k,
569                alpha,
570                a,
571                lda,
572                b,
573                ldb,
574                beta,
575                c,
576                ldc,
577                group_count,
578                group_size,
579            );
580        }
581        CBlasBackend::Blis => {
582            unimplemented!()
583        }
584        CBlasBackend::OpenBlas => {
585            unimplemented!()
586        }
587    }
588}
589
/// Storage orientation of the GEMM operands `A` and `B`. The first letter refers to `A`,
/// the second to `B`: `N` stores the operand column-major, `T` stores its transpose
/// column-major (see [`layout_to_strides`]).
///
/// Derives `Clone`/`Copy`/`Debug`/`PartialEq` so layouts can be reused, compared and
/// printed in test diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ABLayout {
    NN,
    NT,
    TN,
    TT,
}

/// Returns densely packed strides `(a_rs, a_cs, b_rs, b_cs, c_rs, c_cs)` for an `m x k`
/// `A`, a `k x n` `B` and an `m x n` `C` stored as described by `layout`.
///
/// `C` is always column-major (`c_rs = 1`, `c_cs = m`).
pub fn layout_to_strides(
    layout: &ABLayout,
    m: usize,
    n: usize,
    k: usize,
) -> (usize, usize, usize, usize, usize, usize) {
    // (row stride, column stride) of A and B.
    let (a_rs, a_cs) = match layout {
        ABLayout::NN | ABLayout::NT => (1, m),
        ABLayout::TN | ABLayout::TT => (k, 1),
    };
    let (b_rs, b_cs) = match layout {
        ABLayout::NN | ABLayout::TN => (1, k),
        ABLayout::NT | ABLayout::TT => (n, 1),
    };
    (a_rs, a_cs, b_rs, b_cs, 1, m)
}
610
611use rand::distributions::{Distribution, Uniform};
612use rand::rngs::StdRng;
613use rand::{Rng, SeedableRng};
614
615pub trait Bound {
616    type X: rand::distributions::uniform::SampleUniform;
617    fn min_value() -> Self::X;
618    fn max_value() -> Self::X;
619    fn my_sample(dist: &Uniform<Self::X>, rng: &mut StdRng) -> Self;
620}
621
622impl Bound for f32 {
623    type X = f32;
624    fn min_value() -> Self {
625        -2.0
626    }
627    fn max_value() -> Self {
628        2.0
629    }
630    fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
631        dist.sample(rng)
632    }
633}
634
635impl Bound for f64 {
636    type X = f64;
637    fn min_value() -> Self {
638        -10.0
639    }
640    fn max_value() -> Self {
641        10.0
642    }
643    fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
644        dist.sample(rng)
645    }
646}
647
648impl Bound for i16 {
649    type X = i16;
650    fn min_value() -> Self {
651        -10
652    }
653    fn max_value() -> Self {
654        10
655    }
656
657    fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
658        dist.sample(rng)
659    }
660}
661
662impl Bound for i8 {
663    type X = i8;
664    fn min_value() -> Self {
665        -10
666    }
667    fn max_value() -> Self {
668        10
669    }
670
671    fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
672        dist.sample(rng)
673    }
674}
675
676impl Bound for u8 {
677    type X = u8;
678    fn min_value() -> Self {
679        10
680    }
681    fn max_value() -> Self {
682        20
683    }
684
685    fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
686        dist.sample(rng)
687    }
688}
689
690impl Bound for i32 {
691    type X = i32;
692    fn min_value() -> Self {
693        -10
694    }
695    fn max_value() -> Self {
696        10
697    }
698    fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
699        dist.sample(rng)
700    }
701}
702
703impl Bound for f16 {
704    type X = f16;
705    fn min_value() -> Self {
706        f16::from_f32(-1.0)
707    }
708    fn max_value() -> Self {
709        f16::from_f32(1.0)
710    }
711    fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
712        dist.sample(rng)
713    }
714}
715
716impl Bound for Complex<f32> {
717    type X = f32;
718    fn min_value() -> f32 {
719        -1.0
720    }
721    fn max_value() -> f32 {
722        1.0
723    }
724    fn my_sample(dist: &Uniform<f32>, rng: &mut StdRng) -> Self {
725        // dist.sample(rng)
726        let x = dist.sample(rng);
727        let y = dist.sample(rng);
728        Complex::new(x, y)
729    }
730}
731
732impl Bound for Complex<f64> {
733    type X = f64;
734    fn min_value() -> f64 {
735        -1.0
736    }
737    fn max_value() -> f64 {
738        1.0
739    }
740    fn my_sample(dist: &Uniform<f64>, rng: &mut StdRng) -> Self {
741        // dist.sample(rng)
742        let x = dist.sample(rng);
743        let y = dist.sample(rng);
744        Complex::new(x, y)
745    }
746}
747
748pub fn random_matrix_std<T>(arr: &mut [T])
749where
750    rand::distributions::Standard: rand::prelude::Distribution<T>,
751{
752    let mut x = StdRng::seed_from_u64(43);
753    arr.iter_mut().for_each(|p| *p = x.gen::<T>());
754}
755
756pub fn random_matrix_uniform<T>(arr: &mut [T])
757where
758    T: Bound,
759    T::X: rand::distributions::uniform::SampleUniform,
760{
761    let t0 = T::min_value();
762    let t1 = T::max_value();
763    let mut x = StdRng::seed_from_u64(43);
764    let un_dist = Uniform::new(t0, t1);
765    arr.iter_mut().for_each(|p| *p = T::my_sample(&un_dist, &mut x));
766}
767
/// Scalar error metric used by `max_abs_diff` to compare two result buffers element-wise.
pub trait Diff {
    /// Returns a non-negative error between `self` and `other`.
    fn diff(&self, other: &Self) -> f64;
}

// Floats: the smaller of the absolute error and the error relative to |self|.
// Equal values (including equal infinities) give 0.0. A NaN difference gives f64::INFINITY,
// so a NaN in either operand can never look like a match. Previously the NaN propagated
// and was ignored by `>` comparisons downstream.
macro_rules! impl_float_diff {
    ($($t:ty),*) => {$(
        impl Diff for $t {
            fn diff(&self, other: &Self) -> f64 {
                if self == other {
                    return 0.0;
                }
                let diff_abs = (self - other).abs();
                if diff_abs.is_nan() {
                    return f64::INFINITY;
                }
                let diff_rel = diff_abs / self.abs();
                diff_abs.min(diff_rel) as f64
            }
        }
    )*};
}

impl_float_diff!(f32, f64);

// Integers: the exact absolute difference, computed in i64. The previous same-width
// subtraction overflowed (panicking in debug builds) for i16/i32 values of opposite sign
// and large magnitude.
macro_rules! impl_int_diff {
    ($($t:ty),*) => {$(
        impl Diff for $t {
            fn diff(&self, other: &Self) -> f64 {
                (i64::from(*self) - i64::from(*other)).abs() as f64
            }
        }
    )*};
}

impl_int_diff!(i16, i8, u8, i32);
815
816impl Diff for f16 {
817    fn diff(&self, other: &Self) -> f64 {
818        let x = self.to_f32();
819        let y = other.to_f32();
820        let diff_abs = (x - y).abs();
821        let diff_rel = diff_abs / x.abs();
822        diff_abs.min(diff_rel) as f64
823    }
824}
825
826use num_complex::Complex;
827
828impl Diff for Complex<f32> {
829    fn diff(&self, other: &Self) -> f64 {
830        let diff_re = self.re.diff(&other.re);
831        let diff_im = self.im.diff(&other.im);
832        diff_re.max(diff_im)
833    }
834}
835
836impl Diff for Complex<f64> {
837    fn diff(&self, other: &Self) -> f64 {
838        let diff_re = self.re.diff(&other.re);
839        let diff_im = self.im.diff(&other.im);
840        diff_re.max(diff_im)
841    }
842}
843
844pub fn max_abs_diff<T: Copy + std::fmt::Debug>(ap: &[T], bp: &[T], eps: f64) -> f64
845where
846    T: Diff,
847{
848    let mut diff = 0_f64;
849    let len = ap.len();
850    // println!("------------------------------");
851    let mut diff_idx = 0;
852    for i in 0..len {
853        let a = ap[i];
854        let b = bp[i];
855        let cur_diff: f64 = a.diff(&b);
856        if cur_diff > diff {
857            diff_idx = i;
858            diff = cur_diff;
859        }
860    }
861    diff
862}
863
/// Reference triple-loop GEMM in `f64`: `C = alpha * A * B + beta * C`.
///
/// Element `(i, j)` of each operand is read at `ptr + i * rs + j * cs` using that operand's
/// row/column strides. `A` is `m x k`, `B` is `k x n`, `C` is `m x n`.
///
/// # Safety
/// Every element addressed through the given shape and strides must be valid (`c` also
/// writable). Results are only meaningful when `c` does not overlap `a` or `b`.
pub unsafe fn gemm_fallback_f64(
    m: usize,
    n: usize,
    k: usize,
    alpha: f64,
    a: *const f64,
    a_rs: usize,
    a_cs: usize,
    b: *const f64,
    b_rs: usize,
    b_cs: usize,
    beta: f64,
    c: *mut f64,
    c_rs: usize,
    c_cs: usize,
) {
    for row in 0..m {
        for col in 0..n {
            let dot = (0..k).fold(0.0_f64, |acc, p| acc + *a.add(a_rs * row + a_cs * p) * *b.add(b_rs * p + b_cs * col));
            let dst = c.add(c_rs * row + c_cs * col);
            *dst = alpha * dot + beta * *dst;
        }
    }
}
890
/// Reference triple-loop GEMM in `f32`: `C = alpha * A * B + beta * C`.
///
/// Element `(i, j)` of each operand is read at `ptr + i * rs + j * cs` using that operand's
/// row/column strides. `A` is `m x k`, `B` is `k x n`, `C` is `m x n`.
///
/// # Safety
/// Every element addressed through the given shape and strides must be valid (`c` also
/// writable). Results are only meaningful when `c` does not overlap `a` or `b`.
pub unsafe fn gemm_fallback_f32(
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: *const f32,
    a_rs: usize,
    a_cs: usize,
    b: *const f32,
    b_rs: usize,
    b_cs: usize,
    beta: f32,
    c: *mut f32,
    c_rs: usize,
    c_cs: usize,
) {
    for row in 0..m {
        for col in 0..n {
            let dot = (0..k).fold(0.0_f32, |acc, p| acc + *a.add(a_rs * row + a_cs * p) * *b.add(b_rs * p + b_cs * col));
            let dst = c.add(c_rs * row + c_cs * col);
            *dst = alpha * dot + beta * *dst;
        }
    }
}
917
/// Reference GEMM for `i16` operands with an `i32` result.
///
/// Dot products accumulate in `i32`. Each output is then
/// `(alpha * dot as f32 + beta * C[i, j] as f32) as i32`, so the fractional part is
/// truncated toward zero.
///
/// # Safety
/// Every element addressed through the given shape and strides must be valid (`c` also
/// writable).
pub unsafe fn gemm_fallback_s16s16s32(
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: *const i16,
    a_rs: usize,
    a_cs: usize,
    b: *const i16,
    b_rs: usize,
    b_cs: usize,
    beta: f32,
    c: *mut i32,
    c_rs: usize,
    c_cs: usize,
) {
    for row in 0..m {
        for col in 0..n {
            let dot = (0..k).fold(0i32, |acc, p| {
                let lhs = *a.add(a_rs * row + a_cs * p) as i32;
                let rhs = *b.add(b_rs * p + b_cs * col) as i32;
                acc + lhs * rhs
            });
            let dst = c.add(c_rs * row + c_cs * col);
            let scaled = alpha * dot as f32 + beta * *dst as f32;
            *dst = scaled as i32;
        }
    }
}
944
/// Reference GEMM for an `i8` `A` and a `u8` `B` with an `i32` result.
///
/// Dot products accumulate in `i32`. Each output is then
/// `(alpha * dot as f32 + beta * C[i, j] as f32) as i32`, so the fractional part is
/// truncated toward zero.
///
/// # Safety
/// Every element addressed through the given shape and strides must be valid (`c` also
/// writable).
pub unsafe fn gemm_fallback_s8u8s32(
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: *const i8,
    a_rs: usize,
    a_cs: usize,
    b: *const u8,
    b_rs: usize,
    b_cs: usize,
    beta: f32,
    c: *mut i32,
    c_rs: usize,
    c_cs: usize,
) {
    for row in 0..m {
        for col in 0..n {
            let dot = (0..k).fold(0i32, |acc, p| {
                let lhs = *a.add(a_rs * row + a_cs * p) as i32;
                let rhs = *b.add(b_rs * p + b_cs * col) as i32;
                acc + lhs * rhs
            });
            let dst = c.add(c_rs * row + c_cs * col);
            let scaled = alpha * dot as f32 + beta * *dst as f32;
            *dst = scaled as i32;
        }
    }
}
971
972pub unsafe fn gemm_fallback_c32(
973    m: usize,
974    n: usize,
975    k: usize,
976    alpha: Complex32,
977    a: *const Complex32,
978    a_rs: usize,
979    a_cs: usize,
980    b: *const Complex32,
981    b_rs: usize,
982    b_cs: usize,
983    beta: Complex32,
984    c: *mut Complex32,
985    c_rs: usize,
986    c_cs: usize,
987) {
988    for i in 0..m {
989        for j in 0..n {
990            let mut dx = Complex32::ZERO;
991            for p in 0..k {
992                dx += *a.add(a_rs * i + a_cs * p) * *b.add(b_rs * p + b_cs * j);
993            }
994            *c.add(c_rs * i + c_cs * j) = alpha * dx + beta * *c.add(c_rs * i + c_cs * j);
995        }
996    }
997}
998
999pub unsafe fn gemm_fallback_c64(
1000    m: usize,
1001    n: usize,
1002    k: usize,
1003    alpha: Complex64,
1004    a: *const Complex64,
1005    a_rs: usize,
1006    a_cs: usize,
1007    b: *const Complex64,
1008    b_rs: usize,
1009    b_cs: usize,
1010    beta: Complex64,
1011    c: *mut Complex64,
1012    c_rs: usize,
1013    c_cs: usize,
1014) {
1015    for i in 0..m {
1016        for j in 0..n {
1017            let mut dx = Complex64::ZERO;
1018            for p in 0..k {
1019                dx += *a.add(a_rs * i + a_cs * p) * *b.add(b_rs * p + b_cs * j);
1020            }
1021            *c.add(c_rs * i + c_cs * j) = alpha * dx + beta * *c.add(c_rs * i + c_cs * j);
1022        }
1023    }
1024}
1025
1026pub unsafe fn gemm_fallback_f16(
1027    m: usize,
1028    n: usize,
1029    k: usize,
1030    alpha: f16,
1031    a: *const f16,
1032    a_rs: usize,
1033    a_cs: usize,
1034    b: *const f16,
1035    b_rs: usize,
1036    b_cs: usize,
1037    beta: f16,
1038    c: *mut f16,
1039    c_rs: usize,
1040    c_cs: usize,
1041) {
1042    for i in 0..m {
1043        for j in 0..n {
1044            let mut dx = f16::ZERO;
1045            for p in 0..k {
1046                dx += *a.add(a_rs * i + a_cs * p) * *b.add(b_rs * p + b_cs * j);
1047            }
1048            *c.add(c_rs * i + c_cs * j) = alpha * dx + beta * *c.add(c_rs * i + c_cs * j);
1049        }
1050    }
1051}
1052
1053pub fn stride_to_cblas(
1054    m: usize,
1055    n: usize,
1056    k: usize,
1057    a_rs: usize,
1058    a_cs: usize,
1059    b_rs: usize,
1060    b_cs: usize,
1061    c_rs: usize,
1062    c_cs: usize,
1063) -> (CBLAS_LAYOUT, CBLAS_TRANSPOSE, CBLAS_TRANSPOSE, c_int, c_int, c_int) {
1064    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = if c_rs == 1 {
1065        (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs)
1066    } else if c_cs == 1 {
1067        (a_cs, a_rs, b_cs, b_rs, c_cs, c_rs)
1068    } else {
1069        panic!("Non Trivial Stride is not available for Cblas Api");
1070    };
1071    // c_rs == 1
1072    let ldc = c_cs as c_int;
1073    let (a_trans, b_trans, lda, ldb) = if a_rs == 1 && b_rs == 1 && a_cs == m && b_cs == k {
1074        (CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasNoTrans, a_cs as c_int, b_cs as c_int)
1075    } else if a_rs == 1 && b_cs == 1 && a_cs == m && b_rs == n {
1076        (CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans, a_cs as c_int, b_rs as c_int)
1077    } else if a_cs == 1 && b_rs == 1 && a_rs == k && b_cs == k {
1078        (CBLAS_TRANSPOSE::CblasTrans, CBLAS_TRANSPOSE::CblasNoTrans, a_rs as c_int, b_cs as c_int)
1079    } else if a_cs == 1 && b_cs == 1 && a_rs == k && b_rs == n {
1080        (CBLAS_TRANSPOSE::CblasTrans, CBLAS_TRANSPOSE::CblasTrans, a_rs as c_int, b_rs as c_int)
1081    } else {
1082        panic!("Non Trivial Stride is not available for Cblas Api");
1083    };
1084    (CBLAS_LAYOUT::CblasColMajor, a_trans, b_trans, lda, ldb, ldc)
1085}
1086
1087fn cblas_to_stride(
1088    layout: CBLAS_LAYOUT,
1089    transa: CBLAS_TRANSPOSE,
1090    transb: CBLAS_TRANSPOSE,
1091    lda: c_int,
1092    ldb: c_int,
1093    ldc: c_int,
1094) -> (usize, usize, usize, usize, usize, usize) {
1095    if layout == CBLAS_LAYOUT::CblasColMajor {
1096        let (a_rs, a_cs) = if transa == CBLAS_TRANSPOSE::CblasNoTrans { (1, lda as usize) } else { (lda as usize, 1) };
1097        let (b_rs, b_cs) = if transb == CBLAS_TRANSPOSE::CblasNoTrans { (1, ldb as usize) } else { (ldb as usize, 1) };
1098        (a_rs, a_cs, b_rs, b_cs, 1, ldc as usize)
1099    } else {
1100        let (a_rs, a_cs) = if transa == CBLAS_TRANSPOSE::CblasNoTrans { (lda as usize, 1) } else { (1, lda as usize) };
1101        let (b_rs, b_cs) = if transb == CBLAS_TRANSPOSE::CblasNoTrans { (ldb as usize, 1) } else { (1, ldb as usize) };
1102        (a_rs, a_cs, b_rs, b_cs, ldc as usize, 1)
1103    }
1104}
1105
1106pub unsafe fn check_gemm_s16s16s32(
1107    m: usize,
1108    n: usize,
1109    k: usize,
1110    alpha: f32,
1111    a: *const i16,
1112    a_rs: usize,
1113    a_cs: usize,
1114    b: *const i16,
1115    b_rs: usize,
1116    b_cs: usize,
1117    beta: f32,
1118    c: &[i32],
1119    c_rs: usize,
1120    c_cs: usize,
1121    c_ref: &mut [i32],
1122    unary: unsafe fn(*mut i32, m: usize),
1123    eps: f64,
1124) -> f64 {
1125    #[cfg(feature = "mkl")]
1126    {
1127        let oc_val = 0;
1128        let oc = &oc_val as *const c_int;
1129        let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1130        cblas_gemm_s16s16s32(
1131            layout,
1132            transa,
1133            transb,
1134            CblasFixOffset,
1135            m as c_int,
1136            n as c_int,
1137            k as c_int,
1138            alpha,
1139            a,
1140            lda,
1141            0,
1142            b,
1143            ldb,
1144            0,
1145            beta,
1146            c_ref.as_mut_ptr(),
1147            ldc,
1148            oc,
1149            CBlasBackend::Mkl,
1150        );
1151    }
1152    #[cfg(not(feature = "mkl"))]
1153    {
1154        // calculate diff using fallback
1155        gemm_fallback_s16s16s32(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1156    }
1157
1158    let c_ref_ptr = c_ref.as_mut_ptr();
1159    if c_rs == 1 {
1160        for j in 0..n {
1161            unary(c_ref_ptr.add(j * c_cs), m);
1162        }
1163    } else if c_cs == 1 {
1164        for i in 0..m {
1165            unary(c_ref_ptr.add(i * c_rs), n);
1166        }
1167    } else {
1168        for i in 0..m {
1169            for j in 0..n {
1170                unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1171            }
1172        }
1173    }
1174
1175    let diff = max_abs_diff(&c, &c_ref, eps);
1176    return diff;
1177}
1178
1179pub unsafe fn check_gemm_s8u8s32(
1180    m: usize,
1181    n: usize,
1182    k: usize,
1183    alpha: f32,
1184    a: *const i8,
1185    a_rs: usize,
1186    a_cs: usize,
1187    b: *const u8,
1188    b_rs: usize,
1189    b_cs: usize,
1190    beta: f32,
1191    c: &[i32],
1192    c_rs: usize,
1193    c_cs: usize,
1194    c_ref: &mut [i32],
1195    unary: unsafe fn(*mut i32, m: usize),
1196    eps: f64,
1197) -> f64 {
1198    #[cfg(feature = "mkl")]
1199    {
1200        let oc_val = 0;
1201        let oc = &oc_val as *const c_int;
1202        let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1203        let a = a as *const c_void;
1204        let b = b as *const c_void;
1205        cblas_gemm_s8u8s32(
1206            layout,
1207            transa,
1208            transb,
1209            CblasFixOffset,
1210            m as c_int,
1211            n as c_int,
1212            k as c_int,
1213            alpha,
1214            a,
1215            lda,
1216            0,
1217            b,
1218            ldb,
1219            0,
1220            beta,
1221            c_ref.as_mut_ptr(),
1222            ldc,
1223            oc,
1224            CBlasBackend::Mkl,
1225        );
1226    }
1227    #[cfg(not(feature = "mkl"))]
1228    {
1229        // calculate diff using fallback
1230        gemm_fallback_s8u8s32(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1231    }
1232
1233    let c_ref_ptr = c_ref.as_mut_ptr();
1234    if c_rs == 1 {
1235        for j in 0..n {
1236            unary(c_ref_ptr.add(j * c_cs), m);
1237        }
1238    } else if c_cs == 1 {
1239        for i in 0..m {
1240            unary(c_ref_ptr.add(i * c_rs), n);
1241        }
1242    } else {
1243        for i in 0..m {
1244            for j in 0..n {
1245                unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1246            }
1247        }
1248    }
1249
1250    let diff = max_abs_diff(&c, &c_ref, eps);
1251    return diff;
1252}
1253
1254pub unsafe fn check_gemm_f16(
1255    m: usize,
1256    n: usize,
1257    k: usize,
1258    alpha: f16,
1259    a: *const f16,
1260    a_rs: usize,
1261    a_cs: usize,
1262    b: *const f16,
1263    b_rs: usize,
1264    b_cs: usize,
1265    beta: f16,
1266    c: &[f16],
1267    c_rs: usize,
1268    c_cs: usize,
1269    c_ref: &mut [f16],
1270    unary: unsafe fn(*mut f16, m: usize),
1271    eps: f64,
1272) -> f64 {
1273    #[cfg(feature = "mkl")]
1274    {
1275        let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1276        let a = a as *const c_ushort;
1277        let b = b as *const c_ushort;
1278        let c_ref_ptr = c_ref.as_mut_ptr() as *mut c_ushort;
1279        let alpha = alpha.to_bits();
1280        let beta = beta.to_bits();
1281        cblas_hgemm(
1282            layout,
1283            transa,
1284            transb,
1285            m as c_int,
1286            n as c_int,
1287            k as c_int,
1288            alpha,
1289            a,
1290            lda,
1291            b,
1292            ldb,
1293            beta,
1294            c_ref_ptr,
1295            ldc,
1296            CBlasBackend::Mkl,
1297        );
1298    }
1299    #[cfg(not(feature = "mkl"))]
1300    {
1301        // calculate diff using fallback
1302        gemm_fallback_f16(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1303    }
1304
1305    let c_ref_ptr = c_ref.as_mut_ptr();
1306    if c_rs == 1 {
1307        for j in 0..n {
1308            unary(c_ref_ptr.add(j * c_cs), m);
1309        }
1310    } else if c_cs == 1 {
1311        for i in 0..m {
1312            unary(c_ref_ptr.add(i * c_rs), n);
1313        }
1314    } else {
1315        for i in 0..m {
1316            for j in 0..n {
1317                unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1318            }
1319        }
1320    }
1321
1322    let diff = max_abs_diff(&c, &c_ref, eps);
1323    return diff;
1324}
1325
1326pub unsafe fn check_gemm_f64(
1327    m: usize,
1328    n: usize,
1329    k: usize,
1330    alpha: f64,
1331    a: *const f64,
1332    a_rs: usize,
1333    a_cs: usize,
1334    b: *const f64,
1335    b_rs: usize,
1336    b_cs: usize,
1337    beta: f64,
1338    c: &[f64],
1339    c_rs: usize,
1340    c_cs: usize,
1341    c_ref: &mut [f64],
1342    unary: unsafe fn(*mut f64, m: usize),
1343    eps: f64,
1344) -> f64 {
1345    #[cfg(feature = "mkl")]
1346    {
1347        let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1348        cblas_dgemm(
1349            layout,
1350            transa,
1351            transb,
1352            m as c_int,
1353            n as c_int,
1354            k as c_int,
1355            alpha,
1356            a,
1357            lda,
1358            b,
1359            ldb,
1360            beta,
1361            c_ref.as_mut_ptr(),
1362            ldc,
1363            CBlasBackend::Mkl,
1364        );
1365    }
1366    #[cfg(not(feature = "mkl"))]
1367    {
1368        // calculate diff using fallback
1369        gemm_fallback_f64(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1370    }
1371
1372    let c_ref_ptr = c_ref.as_mut_ptr();
1373    if c_rs == 1 {
1374        for j in 0..n {
1375            unary(c_ref_ptr.add(j * c_cs), m);
1376        }
1377    } else if c_cs == 1 {
1378        for i in 0..m {
1379            unary(c_ref_ptr.add(i * c_rs), n);
1380        }
1381    } else {
1382        for i in 0..m {
1383            for j in 0..n {
1384                unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1385            }
1386        }
1387    }
1388
1389    let diff = max_abs_diff(&c, &c_ref, eps);
1390    return diff;
1391}
1392
1393pub unsafe fn check_gemm_f32(
1394    m: usize,
1395    n: usize,
1396    k: usize,
1397    alpha: f32,
1398    a: *const f32,
1399    a_rs: usize,
1400    a_cs: usize,
1401    b: *const f32,
1402    b_rs: usize,
1403    b_cs: usize,
1404    beta: f32,
1405    c: &[f32],
1406    c_rs: usize,
1407    c_cs: usize,
1408    c_ref: &mut [f32],
1409    unary: unsafe fn(*mut f32, m: usize),
1410    eps: f64,
1411) -> f64 {
1412    #[cfg(feature = "mkl")]
1413    {
1414        let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1415        cblas_sgemm(
1416            layout,
1417            transa,
1418            transb,
1419            m as c_int,
1420            n as c_int,
1421            k as c_int,
1422            alpha,
1423            a,
1424            lda,
1425            b,
1426            ldb,
1427            beta,
1428            c_ref.as_mut_ptr(),
1429            ldc,
1430            CBlasBackend::Mkl,
1431        );
1432    }
1433    #[cfg(not(feature = "mkl"))]
1434    {
1435        // calculate diff using fallback
1436        gemm_fallback_f32(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1437    }
1438    let c_ref_ptr = c_ref.as_mut_ptr();
1439    if c_rs == 1 {
1440        for j in 0..n {
1441            unary(c_ref_ptr.add(j * c_cs), m);
1442        }
1443    } else if c_cs == 1 {
1444        for i in 0..m {
1445            unary(c_ref_ptr.add(i * c_rs), n);
1446        }
1447    } else {
1448        for i in 0..m {
1449            for j in 0..n {
1450                unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1451            }
1452        }
1453    }
1454
1455    let diff = max_abs_diff(&c, &c_ref, eps);
1456    return diff;
1457}
1458
1459pub unsafe fn check_gemm_c32(
1460    m: usize,
1461    n: usize,
1462    k: usize,
1463    alpha: Complex32,
1464    a: *const Complex32,
1465    a_rs: usize,
1466    a_cs: usize,
1467    b: *const Complex32,
1468    b_rs: usize,
1469    b_cs: usize,
1470    beta: Complex32,
1471    c: &[Complex32],
1472    c_rs: usize,
1473    c_cs: usize,
1474    c_ref: &mut [Complex32],
1475    unary: unsafe fn(*mut Complex32, m: usize),
1476    eps: f64,
1477) -> f64 {
1478    #[cfg(feature = "mkl")]
1479    {
1480        let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1481        let a = a as *const c_void;
1482        let b = b as *const c_void;
1483        let c_ref_ptr = c_ref.as_mut_ptr() as *mut c_void;
1484        let alpha_ptr = &alpha as *const Complex32 as *const c_void;
1485        let beta_ptr = &beta as *const Complex32 as *const c_void;
1486        cblas_cgemm(
1487            layout,
1488            transa,
1489            transb,
1490            m as c_int,
1491            n as c_int,
1492            k as c_int,
1493            alpha_ptr,
1494            a,
1495            lda,
1496            b,
1497            ldb,
1498            beta_ptr,
1499            c_ref_ptr,
1500            ldc,
1501            CBlasBackend::Mkl,
1502        );
1503    }
1504    #[cfg(not(feature = "mkl"))]
1505    {
1506        // calculate diff using fallback
1507        gemm_fallback_c32(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1508    }
1509    let c_ref_ptr = c_ref.as_mut_ptr();
1510    if c_rs == 1 {
1511        for j in 0..n {
1512            unary(c_ref_ptr.add(j * c_cs), m);
1513        }
1514    } else if c_cs == 1 {
1515        for i in 0..m {
1516            unary(c_ref_ptr.add(i * c_rs), n);
1517        }
1518    } else {
1519        for i in 0..m {
1520            for j in 0..n {
1521                unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1522            }
1523        }
1524    }
1525
1526    let diff = max_abs_diff(&c, &c_ref, eps);
1527    return diff;
1528}
1529
1530pub unsafe fn check_gemm_c64(
1531    m: usize,
1532    n: usize,
1533    k: usize,
1534    alpha: Complex64,
1535    a: *const Complex64,
1536    a_rs: usize,
1537    a_cs: usize,
1538    b: *const Complex64,
1539    b_rs: usize,
1540    b_cs: usize,
1541    beta: Complex64,
1542    c: &[Complex64],
1543    c_rs: usize,
1544    c_cs: usize,
1545    c_ref: &mut [Complex64],
1546    unary: unsafe fn(*mut Complex64, m: usize),
1547    eps: f64,
1548) -> f64 {
1549    #[cfg(feature = "mkl")]
1550    {
1551        let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1552        let a = a as *const c_void;
1553        let b = b as *const c_void;
1554        let c_ref_ptr = c_ref.as_mut_ptr() as *mut c_void;
1555        let alpha_ptr = &alpha as *const Complex64 as *const c_void;
1556        let beta_ptr = &beta as *const Complex64 as *const c_void;
1557        cblas_zgemm(
1558            layout,
1559            transa,
1560            transb,
1561            m as c_int,
1562            n as c_int,
1563            k as c_int,
1564            alpha_ptr,
1565            a,
1566            lda,
1567            b,
1568            ldb,
1569            beta_ptr,
1570            c_ref_ptr,
1571            ldc,
1572            CBlasBackend::Mkl,
1573        );
1574    }
1575    #[cfg(not(feature = "mkl"))]
1576    {
1577        // calculate diff using fallback
1578        gemm_fallback_c64(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1579    }
1580
1581    let c_ref_ptr = c_ref.as_mut_ptr();
1582    if c_rs == 1 {
1583        for j in 0..n {
1584            unary(c_ref_ptr.add(j * c_cs), m);
1585        }
1586    } else if c_cs == 1 {
1587        for i in 0..m {
1588            unary(c_ref_ptr.add(i * c_rs), n);
1589        }
1590    } else {
1591        for i in 0..m {
1592            for j in 0..n {
1593                unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1594            }
1595        }
1596    }
1597
1598    let diff = max_abs_diff(&c, &c_ref, eps);
1599    return diff;
1600}
1601
1602pub fn cblas_params_from_str(
1603    layout_str: &str,
1604    m: usize,
1605    n: usize,
1606    k: usize,
1607) -> (i32, i32, i32, CBLAS_TRANSPOSE, CBLAS_TRANSPOSE) {
1608    if layout_str == "nn" {
1609        (m as i32, k as i32, m as i32, CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasNoTrans)
1610    } else if layout_str == "nt" {
1611        (m as i32, n as i32, m as i32, CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans)
1612    } else if layout_str == "tn" {
1613        (k as i32, k as i32, m as i32, CBLAS_TRANSPOSE::CblasTrans, CBLAS_TRANSPOSE::CblasNoTrans)
1614    } else if layout_str == "tt" {
1615        (k as i32, n as i32, m as i32, CBLAS_TRANSPOSE::CblasTrans, CBLAS_TRANSPOSE::CblasTrans)
1616    } else {
1617        panic!("Unsupported layout str");
1618    }
1619}
1620
/// Returns the `m` dimensions exercised by the GEMM tests.
///
/// While quick mode is enabled (the current setting), this returns the fixed list
/// `[1, 67, 137]` and `mc`/`mr` are not used. With quick mode disabled, the list contains
/// `m`, `m + 100` and `m + 1000` for every `m` in `1..mr`, followed by `mc + 29`.
pub fn generate_m_dims(mc: usize, mr: usize) -> Vec<usize> {
    // Set to `false` to run the full sweep. A named constant, unlike an unconditional early
    // `return`, keeps the sweep below reachable code rather than silently dead code.
    const QUICK_M_DIMS: bool = true;
    if QUICK_M_DIMS {
        return vec![1, 67, 137];
    }

    let mut dims = Vec::new();
    for m in 1..mr {
        dims.extend_from_slice(&[m, m + 100, m + 1000]);
    }
    dims.push(mc + 29);
    dims
}
1633
/// Returns the `n` dimensions exercised by the GEMM tests.
///
/// For every `n` in `1..nr` the list contains `n`, `n + 400` and `n + nc`, in that order.
/// The result is empty when `nr <= 1`.
pub fn generate_n_dims(nc: usize, nr: usize) -> Vec<usize> {
    (1..nr).fold(Vec::new(), |mut dims, base| {
        dims.extend_from_slice(&[base, base + 400, base + nc]);
        dims
    })
}
// `kr` does not really exist for the k dimension; it is kept so this generator follows the same
// pattern as the other dims generators. It can also be thought of as the k-unrolling parameter
// that the sweep is tested against.
/// Returns the `k` dimensions exercised by the GEMM tests.
///
/// For every `k` in `1..8` the list contains `k`, `k + 50` and `k + kc`, in that order
/// (21 entries in total). The `kr` argument is currently ignored, because the sweep bound is
/// fixed at 8. NOTE(review): confirm whether callers expect their `kr` to be honoured.
pub fn generate_k_dims(kc: usize, kr: usize) -> Vec<usize> {
    // Fixed (exclusive) upper bound of the k sweep; the caller's `kr` is not consulted.
    const K_SWEEP_END: usize = 8;
    let _ = kr;

    let mut dims = Vec::new();
    for k in 1..K_SWEEP_END {
        dims.extend_from_slice(&[k, k + 50, k + kc]);
    }
    dims
}