//! rstsr_openblas/matmul_impl.rs

1#![allow(non_camel_case_types)]
2
3use crate::prelude_dev::*;
4use lapack_ffi::cblas;
5use num::complex::Complex;
6use num::traits::ConstZero;
7use rayon::prelude::*;
8use rstsr_core::prelude_dev::uninitialized_vec;
9use std::ffi::c_void;
10
11type c32 = Complex<f32>;
12type c64 = Complex<f64>;
13
14use cblas::CBLAS_LAYOUT::CblasColMajor as ColMajor;
15use cblas::CBLAS_TRANSPOSE::CblasNoTrans as NoTrans;
16use cblas::CBLAS_TRANSPOSE::CblasTrans as Trans;
17use cblas::CBLAS_UPLO::CblasUpper as Upper;
18
19/* #region gemm */
20
21#[duplicate_item(
22     ty    fn_name                 cblas_wrap       ;
23    [f32] [gemm_blas_no_conj_f32] [cblas_sgemm_wrap];
24    [f64] [gemm_blas_no_conj_f64] [cblas_dgemm_wrap];
25    [c32] [gemm_blas_no_conj_c32] [cblas_cgemm_wrap];
26    [c64] [gemm_blas_no_conj_c64] [cblas_zgemm_wrap];
27)]
28#[allow(clippy::too_many_arguments)]
29pub fn fn_name(
30    c: &mut [ty],
31    lc: &Layout<Ix2>,
32    a: &[ty],
33    la: &Layout<Ix2>,
34    b: &[ty],
35    lb: &Layout<Ix2>,
36    alpha: ty,
37    beta: ty,
38    pool: Option<&ThreadPool>,
39) -> Result<()> {
40    // nthreads is only used for `assign_cpu_rayon`.
41    // the threading of openblas should be handled outside this function.
42
43    // check layout of output
44    if !lc.f_prefer() {
45        // change to f-contig anyway
46        // we do not handle conj, so this can be done easily
47        if lc.c_prefer() {
48            // c-prefer, transpose and run
49            return fn_name(c, &lc.reverse_axes(), b, &lb.reverse_axes(), a, &la.reverse_axes(), alpha, beta, pool);
50        } else {
51            // not c-prefer, allocate new buffer and copy back
52            let lc_new = lc.shape().new_f_contig(None);
53            let mut c_new = unsafe { uninitialized_vec(lc_new.size())? };
54            if beta == <ty>::ZERO {
55                fill_cpu_rayon(&mut c_new, &lc_new, <ty>::ZERO, pool)?;
56            } else {
57                assign_cpu_rayon(&mut c_new, &lc_new, c, lc, pool)?;
58            }
59            fn_name(&mut c_new, &lc_new, a, la, b, lb, alpha, <ty>::ZERO, pool)?;
60            assign_cpu_rayon(c, lc, &c_new, &lc_new, pool)?;
61            return Ok(());
62        }
63    }
64
65    // we assume that the layout is correct
66    let sc = lc.shape();
67    let sa = la.shape();
68    let sb = lb.shape();
69    rstsr_assert_eq!(sc[0], sa[0], InvalidLayout)?;
70    rstsr_assert_eq!(sa[1], sb[0], InvalidLayout)?;
71    rstsr_assert_eq!(sc[1], sb[1], InvalidLayout)?;
72
73    let m = sc[0];
74    let n = sc[1];
75    let k = sa[1];
76
77    // handle the special case that k is zero-dimensional
78    if k == 0 {
79        // if k is zero, the result is a zero matrix
80        return fill_cpu_rayon(c, lc, <ty>::ZERO, pool);
81    }
82
83    // handle the special case that n/m is zero-dimensional
84    if n == 0 || m == 0 {
85        // if n or m is zero, the result matrix size is zero, and nothing to do
86        return Ok(());
87    }
88
89    // determine trans/layout and clone data if necessary
90    let mut a_data: Option<Vec<ty>> = None;
91    let mut b_data: Option<Vec<ty>> = None;
92    let (a_trans, la) = if la.f_prefer() {
93        (NoTrans, la.clone())
94    } else if la.c_prefer() {
95        (Trans, la.reverse_axes())
96    } else {
97        let len = la.size();
98        a_data = unsafe { Some(uninitialized_vec(len)?) };
99        let la_data = la.shape().new_f_contig(None);
100        assign_cpu_rayon(a_data.as_mut().unwrap(), &la_data, a, la, pool)?;
101        (NoTrans, la_data)
102    };
103    let (b_trans, lb) = if lb.f_prefer() {
104        (NoTrans, lb.clone())
105    } else if lb.c_prefer() {
106        (Trans, lb.reverse_axes())
107    } else {
108        let len = lb.size();
109        b_data = unsafe { Some(uninitialized_vec(len)?) };
110        let lb_data = lb.shape().new_f_contig(None);
111        assign_cpu_rayon(b_data.as_mut().unwrap(), &lb_data, b, lb, pool)?;
112        (NoTrans, lb_data)
113    };
114
115    // final configuration
116    // shape may be broadcasted for one-dimension case, so make this check
117    let lda = if la.shape()[1] != 1 { la.stride()[1] } else { la.shape()[0] as isize };
118    let ldb = if lb.shape()[1] != 1 { lb.stride()[1] } else { lb.shape()[0] as isize };
119    let ldc = if lc.shape()[1] != 1 { lc.stride()[1] } else { lc.shape()[0] as isize };
120
121    let ptr_c = unsafe { c.as_mut_ptr().add(lc.offset()) };
122    let ptr_a =
123        if let Some(a_data) = a_data.as_ref() { a_data.as_ptr() } else { unsafe { a.as_ptr().add(la.offset()) } };
124    let ptr_b =
125        if let Some(b_data) = b_data.as_ref() { b_data.as_ptr() } else { unsafe { b.as_ptr().add(lb.offset()) } };
126
127    // actual computation
128    unsafe {
129        cblas_wrap(ColMajor, a_trans, b_trans, m, n, k, alpha, ptr_a, lda, ptr_b, ldb, beta, ptr_c, ldc);
130    }
131    Ok(())
132}
133
134#[allow(clippy::too_many_arguments)]
135unsafe fn cblas_sgemm_wrap(
136    order: cblas::CBLAS_LAYOUT,
137    a_trans: cblas::CBLAS_TRANSPOSE,
138    b_trans: cblas::CBLAS_TRANSPOSE,
139    m: usize,
140    n: usize,
141    k: usize,
142    alpha: f32,
143    ptr_a: *const f32,
144    lda: isize,
145    ptr_b: *const f32,
146    ldb: isize,
147    beta: f32,
148    ptr_c: *mut f32,
149    ldc: isize,
150) {
151    unsafe {
152        cblas::cblas_sgemm(
153            order as cblas::CBLAS_LAYOUT,
154            a_trans as cblas::CBLAS_TRANSPOSE,
155            b_trans as cblas::CBLAS_TRANSPOSE,
156            m as cblas::blas_int,
157            n as cblas::blas_int,
158            k as cblas::blas_int,
159            alpha,
160            ptr_a,
161            lda as cblas::blas_int,
162            ptr_b,
163            ldb as cblas::blas_int,
164            beta,
165            ptr_c,
166            ldc as cblas::blas_int,
167        );
168    }
169}
170
171#[allow(clippy::too_many_arguments)]
172unsafe fn cblas_dgemm_wrap(
173    order: cblas::CBLAS_LAYOUT,
174    a_trans: cblas::CBLAS_TRANSPOSE,
175    b_trans: cblas::CBLAS_TRANSPOSE,
176    m: usize,
177    n: usize,
178    k: usize,
179    alpha: f64,
180    ptr_a: *const f64,
181    lda: isize,
182    ptr_b: *const f64,
183    ldb: isize,
184    beta: f64,
185    ptr_c: *mut f64,
186    ldc: isize,
187) {
188    unsafe {
189        cblas::cblas_dgemm(
190            order as cblas::CBLAS_LAYOUT,
191            a_trans as cblas::CBLAS_TRANSPOSE,
192            b_trans as cblas::CBLAS_TRANSPOSE,
193            m as cblas::blas_int,
194            n as cblas::blas_int,
195            k as cblas::blas_int,
196            alpha,
197            ptr_a,
198            lda as cblas::blas_int,
199            ptr_b,
200            ldb as cblas::blas_int,
201            beta,
202            ptr_c,
203            ldc as cblas::blas_int,
204        );
205    }
206}
207
208#[allow(clippy::too_many_arguments)]
209unsafe fn cblas_cgemm_wrap(
210    order: cblas::CBLAS_LAYOUT,
211    a_trans: cblas::CBLAS_TRANSPOSE,
212    b_trans: cblas::CBLAS_TRANSPOSE,
213    m: usize,
214    n: usize,
215    k: usize,
216    alpha: c32,
217    ptr_a: *const c32,
218    lda: isize,
219    ptr_b: *const c32,
220    ldb: isize,
221    beta: c32,
222    ptr_c: *mut c32,
223    ldc: isize,
224) {
225    unsafe {
226        cblas::cblas_cgemm(
227            order as cblas::CBLAS_LAYOUT,
228            a_trans as cblas::CBLAS_TRANSPOSE,
229            b_trans as cblas::CBLAS_TRANSPOSE,
230            m as cblas::blas_int,
231            n as cblas::blas_int,
232            k as cblas::blas_int,
233            &alpha as *const _ as *const c_void,
234            ptr_a as *const c_void,
235            lda as cblas::blas_int,
236            ptr_b as *const c_void,
237            ldb as cblas::blas_int,
238            &beta as *const _ as *const c_void,
239            ptr_c as *mut c_void,
240            ldc as cblas::blas_int,
241        );
242    }
243}
244
245#[allow(clippy::too_many_arguments)]
246unsafe fn cblas_zgemm_wrap(
247    order: cblas::CBLAS_LAYOUT,
248    a_trans: cblas::CBLAS_TRANSPOSE,
249    b_trans: cblas::CBLAS_TRANSPOSE,
250    m: usize,
251    n: usize,
252    k: usize,
253    alpha: c64,
254    ptr_a: *const c64,
255    lda: isize,
256    ptr_b: *const c64,
257    ldb: isize,
258    beta: c64,
259    ptr_c: *mut c64,
260    ldc: isize,
261) {
262    unsafe {
263        cblas::cblas_zgemm(
264            order as cblas::CBLAS_LAYOUT,
265            a_trans as cblas::CBLAS_TRANSPOSE,
266            b_trans as cblas::CBLAS_TRANSPOSE,
267            m as cblas::blas_int,
268            n as cblas::blas_int,
269            k as cblas::blas_int,
270            &alpha as *const _ as *const c_void,
271            ptr_a as *const c_void,
272            lda as cblas::blas_int,
273            ptr_b as *const c_void,
274            ldb as cblas::blas_int,
275            &beta as *const _ as *const c_void,
276            ptr_c as *mut c_void,
277            ldc as cblas::blas_int,
278        );
279    }
280}
281
282/* #endregion */
283
284/* #region syrk */
285
#[duplicate_item(
     ty    fn_name                 cblas_wrap       ;
    [f32] [syrk_blas_no_conj_f32] [cblas_ssyrk_wrap];
    [f64] [syrk_blas_no_conj_f64] [cblas_dsyrk_wrap];
    [c32] [syrk_blas_no_conj_c32] [cblas_csyrk_wrap];
    [c64] [syrk_blas_no_conj_c64] [cblas_zsyrk_wrap];
)]
/// SYRK driver: computes `c = alpha * a . a^T` through CBLAS (upper triangle
/// only) and then mirrors the result into the lower triangle so `c` holds the
/// full symmetric matrix. No conjugation handling.
///
/// - `c`, `lc`: output buffer and its 2-D layout, shape `(n, n)`.
/// - `a`, `la`: operand and its 2-D layout, shape `(n, k)`.
/// - `pool`: rayon pool used for pack/fill/copy helpers and the mirror step;
///   BLAS threading is expected to be configured outside this function.
pub fn fn_name(
    c: &mut [ty],
    lc: &Layout<Ix2>,
    a: &[ty],
    la: &Layout<Ix2>,
    alpha: ty,
    pool: Option<&ThreadPool>,
) -> Result<()> {
    // beta is assumed to be zero, and not passed as argument.

    // check layout of output; CBLAS is invoked column-major below
    if !lc.f_prefer() {
        // change to f-contig anyway
        // we do not handle conj, so this can be done easily
        if lc.c_prefer() {
            // c-prefer, transpose and run; the final result is mirrored into
            // both triangles, so transposing the output layout is sufficient
            return fn_name(c, &lc.reverse_axes(), a, la, alpha, pool);
        } else {
            // not c-prefer, allocate new buffer and copy back
            let lc_new = lc.shape().new_f_contig(None);
            let mut c_new = unsafe { uninitialized_vec(lc_new.size())? };
            fill_cpu_rayon(&mut c_new, &lc_new, <ty>::ZERO, pool)?;
            fn_name(&mut c_new, &lc_new, a, la, alpha, pool)?;
            assign_cpu_rayon(c, lc, &c_new, &lc_new, pool)?;
            return Ok(());
        }
    }

    // we assume that the layout is correct: c must be square and match a's rows
    let sc = lc.shape();
    let sa = la.shape();
    rstsr_assert_eq!(sc[0], sa[0], InvalidLayout)?;
    rstsr_assert_eq!(sc[1], sc[0], InvalidLayout)?;

    let n = sc[0];
    let k = sa[1];

    // handle the special case that k is zero-dimensional
    if k == 0 {
        // if k is zero, the result is a zero matrix (beta is fixed at zero)
        return fill_cpu_rayon(c, lc, <ty>::ZERO, pool);
    }

    // handle the special case that n is zero-dimensional
    if n == 0 {
        // if n is zero, the result matrix size is zero, and nothing to do
        return Ok(());
    }

    // determine trans/layout and clone data if necessary:
    // f-prefer -> NoTrans, c-prefer -> Trans on the reversed layout,
    // otherwise pack into a fresh f-contiguous buffer
    let mut a_data: Option<Vec<ty>> = None;
    let (a_trans, la) = if la.f_prefer() {
        (NoTrans, la.clone())
    } else if la.c_prefer() {
        (Trans, la.reverse_axes())
    } else {
        let len = la.size();
        a_data = unsafe { Some(uninitialized_vec(len)?) };
        let la_data = la.shape().new_f_contig(None);
        assign_cpu_rayon(a_data.as_mut().unwrap(), &la_data, a, la, pool)?;
        (NoTrans, la_data)
    };

    // final configuration
    // shape may be broadcasted for one-dimension case (stride can be 0 when
    // the second extent is 1), so make this check
    let lda = if la.shape()[1] != 1 { la.stride()[1] } else { la.shape()[0] as isize };
    let ldc = if lc.shape()[1] != 1 { lc.stride()[1] } else { lc.shape()[0] as isize };

    // packed operand starts at index 0; otherwise honor the layout offset
    let ptr_c = unsafe { c.as_mut_ptr().add(lc.offset()) };
    let ptr_a =
        if let Some(a_data) = a_data.as_ref() { a_data.as_ptr() } else { unsafe { a.as_ptr().add(la.offset()) } };

    // actual computation: only the Upper triangle of c is written here
    unsafe {
        cblas_wrap(ColMajor, Upper, a_trans, n, k, alpha, ptr_a, lda, <ty>::ZERO, ptr_c, ldc);
    }

    // write back to lower triangle: copy element (j, i) [upper, written by
    // BLAS above] into (i, j) [lower] for every i > j.
    // Indexing assumes unit stride along axis 0 — presumably guaranteed by the
    // f_prefer() check above; TODO confirm.
    let n = sc[0];
    let ldc = lc.stride()[1];
    let offset = lc.offset() as isize;
    let task = || {
        // parallel over columns j; each (i, j) with i > j is written by exactly
        // one thread, and reads touch only upper-triangle elements, so the
        // mirror writes do not race with each other.
        // NOTE(review): writing through a `*const -> *mut` cast obtained from a
        // shared borrow of `c` is a questionable aliasing pattern (Stacked
        // Borrows); worth checking with miri or replacing with a Send-able
        // mutable-pointer wrapper.
        (0..(n as isize)).into_par_iter().for_each(|j| {
            ((j + 1)..(n as isize)).for_each(|i| unsafe {
                let idx_ij = (offset + j * ldc + i) as usize;
                let idx_ji = (offset + i * ldc + j) as usize;
                let c_ptr_ij = c.as_ptr().add(idx_ij) as *mut ty;
                *c_ptr_ij = c[idx_ji];
            });
        });
    };
    // run the mirror step inside the caller's pool when one is provided
    pool.map_or_else(task, |pool| pool.install(task));
    Ok(())
}
387
388#[allow(clippy::too_many_arguments)]
389unsafe fn cblas_ssyrk_wrap(
390    order: cblas::CBLAS_LAYOUT,
391    uplo: cblas::CBLAS_UPLO,
392    a_trans: cblas::CBLAS_TRANSPOSE,
393    n: usize,
394    k: usize,
395    alpha: f32,
396    ptr_a: *const f32,
397    lda: isize,
398    beta: f32,
399    ptr_c: *mut f32,
400    ldc: isize,
401) {
402    unsafe {
403        cblas::cblas_ssyrk(
404            order as cblas::CBLAS_LAYOUT,
405            uplo as cblas::CBLAS_UPLO,
406            a_trans as cblas::CBLAS_TRANSPOSE,
407            n as cblas::blas_int,
408            k as cblas::blas_int,
409            alpha,
410            ptr_a,
411            lda as cblas::blas_int,
412            beta,
413            ptr_c,
414            ldc as cblas::blas_int,
415        );
416    }
417}
418
419#[allow(clippy::too_many_arguments)]
420unsafe fn cblas_dsyrk_wrap(
421    order: cblas::CBLAS_LAYOUT,
422    uplo: cblas::CBLAS_UPLO,
423    a_trans: cblas::CBLAS_TRANSPOSE,
424    n: usize,
425    k: usize,
426    alpha: f64,
427    ptr_a: *const f64,
428    lda: isize,
429    beta: f64,
430    ptr_c: *mut f64,
431    ldc: isize,
432) {
433    unsafe {
434        cblas::cblas_dsyrk(
435            order as cblas::CBLAS_LAYOUT,
436            uplo as cblas::CBLAS_UPLO,
437            a_trans as cblas::CBLAS_TRANSPOSE,
438            n as cblas::blas_int,
439            k as cblas::blas_int,
440            alpha,
441            ptr_a,
442            lda as cblas::blas_int,
443            beta,
444            ptr_c,
445            ldc as cblas::blas_int,
446        );
447    }
448}
449
450#[allow(clippy::too_many_arguments)]
451unsafe fn cblas_csyrk_wrap(
452    order: cblas::CBLAS_LAYOUT,
453    uplo: cblas::CBLAS_UPLO,
454    a_trans: cblas::CBLAS_TRANSPOSE,
455    n: usize,
456    k: usize,
457    alpha: c32,
458    ptr_a: *const c32,
459    lda: isize,
460    beta: c32,
461    ptr_c: *mut c32,
462    ldc: isize,
463) {
464    unsafe {
465        cblas::cblas_csyrk(
466            order as cblas::CBLAS_LAYOUT,
467            uplo as cblas::CBLAS_UPLO,
468            a_trans as cblas::CBLAS_TRANSPOSE,
469            n as cblas::blas_int,
470            k as cblas::blas_int,
471            &alpha as *const _ as *const c_void,
472            ptr_a as *const c_void,
473            lda as cblas::blas_int,
474            &beta as *const _ as *const c_void,
475            ptr_c as *mut c_void,
476            ldc as cblas::blas_int,
477        );
478    }
479}
480
481#[allow(clippy::too_many_arguments)]
482unsafe fn cblas_zsyrk_wrap(
483    order: cblas::CBLAS_LAYOUT,
484    uplo: cblas::CBLAS_UPLO,
485    a_trans: cblas::CBLAS_TRANSPOSE,
486    n: usize,
487    k: usize,
488    alpha: c64,
489    ptr_a: *const c64,
490    lda: isize,
491    beta: c64,
492    ptr_c: *mut c64,
493    ldc: isize,
494) {
495    unsafe {
496        cblas::cblas_zsyrk(
497            order as cblas::CBLAS_LAYOUT,
498            uplo as cblas::CBLAS_UPLO,
499            a_trans as cblas::CBLAS_TRANSPOSE,
500            n as cblas::blas_int,
501            k as cblas::blas_int,
502            &alpha as *const _ as *const c_void,
503            ptr_a as *const c_void,
504            lda as cblas::blas_int,
505            &beta as *const _ as *const c_void,
506            ptr_c as *mut c_void,
507            ldc as cblas::blas_int,
508        );
509    }
510}
511
512/* #endregion */
513
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke-tests the f32 GEMM driver over several layout combinations of
    /// the operands and output, plus a non-unit alpha.
    #[test]
    fn test_f32() {
        // a: 2x3, b: 3x4; c buffer is oversized (16) on purpose so strided
        // output layouts fit
        let a = vec![1., 2., 3., 4., 5., 6.];
        let b = vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.];
        let mut c = vec![0.0; 16];

        // case: a c-contig, b c-contig, c c-contig
        let la = [2, 3].c();
        let lb = [3, 4].c();
        let lc = [2, 4].c();
        let pool = rayon::ThreadPoolBuilder::new().num_threads(16).build().unwrap();
        let pool = Some(&pool);
        gemm_blas_no_conj_f32(&mut c, &lc, &a, &la, &b, &lb, 1.0, 0.0, pool).unwrap();
        let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc);
        println!("{c_tsr:}");
        println!("{:}", c_tsr.reshape([8]));
        let c_ref = asarray(vec![38., 44., 50., 56., 83., 98., 113., 128.]);
        assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref));

        // case: same operands, f-contig output (exercises the transpose path)
        let la = [2, 3].c();
        let lb = [3, 4].c();
        let lc = [2, 4].f();
        gemm_blas_no_conj_f32(&mut c, &lc, &a, &la, &b, &lb, 1.0, 0.0, pool).unwrap();
        let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc);
        println!("{c_tsr:}");
        println!("{:}", c_tsr.reshape([8]));
        let c_ref = asarray(vec![38., 44., 50., 56., 83., 98., 113., 128.]);
        assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref));

        // case: a f-contig (NoTrans branch), c c-contig; note the reference
        // values differ because `a`'s data is reinterpreted column-major
        let la = [2, 3].f();
        let lb = [3, 4].c();
        let lc = [2, 4].c();
        gemm_blas_no_conj_f32(&mut c, &lc, &a, &la, &b, &lb, 1.0, 0.0, pool).unwrap();
        let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc);
        println!("{c_tsr:}");
        println!("{:}", c_tsr.reshape([8]));
        let c_ref = asarray(vec![61., 70., 79., 88., 76., 88., 100., 112.]);
        assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref));

        // case: alpha = 2 scaling with f-contig output
        let la = [2, 3].f();
        let lb = [3, 4].c();
        let lc = [2, 4].f();
        gemm_blas_no_conj_f32(&mut c, &lc, &a, &la, &b, &lb, 2.0, 0.0, pool).unwrap();
        let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc);
        println!("{c_tsr:}");
        println!("{:}", c_tsr.reshape([8]));
        let c_ref = 2 * asarray(vec![61., 70., 79., 88., 76., 88., 100., 112.]);
        assert!(allclose_f64(&c_tsr.map(|v| *v as f64), &c_ref));
    }

    /// Smoke-tests the c32 GEMM driver; prints results only, no numeric
    /// reference assertion.
    #[test]
    fn test_c32() {
        let a = linspace((c32::new(1., 1.), c32::new(6., 6.), 6)).into_vec();
        let b = linspace((c32::new(1., 1.), c32::new(12., 12.), 12)).into_vec();
        let mut c = vec![c32::ZERO; 16];

        let la = [2, 3].c();
        let lb = [3, 4].c();
        let lc = [2, 4].c();
        let pool = rayon::ThreadPoolBuilder::new().num_threads(16).build().unwrap();
        let pool = Some(&pool);
        gemm_blas_no_conj_c32(&mut c, &lc, &a, &la, &b, &lb, c32::ONE, c32::ZERO, pool).unwrap();
        let c_tsr = TensorView::new(asarray(&c).into_raw_parts().0, lc);
        println!("{c_tsr:}");
        println!("{:}", c_tsr.reshape([8]));
    }
}
583}