rstsr_openblas/
matmul.rs

1use crate::matmul_impl::*;
2use crate::prelude_dev::*;
3use crate::threading::with_num_threads;
4use core::any::TypeId;
5use core::ops::{Add, Mul};
6use core::slice::{from_raw_parts, from_raw_parts_mut};
7use num::{Complex, Zero};
8use rayon::prelude::*;
9
/// Returns `true` iff `A` and `B` are the same concrete type.
///
/// Adapted from `ndarray`. Relies on `TypeId` equality, which is why both
/// parameters carry a `'static` bound.
fn same_type<A: 'static, B: 'static>() -> bool {
    let (id_a, id_b) = (TypeId::of::<A>(), TypeId::of::<B>());
    id_a == id_b
}
14
/// Dispatches the 2-D matrix multiplication `C = alpha * A * B + beta * C`
/// (no complex conjugation applied to either operand) to a BLAS routine when
/// all three element types are one of `f32`, `f64`, `Complex<f32>`,
/// `Complex<f64>`; otherwise falls back to the naive rayon implementation.
///
/// When `beta == 0` and `B` is exactly the transposed view of `A` (same data
/// pointer after offsets, reversed shape and stride), the SYRK routine is
/// used instead of GEMM.
///
/// * `c`, `lc` — output buffer and its 2-D layout
/// * `a`, `la` — left operand and its 2-D layout
/// * `b`, `lb` — right operand and its 2-D layout
/// * `alpha`, `beta` — scaling factors
/// * `pool` — optional thread pool forwarded to the chosen implementation
#[allow(clippy::too_many_arguments)]
pub fn gemm_blas_ix2_no_conj_dispatch<TA, TB, TC>(
    c: &mut [TC],
    lc: &Layout<Ix2>,
    a: &[TA],
    la: &Layout<Ix2>,
    b: &[TB],
    lb: &Layout<Ix2>,
    alpha: TC,
    beta: TC,
    pool: Option<&ThreadPool>,
) -> Result<()>
where
    TA: Clone + Send + Sync + 'static,
    TB: Clone + Send + Sync + 'static,
    TC: Clone + Send + Sync + 'static,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + PartialEq,
{
    // check if syrk could be applicable:
    // beta must be zero (the syrk wrappers below take no beta), all element
    // types must agree, and b must be the transposed view of a.
    let able_syrk = beta == TC::zero()
        && same_type::<TB, TC>()
        && same_type::<TA, TC>()
        && unsafe {
            // SAFETY: the offset pointers are only compared, never
            // dereferenced; offsets are assumed in-bounds of their buffers.
            let a_ptr = a.as_ptr().add(la.offset()) as *const TC;
            let b_ptr = b.as_ptr().add(lb.offset()) as *const TC;
            let equal_ptr = core::ptr::eq(a_ptr, b_ptr);
            // b must equal a transposed: same shape and stride once b's
            // axes are reversed.
            let equal_shape = la.shape() == lb.reverse_axes().shape();
            let equal_stride = la.stride() == lb.reverse_axes().stride();
            equal_ptr && equal_shape && equal_stride
        };

    // type check and dispatch: each expansion tests for one supported BLAS
    // element type and, on match, reinterprets the buffers and returns.
    macro_rules! impl_gemm_dispatch {
        ($ty: ty, $fn_gemm_name: ident, $fn_syrk_name: ident) => {
            if (same_type::<TA, $ty>() && same_type::<TB, $ty>() && same_type::<TC, $ty>()) {
                // SAFETY: the same_type checks above guarantee TA, TB and TC
                // are exactly $ty, so these reinterpreting casts (and the
                // alpha/beta reads below) are sound; lengths are unchanged.
                let a_slice = unsafe { from_raw_parts(a.as_ptr() as *const $ty, a.len()) };
                let b_slice = unsafe { from_raw_parts(b.as_ptr() as *const $ty, b.len()) };
                let c_slice = unsafe { from_raw_parts_mut(c.as_mut_ptr() as *mut $ty, c.len()) };
                let alpha = unsafe { *(&alpha as *const TC as *const $ty) };
                let beta = unsafe { *(&beta as *const TC as *const $ty) };
                if able_syrk {
                    $fn_syrk_name(c_slice, lc, a_slice, la, alpha, pool)?;
                } else {
                    $fn_gemm_name(c_slice, lc, a_slice, la, b_slice, lb, alpha, beta, pool)?;
                }
                return Ok(());
            }
        };
    }

    impl_gemm_dispatch!(f32, gemm_blas_no_conj_f32, syrk_blas_no_conj_f32);
    impl_gemm_dispatch!(f64, gemm_blas_no_conj_f64, syrk_blas_no_conj_f64);
    impl_gemm_dispatch!(Complex<f32>, gemm_blas_no_conj_c32, syrk_blas_no_conj_c32);
    impl_gemm_dispatch!(Complex<f64>, gemm_blas_no_conj_c64, syrk_blas_no_conj_c64);

    // not able to be accelerated by blas_no_conj;
    // fall back to the naive (generic-type) implementation
    let c_slice = c;
    let a_slice = a;
    let b_slice = b;
    return gemm_ix2_naive_cpu_rayon(c_slice, lc, a_slice, la, b_slice, lb, alpha, beta, pool);
}
78
/// Row-major matrix multiplication `c = alpha * a @ b + beta * c` with
/// numpy-style broadcasting over leading (batch) dimensions.
///
/// Dispatch by `(la.ndim(), lb.ndim(), lc.ndim())`:
/// * rule 1 `(1, 1, 0)`: vector inner product;
/// * rule 2 `(2, 2, 2)`: single GEMM (BLAS-dispatched);
/// * rule 3 `(1, 2.., _)`: `K` × `..., K, N` → `..., N`;
/// * rule 4 `(2.., 1, _)`: `..., M, K` × `K` → `..., M`;
/// * rule 5 `(2, 3.., _)`: `M, K` × `..., K, N` → `..., M, N`;
/// * rule 6 `(3.., 2, _)`: `..., M, K` × `K, N` → `..., M, N`;
/// * rule 7 `(3.., 3.., _)`: fully batched `..., M, K` × `..., K, N`.
///
/// Returns an `InvalidLayout` error when the dimensionalities cannot be
/// matched by any of the rules above.
#[allow(clippy::too_many_arguments)]
pub fn matmul_row_major_blas<TA, TB, TC, DA, DB, DC>(
    c: &mut [TC],
    lc: &Layout<DC>,
    a: &[TA],
    la: &Layout<DA>,
    b: &[TB],
    lb: &Layout<DB>,
    alpha: TC,
    beta: TC,
    pool: Option<&ThreadPool>,
) -> Result<()>
where
    TA: Clone + Send + Sync + 'static,
    TB: Clone + Send + Sync + 'static,
    TC: Clone + Send + Sync + 'static,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + PartialEq,
{
    // NOTE: this only works for row-major layout
    // for column-major layout, we need to transpose the input:
    // C = A * B  =>  C^T = B^T * A^T

    // quick return for empty matrix
    // in this case, we do not check the shape of a, b, c
    if lc.size() == 0 {
        return Ok(());
    }

    // number of threads available; 1 when no pool was handed in
    let nthreads = match pool {
        Some(pool) => pool.current_num_threads(),
        None => 1,
    };

    // handle special (non-broadcast) cases
    match (la.ndim(), lb.ndim(), lc.ndim()) {
        (1, 1, 0) => {
            // rule 1: vector inner dot
            let la = &la.clone().into_dim::<Ix1>().unwrap();
            let lb = &lb.clone().into_dim::<Ix1>().unwrap();
            let lc = &lc.clone().into_dim::<Ix0>().unwrap();
            let c_num = &mut c[lc.offset()];
            return with_num_threads(nthreads, || inner_dot_naive_cpu_rayon(c_num, a, la, b, lb, alpha, beta, pool));
        },
        (2, 2, 2) => {
            // rule 2: matrix multiplication
            let la = &la.clone().into_dim::<Ix2>().unwrap();
            let lb = &lb.clone().into_dim::<Ix2>().unwrap();
            let lc = &lc.clone().into_dim::<Ix2>().unwrap();
            return with_num_threads(nthreads, || {
                gemm_blas_ix2_no_conj_dispatch(c, lc, a, la, b, lb, alpha, beta, pool)
            });
        },
        _ => (),
    };

    // handle broadcasted cases
    // temporary variables:
    // lx_matmul — trailing 2-D layout of the per-task matmul;
    // lx_rest   — leading (batch) layout iterated over below
    let la_matmul;
    let lb_matmul;
    let lc_matmul;
    let la_rest;
    let lb_rest;
    let lc_rest;

    match (la.ndim(), lb.ndim(), lc.ndim()) {
        // we have already handled these cases
        (1, 1, 0) | (2, 2, 2) => unreachable!(),
        (1, 2.., _) => {
            // rule 3: | `        K` | `..., K, N` | `   ..., N` |
            rstsr_assert_eq!(lb.ndim(), lc.ndim() + 1, InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-1)?;
            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
            let (lc_r, lc_m) = lc.dim_split_at(-1)?;
            // a has no batch dims of its own; broadcast its (empty) rest
            // layout against c's batch dims
            la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
            lb_rest = lb_r;
            lc_rest = lc_r;
            // promote the vector operands to 2-D: (K) -> (1, K), (N) -> (1, N)
            la_matmul = la_m.dim_insert(0)?.into_dim::<Ix2>()?;
            lb_matmul = lb_m.into_dim::<Ix2>()?;
            lc_matmul = lc_m.dim_insert(0)?.into_dim::<Ix2>()?;
        },
        (2.., 1, _) => {
            // rule 4: | `..., M, K` | `        K` | `   ..., M` |
            rstsr_assert_eq!(la.ndim(), lc.ndim() + 1, InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-2)?;
            let (lb_r, lb_m) = lb.dim_split_at(-1)?;
            let (lc_r, lc_m) = lc.dim_split_at(-1)?;
            la_rest = la_r;
            // b has no batch dims; broadcast against c's batch dims
            lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
            lc_rest = lc_r;
            // promote the vector operands to 2-D: (K) -> (K, 1), (M) -> (M, 1)
            la_matmul = la_m.into_dim::<Ix2>()?;
            lb_matmul = lb_m.dim_insert(1)?.into_dim::<Ix2>()?;
            lc_matmul = lc_m.dim_insert(1)?.into_dim::<Ix2>()?;
        },
        (2, 3.., _) => {
            // rule 5: | `     M, K` | `..., K, N` | `..., M, N` |
            rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-2)?;
            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
            let (lc_r, lc_m) = lc.dim_split_at(-2)?;
            // a is a single matrix; broadcast its rest layout so every batch
            // task reuses the same a
            la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
            lb_rest = lb_r;
            lc_rest = lc_r;
            la_matmul = la_m.into_dim::<Ix2>()?;
            lb_matmul = lb_m.into_dim::<Ix2>()?;
            lc_matmul = lc_m.into_dim::<Ix2>()?;
        },
        (3.., 2, _) => {
            // rule 6: | `..., M, K` | `     K, N` | `..., M, N` |
            rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-2)?;
            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
            let (lc_r, lc_m) = lc.dim_split_at(-2)?;
            la_rest = la_r;
            // b is a single matrix; broadcast its rest layout over the batch
            lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
            lc_rest = lc_r;
            la_matmul = la_m.into_dim::<Ix2>()?;
            lb_matmul = lb_m.into_dim::<Ix2>()?;
            lc_matmul = lc_m.into_dim::<Ix2>()?;
        },
        (3.., 3.., _) => {
            // rule 7: | `..., M, K` | `..., K, N` | `..., M, N` |
            rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?;
            rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-2)?;
            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
            let (lc_r, lc_m) = lc.dim_split_at(-2)?;
            la_rest = la_r;
            lb_rest = lb_r;
            lc_rest = lc_r;
            la_matmul = la_m.into_dim::<Ix2>()?;
            lb_matmul = lb_m.into_dim::<Ix2>()?;
            lc_matmul = lc_m.into_dim::<Ix2>()?;
        },
        _ => {
            rstsr_raise!(InvalidLayout, "This is not valid layout for matmul broadcasting.")?;
            unreachable!()
        },
    }
    // now, lx_rest should all have the same shape, while lx_matmul should be
    // matmul-compatible.
    // Parallelization strategy: when there are many independent batch tasks
    // (n_task >= 4 * nthreads), parallelize the outer loop and run each
    // matmul sequentially; otherwise run the outer loop sequentially and let
    // each matmul use the full thread pool.
    rstsr_assert_eq!(la_rest.shape(), lb_rest.shape(), InvalidLayout)?;
    rstsr_assert_eq!(lb_rest.shape(), lc_rest.shape(), InvalidLayout)?;
    let n_task = la_rest.size();
    let ita_rest = IterLayoutColMajor::new(&la_rest)?;
    let itb_rest = IterLayoutColMajor::new(&lb_rest)?;
    let itc_rest = IterLayoutColMajor::new(&lc_rest)?;
    if n_task >= 4 * nthreads {
        // parallel outer, sequential matmul
        with_num_threads(1, || {
            let task = || {
                ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each(
                    |((ia_rest, ib_rest), ic_rest)| -> Result<()> {
                        // prepare layout: same inner matmul layout per task,
                        // shifted to this task's batch offsets
                        let mut la_m = la_matmul.clone();
                        let mut lb_m = lb_matmul.clone();
                        let mut lc_m = lc_matmul.clone();
                        unsafe {
                            // SAFETY: offsets come from iterating the rest
                            // layouts, which index valid buffer elements
                            la_m.set_offset(ia_rest);
                            lb_m.set_offset(ib_rest);
                            lc_m.set_offset(ic_rest);
                        }
                        // move mutable reference into parallel closure
                        let c = unsafe {
                            // SAFETY: each task writes through a distinct
                            // ic_rest offset with the same inner layout, so
                            // the aliased mutable slices never touch the same
                            // element — assumes lc describes non-overlapping
                            // output blocks (TODO confirm for exotic strides)
                            let c_ptr = c.as_ptr() as *mut TC;
                            let c_len = c.len();
                            from_raw_parts_mut(c_ptr, c_len)
                        };
                        // clone alpha and beta
                        let alpha = alpha.clone();
                        let beta = beta.clone();
                        gemm_blas_ix2_no_conj_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, None)
                    },
                )
            };
            // run on the supplied pool if any, otherwise inline
            match pool {
                Some(pool) => pool.install(task),
                None => task(),
            }
        })
    } else {
        // sequential outer, parallel matmul
        with_num_threads(nthreads, || -> Result<()> {
            izip!(ita_rest, itb_rest, itc_rest).try_for_each(|(ia_rest, ib_rest, ic_rest)| {
                // prepare layout
                let mut la_m = la_matmul.clone();
                let mut lb_m = lb_matmul.clone();
                let mut lc_m = lc_matmul.clone();
                unsafe {
                    // SAFETY: offsets come from iterating the rest layouts,
                    // which index valid buffer elements
                    la_m.set_offset(ia_rest);
                    lb_m.set_offset(ib_rest);
                    lc_m.set_offset(ic_rest);
                }
                // clone alpha and beta
                let alpha = alpha.clone();
                let beta = beta.clone();
                gemm_blas_ix2_no_conj_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, pool)
            })
        })
    }
}
285
286#[allow(clippy::too_many_arguments)]
287impl<TA, TB, TC, DA, DB, DC> DeviceMatMulAPI<TA, TB, TC, DA, DB, DC> for DeviceBLAS
288where
289    TA: Clone + Send + Sync + 'static,
290    TB: Clone + Send + Sync + 'static,
291    TC: Clone + Send + Sync + 'static,
292    DA: DimAPI,
293    DB: DimAPI,
294    DC: DimAPI,
295    TA: Mul<TB, Output = TC>,
296    TB: Mul<TA, Output = TC>,
297    TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + PartialEq,
298{
299    fn matmul(
300        &self,
301        c: &mut Vec<TC>,
302        lc: &Layout<DC>,
303        a: &Vec<TA>,
304        la: &Layout<DA>,
305        b: &Vec<TB>,
306        lb: &Layout<DB>,
307        alpha: TC,
308        beta: TC,
309    ) -> Result<()> {
310        let default_order = self.default_order();
311        let pool = self.get_current_pool();
312        match default_order {
313            RowMajor => matmul_row_major_blas(c, lc, a, la, b, lb, alpha, beta, pool),
314            ColMajor => {
315                let la = la.reverse_axes();
316                let lb = lb.reverse_axes();
317                let lc = lc.reverse_axes();
318                matmul_row_major_blas(c, &lc, b, &lb, a, &la, alpha, beta, pool)
319            },
320        }
321    }
322}
323
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test: exercises broadcast rules via the `%` (matmul) operator:
    /// 2-D x 2-D, 1-D x 1-D (inner dot), 1-D x 3-D, 3-D x 1-D,
    /// 2-D x 3-D and 3-D x 2-D. Output is printed, not asserted.
    #[test]
    fn test_matmul() {
        let device = DeviceBLAS::default();
        let a = linspace((0.0, 14.0, 15, &device)).into_shape([3, 5]);
        let b = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]);
        println!("{:}", &a % &b);

        let a = linspace((0.0, 14.0, 15, &device));
        let b = linspace((0.0, 14.0, 15, &device));
        println!("{:}", &a % &b);

        let a = linspace((0.0, 2.0, 3, &device));
        let b = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
        println!("{:}", &a % &b);

        let a = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
        let b = linspace((0.0, 4.0, 5, &device));
        println!("{:}", &a % &b);

        let a = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]);
        let b = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
        println!("{:}", &a % &b);

        let a = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
        let b = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]);
        println!("{:}", &a % &b);
    }

    /// Timing benchmark for a single large GEMM (rule 2); ignored by default.
    #[test]
    #[ignore]
    fn parallel_test_full() {
        let device = DeviceBLAS::default();
        let a = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([8192, 8192]);
        let b = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([8192, 8192]);
        for _ in 0..10 {
            let start = std::time::Instant::now();
            let _ = &a % &b;
            println!("time: {:?}", start.elapsed());
        }
    }

    /// Timing benchmark for many repeated small GEMMs on a single-thread
    /// device; ignored by default.
    #[test]
    #[ignore]
    fn parallel_test_full_512() {
        let device = DeviceBLAS::new(1);
        let a = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]);
        let b = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]);
        for _ in 0..1000 {
            let start = std::time::Instant::now();
            let c = &a % &b;
            println!("{:?}", c.device());
            println!("time: {:?}", start.elapsed());
        }
    }

    /// Timing benchmark for fully batched matmul (rule 7); ignored by default.
    #[test]
    #[ignore]
    fn parallel_test_par_rule7() {
        let device = DeviceBLAS::default();
        let a = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]);
        let b = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]);
        for i in 0..10 {
            let start = std::time::Instant::now();
            let c = &a % &b;
            println!("{:?}", c.layout());
            println!("time: {:?}", start.elapsed());
            if i == 0 {
                println!("{c:?}");
            }
        }
    }

    /// Timing benchmark for batched-times-single matmul (rule 6);
    /// ignored by default.
    #[test]
    #[ignore]
    fn parallel_test_par_rule6() {
        let device = DeviceBLAS::default();
        let a = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]);
        let b = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]);
        for i in 0..10 {
            let start = std::time::Instant::now();
            let c = &a % &b;
            println!("{:?}", c.layout());
            println!("time: {:?}", start.elapsed());
            if i == 0 {
                println!("{c:?}");
            }
        }
    }

    /// Same as rule-6 benchmark but with a Fortran-preferred (reversed-axes)
    /// left operand; ignored by default.
    #[test]
    #[ignore]
    fn parallel_test_par_rule6_fprefer() {
        let device = DeviceBLAS::default();
        let a = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([512, 512, 256]).into_reverse_axes();
        let b = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]);
        for i in 0..10 {
            let start = std::time::Instant::now();
            let c = &a % &b;
            println!("{:?}", c.layout());
            println!("time: {:?}", start.elapsed());
            if i == 0 {
                println!("{c:?}");
            }
        }
    }

    /// Verifies the SYRK fast path (taken when `b` is `a` transposed and
    /// shares its buffer) produces the same result as the generic GEMM path
    /// (forced by using a distinct buffer `b` with identical values).
    #[test]
    fn syrk_correctness() {
        let device = DeviceBLAS::default();
        let a = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]);
        let b = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]);
        let c = &a % &a.t();
        let d = &a % &b.t();
        assert!(allclose_f64(&c, &d));

        // batched variant: SYRK vs GEMM inside the broadcast loop
        let device = DeviceBLAS::default();
        let a = linspace((0.0, 1.0, 1024 * 1024, &device)).into_shape([4, 512, 512]);
        let b = linspace((0.0, 1.0, 1024 * 1024, &device)).into_shape([4, 512, 512]);
        let c = &a % &a.swapaxes(-1, -2);
        let d = &a % &b.swapaxes(-1, -2);
        assert!(allclose_f64(&c, &d));
    }

    /// Timing comparison of the SYRK path against the plain GEMM path, for
    /// both batched and single large matrices; ignored by default.
    #[test]
    #[ignore]
    fn syrk_efficiency() {
        use std::hint::black_box;
        let device = DeviceBLAS::default();
        let a = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]);
        let b = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]);
        for _ in 0..10 {
            let start = std::time::Instant::now();
            black_box(&a % &a.swapaxes(-1, -2));
            println!("syrk time: {:?}", start.elapsed());
            let start = std::time::Instant::now();
            black_box(&a % &b.swapaxes(-1, -2));
            println!("gemm time: {:?}", start.elapsed());
        }

        println!("---------------------");
        let a = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([8192, 8192]);
        let b = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([8192, 8192]);
        for _ in 0..10 {
            let start = std::time::Instant::now();
            black_box(&a % &a.swapaxes(-1, -2));
            println!("syrk time: {:?}", start.elapsed());
            let start = std::time::Instant::now();
            black_box(&a % &b.swapaxes(-1, -2));
            println!("gemm time: {:?}", start.elapsed());
        }
    }
}