rstsr_core/device_cpu_serial/
matmul.rs

//! Matrix multiplication for CPU backend.
//!
//! **This implementation is not optimized!**

use core::ops::{Add, Mul};

use crate::prelude_dev::*;

impl<TA, TB, TC, DA, DB, DC> DeviceMatMulAPI<TA, TB, TC, DA, DB, DC> for DeviceCpuSerial
where
    TA: Clone,
    TB: Clone,
    TC: Clone,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
    Self: DeviceGEMMAPI<TA, TB, TC>
        + DeviceGEMVAPI<TA, TB, TC>
        + DeviceInnerDotAPI<TA, TB, TC>
        + DeviceAPI<TA, Raw = Vec<TA>>
        + DeviceAPI<TB, Raw = Vec<TB>>
        + DeviceAPI<TC, Raw = Vec<TC>>,
{
    fn matmul(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<DC>,
        a: &Vec<TA>,
        la: &Layout<DA>,
        b: &Vec<TB>,
        lb: &Layout<DB>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        let default_order = self.default_order();
        match (la.ndim(), lb.ndim(), lc.ndim()) {
            (1, 1, 0) => {
                // rule 1: vector inner dot
                let la = &la.clone().into_dim::<Ix1>().unwrap();
                let lb = &lb.clone().into_dim::<Ix1>().unwrap();
                let lc = &lc.clone().into_dim::<Ix0>().unwrap();
                self.inner_dot(c, lc, a, la, b, lb, alpha, beta)?;
            },
            (2, 2, 2) => {
                // rule 2: matrix multiplication
                let la = &la.clone().into_dim::<Ix2>().unwrap();
                let lb = &lb.clone().into_dim::<Ix2>().unwrap();
                let lc = &lc.clone().into_dim::<Ix2>().unwrap();
                self.gemm(c, lc, a, la, b, lb, alpha, beta)?;
            },
            (2, 1, 1) => {
                // rule 4 special: 2 x 1
                let la = &la.clone().into_dim::<Ix2>().unwrap();
                let lb = &lb.clone().into_dim::<Ix1>().unwrap();
                let lc = &lc.clone().into_dim::<Ix1>().unwrap();
                self.gemv(c, lc, a, la, b, lb, alpha, beta)?;
            },
            (1, 2, 1) => {
                // rule 3 special: 1 x 2
                let la = &la.clone().into_dim::<Ix1>().unwrap();
                let lb = &lb.clone().into_dim::<Ix2>().unwrap();
                let lc = &lc.clone().into_dim::<Ix1>().unwrap();
                self.gevm(c, lc, a, la, b, lb, alpha, beta)?;
            },
            (1, 2.., _) => {
                // rule 3: | `        K` | `..., K, N` | `   ..., N` |
                rstsr_assert_eq!(lb.ndim(), lc.ndim() + 1, InvalidLayout)?;
                if default_order == ColMajor && lb.ndim() > 2 {
                    rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")?;
                }
                let la = &la.clone().into_dim::<Ix1>().unwrap();
                let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
                let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
                let lb_matmul = &mut lb_matmul.into_dim::<Ix2>()?;
                let lc_matmul = &mut lc_matmul.into_dim::<Ix1>()?;
                let l_rest = translate_to_col_major(&[&lc_rest, &lb_rest], TensorIterOrder::K)?;
                let (lc_rest, lb_rest) = (&l_rest[0], &l_rest[1]);
                let itb_rest = IterLayoutColMajor::new(lb_rest)?;
                let itc_rest = IterLayoutColMajor::new(lc_rest)?;
                for (ib_rest, ic_rest) in izip!(itb_rest, itc_rest) {
                    unsafe { lb_matmul.set_offset(ib_rest) };
                    unsafe { lc_matmul.set_offset(ic_rest) };
                    self.gevm(c, lc_matmul, a, la, b, lb_matmul, alpha.clone(), beta.clone())?;
                }
            },
            (2.., 1, _) => {
                // rule 4: | `..., M, K` | `        K` | `   ..., M` |
                rstsr_assert_eq!(la.ndim(), lc.ndim() + 1, InvalidLayout)?;
                if default_order == ColMajor && la.ndim() > 2 {
                    rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")?;
                }
                let lb = &lb.clone().into_dim::<Ix1>().unwrap();
                let (la_rest, la_matmul) = la.dim_split_at(-2)?;
                let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
                let la_matmul = &mut la_matmul.into_dim::<Ix2>()?;
                let lc_matmul = &mut lc_matmul.into_dim::<Ix1>()?;
                let l_rest = translate_to_col_major(&[&lc_rest, &la_rest], TensorIterOrder::K)?;
                let (lc_rest, la_rest) = (&l_rest[0], &l_rest[1]);
                let ita_rest = IterLayoutColMajor::new(la_rest)?;
                let itc_rest = IterLayoutColMajor::new(lc_rest)?;
                for (ia_rest, ic_rest) in izip!(ita_rest, itc_rest) {
                    unsafe { la_matmul.set_offset(ia_rest) };
                    unsafe { lc_matmul.set_offset(ic_rest) };
                    self.gemv(c, lc_matmul, a, la_matmul, b, lb, alpha.clone(), beta.clone())?;
                }
            },
            (2, 3.., _) => {
                // rule 5: | `     M, K` | `..., K, N` | `..., M, N` |
                rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?;
                if default_order == ColMajor {
                    rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")?;
                }
                let la = &la.clone().into_dim::<Ix2>().unwrap();
                let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
                let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
                let lb_matmul = &mut lb_matmul.into_dim::<Ix2>()?;
                let lc_matmul = &mut lc_matmul.into_dim::<Ix2>()?;
                let l_rest = translate_to_col_major(&[&lc_rest, &lb_rest], TensorIterOrder::K)?;
                let (lc_rest, lb_rest) = (&l_rest[0], &l_rest[1]);
                let itb_rest = IterLayoutColMajor::new(lb_rest)?;
                let itc_rest = IterLayoutColMajor::new(lc_rest)?;
                for (ib_rest, ic_rest) in izip!(itb_rest, itc_rest) {
                    unsafe { lb_matmul.set_offset(ib_rest) };
                    unsafe { lc_matmul.set_offset(ic_rest) };
                    self.gemm(c, lc_matmul, a, la, b, lb_matmul, alpha.clone(), beta.clone())?;
                }
            },
            (3.., 2, _) => {
                // rule 6: | `..., M, K` | `     K, N` | `..., M, N` |
                rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?;
                if default_order == ColMajor {
                    rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")?;
                }
                let lb = &lb.clone().into_dim::<Ix2>().unwrap();
                let (la_rest, la_matmul) = la.dim_split_at(-2)?;
                let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
                let la_matmul = &mut la_matmul.into_dim::<Ix2>()?;
                let lc_matmul = &mut lc_matmul.into_dim::<Ix2>()?;
                let l_rest = translate_to_col_major(&[&lc_rest, &la_rest], TensorIterOrder::K)?;
                let (lc_rest, la_rest) = (&l_rest[0], &l_rest[1]);
                let ita_rest = IterLayoutColMajor::new(la_rest)?;
                let itc_rest = IterLayoutColMajor::new(lc_rest)?;
                for (ia_rest, ic_rest) in izip!(ita_rest, itc_rest) {
                    unsafe { la_matmul.set_offset(ia_rest) };
                    unsafe { lc_matmul.set_offset(ic_rest) };
                    self.gemm(c, lc_matmul, a, la_matmul, b, lb, alpha.clone(), beta.clone())?;
                }
            },
            (3.., 3.., _) => {
                // rule 7: | `..., M, K` | `..., K, N` | `..., M, N` |
                rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?;
                rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?;
                if default_order == ColMajor {
                    rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")?;
                }
                let (la_rest, la_matmul) = la.dim_split_at(-2)?;
                let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
                let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
                let la_matmul = &mut la_matmul.into_dim::<Ix2>()?;
                let lb_matmul = &mut lb_matmul.into_dim::<Ix2>()?;
                let lc_matmul = &mut lc_matmul.into_dim::<Ix2>()?;
                let l_rest =
                    translate_to_col_major(&[&lc_rest, &la_rest, &lb_rest], TensorIterOrder::K)?;
                let (lc_rest, la_rest, lb_rest) = (&l_rest[0], &l_rest[1], &l_rest[2]);
                let ita_rest = IterLayoutColMajor::new(la_rest)?;
                let itb_rest = IterLayoutColMajor::new(lb_rest)?;
                let itc_rest = IterLayoutColMajor::new(lc_rest)?;
                for (ia_rest, ib_rest, ic_rest) in izip!(ita_rest, itb_rest, itc_rest) {
                    unsafe { la_matmul.set_offset(ia_rest) };
                    unsafe { lb_matmul.set_offset(ib_rest) };
                    unsafe { lc_matmul.set_offset(ic_rest) };
                    self.gemm(
                        c,
                        lc_matmul,
                        a,
                        la_matmul,
                        b,
                        lb_matmul,
                        alpha.clone(),
                        beta.clone(),
                    )?;
                }
            },
            // handle other cases
            (0, _, _) | (_, 0, _) // zero-dimension input
            | (1, 1, 1..) // rule 1 invalid
            | (2, 2, 3..) | (2, 2, 0..2) // rule 2 invalid
            => {
                rstsr_raise!(
                    InvalidLayout,
                    "Invalid ndim for matmul: {}, {}, {}",
                    la.ndim(),
                    lb.ndim(),
                    lc.ndim()
                )?;
            },
        }
        return Ok(());
    }
}
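
// NOTE (illustrative only, not part of the crate API): the dispatch above follows
// NumPy-style matmul rules on the operand dimensionalities. A minimal standalone
// sketch of the expected output ndim for non-scalar operands, mirroring rules 1-7
// of `DeviceMatMulAPI::matmul` above; `expected_out_ndim` is a hypothetical helper
// written only for this sketch.
#[cfg(test)]
mod matmul_ndim_dispatch_sketch {
    /// Expected output ndim for `na = a.ndim()`, `nb = b.ndim()`.
    /// Returns `None` for combinations rejected by the dispatch above.
    fn expected_out_ndim(na: usize, nb: usize) -> Option<usize> {
        match (na, nb) {
            (0, _) | (_, 0) => None,            // zero-dimension input is invalid
            (1, 1) => Some(0),                  // rule 1: vector inner dot
            (2, 2) => Some(2),                  // rule 2: matrix multiplication
            (2, 1) => Some(1),                  // rule 4 special: gemv
            (1, 2) => Some(1),                  // rule 3 special: gevm
            (1, _) => Some(nb - 1),             // rule 3: broadcasted gevm
            (_, 1) => Some(na - 1),             // rule 4: broadcasted gemv
            (2, _) => Some(nb),                 // rule 5: a fixed, b broadcasted
            (_, 2) => Some(na),                 // rule 6: b fixed, a broadcasted
            (_, _) => (na == nb).then_some(na), // rule 7: matching batch ndim
        }
    }

    #[test]
    fn spot_checks() {
        assert_eq!(expected_out_ndim(1, 1), Some(0));
        assert_eq!(expected_out_ndim(2, 1), Some(1));
        assert_eq!(expected_out_ndim(1, 3), Some(2));
        assert_eq!(expected_out_ndim(3, 2), Some(3));
        assert_eq!(expected_out_ndim(4, 4), Some(4));
        assert_eq!(expected_out_ndim(0, 2), None);
    }
}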

impl<TA, TB, TC> DeviceGEMMAPI<TA, TB, TC> for DeviceCpuSerial
where
    TA: Clone,
    TB: Clone,
    TC: Clone,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
{
    fn gemm(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<Ix2>,
        a: &Vec<TA>,
        la: &Layout<Ix2>,
        b: &Vec<TB>,
        lb: &Layout<Ix2>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        // shape check
        let sc = lc.shape();
        let sa = la.shape();
        let sb = lb.shape();
        rstsr_assert_eq!(sc[0], sa[0], InvalidLayout)?;
        rstsr_assert_eq!(sa[1], sb[0], InvalidLayout)?;
        rstsr_assert_eq!(sc[1], sb[1], InvalidLayout)?;
        let (m, n, k) = (sc[0], sc[1], sa[1]);

        // naive iteration: assuming c-prefer
        unsafe {
            for i_m in 0..m {
                for i_n in 0..n {
                    let idx_c = lc.index_uncheck(&[i_m, i_n]) as usize;
                    c[idx_c] = beta.clone() * c[idx_c].clone();
                }
                for i_k in 0..k {
                    let idx_a = la.index_uncheck(&[i_m, i_k]) as usize;
                    for i_n in 0..n {
                        let idx_c = lc.index_uncheck(&[i_m, i_n]) as usize;
                        let idx_b = lb.index_uncheck(&[i_k, i_n]) as usize;
                        c[idx_c] = alpha.clone() * (a[idx_a].clone() * b[idx_b].clone())
                            + c[idx_c].clone();
                    }
                }
            }
        }
        return Ok(());
    }
}
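
// Sketch (not part of the crate): the same `c = alpha * (a @ b) + beta * c`
// recurrence as the gemm above, written for plain contiguous row-major buffers
// without `Layout` indexing, plus a small worked 2 x 2 example.
#[cfg(test)]
mod naive_gemm_sketch {
    /// c (m x n) = alpha * a (m x k) @ b (k x n) + beta * c, row-major slices.
    fn gemm_row_major(
        c: &mut [f64],
        a: &[f64],
        b: &[f64],
        (m, n, k): (usize, usize, usize),
        alpha: f64,
        beta: f64,
    ) {
        for i_m in 0..m {
            // scale the output row first, then accumulate the rank-1 updates
            for i_n in 0..n {
                c[i_m * n + i_n] *= beta;
            }
            for i_k in 0..k {
                let a_val = a[i_m * k + i_k];
                for i_n in 0..n {
                    c[i_m * n + i_n] += alpha * a_val * b[i_k * n + i_n];
                }
            }
        }
    }

    #[test]
    fn two_by_two() {
        // [[1, 2], [3, 4]] @ [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let a = vec![1., 2., 3., 4.];
        let b = vec![5., 6., 7., 8.];
        let mut c = vec![0.; 4];
        gemm_row_major(&mut c, &a, &b, (2, 2, 2), 1.0, 0.0);
        assert_eq!(c, vec![19., 22., 43., 50.]);
    }
}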

impl<TA, TB, TC> DeviceGEMVAPI<TA, TB, TC> for DeviceCpuSerial
where
    TA: Clone,
    TB: Clone,
    TC: Clone,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
{
    fn gemv(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<Ix1>,
        a: &Vec<TA>,
        la: &Layout<Ix2>,
        b: &Vec<TB>,
        lb: &Layout<Ix1>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        // shape check
        let sc = lc.shape();
        let sa = la.shape();
        let sb = lb.shape();
        rstsr_assert_eq!(sc[0], sa[0], InvalidLayout)?;
        rstsr_assert_eq!(sa[1], sb[0], InvalidLayout)?;
        let (n, k) = (sa[0], sa[1]);

        // naive iteration: assuming c-prefer
        unsafe {
            for i_n in 0..n {
                let idx_c = lc.index_uncheck(&[i_n]) as usize;
                c[idx_c] = beta.clone() * c[idx_c].clone();
                for i_k in 0..k {
                    let idx_a = la.index_uncheck(&[i_n, i_k]) as usize;
                    let idx_b = lb.index_uncheck(&[i_k]) as usize;
                    c[idx_c] =
                        alpha.clone() * (a[idx_a].clone() * b[idx_b].clone()) + c[idx_c].clone();
                }
            }
        }
        return Ok(());
    }

    fn gevm(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<Ix1>,
        a: &Vec<TA>,
        la: &Layout<Ix1>,
        b: &Vec<TB>,
        lb: &Layout<Ix2>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        // shape check
        let sc = lc.shape();
        let sa = la.shape();
        let sb = lb.shape();
        rstsr_assert_eq!(sc[0], sb[1], InvalidLayout)?;
        rstsr_assert_eq!(sa[0], sb[0], InvalidLayout)?;
        let (n, k) = (sb[1], sb[0]);

        // naive iteration: assuming c-prefer
        unsafe {
            for i_n in 0..n {
                let idx_c = lc.index_uncheck(&[i_n]) as usize;
                c[idx_c] = beta.clone() * c[idx_c].clone();
                for i_k in 0..k {
                    let idx_a = la.index_uncheck(&[i_k]) as usize;
                    let idx_b = lb.index_uncheck(&[i_k, i_n]) as usize;
                    c[idx_c] =
                        alpha.clone() * (a[idx_a].clone() * b[idx_b].clone()) + c[idx_c].clone();
                }
            }
        }
        return Ok(());
    }
}
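
// Sketch (not part of the crate): the gemv recurrence above on contiguous
// row-major buffers, i.e. `c = alpha * (a @ b) + beta * c` with a matrix
// `a` (n x k) and a vector `b` (k). gevm is the analogous update with the
// vector on the left.
#[cfg(test)]
mod naive_gemv_sketch {
    fn gemv_row_major(
        c: &mut [f64],
        a: &[f64],
        b: &[f64],
        (n, k): (usize, usize),
        alpha: f64,
        beta: f64,
    ) {
        for i_n in 0..n {
            c[i_n] *= beta;
            for i_k in 0..k {
                c[i_n] += alpha * a[i_n * k + i_k] * b[i_k];
            }
        }
    }

    #[test]
    fn two_by_three() {
        // [[1, 2, 3], [4, 5, 6]] @ [1, 1, 1] = [6, 15]
        let a = vec![1., 2., 3., 4., 5., 6.];
        let b = vec![1., 1., 1.];
        let mut c = vec![0.; 2];
        gemv_row_major(&mut c, &a, &b, (2, 3), 1.0, 0.0);
        assert_eq!(c, vec![6., 15.]);
    }
}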

impl<TA, TB, TC> DeviceInnerDotAPI<TA, TB, TC> for DeviceCpuSerial
where
    TA: Clone,
    TB: Clone,
    TC: Clone,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
{
    fn inner_dot(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<Ix0>,
        a: &Vec<TA>,
        la: &Layout<Ix1>,
        b: &Vec<TB>,
        lb: &Layout<Ix1>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        // shape check
        let sa = la.shape();
        let sb = lb.shape();
        rstsr_assert_eq!(sa[0], sb[0], InvalidLayout)?;
        let n = sa[0];

        // naive iteration
        unsafe {
            let idx_c = lc.index_uncheck(&[]) as usize;
            let mut sum = beta * c[idx_c].clone();
            for i in 0..n {
                let idx_a = la.index_uncheck(&[i]) as usize;
                let idx_b = lb.index_uncheck(&[i]) as usize;
                sum = sum + alpha.clone() * (a[idx_a].clone() * b[idx_b].clone());
            }
            c[idx_c] = sum;
        }
        return Ok(());
    }
}
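
// Sketch (not part of the crate): the inner-dot recurrence above on plain slices,
// with a small worked example: alpha = 2, beta = 0, a = [1, 2, 3], b = [4, 5, 6]
// gives 2 * (1*4 + 2*5 + 3*6) = 64.
#[cfg(test)]
mod inner_dot_sketch {
    fn inner_dot(a: &[f64], b: &[f64], alpha: f64, beta: f64, c0: f64) -> f64 {
        let mut sum = beta * c0;
        for (x, y) in a.iter().zip(b.iter()) {
            sum += alpha * x * y;
        }
        sum
    }

    #[test]
    fn worked_example() {
        assert_eq!(inner_dot(&[1., 2., 3.], &[4., 5., 6.], 2.0, 0.0, 0.0), 64.0);
    }
}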