1#![allow(non_camel_case_types)]
2#![allow(dead_code)]
3#![allow(unused)]
4
5use libc::{c_double, c_float, c_int, c_schar, c_short, c_ushort, c_void};
9
10use num_complex::{c32, c64, Complex32, Complex64};
11
12use half::f16;
13use once_cell::sync::Lazy;
14
/// Storage order of the matrices handed to a CBLAS routine.
///
/// `#[repr(C)]` with explicit discriminants so values can be passed straight through the
/// `extern "C"` function-pointer types below.
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub enum CBLAS_LAYOUT {
    /// Consecutive elements of a row are adjacent in memory.
    CblasRowMajor = 101,
    /// Consecutive elements of a column are adjacent in memory.
    CblasColMajor = 102,
}
pub use self::CBLAS_LAYOUT::{CblasColMajor, CblasRowMajor};
23
/// Opaque `#[repr(C)]` wrapper around an `i32`.
///
/// NOTE(review): not referenced in this part of the file; presumably a placeholder for the
/// BLIS `cntx_t` context type — confirm before relying on its layout.
#[repr(C)]
pub struct cntx_t(i32);
26
/// How a CBLAS routine should read an input matrix: as stored, transposed, or
/// conjugate-transposed.
///
/// `#[repr(C)]` with explicit discriminants so values can be passed straight through the
/// `extern "C"` function-pointer types below.
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub enum CBLAS_TRANSPOSE {
    /// Use the matrix as stored.
    CblasNoTrans = 111,
    /// Use the transpose of the stored matrix.
    CblasTrans = 112,
    /// Use the conjugate transpose of the stored matrix.
    CblasConjTrans = 113,
}
pub use self::CBLAS_TRANSPOSE::{CblasConjTrans, CblasNoTrans, CblasTrans};
36
/// Selects how the `oc` offset array is applied by the integer GEMM routines
/// (`cblas_gemm_s8u8s32` / `cblas_gemm_s16s16s32`).
///
/// `#[repr(C)]` with explicit discriminants so values cross the C ABI unchanged. Derives
/// `PartialEq` for consistency with `CBLAS_LAYOUT` and `CBLAS_TRANSPOSE`, so offsets can be
/// compared the same way.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[allow(clippy::enum_variant_names)]
pub enum CBLAS_OFFSET {
    CblasRowOffset = 171,
    CblasColOffset = 172,
    /// A single offset value for the whole output (used by the `check_gemm_*` helpers).
    CblasFixOffset = 173,
}
pub use self::CBLAS_OFFSET::*;
46
47type SGEMM_FN_TYPE = unsafe extern "C" fn(
48 CBLAS_LAYOUT,
49 CBLAS_TRANSPOSE,
50 CBLAS_TRANSPOSE,
51 c_int,
52 c_int,
53 c_int,
54 c_float,
55 *const c_float,
56 c_int,
57 *const c_float,
58 c_int,
59 c_float,
60 *mut c_float,
61 c_int,
62);
63
64type SGEMM_B_FN_TYPE = unsafe extern "C" fn(
65 CBLAS_LAYOUT,
66 *const CBLAS_TRANSPOSE,
67 *const CBLAS_TRANSPOSE,
68 *const c_int,
69 *const c_int,
70 *const c_int,
71 *const c_float,
72 *const *const c_float,
73 *const c_int,
74 *const *const c_float,
75 *const c_int,
76 *const c_float,
77 *const *mut c_float,
78 *const c_int,
79 c_int,
80 *const c_int,
81);
82
83type DGEMM_FN_TYPE = unsafe extern "C" fn(
84 CBLAS_LAYOUT,
85 CBLAS_TRANSPOSE,
86 CBLAS_TRANSPOSE,
87 c_int,
88 c_int,
89 c_int,
90 c_double,
91 *const c_double,
92 c_int,
93 *const c_double,
94 c_int,
95 c_double,
96 *mut c_double,
97 c_int,
98);
99
100type CGEMM_FN_TYPE = unsafe extern "C" fn(
101 CBLAS_LAYOUT,
102 CBLAS_TRANSPOSE,
103 CBLAS_TRANSPOSE,
104 c_int,
105 c_int,
106 c_int,
107 *const c_void,
108 *const c_void,
109 c_int,
110 *const c_void,
111 c_int,
112 *const c_void,
113 *mut c_void,
114 c_int,
115);
116
117type ZGEMM_FN_TYPE = unsafe extern "C" fn(
118 CBLAS_LAYOUT,
119 CBLAS_TRANSPOSE,
120 CBLAS_TRANSPOSE,
121 c_int,
122 c_int,
123 c_int,
124 *const c_void,
125 *const c_void,
126 c_int,
127 *const c_void,
128 c_int,
129 *const c_void,
130 *mut c_void,
131 c_int,
132);
133
134type HGEMM_FN_TYPE = unsafe extern "C" fn(
135 CBLAS_LAYOUT,
136 CBLAS_TRANSPOSE,
137 CBLAS_TRANSPOSE,
138 c_int,
139 c_int,
140 c_int,
141 c_ushort,
142 *const c_ushort,
143 c_int,
144 *const c_ushort,
145 c_int,
146 c_ushort,
147 *mut c_ushort,
148 c_int,
149);
150
151type GEMM_I8_FN_TYPE = unsafe extern "C" fn(
152 CBLAS_LAYOUT,
153 CBLAS_TRANSPOSE,
154 CBLAS_TRANSPOSE,
155 CBLAS_OFFSET,
156 c_int,
157 c_int,
158 c_int,
159 c_float,
160 *const c_void,
161 c_int,
162 c_schar,
163 *const c_void,
164 c_int,
165 c_schar,
166 c_float,
167 *mut c_int,
168 c_int,
169 *const c_int,
170);
171
172type GEMM_I16_FN_TYPE = unsafe extern "C" fn(
173 CBLAS_LAYOUT,
174 CBLAS_TRANSPOSE,
175 CBLAS_TRANSPOSE,
176 CBLAS_OFFSET,
177 c_int,
178 c_int,
179 c_int,
180 c_float,
181 *const c_short,
182 c_int,
183 c_short,
184 *const c_short,
185 c_int,
186 c_short,
187 c_float,
188 *mut c_int,
189 c_int,
190 *const c_int,
191);
192
193const PROJECT_DIR: &str = core::env!("CARGO_MANIFEST_DIR");
194
195pub static CBLAS_LIBRARY_MKL: Lazy<libloading::Library> = Lazy::new(|| unsafe {
197 let default_mkl_path = format!("{PROJECT_DIR}/../../.env/Library/bin/mkl_rt.2.dll");
198 let mkl_path = std::env::var("PIRE_MKL_PATH").unwrap_or(default_mkl_path);
199 libloading::Library::new(mkl_path).unwrap()
200});
201
202pub static CBLAS_LIBRARY_OPENBLAS: Lazy<libloading::Library> = Lazy::new(|| unsafe {
203 let default_openblas_path = format!("{PROJECT_DIR}/../../openblas/openblas.dll");
204 let openblas_path = std::env::var("PIRE_OPENBLAS_PATH").unwrap_or(default_openblas_path);
205 libloading::Library::new(openblas_path).unwrap()
206});
207
208pub static CBLAS_LIBRARY_BLIS: Lazy<libloading::Library> = Lazy::new(|| unsafe {
209 let default_blis_path = format!("{PROJECT_DIR}/../../blis/blis.dll");
210 let blis_path = std::env::var("PIRE_BLIS_PATH").unwrap_or(default_blis_path);
211 libloading::Library::new(blis_path).unwrap()
212});
213
214pub static CBLAS_SGEMM_MKL: Lazy<libloading::Symbol<'static, SGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
215 let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_sgemm").unwrap();
216 cblas_gemm
217});
218
219pub static CBLAS_SGEMM_OPENBLAS: Lazy<libloading::Symbol<'static, SGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
220 let cblas_gemm = CBLAS_LIBRARY_OPENBLAS.get(b"cblas_sgemm").unwrap();
221 cblas_gemm
222});
223
224pub static CBLAS_SGEMM_BLIS: Lazy<libloading::Symbol<'static, SGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
225 let cblas_gemm = CBLAS_LIBRARY_BLIS.get(b"cblas_sgemm").unwrap();
226 cblas_gemm
227});
228
229pub static CBLAS_SGEMM_B_MKL: Lazy<libloading::Symbol<'static, SGEMM_B_FN_TYPE>> = Lazy::new(|| unsafe {
230 let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_sgemm_batch").unwrap();
231 cblas_gemm
232});
233
234pub static CBLAS_DGEMM_MKL: Lazy<libloading::Symbol<'static, DGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
235 let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_dgemm").unwrap();
236 cblas_gemm
237});
238
239pub static CBLAS_DGEMM_OPENBLAS: Lazy<libloading::Symbol<'static, DGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
240 let cblas_gemm = CBLAS_LIBRARY_OPENBLAS.get(b"cblas_dgemm").unwrap();
241 cblas_gemm
242});
243
244pub static CBLAS_DGEMM_BLIS: Lazy<libloading::Symbol<'static, DGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
245 let cblas_gemm = CBLAS_LIBRARY_BLIS.get(b"cblas_dgemm").unwrap();
246 cblas_gemm
247});
248
249pub static CBLAS_CGEMM_MKL: Lazy<libloading::Symbol<'static, CGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
250 let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_cgemm").unwrap();
251 cblas_gemm
252});
253
254pub static CBLAS_CGEMM_OPENBLAS: Lazy<libloading::Symbol<'static, CGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
255 let cblas_gemm = CBLAS_LIBRARY_OPENBLAS.get(b"cblas_cgemm").unwrap();
256 cblas_gemm
257});
258
259pub static CBLAS_CGEMM_BLIS: Lazy<libloading::Symbol<'static, CGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
260 let cblas_gemm = CBLAS_LIBRARY_BLIS.get(b"cblas_cgemm").unwrap();
261 cblas_gemm
262});
263
264pub static CBLAS_ZGEMM_MKL: Lazy<libloading::Symbol<'static, ZGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
265 let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_zgemm").unwrap();
266 cblas_gemm
267});
268
269pub static CBLAS_ZGEMM_OPENBLAS: Lazy<libloading::Symbol<'static, ZGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
270 let cblas_gemm = CBLAS_LIBRARY_OPENBLAS.get(b"cblas_zgemm").unwrap();
271 cblas_gemm
272});
273
274pub static CBLAS_ZGEMM_BLIS: Lazy<libloading::Symbol<'static, ZGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
275 let cblas_gemm = CBLAS_LIBRARY_BLIS.get(b"cblas_zgemm").unwrap();
276 cblas_gemm
277});
278
279pub static CBLAS_HGEMM_MKL: Lazy<libloading::Symbol<'static, HGEMM_FN_TYPE>> = Lazy::new(|| unsafe {
280 let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_hgemm").unwrap();
281 cblas_gemm
282});
283
284pub static CBLAS_GEMM_I8: Lazy<libloading::Symbol<'static, GEMM_I8_FN_TYPE>> = Lazy::new(|| unsafe {
285 let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_gemm_s8u8s32").unwrap();
286 cblas_gemm
287});
288
289pub static CBLAS_GEMM_I16: Lazy<libloading::Symbol<'static, GEMM_I16_FN_TYPE>> = Lazy::new(|| unsafe {
290 let cblas_gemm = CBLAS_LIBRARY_MKL.get(b"cblas_gemm_s16s16s32").unwrap();
291 cblas_gemm
292});
293
/// Which dynamically loaded BLAS library a `cblas_*` wrapper dispatches to.
///
/// Derives `Clone`/`Copy`/`Debug`/`PartialEq`/`Eq`. The wrappers take the backend by value,
/// so a copyable, comparable, printable selector can be reused across calls and logged in
/// test output.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CBlasBackend {
    Mkl,
    Blis,
    OpenBlas,
}
299
300pub unsafe fn cblas_sgemm(
301 layout: CBLAS_LAYOUT,
302 transa: CBLAS_TRANSPOSE,
303 transb: CBLAS_TRANSPOSE,
304 m: c_int,
305 n: c_int,
306 k: c_int,
307 alpha: c_float,
308 a: *const c_float,
309 lda: c_int,
310 b: *const c_float,
311 ldb: c_int,
312 beta: c_float,
313 c: *mut c_float,
314 ldc: c_int,
315 backend: CBlasBackend,
316) {
317 match backend {
318 CBlasBackend::Mkl => {
319 CBLAS_SGEMM_MKL(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
320 }
321 CBlasBackend::Blis => {
322 CBLAS_SGEMM_BLIS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
323 }
324 CBlasBackend::OpenBlas => {
325 CBLAS_SGEMM_OPENBLAS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
326 }
327 }
328}
329
330pub unsafe fn cblas_dgemm(
331 layout: CBLAS_LAYOUT,
332 transa: CBLAS_TRANSPOSE,
333 transb: CBLAS_TRANSPOSE,
334 m: c_int,
335 n: c_int,
336 k: c_int,
337 alpha: c_double,
338 a: *const c_double,
339 lda: c_int,
340 b: *const c_double,
341 ldb: c_int,
342 beta: c_double,
343 c: *mut c_double,
344 ldc: c_int,
345 backend: CBlasBackend,
346) {
347 match backend {
348 CBlasBackend::Mkl => {
349 CBLAS_DGEMM_MKL(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
350 }
351 CBlasBackend::Blis => {
352 CBLAS_DGEMM_BLIS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
353 }
354 CBlasBackend::OpenBlas => {
355 CBLAS_DGEMM_OPENBLAS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
356 }
357 }
358}
359
360pub unsafe fn cblas_cgemm(
361 layout: CBLAS_LAYOUT,
362 transa: CBLAS_TRANSPOSE,
363 transb: CBLAS_TRANSPOSE,
364 m: c_int,
365 n: c_int,
366 k: c_int,
367 alpha: *const c_void,
368 a: *const c_void,
369 lda: c_int,
370 b: *const c_void,
371 ldb: c_int,
372 beta: *const c_void,
373 c: *mut c_void,
374 ldc: c_int,
375 backend: CBlasBackend,
376) {
377 match backend {
378 CBlasBackend::Mkl => {
379 CBLAS_CGEMM_MKL(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
380 }
381 CBlasBackend::Blis => {
382 CBLAS_CGEMM_BLIS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
383 }
384 CBlasBackend::OpenBlas => {
385 CBLAS_CGEMM_OPENBLAS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
386 }
387 }
388}
389
390pub unsafe fn cblas_zgemm(
391 layout: CBLAS_LAYOUT,
392 transa: CBLAS_TRANSPOSE,
393 transb: CBLAS_TRANSPOSE,
394 m: c_int,
395 n: c_int,
396 k: c_int,
397 alpha: *const c_void,
398 a: *const c_void,
399 lda: c_int,
400 b: *const c_void,
401 ldb: c_int,
402 beta: *const c_void,
403 c: *mut c_void,
404 ldc: c_int,
405 backend: CBlasBackend,
406) {
407 match backend {
408 CBlasBackend::Mkl => {
409 CBLAS_ZGEMM_MKL(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
410 }
411 CBlasBackend::Blis => {
412 CBLAS_ZGEMM_BLIS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
413 }
414 CBlasBackend::OpenBlas => {
415 CBLAS_ZGEMM_OPENBLAS(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
416 }
417 }
418}
419pub unsafe fn cblas_hgemm(
420 layout: CBLAS_LAYOUT,
421 transa: CBLAS_TRANSPOSE,
422 transb: CBLAS_TRANSPOSE,
423 m: c_int,
424 n: c_int,
425 k: c_int,
426 alpha: c_ushort,
427 a: *const c_ushort,
428 lda: c_int,
429 b: *const c_ushort,
430 ldb: c_int,
431 beta: c_ushort,
432 c: *mut c_ushort,
433 ldc: c_int,
434 backend: CBlasBackend,
435) {
436 match backend {
437 CBlasBackend::Mkl => {
438 CBLAS_HGEMM_MKL(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
439 }
440 CBlasBackend::Blis => {
441 unimplemented!()
442 }
443 CBlasBackend::OpenBlas => {
444 unimplemented!()
445 }
446 }
447}
448
449pub unsafe fn cblas_gemm_s8u8s32(
450 layout: CBLAS_LAYOUT,
451 transa: CBLAS_TRANSPOSE,
452 transb: CBLAS_TRANSPOSE,
453 offsetc: CBLAS_OFFSET,
454 m: c_int,
455 n: c_int,
456 k: c_int,
457 alpha: c_float,
458 a: *const c_void,
459 lda: c_int,
460 oa: c_schar,
461 b: *const c_void,
462 ldb: c_int,
463 ob: c_schar,
464 beta: c_float,
465 c: *mut c_int,
466 ldc: c_int,
467 oc: *const c_int,
468 backend: CBlasBackend,
469) {
470 match backend {
471 CBlasBackend::Mkl => {
472 CBLAS_GEMM_I8(layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
473 }
474 CBlasBackend::Blis => {
475 unimplemented!()
476 }
477 CBlasBackend::OpenBlas => {
478 unimplemented!()
479 }
480 }
481}
482
483#[allow(clippy::too_many_arguments)]
484pub unsafe fn cblas_gemm_s16s16s32(
485 layout: CBLAS_LAYOUT,
486 transa: CBLAS_TRANSPOSE,
487 transb: CBLAS_TRANSPOSE,
488 offsetc: CBLAS_OFFSET,
489 m: c_int,
490 n: c_int,
491 k: c_int,
492 alpha: c_float,
493 a: *const c_short,
494 lda: c_int,
495 oa: c_short,
496 b: *const c_short,
497 ldb: c_int,
498 ob: c_short,
499 beta: c_float,
500 c: *mut c_int,
501 ldc: c_int,
502 oc: *const c_int,
503 backend: CBlasBackend,
504) {
505 match backend {
506 CBlasBackend::Mkl => {
507 CBLAS_GEMM_I16(layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
508 }
509 CBlasBackend::Blis => {
510 unimplemented!()
511 }
512 CBlasBackend::OpenBlas => {
513 unimplemented!()
514 }
515 }
516}
517
518pub unsafe fn cblas_sgemm_batch(
519 layout: CBLAS_LAYOUT,
520 transa: *const CBLAS_TRANSPOSE,
521 transb: *const CBLAS_TRANSPOSE,
522 m: *const c_int,
523 n: *const c_int,
524 k: *const c_int,
525 alpha: *const c_float,
526 a: *const *const c_float,
527 lda: *const c_int,
528 b: *const *const c_float,
529 ldb: *const c_int,
530 beta: *const c_float,
531 c: *const *mut c_float,
532 ldc: *const c_int,
533 group_count: c_int,
534 group_size: *const c_int,
535 backend: CBlasBackend,
536) {
537 let lib = libloading::Library::new("C:/Users/I011745/Desktop/corenum/pire/.env/Library/bin/mkl_rt.2.dll").unwrap();
538 let cblas_sgemm_batch: libloading::Symbol<
539 unsafe extern "C" fn(
540 CBLAS_LAYOUT,
541 *const CBLAS_TRANSPOSE,
542 *const CBLAS_TRANSPOSE,
543 *const c_int,
544 *const c_int,
545 *const c_int,
546 *const c_float,
547 *const *const c_float,
548 *const c_int,
549 *const *const c_float,
550 *const c_int,
551 *const c_float,
552 *const *mut c_float,
553 *const c_int,
554 c_int,
555 *const c_int,
556 ),
557 > = lib.get(b"cblas_sgemm_batch").unwrap();
558 cblas_sgemm_batch(layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, group_size);
559
560 match backend {
561 CBlasBackend::Mkl => {
562 CBLAS_SGEMM_B_MKL(
563 layout,
564 transa,
565 transb,
566 m,
567 n,
568 k,
569 alpha,
570 a,
571 lda,
572 b,
573 ldb,
574 beta,
575 c,
576 ldc,
577 group_count,
578 group_size,
579 );
580 }
581 CBlasBackend::Blis => {
582 unimplemented!()
583 }
584 CBlasBackend::OpenBlas => {
585 unimplemented!()
586 }
587 }
588}
589
/// Storage of the GEMM inputs: `N` = stored as is, `T` = stored transposed.
///
/// The first letter refers to `A`, the second to `B` (so `NT` is "A as is, B transposed").
/// Derives the common value traits so layouts can be copied, compared and printed (e.g. in
/// test names or assertion messages).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ABLayout {
    NN,
    NT,
    TN,
    TT,
}
596
597pub fn layout_to_strides(
598 layout: &ABLayout,
599 m: usize,
600 n: usize,
601 k: usize,
602) -> (usize, usize, usize, usize, usize, usize) {
603 match layout {
604 ABLayout::NN => (1, m, 1, k, 1, m),
605 ABLayout::NT => (1, m, n, 1, 1, m),
606 ABLayout::TN => (k, 1, 1, k, 1, m),
607 ABLayout::TT => (k, 1, n, 1, 1, m),
608 }
609}
610
611use rand::distributions::{Distribution, Uniform};
612use rand::rngs::StdRng;
613use rand::{Rng, SeedableRng};
614
615pub trait Bound {
616 type X: rand::distributions::uniform::SampleUniform;
617 fn min_value() -> Self::X;
618 fn max_value() -> Self::X;
619 fn my_sample(dist: &Uniform<Self::X>, rng: &mut StdRng) -> Self;
620}
621
622impl Bound for f32 {
623 type X = f32;
624 fn min_value() -> Self {
625 -2.0
626 }
627 fn max_value() -> Self {
628 2.0
629 }
630 fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
631 dist.sample(rng)
632 }
633}
634
635impl Bound for f64 {
636 type X = f64;
637 fn min_value() -> Self {
638 -10.0
639 }
640 fn max_value() -> Self {
641 10.0
642 }
643 fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
644 dist.sample(rng)
645 }
646}
647
648impl Bound for i16 {
649 type X = i16;
650 fn min_value() -> Self {
651 -10
652 }
653 fn max_value() -> Self {
654 10
655 }
656
657 fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
658 dist.sample(rng)
659 }
660}
661
662impl Bound for i8 {
663 type X = i8;
664 fn min_value() -> Self {
665 -10
666 }
667 fn max_value() -> Self {
668 10
669 }
670
671 fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
672 dist.sample(rng)
673 }
674}
675
676impl Bound for u8 {
677 type X = u8;
678 fn min_value() -> Self {
679 10
680 }
681 fn max_value() -> Self {
682 20
683 }
684
685 fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
686 dist.sample(rng)
687 }
688}
689
690impl Bound for i32 {
691 type X = i32;
692 fn min_value() -> Self {
693 -10
694 }
695 fn max_value() -> Self {
696 10
697 }
698 fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
699 dist.sample(rng)
700 }
701}
702
703impl Bound for f16 {
704 type X = f16;
705 fn min_value() -> Self {
706 f16::from_f32(-1.0)
707 }
708 fn max_value() -> Self {
709 f16::from_f32(1.0)
710 }
711 fn my_sample(dist: &Uniform<Self>, rng: &mut StdRng) -> Self {
712 dist.sample(rng)
713 }
714}
715
716impl Bound for Complex<f32> {
717 type X = f32;
718 fn min_value() -> f32 {
719 -1.0
720 }
721 fn max_value() -> f32 {
722 1.0
723 }
724 fn my_sample(dist: &Uniform<f32>, rng: &mut StdRng) -> Self {
725 let x = dist.sample(rng);
727 let y = dist.sample(rng);
728 Complex::new(x, y)
729 }
730}
731
732impl Bound for Complex<f64> {
733 type X = f64;
734 fn min_value() -> f64 {
735 -1.0
736 }
737 fn max_value() -> f64 {
738 1.0
739 }
740 fn my_sample(dist: &Uniform<f64>, rng: &mut StdRng) -> Self {
741 let x = dist.sample(rng);
743 let y = dist.sample(rng);
744 Complex::new(x, y)
745 }
746}
747
748pub fn random_matrix_std<T>(arr: &mut [T])
749where
750 rand::distributions::Standard: rand::prelude::Distribution<T>,
751{
752 let mut x = StdRng::seed_from_u64(43);
753 arr.iter_mut().for_each(|p| *p = x.gen::<T>());
754}
755
756pub fn random_matrix_uniform<T>(arr: &mut [T])
757where
758 T: Bound,
759 T::X: rand::distributions::uniform::SampleUniform,
760{
761 let t0 = T::min_value();
762 let t1 = T::max_value();
763 let mut x = StdRng::seed_from_u64(43);
764 let un_dist = Uniform::new(t0, t1);
765 arr.iter_mut().for_each(|p| *p = T::my_sample(&un_dist, &mut x));
766}
767
/// Scalar measure of how far a computed value is from a reference value, used by
/// `max_abs_diff`.
pub trait Diff {
    /// Distance between `self` and `rhs`, as an `f64` (larger means more different).
    fn diff(&self, rhs: &Self) -> f64;
}
771
772impl Diff for f32 {
773 fn diff(&self, other: &Self) -> f64 {
774 let diff_abs = (self - other).abs();
775 let diff_rel = diff_abs / self.abs();
776 diff_abs.min(diff_rel) as f64
777 }
778}
779
780impl Diff for f64 {
781 fn diff(&self, other: &Self) -> f64 {
782 let diff_abs = (self - other).abs();
783 let diff_rel = diff_abs / self.abs();
784 diff_abs.min(diff_rel) as f64
785 }
786}
787
788impl Diff for i16 {
789 fn diff(&self, other: &Self) -> f64 {
790 let diff_abs = (*self - *other).abs() as f64;
791 diff_abs
792 }
793}
794
795impl Diff for i8 {
796 fn diff(&self, other: &Self) -> f64 {
797 let diff_abs = (*self as i16 - *other as i16).abs() as f64;
798 diff_abs
799 }
800}
801
802impl Diff for u8 {
803 fn diff(&self, other: &Self) -> f64 {
804 let diff_abs = (*self as i16 - *other as i16).abs() as f64;
805 diff_abs
806 }
807}
808
809impl Diff for i32 {
810 fn diff(&self, other: &Self) -> f64 {
811 let diff_abs = (*self - *other).abs() as f64;
812 diff_abs
813 }
814}
815
816impl Diff for f16 {
817 fn diff(&self, other: &Self) -> f64 {
818 let x = self.to_f32();
819 let y = other.to_f32();
820 let diff_abs = (x - y).abs();
821 let diff_rel = diff_abs / x.abs();
822 diff_abs.min(diff_rel) as f64
823 }
824}
825
826use num_complex::Complex;
827
828impl Diff for Complex<f32> {
829 fn diff(&self, other: &Self) -> f64 {
830 let diff_re = self.re.diff(&other.re);
831 let diff_im = self.im.diff(&other.im);
832 diff_re.max(diff_im)
833 }
834}
835
836impl Diff for Complex<f64> {
837 fn diff(&self, other: &Self) -> f64 {
838 let diff_re = self.re.diff(&other.re);
839 let diff_im = self.im.diff(&other.im);
840 diff_re.max(diff_im)
841 }
842}
843
844pub fn max_abs_diff<T: Copy + std::fmt::Debug>(ap: &[T], bp: &[T], eps: f64) -> f64
845where
846 T: Diff,
847{
848 let mut diff = 0_f64;
849 let len = ap.len();
850 let mut diff_idx = 0;
852 for i in 0..len {
853 let a = ap[i];
854 let b = bp[i];
855 let cur_diff: f64 = a.diff(&b);
856 if cur_diff > diff {
857 diff_idx = i;
858 diff = cur_diff;
859 }
860 }
861 diff
862}
863
/// Naive reference GEMM on `f64`: `C = alpha * A * B + beta * C`.
///
/// `A` is `m x k`, `B` is `k x n` and `C` is `m x n`. Each matrix is addressed through explicit
/// row/column strides (in elements), so any dense layout can be described.
///
/// # Safety
/// `a`, `b` and `c` must be valid for every element reachable through `(m, n, k)` and the given
/// strides.
#[allow(clippy::too_many_arguments)]
pub unsafe fn gemm_fallback_f64(
    m: usize,
    n: usize,
    k: usize,
    alpha: f64,
    a: *const f64,
    a_rs: usize,
    a_cs: usize,
    b: *const f64,
    b_rs: usize,
    b_cs: usize,
    beta: f64,
    c: *mut f64,
    c_rs: usize,
    c_cs: usize,
) {
    for i in 0..m {
        for j in 0..n {
            // Accumulate left to right, starting from +0.0, like a scalar loop.
            let dot = (0..k).fold(0.0_f64, |acc, p| {
                acc + *a.add(a_rs * i + a_cs * p) * *b.add(b_rs * p + b_cs * j)
            });
            let dst = c.add(c_rs * i + c_cs * j);
            *dst = alpha * dot + beta * *dst;
        }
    }
}
890
/// Naive reference GEMM on `f32`: `C = alpha * A * B + beta * C`.
///
/// `A` is `m x k`, `B` is `k x n` and `C` is `m x n`. Each matrix is addressed through explicit
/// row/column strides (in elements), so any dense layout can be described.
///
/// # Safety
/// `a`, `b` and `c` must be valid for every element reachable through `(m, n, k)` and the given
/// strides.
#[allow(clippy::too_many_arguments)]
pub unsafe fn gemm_fallback_f32(
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: *const f32,
    a_rs: usize,
    a_cs: usize,
    b: *const f32,
    b_rs: usize,
    b_cs: usize,
    beta: f32,
    c: *mut f32,
    c_rs: usize,
    c_cs: usize,
) {
    for i in 0..m {
        for j in 0..n {
            // Accumulate left to right, starting from +0.0, like a scalar loop.
            let dot = (0..k).fold(0.0_f32, |acc, p| {
                acc + *a.add(a_rs * i + a_cs * p) * *b.add(b_rs * p + b_cs * j)
            });
            let dst = c.add(c_rs * i + c_cs * j);
            *dst = alpha * dot + beta * *dst;
        }
    }
}
917
/// Naive reference GEMM for `i16` inputs with `i32` accumulation.
///
/// Computes `C = alpha * (A * B) + beta * C`. The integer dot product is scaled in `f32`, and
/// the result is converted back with a saturating `as i32` cast.
///
/// # Safety
/// `a`, `b` and `c` must be valid for every element reachable through `(m, n, k)` and the given
/// strides.
#[allow(clippy::too_many_arguments)]
pub unsafe fn gemm_fallback_s16s16s32(
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: *const i16,
    a_rs: usize,
    a_cs: usize,
    b: *const i16,
    b_rs: usize,
    b_cs: usize,
    beta: f32,
    c: *mut i32,
    c_rs: usize,
    c_cs: usize,
) {
    for i in 0..m {
        for j in 0..n {
            let dot = (0..k).fold(0_i32, |acc, p| {
                acc + i32::from(*a.add(a_rs * i + a_cs * p)) * i32::from(*b.add(b_rs * p + b_cs * j))
            });
            let dst = c.add(c_rs * i + c_cs * j);
            *dst = (alpha * dot as f32 + beta * *dst as f32) as i32;
        }
    }
}
944
/// Naive reference GEMM for `i8` x `u8` inputs with `i32` accumulation.
///
/// Computes `C = alpha * (A * B) + beta * C`. The integer dot product is scaled in `f32`, and
/// the result is converted back with a saturating `as i32` cast.
///
/// # Safety
/// `a`, `b` and `c` must be valid for every element reachable through `(m, n, k)` and the given
/// strides.
#[allow(clippy::too_many_arguments)]
pub unsafe fn gemm_fallback_s8u8s32(
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: *const i8,
    a_rs: usize,
    a_cs: usize,
    b: *const u8,
    b_rs: usize,
    b_cs: usize,
    beta: f32,
    c: *mut i32,
    c_rs: usize,
    c_cs: usize,
) {
    for i in 0..m {
        for j in 0..n {
            let dot = (0..k).fold(0_i32, |acc, p| {
                acc + i32::from(*a.add(a_rs * i + a_cs * p)) * i32::from(*b.add(b_rs * p + b_cs * j))
            });
            let dst = c.add(c_rs * i + c_cs * j);
            *dst = (alpha * dot as f32 + beta * *dst as f32) as i32;
        }
    }
}
971
972pub unsafe fn gemm_fallback_c32(
973 m: usize,
974 n: usize,
975 k: usize,
976 alpha: Complex32,
977 a: *const Complex32,
978 a_rs: usize,
979 a_cs: usize,
980 b: *const Complex32,
981 b_rs: usize,
982 b_cs: usize,
983 beta: Complex32,
984 c: *mut Complex32,
985 c_rs: usize,
986 c_cs: usize,
987) {
988 for i in 0..m {
989 for j in 0..n {
990 let mut dx = Complex32::ZERO;
991 for p in 0..k {
992 dx += *a.add(a_rs * i + a_cs * p) * *b.add(b_rs * p + b_cs * j);
993 }
994 *c.add(c_rs * i + c_cs * j) = alpha * dx + beta * *c.add(c_rs * i + c_cs * j);
995 }
996 }
997}
998
999pub unsafe fn gemm_fallback_c64(
1000 m: usize,
1001 n: usize,
1002 k: usize,
1003 alpha: Complex64,
1004 a: *const Complex64,
1005 a_rs: usize,
1006 a_cs: usize,
1007 b: *const Complex64,
1008 b_rs: usize,
1009 b_cs: usize,
1010 beta: Complex64,
1011 c: *mut Complex64,
1012 c_rs: usize,
1013 c_cs: usize,
1014) {
1015 for i in 0..m {
1016 for j in 0..n {
1017 let mut dx = Complex64::ZERO;
1018 for p in 0..k {
1019 dx += *a.add(a_rs * i + a_cs * p) * *b.add(b_rs * p + b_cs * j);
1020 }
1021 *c.add(c_rs * i + c_cs * j) = alpha * dx + beta * *c.add(c_rs * i + c_cs * j);
1022 }
1023 }
1024}
1025
1026pub unsafe fn gemm_fallback_f16(
1027 m: usize,
1028 n: usize,
1029 k: usize,
1030 alpha: f16,
1031 a: *const f16,
1032 a_rs: usize,
1033 a_cs: usize,
1034 b: *const f16,
1035 b_rs: usize,
1036 b_cs: usize,
1037 beta: f16,
1038 c: *mut f16,
1039 c_rs: usize,
1040 c_cs: usize,
1041) {
1042 for i in 0..m {
1043 for j in 0..n {
1044 let mut dx = f16::ZERO;
1045 for p in 0..k {
1046 dx += *a.add(a_rs * i + a_cs * p) * *b.add(b_rs * p + b_cs * j);
1047 }
1048 *c.add(c_rs * i + c_cs * j) = alpha * dx + beta * *c.add(c_rs * i + c_cs * j);
1049 }
1050 }
1051}
1052
1053pub fn stride_to_cblas(
1054 m: usize,
1055 n: usize,
1056 k: usize,
1057 a_rs: usize,
1058 a_cs: usize,
1059 b_rs: usize,
1060 b_cs: usize,
1061 c_rs: usize,
1062 c_cs: usize,
1063) -> (CBLAS_LAYOUT, CBLAS_TRANSPOSE, CBLAS_TRANSPOSE, c_int, c_int, c_int) {
1064 let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = if c_rs == 1 {
1065 (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs)
1066 } else if c_cs == 1 {
1067 (a_cs, a_rs, b_cs, b_rs, c_cs, c_rs)
1068 } else {
1069 panic!("Non Trivial Stride is not available for Cblas Api");
1070 };
1071 let ldc = c_cs as c_int;
1073 let (a_trans, b_trans, lda, ldb) = if a_rs == 1 && b_rs == 1 && a_cs == m && b_cs == k {
1074 (CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasNoTrans, a_cs as c_int, b_cs as c_int)
1075 } else if a_rs == 1 && b_cs == 1 && a_cs == m && b_rs == n {
1076 (CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans, a_cs as c_int, b_rs as c_int)
1077 } else if a_cs == 1 && b_rs == 1 && a_rs == k && b_cs == k {
1078 (CBLAS_TRANSPOSE::CblasTrans, CBLAS_TRANSPOSE::CblasNoTrans, a_rs as c_int, b_cs as c_int)
1079 } else if a_cs == 1 && b_cs == 1 && a_rs == k && b_rs == n {
1080 (CBLAS_TRANSPOSE::CblasTrans, CBLAS_TRANSPOSE::CblasTrans, a_rs as c_int, b_rs as c_int)
1081 } else {
1082 panic!("Non Trivial Stride is not available for Cblas Api");
1083 };
1084 (CBLAS_LAYOUT::CblasColMajor, a_trans, b_trans, lda, ldb, ldc)
1085}
1086
1087fn cblas_to_stride(
1088 layout: CBLAS_LAYOUT,
1089 transa: CBLAS_TRANSPOSE,
1090 transb: CBLAS_TRANSPOSE,
1091 lda: c_int,
1092 ldb: c_int,
1093 ldc: c_int,
1094) -> (usize, usize, usize, usize, usize, usize) {
1095 if layout == CBLAS_LAYOUT::CblasColMajor {
1096 let (a_rs, a_cs) = if transa == CBLAS_TRANSPOSE::CblasNoTrans { (1, lda as usize) } else { (lda as usize, 1) };
1097 let (b_rs, b_cs) = if transb == CBLAS_TRANSPOSE::CblasNoTrans { (1, ldb as usize) } else { (ldb as usize, 1) };
1098 (a_rs, a_cs, b_rs, b_cs, 1, ldc as usize)
1099 } else {
1100 let (a_rs, a_cs) = if transa == CBLAS_TRANSPOSE::CblasNoTrans { (lda as usize, 1) } else { (1, lda as usize) };
1101 let (b_rs, b_cs) = if transb == CBLAS_TRANSPOSE::CblasNoTrans { (ldb as usize, 1) } else { (1, ldb as usize) };
1102 (a_rs, a_cs, b_rs, b_cs, ldc as usize, 1)
1103 }
1104}
1105
1106pub unsafe fn check_gemm_s16s16s32(
1107 m: usize,
1108 n: usize,
1109 k: usize,
1110 alpha: f32,
1111 a: *const i16,
1112 a_rs: usize,
1113 a_cs: usize,
1114 b: *const i16,
1115 b_rs: usize,
1116 b_cs: usize,
1117 beta: f32,
1118 c: &[i32],
1119 c_rs: usize,
1120 c_cs: usize,
1121 c_ref: &mut [i32],
1122 unary: unsafe fn(*mut i32, m: usize),
1123 eps: f64,
1124) -> f64 {
1125 #[cfg(feature = "mkl")]
1126 {
1127 let oc_val = 0;
1128 let oc = &oc_val as *const c_int;
1129 let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1130 cblas_gemm_s16s16s32(
1131 layout,
1132 transa,
1133 transb,
1134 CblasFixOffset,
1135 m as c_int,
1136 n as c_int,
1137 k as c_int,
1138 alpha,
1139 a,
1140 lda,
1141 0,
1142 b,
1143 ldb,
1144 0,
1145 beta,
1146 c_ref.as_mut_ptr(),
1147 ldc,
1148 oc,
1149 CBlasBackend::Mkl,
1150 );
1151 }
1152 #[cfg(not(feature = "mkl"))]
1153 {
1154 gemm_fallback_s16s16s32(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1156 }
1157
1158 let c_ref_ptr = c_ref.as_mut_ptr();
1159 if c_rs == 1 {
1160 for j in 0..n {
1161 unary(c_ref_ptr.add(j * c_cs), m);
1162 }
1163 } else if c_cs == 1 {
1164 for i in 0..m {
1165 unary(c_ref_ptr.add(i * c_rs), n);
1166 }
1167 } else {
1168 for i in 0..m {
1169 for j in 0..n {
1170 unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1171 }
1172 }
1173 }
1174
1175 let diff = max_abs_diff(&c, &c_ref, eps);
1176 return diff;
1177}
1178
1179pub unsafe fn check_gemm_s8u8s32(
1180 m: usize,
1181 n: usize,
1182 k: usize,
1183 alpha: f32,
1184 a: *const i8,
1185 a_rs: usize,
1186 a_cs: usize,
1187 b: *const u8,
1188 b_rs: usize,
1189 b_cs: usize,
1190 beta: f32,
1191 c: &[i32],
1192 c_rs: usize,
1193 c_cs: usize,
1194 c_ref: &mut [i32],
1195 unary: unsafe fn(*mut i32, m: usize),
1196 eps: f64,
1197) -> f64 {
1198 #[cfg(feature = "mkl")]
1199 {
1200 let oc_val = 0;
1201 let oc = &oc_val as *const c_int;
1202 let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1203 let a = a as *const c_void;
1204 let b = b as *const c_void;
1205 cblas_gemm_s8u8s32(
1206 layout,
1207 transa,
1208 transb,
1209 CblasFixOffset,
1210 m as c_int,
1211 n as c_int,
1212 k as c_int,
1213 alpha,
1214 a,
1215 lda,
1216 0,
1217 b,
1218 ldb,
1219 0,
1220 beta,
1221 c_ref.as_mut_ptr(),
1222 ldc,
1223 oc,
1224 CBlasBackend::Mkl,
1225 );
1226 }
1227 #[cfg(not(feature = "mkl"))]
1228 {
1229 gemm_fallback_s8u8s32(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1231 }
1232
1233 let c_ref_ptr = c_ref.as_mut_ptr();
1234 if c_rs == 1 {
1235 for j in 0..n {
1236 unary(c_ref_ptr.add(j * c_cs), m);
1237 }
1238 } else if c_cs == 1 {
1239 for i in 0..m {
1240 unary(c_ref_ptr.add(i * c_rs), n);
1241 }
1242 } else {
1243 for i in 0..m {
1244 for j in 0..n {
1245 unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1246 }
1247 }
1248 }
1249
1250 let diff = max_abs_diff(&c, &c_ref, eps);
1251 return diff;
1252}
1253
1254pub unsafe fn check_gemm_f16(
1255 m: usize,
1256 n: usize,
1257 k: usize,
1258 alpha: f16,
1259 a: *const f16,
1260 a_rs: usize,
1261 a_cs: usize,
1262 b: *const f16,
1263 b_rs: usize,
1264 b_cs: usize,
1265 beta: f16,
1266 c: &[f16],
1267 c_rs: usize,
1268 c_cs: usize,
1269 c_ref: &mut [f16],
1270 unary: unsafe fn(*mut f16, m: usize),
1271 eps: f64,
1272) -> f64 {
1273 #[cfg(feature = "mkl")]
1274 {
1275 let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1276 let a = a as *const c_ushort;
1277 let b = b as *const c_ushort;
1278 let c_ref_ptr = c_ref.as_mut_ptr() as *mut c_ushort;
1279 let alpha = alpha.to_bits();
1280 let beta = beta.to_bits();
1281 cblas_hgemm(
1282 layout,
1283 transa,
1284 transb,
1285 m as c_int,
1286 n as c_int,
1287 k as c_int,
1288 alpha,
1289 a,
1290 lda,
1291 b,
1292 ldb,
1293 beta,
1294 c_ref_ptr,
1295 ldc,
1296 CBlasBackend::Mkl,
1297 );
1298 }
1299 #[cfg(not(feature = "mkl"))]
1300 {
1301 gemm_fallback_f16(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1303 }
1304
1305 let c_ref_ptr = c_ref.as_mut_ptr();
1306 if c_rs == 1 {
1307 for j in 0..n {
1308 unary(c_ref_ptr.add(j * c_cs), m);
1309 }
1310 } else if c_cs == 1 {
1311 for i in 0..m {
1312 unary(c_ref_ptr.add(i * c_rs), n);
1313 }
1314 } else {
1315 for i in 0..m {
1316 for j in 0..n {
1317 unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1318 }
1319 }
1320 }
1321
1322 let diff = max_abs_diff(&c, &c_ref, eps);
1323 return diff;
1324}
1325
1326pub unsafe fn check_gemm_f64(
1327 m: usize,
1328 n: usize,
1329 k: usize,
1330 alpha: f64,
1331 a: *const f64,
1332 a_rs: usize,
1333 a_cs: usize,
1334 b: *const f64,
1335 b_rs: usize,
1336 b_cs: usize,
1337 beta: f64,
1338 c: &[f64],
1339 c_rs: usize,
1340 c_cs: usize,
1341 c_ref: &mut [f64],
1342 unary: unsafe fn(*mut f64, m: usize),
1343 eps: f64,
1344) -> f64 {
1345 #[cfg(feature = "mkl")]
1346 {
1347 let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1348 cblas_dgemm(
1349 layout,
1350 transa,
1351 transb,
1352 m as c_int,
1353 n as c_int,
1354 k as c_int,
1355 alpha,
1356 a,
1357 lda,
1358 b,
1359 ldb,
1360 beta,
1361 c_ref.as_mut_ptr(),
1362 ldc,
1363 CBlasBackend::Mkl,
1364 );
1365 }
1366 #[cfg(not(feature = "mkl"))]
1367 {
1368 gemm_fallback_f64(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1370 }
1371
1372 let c_ref_ptr = c_ref.as_mut_ptr();
1373 if c_rs == 1 {
1374 for j in 0..n {
1375 unary(c_ref_ptr.add(j * c_cs), m);
1376 }
1377 } else if c_cs == 1 {
1378 for i in 0..m {
1379 unary(c_ref_ptr.add(i * c_rs), n);
1380 }
1381 } else {
1382 for i in 0..m {
1383 for j in 0..n {
1384 unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1385 }
1386 }
1387 }
1388
1389 let diff = max_abs_diff(&c, &c_ref, eps);
1390 return diff;
1391}
1392
1393pub unsafe fn check_gemm_f32(
1394 m: usize,
1395 n: usize,
1396 k: usize,
1397 alpha: f32,
1398 a: *const f32,
1399 a_rs: usize,
1400 a_cs: usize,
1401 b: *const f32,
1402 b_rs: usize,
1403 b_cs: usize,
1404 beta: f32,
1405 c: &[f32],
1406 c_rs: usize,
1407 c_cs: usize,
1408 c_ref: &mut [f32],
1409 unary: unsafe fn(*mut f32, m: usize),
1410 eps: f64,
1411) -> f64 {
1412 #[cfg(feature = "mkl")]
1413 {
1414 let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1415 cblas_sgemm(
1416 layout,
1417 transa,
1418 transb,
1419 m as c_int,
1420 n as c_int,
1421 k as c_int,
1422 alpha,
1423 a,
1424 lda,
1425 b,
1426 ldb,
1427 beta,
1428 c_ref.as_mut_ptr(),
1429 ldc,
1430 CBlasBackend::Mkl,
1431 );
1432 }
1433 #[cfg(not(feature = "mkl"))]
1434 {
1435 gemm_fallback_f32(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1437 }
1438 let c_ref_ptr = c_ref.as_mut_ptr();
1439 if c_rs == 1 {
1440 for j in 0..n {
1441 unary(c_ref_ptr.add(j * c_cs), m);
1442 }
1443 } else if c_cs == 1 {
1444 for i in 0..m {
1445 unary(c_ref_ptr.add(i * c_rs), n);
1446 }
1447 } else {
1448 for i in 0..m {
1449 for j in 0..n {
1450 unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1451 }
1452 }
1453 }
1454
1455 let diff = max_abs_diff(&c, &c_ref, eps);
1456 return diff;
1457}
1458
1459pub unsafe fn check_gemm_c32(
1460 m: usize,
1461 n: usize,
1462 k: usize,
1463 alpha: Complex32,
1464 a: *const Complex32,
1465 a_rs: usize,
1466 a_cs: usize,
1467 b: *const Complex32,
1468 b_rs: usize,
1469 b_cs: usize,
1470 beta: Complex32,
1471 c: &[Complex32],
1472 c_rs: usize,
1473 c_cs: usize,
1474 c_ref: &mut [Complex32],
1475 unary: unsafe fn(*mut Complex32, m: usize),
1476 eps: f64,
1477) -> f64 {
1478 #[cfg(feature = "mkl")]
1479 {
1480 let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1481 let a = a as *const c_void;
1482 let b = b as *const c_void;
1483 let c_ref_ptr = c_ref.as_mut_ptr() as *mut c_void;
1484 let alpha_ptr = &alpha as *const Complex32 as *const c_void;
1485 let beta_ptr = &beta as *const Complex32 as *const c_void;
1486 cblas_cgemm(
1487 layout,
1488 transa,
1489 transb,
1490 m as c_int,
1491 n as c_int,
1492 k as c_int,
1493 alpha_ptr,
1494 a,
1495 lda,
1496 b,
1497 ldb,
1498 beta_ptr,
1499 c_ref_ptr,
1500 ldc,
1501 CBlasBackend::Mkl,
1502 );
1503 }
1504 #[cfg(not(feature = "mkl"))]
1505 {
1506 gemm_fallback_c32(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1508 }
1509 let c_ref_ptr = c_ref.as_mut_ptr();
1510 if c_rs == 1 {
1511 for j in 0..n {
1512 unary(c_ref_ptr.add(j * c_cs), m);
1513 }
1514 } else if c_cs == 1 {
1515 for i in 0..m {
1516 unary(c_ref_ptr.add(i * c_rs), n);
1517 }
1518 } else {
1519 for i in 0..m {
1520 for j in 0..n {
1521 unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1522 }
1523 }
1524 }
1525
1526 let diff = max_abs_diff(&c, &c_ref, eps);
1527 return diff;
1528}
1529
1530pub unsafe fn check_gemm_c64(
1531 m: usize,
1532 n: usize,
1533 k: usize,
1534 alpha: Complex64,
1535 a: *const Complex64,
1536 a_rs: usize,
1537 a_cs: usize,
1538 b: *const Complex64,
1539 b_rs: usize,
1540 b_cs: usize,
1541 beta: Complex64,
1542 c: &[Complex64],
1543 c_rs: usize,
1544 c_cs: usize,
1545 c_ref: &mut [Complex64],
1546 unary: unsafe fn(*mut Complex64, m: usize),
1547 eps: f64,
1548) -> f64 {
1549 #[cfg(feature = "mkl")]
1550 {
1551 let (layout, transa, transb, lda, ldb, ldc) = stride_to_cblas(m, n, k, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs);
1552 let a = a as *const c_void;
1553 let b = b as *const c_void;
1554 let c_ref_ptr = c_ref.as_mut_ptr() as *mut c_void;
1555 let alpha_ptr = &alpha as *const Complex64 as *const c_void;
1556 let beta_ptr = &beta as *const Complex64 as *const c_void;
1557 cblas_zgemm(
1558 layout,
1559 transa,
1560 transb,
1561 m as c_int,
1562 n as c_int,
1563 k as c_int,
1564 alpha_ptr,
1565 a,
1566 lda,
1567 b,
1568 ldb,
1569 beta_ptr,
1570 c_ref_ptr,
1571 ldc,
1572 CBlasBackend::Mkl,
1573 );
1574 }
1575 #[cfg(not(feature = "mkl"))]
1576 {
1577 gemm_fallback_c64(m, n, k, alpha, a, a_rs, a_cs, b, b_rs, b_cs, beta, c_ref.as_mut_ptr(), c_rs, c_cs);
1579 }
1580
1581 let c_ref_ptr = c_ref.as_mut_ptr();
1582 if c_rs == 1 {
1583 for j in 0..n {
1584 unary(c_ref_ptr.add(j * c_cs), m);
1585 }
1586 } else if c_cs == 1 {
1587 for i in 0..m {
1588 unary(c_ref_ptr.add(i * c_rs), n);
1589 }
1590 } else {
1591 for i in 0..m {
1592 for j in 0..n {
1593 unary(c_ref_ptr.add(i * c_rs + j * c_cs), 1);
1594 }
1595 }
1596 }
1597
1598 let diff = max_abs_diff(&c, &c_ref, eps);
1599 return diff;
1600}
1601
1602pub fn cblas_params_from_str(
1603 layout_str: &str,
1604 m: usize,
1605 n: usize,
1606 k: usize,
1607) -> (i32, i32, i32, CBLAS_TRANSPOSE, CBLAS_TRANSPOSE) {
1608 if layout_str == "nn" {
1609 (m as i32, k as i32, m as i32, CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasNoTrans)
1610 } else if layout_str == "nt" {
1611 (m as i32, n as i32, m as i32, CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans)
1612 } else if layout_str == "tn" {
1613 (k as i32, k as i32, m as i32, CBLAS_TRANSPOSE::CblasTrans, CBLAS_TRANSPOSE::CblasNoTrans)
1614 } else if layout_str == "tt" {
1615 (k as i32, n as i32, m as i32, CBLAS_TRANSPOSE::CblasTrans, CBLAS_TRANSPOSE::CblasTrans)
1616 } else {
1617 panic!("Unsupported layout str");
1618 }
1619}
1620
/// Returns the `m` (row-count) problem sizes exercised by the GEMM tests.
///
/// NOTE(review): this returns a fixed, hand-picked set and ignores `mc` and `mr`. The
/// previous body hid this behind an early `return`, which left a parameter-driven sweep
/// below it as unreachable dead code. That dead code has been removed; the output is
/// unchanged. If a wider sweep is wanted again, the removed version produced, for each
/// `m` in `1..mr`, the values `m`, `m + 100` and `m + 1000`, followed by `mc + 29`.
pub fn generate_m_dims(mc: usize, mr: usize) -> Vec<usize> {
    vec![1, 67, 137]
}
1633
/// Returns the `n` (column-count) problem sizes exercised by the GEMM tests.
///
/// For every `n` in `1..nr` it emits, in this order, `n`, `n + 400` and `n + nc`.
/// The result is empty when `nr <= 1`.
pub fn generate_n_dims(nc: usize, nr: usize) -> Vec<usize> {
    (1..nr).flat_map(|n| [n, n + 400, n + nc]).collect()
}
/// Returns the `k` (inner-dimension) problem sizes exercised by the GEMM tests.
///
/// For every `k` in `1..8` it emits, in this order, `k`, `k + 50` and `k + kc`.
///
/// NOTE(review): the `kr` argument is ignored. A fixed bound of 8 is always used
/// instead; confirm whether this override is intentional.
pub fn generate_k_dims(kc: usize, kr: usize) -> Vec<usize> {
    const K_BOUND: usize = 8;
    let mut dims = Vec::with_capacity(3 * (K_BOUND - 1));
    for k in 1..K_BOUND {
        dims.extend([k, k + 50, k + kc]);
    }
    dims
}