rstsr_core/device_faer/matmul.rs

//! Implementation of faer matmul
//!
//! This implementation does not specialize gemv. We always use gemm for matmul.
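//!
//! A minimal usage sketch (shapes are illustrative and mirror the test at the
//! bottom of this file); `%` is rstsr's matmul operator:
//!
//! ```ignore
//! let device = DeviceFaer::default();
//! let a = linspace((0.0, 14.0, 15, &device)).into_shape_assume_contig([3, 5]);
//! let b = linspace((0.0, 14.0, 15, &device)).into_shape_assume_contig([5, 3]);
//! let c = &a % &b; // 2-D f64 case, dispatched through gemm_faer_ix2_dispatch
//! ```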

use super::matmul_impl::*;
use crate::prelude_dev::*;
use core::any::TypeId;
use core::ops::{Add, Mul};
use core::slice::{from_raw_parts, from_raw_parts_mut};
use num::{Complex, Zero};
use rayon::prelude::*;

// code from ndarray: true iff `A` and `B` are the same concrete type
fn same_type<A: 'static, B: 'static>() -> bool {
    TypeId::of::<A>() == TypeId::of::<B>()
}

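/// Dispatches a 2-D gemm `c <- alpha * (a @ b) + beta * c` to faer when the
/// element type is `f32`, `f64`, `Complex<f32>` or `Complex<f64>`, with a syrk
/// shortcut when `b` is the transposed view of `a`; any other element type
/// falls back to the naive rayon implementation.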
#[allow(clippy::too_many_arguments)]
pub fn gemm_faer_ix2_dispatch<TA, TB, TC>(
    c: &mut [TC],
    lc: &Layout<Ix2>,
    a: &[TA],
    la: &Layout<Ix2>,
    b: &[TB],
    lb: &Layout<Ix2>,
    alpha: TC,
    beta: TC,
    pool: Option<&ThreadPool>,
) -> Result<()>
where
    TA: Clone + Send + Sync + 'static,
    TB: Clone + Send + Sync + 'static,
    TC: Clone + Send + Sync + 'static,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + PartialEq,
{
    // check whether syrk (symmetric rank-k update, C = alpha * A * A^T) is
    // applicable: beta must be zero, and `b` must be exactly the transposed
    // view of `a` (same data pointer, reversed shape and stride)
    let able_syrk = beta == TC::zero()
        && same_type::<TA, TC>()
        && same_type::<TB, TC>()
        && unsafe {
            let a_ptr = a.as_ptr().add(la.offset()) as *const TC;
            let b_ptr = b.as_ptr().add(lb.offset()) as *const TC;
            let equal_ptr = core::ptr::eq(a_ptr, b_ptr);
            let equal_shape = la.shape() == lb.reverse_axes().shape();
            let equal_stride = la.stride() == lb.reverse_axes().stride();
            equal_ptr && equal_shape && equal_stride
        };

    // type check and dispatch
    macro_rules! impl_gemm_dispatch {
        ($ty: ty) => {
            if (same_type::<TA, $ty>() && same_type::<TB, $ty>() && same_type::<TC, $ty>()) {
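                // SAFETY: `same_type` has just proven that TA/TB/TC are all
                // exactly `$ty`, so reinterpreting the slices (and copying
                // alpha/beta) through pointer casts is sound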
                let a_slice = unsafe { from_raw_parts(a.as_ptr() as *const $ty, a.len()) };
                let b_slice = unsafe { from_raw_parts(b.as_ptr() as *const $ty, b.len()) };
                let c_slice = unsafe { from_raw_parts_mut(c.as_mut_ptr() as *mut $ty, c.len()) };
                let alpha = unsafe { *(&alpha as *const TC as *const $ty) };
                let beta = unsafe { *(&beta as *const TC as *const $ty) };
                if able_syrk {
                    gemm_with_syrk_faer(c_slice, lc, a_slice, la, alpha, beta, pool)?;
                } else {
                    gemm_faer(c_slice, lc, a_slice, la, b_slice, lb, alpha, beta, pool)?;
                }
                return Ok(());
            }
        };
    }

    impl_gemm_dispatch!(f32);
    impl_gemm_dispatch!(f64);
    impl_gemm_dispatch!(Complex<f32>);
    impl_gemm_dispatch!(Complex<f64>);

    // not able to be accelerated by faer;
    // fall back to the naive implementation
    let c_slice = c;
    let a_slice = a;
    let b_slice = b;
    return gemm_ix2_naive_cpu_rayon(c_slice, lc, a_slice, la, b_slice, lb, alpha, beta, pool);
}

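/// Row-major matmul driver: handles the vector inner dot and the plain 2-D
/// gemm directly, and reduces broadcasted (batched) cases to a loop of 2-D
/// gemm calls over the batch dimensions.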
#[allow(clippy::too_many_arguments)]
pub fn matmul_row_major_faer<TA, TB, TC, DA, DB, DC>(
    c: &mut [TC],
    lc: &Layout<DC>,
    a: &[TA],
    la: &Layout<DA>,
    b: &[TB],
    lb: &Layout<DB>,
    alpha: TC,
    beta: TC,
    pool: Option<&ThreadPool>,
) -> Result<()>
where
    TA: Clone + Send + Sync + 'static,
    TB: Clone + Send + Sync + 'static,
    TC: Clone + Send + Sync + 'static,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + PartialEq,
{
    // NOTE: this only works for row-major layout;
    // for column-major layout, we need to transpose the inputs:
    // C = A * B  =>  C^T = B^T * A^T
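    // (the `ColMajor` arm of `DeviceMatMulAPI::matmul` below performs exactly
    // this transposition before calling into this function)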

    let nthreads = match pool {
        Some(pool) => pool.current_num_threads(),
        None => 1,
    };

    // handle special cases
    match (la.ndim(), lb.ndim(), lc.ndim()) {
        (1, 1, 0) => {
            // rule 1: vector inner dot
            let la = &la.clone().into_dim::<Ix1>().unwrap();
            let lb = &lb.clone().into_dim::<Ix1>().unwrap();
            let lc = &lc.clone().into_dim::<Ix0>().unwrap();
            let c_num = &mut c[lc.offset()];
            return inner_dot_naive_cpu_rayon(c_num, a, la, b, lb, alpha, beta, pool);
        },
        (2, 2, 2) => {
            // rule 2: matrix multiplication
            let la = &la.clone().into_dim::<Ix2>().unwrap();
            let lb = &lb.clone().into_dim::<Ix2>().unwrap();
            let lc = &lc.clone().into_dim::<Ix2>().unwrap();
            return gemm_faer_ix2_dispatch(c, lc, a, la, b, lb, alpha, beta, pool);
        },
        _ => (),
    }

    // handle broadcasted cases
    // temporary variables
    let la_matmul;
    let lb_matmul;
    let lc_matmul;
    let la_rest;
    let lb_rest;
    let lc_rest;

    match (la.ndim(), lb.ndim(), lc.ndim()) {
        // we have already handled these cases
        (1, 1, 0) | (2, 2, 2) => unreachable!(),
        (1, 2.., _) => {
            // rule 3: | `        K` | `..., K, N` | `   ..., N` |
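            // e.g. a: (K,), b: (B, K, N) -> c: (B, N); `a` and `c` each get a
            // leading unit axis so every core step is a (1, K) x (K, N) gemm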
            rstsr_assert_eq!(lb.ndim(), lc.ndim() + 1, InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-1)?;
            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
            let (lc_r, lc_m) = lc.dim_split_at(-1)?;
            la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
            lb_rest = lb_r;
            lc_rest = lc_r;
            la_matmul = la_m.dim_insert(0)?.into_dim::<Ix2>()?;
            lb_matmul = lb_m.into_dim::<Ix2>()?;
            lc_matmul = lc_m.dim_insert(0)?.into_dim::<Ix2>()?;
        },
        (2.., 1, _) => {
            // rule 4: | `..., M, K` | `        K` | `   ..., M` |
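            // e.g. a: (B, M, K), b: (K,) -> c: (B, M); `b` and `c` each get a
            // trailing unit axis so every core step is an (M, K) x (K, 1) gemm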
            rstsr_assert_eq!(la.ndim(), lc.ndim() + 1, InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-2)?;
            let (lb_r, lb_m) = lb.dim_split_at(-1)?;
            let (lc_r, lc_m) = lc.dim_split_at(-1)?;
            la_rest = la_r;
            lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
            lc_rest = lc_r;
            la_matmul = la_m.into_dim::<Ix2>()?;
            lb_matmul = lb_m.dim_insert(1)?.into_dim::<Ix2>()?;
            lc_matmul = lc_m.dim_insert(1)?.into_dim::<Ix2>()?;
        },
        (2, 3.., _) => {
            // rule 5: | `     M, K` | `..., K, N` | `..., M, N` |
            rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-2)?;
            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
            let (lc_r, lc_m) = lc.dim_split_at(-2)?;
            la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
            lb_rest = lb_r;
            lc_rest = lc_r;
            la_matmul = la_m.into_dim::<Ix2>()?;
            lb_matmul = lb_m.into_dim::<Ix2>()?;
            lc_matmul = lc_m.into_dim::<Ix2>()?;
        },
        (3.., 2, _) => {
            // rule 6: | `..., M, K` | `     K, N` | `..., M, N` |
            rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-2)?;
            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
            let (lc_r, lc_m) = lc.dim_split_at(-2)?;
            la_rest = la_r;
            lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
            lc_rest = lc_r;
            la_matmul = la_m.into_dim::<Ix2>()?;
            lb_matmul = lb_m.into_dim::<Ix2>()?;
            lc_matmul = lc_m.into_dim::<Ix2>()?;
        },
        (3.., 3.., _) => {
            // rule 7: | `..., M, K` | `..., K, N` | `..., M, N` |
            rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?;
            rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-2)?;
            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
            let (lc_r, lc_m) = lc.dim_split_at(-2)?;
            la_rest = la_r;
            lb_rest = lb_r;
            lc_rest = lc_r;
            la_matmul = la_m.into_dim::<Ix2>()?;
            lb_matmul = lb_m.into_dim::<Ix2>()?;
            lc_matmul = lc_m.into_dim::<Ix2>()?;
        },
        _ => todo!(),
    }
    // now each `lx_rest` should share the same shape, while each `lx_matmul`
    // should be a valid 2-D matmul layout.
    // when there are many outer tasks (more than 4 * nthreads), parallelize
    // over the outer iteration and run each matmul sequentially; otherwise run
    // the outer loop sequentially and let each matmul parallelize internally.
    rstsr_assert_eq!(la_rest.shape(), lb_rest.shape(), InvalidLayout)?;
    rstsr_assert_eq!(lb_rest.shape(), lc_rest.shape(), InvalidLayout)?;
    let n_task = la_rest.size();
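    // col-major iteration over the batch dimensions yields, for each batch
    // element, the data offset where its 2-D block starts; the three rest
    // layouts share a shape, so the three iterators stay aligned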
    let ita_rest = IterLayoutColMajor::new(&la_rest)?;
    let itb_rest = IterLayoutColMajor::new(&lb_rest)?;
    let itc_rest = IterLayoutColMajor::new(&lc_rest)?;
    if n_task > 4 * nthreads {
        // parallel outer, sequential matmul
        let task = || {
            ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each(
                |((ia_rest, ib_rest), ic_rest)| -> Result<()> {
                    // prepare layout
                    let mut la_m = la_matmul.clone();
                    let mut lb_m = lb_matmul.clone();
                    let mut lc_m = lc_matmul.clone();
                    unsafe {
                        la_m.set_offset(ia_rest);
                        lb_m.set_offset(ib_rest);
                        lc_m.set_offset(ic_rest);
                    }
                    // SAFETY: rebuild a mutable slice over `c` inside the
                    // parallel closure; each task writes only to the disjoint
                    // sub-block addressed by its own `ic_rest` offset
                    let c = unsafe {
                        let c_ptr = c.as_ptr() as *mut TC;
                        let c_len = c.len();
                        from_raw_parts_mut(c_ptr, c_len)
                    };
                    // clone alpha and beta
                    let alpha = alpha.clone();
                    let beta = beta.clone();
                    gemm_faer_ix2_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, None)
                },
            )
        };
        match pool {
            Some(pool) => pool.install(task)?,
            None => task()?,
        };
    } else {
        // sequential outer, parallel matmul
        for (ia_rest, ib_rest, ic_rest) in izip!(ita_rest, itb_rest, itc_rest) {
            // prepare layout
            let mut la_m = la_matmul.clone();
            let mut lb_m = lb_matmul.clone();
            let mut lc_m = lc_matmul.clone();
            unsafe {
                la_m.set_offset(ia_rest);
                lb_m.set_offset(ib_rest);
                lc_m.set_offset(ic_rest);
            }
            // clone alpha and beta
            let alpha = alpha.clone();
            let beta = beta.clone();
            gemm_faer_ix2_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, pool)?;
        }
    }
    return Ok(());
}

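// `DeviceMatMulAPI` for `DeviceFaer`: honors the device's default order by
// reducing the column-major case to the row-major driver above.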
#[allow(clippy::too_many_arguments)]
impl<TA, TB, TC, DA, DB, DC> DeviceMatMulAPI<TA, TB, TC, DA, DB, DC> for DeviceFaer
where
    TA: Clone + Send + Sync + 'static,
    TB: Clone + Send + Sync + 'static,
    TC: Clone + Send + Sync + 'static,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TA: Mul<TB, Output = TC>,
    TB: Mul<TA, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + PartialEq,
{
    fn matmul(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<DC>,
        a: &Vec<TA>,
        la: &Layout<DA>,
        b: &Vec<TB>,
        lb: &Layout<DB>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        let default_order = self.default_order();
        let pool = self.get_current_pool();
        match default_order {
            RowMajor => matmul_row_major_faer(c, lc, a, la, b, lb, alpha, beta, pool),
            ColMajor => {
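                // treat column-major as row-major of the transposes:
                // C = A * B  <=>  C^T = B^T * A^T, hence swap `a`/`b` and
                // reverse all axes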
                let la = la.reverse_axes();
                let lb = lb.reverse_axes();
                let lc = lc.reverse_axes();
                matmul_row_major_faer(c, &lc, b, &lb, a, &la, alpha, beta, pool)
            },
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_matmul() {
        let device = DeviceFaer::default();
        let a = linspace((0.0, 14.0, 15, &device)).into_shape_assume_contig([3, 5]);
        let b = linspace((0.0, 14.0, 15, &device)).into_shape_assume_contig([5, 3]);

        let d = &a % &b;
        println!("{d}");

        let a = linspace((0.0, 14.0, 15, &device));
        let b = linspace((0.0, 14.0, 15, &device));
        println!("{:}", &a % &b);

        #[cfg(not(feature = "col_major"))]
        {
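            // these broadcasted cases exercise rules 3-6 of matmul_row_major_faer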
            let a = linspace((0.0, 2.0, 3, &device));
            let b = linspace((0.0, 29.0, 30, &device)).into_shape_assume_contig([2, 3, 5]);
            println!("{:}", &a % &b);

            let a = linspace((0.0, 29.0, 30, &device)).into_shape_assume_contig([2, 3, 5]);
            let b = linspace((0.0, 4.0, 5, &device));
            println!("{:}", &a % &b);

            let a = linspace((0.0, 14.0, 15, &device)).into_shape_assume_contig([5, 3]);
            let b = linspace((0.0, 29.0, 30, &device)).into_shape_assume_contig([2, 3, 5]);
            println!("{:}", &a % &b);

            let a = linspace((0.0, 29.0, 30, &device)).into_shape_assume_contig([2, 3, 5]);
            let b = linspace((0.0, 14.0, 15, &device)).into_shape_assume_contig([5, 3]);
            println!("{:}", &a % &b);
        }
    }
}