1#![allow(non_camel_case_types)]
2
3use crate::prelude_dev::*;
4use lapack_ffi::cblas;
5use num::complex::Complex;
6use num::traits::ConstZero;
7use rayon::prelude::*;
8use rstsr_core::prelude_dev::uninitialized_vec;
9use std::ffi::c_void;
10
/// Single-precision complex scalar, `Complex<f32>`. The lowercase name is permitted
/// by the file-level `#![allow(non_camel_case_types)]`.
type c32 = Complex<f32>;
/// Double-precision complex scalar, `Complex<f64>`. The lowercase name is permitted
/// by the file-level `#![allow(non_camel_case_types)]`.
type c64 = Complex<f64>;
13
14use cblas::CBLAS_LAYOUT::CblasColMajor as ColMajor;
15use cblas::CBLAS_TRANSPOSE::CblasNoTrans as NoTrans;
16use cblas::CBLAS_TRANSPOSE::CblasTrans as Trans;
17use cblas::CBLAS_UPLO::CblasUpper as Upper;
18
19#[duplicate_item(
22 ty fn_name cblas_wrap ;
23 [f32] [gemm_blas_no_conj_f32] [cblas_sgemm_wrap];
24 [f64] [gemm_blas_no_conj_f64] [cblas_dgemm_wrap];
25 [c32] [gemm_blas_no_conj_c32] [cblas_cgemm_wrap];
26 [c64] [gemm_blas_no_conj_c64] [cblas_zgemm_wrap];
27)]
28#[allow(clippy::too_many_arguments)]
29pub fn fn_name(
30 c: &mut [ty],
31 lc: &Layout<Ix2>,
32 a: &[ty],
33 la: &Layout<Ix2>,
34 b: &[ty],
35 lb: &Layout<Ix2>,
36 alpha: ty,
37 beta: ty,
38 pool: Option<&ThreadPool>,
39) -> Result<()> {
40 if !lc.f_prefer() {
45 if lc.c_prefer() {
48 return fn_name(c, &lc.reverse_axes(), b, &lb.reverse_axes(), a, &la.reverse_axes(), alpha, beta, pool);
50 } else {
51 let lc_new = lc.shape().new_f_contig(None);
53 let mut c_new = unsafe { uninitialized_vec(lc_new.size())? };
54 if beta == <ty>::ZERO {
55 fill_cpu_rayon(&mut c_new, &lc_new, <ty>::ZERO, pool)?;
56 } else {
57 assign_cpu_rayon(&mut c_new, &lc_new, c, lc, pool)?;
58 }
59 fn_name(&mut c_new, &lc_new, a, la, b, lb, alpha, <ty>::ZERO, pool)?;
60 assign_cpu_rayon(c, lc, &c_new, &lc_new, pool)?;
61 return Ok(());
62 }
63 }
64
65 let sc = lc.shape();
67 let sa = la.shape();
68 let sb = lb.shape();
69 rstsr_assert_eq!(sc[0], sa[0], InvalidLayout)?;
70 rstsr_assert_eq!(sa[1], sb[0], InvalidLayout)?;
71 rstsr_assert_eq!(sc[1], sb[1], InvalidLayout)?;
72
73 let m = sc[0];
74 let n = sc[1];
75 let k = sa[1];
76
77 if k == 0 {
79 return fill_cpu_rayon(c, lc, <ty>::ZERO, pool);
81 }
82
83 if n == 0 || m == 0 {
85 return Ok(());
87 }
88
89 let mut a_data: Option<Vec<ty>> = None;
91 let mut b_data: Option<Vec<ty>> = None;
92 let (a_trans, la) = if la.f_prefer() {
93 (NoTrans, la.clone())
94 } else if la.c_prefer() {
95 (Trans, la.reverse_axes())
96 } else {
97 let len = la.size();
98 a_data = unsafe { Some(uninitialized_vec(len)?) };
99 let la_data = la.shape().new_f_contig(None);
100 assign_cpu_rayon(a_data.as_mut().unwrap(), &la_data, a, la, pool)?;
101 (NoTrans, la_data)
102 };
103 let (b_trans, lb) = if lb.f_prefer() {
104 (NoTrans, lb.clone())
105 } else if lb.c_prefer() {
106 (Trans, lb.reverse_axes())
107 } else {
108 let len = lb.size();
109 b_data = unsafe { Some(uninitialized_vec(len)?) };
110 let lb_data = lb.shape().new_f_contig(None);
111 assign_cpu_rayon(b_data.as_mut().unwrap(), &lb_data, b, lb, pool)?;
112 (NoTrans, lb_data)
113 };
114
115 let lda = if la.shape()[1] != 1 { la.stride()[1] } else { la.shape()[0] as isize };
118 let ldb = if lb.shape()[1] != 1 { lb.stride()[1] } else { lb.shape()[0] as isize };
119 let ldc = if lc.shape()[1] != 1 { lc.stride()[1] } else { lc.shape()[0] as isize };
120
121 let ptr_c = unsafe { c.as_mut_ptr().add(lc.offset()) };
122 let ptr_a =
123 if let Some(a_data) = a_data.as_ref() { a_data.as_ptr() } else { unsafe { a.as_ptr().add(la.offset()) } };
124 let ptr_b =
125 if let Some(b_data) = b_data.as_ref() { b_data.as_ptr() } else { unsafe { b.as_ptr().add(lb.offset()) } };
126
127 unsafe {
129 cblas_wrap(ColMajor, a_trans, b_trans, m, n, k, alpha, ptr_a, lda, ptr_b, ldb, beta, ptr_c, ldc);
130 }
131 Ok(())
132}
133
134#[allow(clippy::too_many_arguments)]
135unsafe fn cblas_sgemm_wrap(
136 order: cblas::CBLAS_LAYOUT,
137 a_trans: cblas::CBLAS_TRANSPOSE,
138 b_trans: cblas::CBLAS_TRANSPOSE,
139 m: usize,
140 n: usize,
141 k: usize,
142 alpha: f32,
143 ptr_a: *const f32,
144 lda: isize,
145 ptr_b: *const f32,
146 ldb: isize,
147 beta: f32,
148 ptr_c: *mut f32,
149 ldc: isize,
150) {
151 unsafe {
152 cblas::cblas_sgemm(
153 order as cblas::CBLAS_LAYOUT,
154 a_trans as cblas::CBLAS_TRANSPOSE,
155 b_trans as cblas::CBLAS_TRANSPOSE,
156 m as cblas::blas_int,
157 n as cblas::blas_int,
158 k as cblas::blas_int,
159 alpha,
160 ptr_a,
161 lda as cblas::blas_int,
162 ptr_b,
163 ldb as cblas::blas_int,
164 beta,
165 ptr_c,
166 ldc as cblas::blas_int,
167 );
168 }
169}
170
171#[allow(clippy::too_many_arguments)]
172unsafe fn cblas_dgemm_wrap(
173 order: cblas::CBLAS_LAYOUT,
174 a_trans: cblas::CBLAS_TRANSPOSE,
175 b_trans: cblas::CBLAS_TRANSPOSE,
176 m: usize,
177 n: usize,
178 k: usize,
179 alpha: f64,
180 ptr_a: *const f64,
181 lda: isize,
182 ptr_b: *const f64,
183 ldb: isize,
184 beta: f64,
185 ptr_c: *mut f64,
186 ldc: isize,
187) {
188 unsafe {
189 cblas::cblas_dgemm(
190 order as cblas::CBLAS_LAYOUT,
191 a_trans as cblas::CBLAS_TRANSPOSE,
192 b_trans as cblas::CBLAS_TRANSPOSE,
193 m as cblas::blas_int,
194 n as cblas::blas_int,
195 k as cblas::blas_int,
196 alpha,
197 ptr_a,
198 lda as cblas::blas_int,
199 ptr_b,
200 ldb as cblas::blas_int,
201 beta,
202 ptr_c,
203 ldc as cblas::blas_int,
204 );
205 }
206}
207
208#[allow(clippy::too_many_arguments)]
209unsafe fn cblas_cgemm_wrap(
210 order: cblas::CBLAS_LAYOUT,
211 a_trans: cblas::CBLAS_TRANSPOSE,
212 b_trans: cblas::CBLAS_TRANSPOSE,
213 m: usize,
214 n: usize,
215 k: usize,
216 alpha: c32,
217 ptr_a: *const c32,
218 lda: isize,
219 ptr_b: *const c32,
220 ldb: isize,
221 beta: c32,
222 ptr_c: *mut c32,
223 ldc: isize,
224) {
225 unsafe {
226 cblas::cblas_cgemm(
227 order as cblas::CBLAS_LAYOUT,
228 a_trans as cblas::CBLAS_TRANSPOSE,
229 b_trans as cblas::CBLAS_TRANSPOSE,
230 m as cblas::blas_int,
231 n as cblas::blas_int,
232 k as cblas::blas_int,
233 &alpha as *const _ as *const c_void,
234 ptr_a as *const c_void,
235 lda as cblas::blas_int,
236 ptr_b as *const c_void,
237 ldb as cblas::blas_int,
238 &beta as *const _ as *const c_void,
239 ptr_c as *mut c_void,
240 ldc as cblas::blas_int,
241 );
242 }
243}
244
245#[allow(clippy::too_many_arguments)]
246unsafe fn cblas_zgemm_wrap(
247 order: cblas::CBLAS_LAYOUT,
248 a_trans: cblas::CBLAS_TRANSPOSE,
249 b_trans: cblas::CBLAS_TRANSPOSE,
250 m: usize,
251 n: usize,
252 k: usize,
253 alpha: c64,
254 ptr_a: *const c64,
255 lda: isize,
256 ptr_b: *const c64,
257 ldb: isize,
258 beta: c64,
259 ptr_c: *mut c64,
260 ldc: isize,
261) {
262 unsafe {
263 cblas::cblas_zgemm(
264 order as cblas::CBLAS_LAYOUT,
265 a_trans as cblas::CBLAS_TRANSPOSE,
266 b_trans as cblas::CBLAS_TRANSPOSE,
267 m as cblas::blas_int,
268 n as cblas::blas_int,
269 k as cblas::blas_int,
270 &alpha as *const _ as *const c_void,
271 ptr_a as *const c_void,
272 lda as cblas::blas_int,
273 ptr_b as *const c_void,
274 ldb as cblas::blas_int,
275 &beta as *const _ as *const c_void,
276 ptr_c as *mut c_void,
277 ldc as cblas::blas_int,
278 );
279 }
280}
281
282#[duplicate_item(
287 ty fn_name cblas_wrap ;
288 [f32] [syrk_blas_no_conj_f32] [cblas_ssyrk_wrap];
289 [f64] [syrk_blas_no_conj_f64] [cblas_dsyrk_wrap];
290 [c32] [syrk_blas_no_conj_c32] [cblas_csyrk_wrap];
291 [c64] [syrk_blas_no_conj_c64] [cblas_zsyrk_wrap];
292)]
293pub fn fn_name(
294 c: &mut [ty],
295 lc: &Layout<Ix2>,
296 a: &[ty],
297 la: &Layout<Ix2>,
298 alpha: ty,
299 pool: Option<&ThreadPool>,
300) -> Result<()> {
301 if !lc.f_prefer() {
305 if lc.c_prefer() {
308 return fn_name(c, &lc.reverse_axes(), a, la, alpha, pool);
310 } else {
311 let lc_new = lc.shape().new_f_contig(None);
313 let mut c_new = unsafe { uninitialized_vec(lc_new.size())? };
314 fill_cpu_rayon(&mut c_new, &lc_new, <ty>::ZERO, pool)?;
315 fn_name(&mut c_new, &lc_new, a, la, alpha, pool)?;
316 assign_cpu_rayon(c, lc, &c_new, &lc_new, pool)?;
317 return Ok(());
318 }
319 }
320
321 let sc = lc.shape();
323 let sa = la.shape();
324 rstsr_assert_eq!(sc[0], sa[0], InvalidLayout)?;
325 rstsr_assert_eq!(sc[1], sc[0], InvalidLayout)?;
326
327 let n = sc[0];
328 let k = sa[1];
329
330 if k == 0 {
332 return fill_cpu_rayon(c, lc, <ty>::ZERO, pool);
334 }
335
336 if n == 0 {
338 return Ok(());
340 }
341
342 let mut a_data: Option<Vec<ty>> = None;
344 let (a_trans, la) = if la.f_prefer() {
345 (NoTrans, la.clone())
346 } else if la.c_prefer() {
347 (Trans, la.reverse_axes())
348 } else {
349 let len = la.size();
350 a_data = unsafe { Some(uninitialized_vec(len)?) };
351 let la_data = la.shape().new_f_contig(None);
352 assign_cpu_rayon(a_data.as_mut().unwrap(), &la_data, a, la, pool)?;
353 (NoTrans, la_data)
354 };
355
356 let lda = if la.shape()[1] != 1 { la.stride()[1] } else { la.shape()[0] as isize };
359 let ldc = if lc.shape()[1] != 1 { lc.stride()[1] } else { lc.shape()[0] as isize };
360
361 let ptr_c = unsafe { c.as_mut_ptr().add(lc.offset()) };
362 let ptr_a =
363 if let Some(a_data) = a_data.as_ref() { a_data.as_ptr() } else { unsafe { a.as_ptr().add(la.offset()) } };
364
365 unsafe {
367 cblas_wrap(ColMajor, Upper, a_trans, n, k, alpha, ptr_a, lda, <ty>::ZERO, ptr_c, ldc);
368 }
369
370 let n = sc[0];
372 let ldc = lc.stride()[1];
373 let offset = lc.offset() as isize;
374 let task = || {
375 (0..(n as isize)).into_par_iter().for_each(|j| {
376 ((j + 1)..(n as isize)).for_each(|i| unsafe {
377 let idx_ij = (offset + j * ldc + i) as usize;
378 let idx_ji = (offset + i * ldc + j) as usize;
379 let c_ptr_ij = c.as_ptr().add(idx_ij) as *mut ty;
380 *c_ptr_ij = c[idx_ji];
381 });
382 });
383 };
384 pool.map_or_else(task, |pool| pool.install(task));
385 Ok(())
386}
387
388#[allow(clippy::too_many_arguments)]
389unsafe fn cblas_ssyrk_wrap(
390 order: cblas::CBLAS_LAYOUT,
391 uplo: cblas::CBLAS_UPLO,
392 a_trans: cblas::CBLAS_TRANSPOSE,
393 n: usize,
394 k: usize,
395 alpha: f32,
396 ptr_a: *const f32,
397 lda: isize,
398 beta: f32,
399 ptr_c: *mut f32,
400 ldc: isize,
401) {
402 unsafe {
403 cblas::cblas_ssyrk(
404 order as cblas::CBLAS_LAYOUT,
405 uplo as cblas::CBLAS_UPLO,
406 a_trans as cblas::CBLAS_TRANSPOSE,
407 n as cblas::blas_int,
408 k as cblas::blas_int,
409 alpha,
410 ptr_a,
411 lda as cblas::blas_int,
412 beta,
413 ptr_c,
414 ldc as cblas::blas_int,
415 );
416 }
417}
418
419#[allow(clippy::too_many_arguments)]
420unsafe fn cblas_dsyrk_wrap(
421 order: cblas::CBLAS_LAYOUT,
422 uplo: cblas::CBLAS_UPLO,
423 a_trans: cblas::CBLAS_TRANSPOSE,
424 n: usize,
425 k: usize,
426 alpha: f64,
427 ptr_a: *const f64,
428 lda: isize,
429 beta: f64,
430 ptr_c: *mut f64,
431 ldc: isize,
432) {
433 unsafe {
434 cblas::cblas_dsyrk(
435 order as cblas::CBLAS_LAYOUT,
436 uplo as cblas::CBLAS_UPLO,
437 a_trans as cblas::CBLAS_TRANSPOSE,
438 n as cblas::blas_int,
439 k as cblas::blas_int,
440 alpha,
441 ptr_a,
442 lda as cblas::blas_int,
443 beta,
444 ptr_c,
445 ldc as cblas::blas_int,
446 );
447 }
448}
449
450#[allow(clippy::too_many_arguments)]
451unsafe fn cblas_csyrk_wrap(
452 order: cblas::CBLAS_LAYOUT,
453 uplo: cblas::CBLAS_UPLO,
454 a_trans: cblas::CBLAS_TRANSPOSE,
455 n: usize,
456 k: usize,
457 alpha: c32,
458 ptr_a: *const c32,
459 lda: isize,
460 beta: c32,
461 ptr_c: *mut c32,
462 ldc: isize,
463) {
464 unsafe {
465 cblas::cblas_csyrk(
466 order as cblas::CBLAS_LAYOUT,
467 uplo as cblas::CBLAS_UPLO,
468 a_trans as cblas::CBLAS_TRANSPOSE,
469 n as cblas::blas_int,
470 k as cblas::blas_int,
471 &alpha as *const _ as *const c_void,
472 ptr_a as *const c_void,
473 lda as cblas::blas_int,
474 &beta as *const _ as *const c_void,
475 ptr_c as *mut c_void,
476 ldc as cblas::blas_int,
477 );
478 }
479}
480
481#[allow(clippy::too_many_arguments)]
482unsafe fn cblas_zsyrk_wrap(
483 order: cblas::CBLAS_LAYOUT,
484 uplo: cblas::CBLAS_UPLO,
485 a_trans: cblas::CBLAS_TRANSPOSE,
486 n: usize,
487 k: usize,
488 alpha: c64,
489 ptr_a: *const c64,
490 lda: isize,
491 beta: c64,
492 ptr_c: *mut c64,
493 ldc: isize,
494) {
495 unsafe {
496 cblas::cblas_zsyrk(
497 order as cblas::CBLAS_LAYOUT,
498 uplo as cblas::CBLAS_UPLO,
499 a_trans as cblas::CBLAS_TRANSPOSE,
500 n as cblas::blas_int,
501 k as cblas::blas_int,
502 &alpha as *const _ as *const c_void,
503 ptr_a as *const c_void,
504 lda as cblas::blas_int,
505 &beta as *const _ as *const c_void,
506 ptr_c as *mut c_void,
507 ldc as cblas::blas_int,
508 );
509 }
510}
511
#[cfg(test)]
mod test {
    use super::*;

    /// GEMM on f32 with every combination of C/F layouts for A and C.
    #[test]
    fn test_f32() {
        let a = vec![1., 2., 3., 4., 5., 6.];
        let b = vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.];
        let mut c = vec![0.0; 16];

        let la = [2, 3].c();
        let lb = [3, 4].c();
        let lc = [2, 4].c();
        let pool = rayon::ThreadPoolBuilder::new().num_threads(16).build().unwrap();
        let pool = Some(&pool);
        gemm_blas_no_conj_f32(&mut c, &lc, &a, &la, &b, &lb, 1.0, 0.0, pool).unwrap();
        let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc);
        println!("{c_tsr:}");
        println!("{:}", c_tsr.reshape([8]));
        let c_ref = asarray(vec![38., 44., 50., 56., 83., 98., 113., 128.]);
        assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref));

        let la = [2, 3].c();
        let lb = [3, 4].c();
        let lc = [2, 4].f();
        gemm_blas_no_conj_f32(&mut c, &lc, &a, &la, &b, &lb, 1.0, 0.0, pool).unwrap();
        let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc);
        println!("{c_tsr:}");
        println!("{:}", c_tsr.reshape([8]));
        let c_ref = asarray(vec![38., 44., 50., 56., 83., 98., 113., 128.]);
        assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref));

        let la = [2, 3].f();
        let lb = [3, 4].c();
        let lc = [2, 4].c();
        gemm_blas_no_conj_f32(&mut c, &lc, &a, &la, &b, &lb, 1.0, 0.0, pool).unwrap();
        let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc);
        println!("{c_tsr:}");
        println!("{:}", c_tsr.reshape([8]));
        let c_ref = asarray(vec![61., 70., 79., 88., 76., 88., 100., 112.]);
        assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref));

        let la = [2, 3].f();
        let lb = [3, 4].c();
        let lc = [2, 4].f();
        gemm_blas_no_conj_f32(&mut c, &lc, &a, &la, &b, &lb, 2.0, 0.0, pool).unwrap();
        let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc);
        println!("{c_tsr:}");
        println!("{:}", c_tsr.reshape([8]));
        let c_ref = 2 * asarray(vec![61., 70., 79., 88., 76., 88., 100., 112.]);
        assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref));
    }

    /// GEMM on c32. Every entry of A and B is an integer times (1 + i), and
    /// (1 + i)^2 = 2i, so C is 2i times the real product checked in `test_f32`.
    #[test]
    fn test_c32() {
        let a = linspace((c32::new(1., 1.), c32::new(6., 6.), 6)).into_vec();
        let b = linspace((c32::new(1., 1.), c32::new(12., 12.), 12)).into_vec();
        let mut c = vec![c32::ZERO; 16];

        let la = [2, 3].c();
        let lb = [3, 4].c();
        let lc = [2, 4].c();
        let pool = rayon::ThreadPoolBuilder::new().num_threads(16).build().unwrap();
        let pool = Some(&pool);
        gemm_blas_no_conj_c32(&mut c, &lc, &a, &la, &b, &lb, c32::ONE, c32::ZERO, pool).unwrap();
        let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc);
        println!("{c_tsr:}");
        println!("{:}", c_tsr.reshape([8]));
        let re_ref = asarray(vec![0.; 8]);
        let im_ref = asarray(vec![76., 88., 100., 112., 166., 196., 226., 256.]);
        assert!(allclose_f64(&c_tsr.map(|v| v.re as f64), &re_ref));
        assert!(allclose_f64(&c_tsr.map(|v| v.im as f64), &im_ref));
    }

    /// SYRK on f32 with C-ordered A and C. This exercises the transposed-A path and
    /// the upper-to-lower mirroring: A = [[1, 2, 3], [4, 5, 6]], A * A^T = [[14, 32], [32, 77]].
    #[test]
    fn test_syrk_f32() {
        let a = vec![1., 2., 3., 4., 5., 6.];
        let mut c = vec![0.0; 4];

        let la = [2, 3].c();
        let lc = [2, 2].c();
        syrk_blas_no_conj_f32(&mut c, &lc, &a, &la, 1.0, None).unwrap();
        let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc);
        println!("{c_tsr:}");
        let c_ref = asarray(vec![14., 32., 32., 77.]);
        assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref));
    }
}