rstsr_common/layout/
matmul.rs

1/*!
2
Layout manipulation for matmul and other linalg operations.
4
5# Rules for matmul
6
We follow the [Python array API](https://data-apis.org/array-api/2023.12/specification/generated/array_api.matmul.html) specification of `matmul`; see it for more information.
8
Please note that the following rules apply in full only to row-major order.
10
11| Id | A | B | C |
12|----|---|---|---|
13| 1. | `        N` | `        N` | `         ` |
14| 2. | `     M, K` | `     K, N` | `     M, N` |
15| 3. | `        K` | `..., K, N` | `   ..., N` |
16| 4. | `..., M, K` | `        K` | `   ..., M` |
17| 5. | `     M, K` | `..., K, N` | `..., M, N` |
18| 6. | `..., M, K` | `     K, N` | `..., M, N` |
19| 7. | `..., M, K` | `..., K, N` | `..., M, N` |
20
For col-major, only rules 1 and 2 and the non-broadcasting forms of rules 3 (`K` × `K, N`) and 4 (`M, K` × `K`) are supported; all other cases return an error.
22
23*/
24
25use crate::prelude_dev::*;
26
/// Kind of matrix multiplication selected by the matmul layout rules.
///
/// Each variant corresponds to one row of the rule table in the module
/// documentation (shapes are written row-major as `A` × `B` → `C`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatMulType {
    /// Rule 1: `N` × `N` → scalar (vector inner product).
    InnerDot,
    /// Rule 2: `M, K` × `K, N` → `M, N` (plain matrix-matrix product).
    GEMM22,
    /// Rule 3: `K` × `..., K, N` → `..., N` (vector-matrix product).
    GEVM,
    /// Rule 4: `..., M, K` × `K` → `..., M` (matrix-vector product).
    GEMV,
    /// Rule 5: `M, K` × `..., K, N` → `..., M, N`.
    GEMM2X,
    /// Rule 6: `..., M, K` × `K, N` → `..., M, N`.
    GEMMX2,
    /// Rule 7: `..., M, K` × `..., K, N` → `..., M, N` (batch dims broadcast).
    GEMMXX,
}
38
39#[derive(Clone, Debug)]
40pub struct LayoutMatMulConfig<DA, DB>
41where
42    DA: DimAPI,
43    DB: DimAPI,
44    Self: LayoutMatMulAPI<DA, DB>,
45{
46    pub matmul_type: MatMulType,
47    pub lc: Layout<<Self as LayoutMatMulAPI<DA, DB>>::DC>,
48    pub la_rest: Option<Layout<IxD>>,
49    pub lb_rest: Option<Layout<IxD>>,
50    pub lc_rest: Option<Layout<IxD>>,
51    pub la_matmul: Layout<IxD>,
52    pub lb_matmul: Layout<IxD>,
53    pub lc_matmul: Layout<IxD>,
54}
55
56pub trait LayoutMatMulAPI<DA, DB>
57where
58    DA: DimAPI,
59    DB: DimAPI,
60    Self: Sized,
61{
62    type DC: DimAPI;
63    /// Layout configuration for matmul.
64    ///
65    /// For order, currently we only accept deterministic order.
66    fn layout_matmul(la: &Layout<DA>, lb: &Layout<DB>, order: FlagOrder) -> Result<Self>;
67}
68
69// rule 1
70impl LayoutMatMulAPI<Ix1, Ix1> for LayoutMatMulConfig<Ix1, Ix1> {
71    type DC = Ix0;
72    fn layout_matmul(la: &Layout<Ix1>, lb: &Layout<Ix1>, _: FlagOrder) -> Result<Self> {
73        // check shape
74        rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
75        let lc = unsafe { Layout::new_unchecked([], [], 0) };
76        Ok(LayoutMatMulConfig {
77            matmul_type: MatMulType::InnerDot,
78            lc: lc.clone(),
79            la_rest: None,
80            lb_rest: None,
81            lc_rest: None,
82            la_matmul: la.to_dim()?,
83            lb_matmul: lb.to_dim()?,
84            lc_matmul: lc.to_dim()?,
85        })
86    }
87}
88
89// rule 2
90impl LayoutMatMulAPI<Ix2, Ix2> for LayoutMatMulConfig<Ix2, Ix2> {
91    type DC = Ix2;
92    fn layout_matmul(la: &Layout<Ix2>, lb: &Layout<Ix2>, order: FlagOrder) -> Result<Self> {
93        // check and generate shape
94        rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
95        let sc = [la.shape()[0], lb.shape()[1]];
96        // layout order determination
97        let lc = match order {
98            RowMajor => sc.c(),
99            ColMajor => sc.f(),
100        };
101        // return layout configuration
102        Ok(LayoutMatMulConfig {
103            matmul_type: MatMulType::GEMM22,
104            lc: lc.clone(),
105            la_rest: None,
106            lb_rest: None,
107            lc_rest: None,
108            la_matmul: la.to_dim()?,
109            lb_matmul: lb.to_dim()?,
110            lc_matmul: lc.to_dim()?,
111        })
112    }
113}
114
115fn layout_matmul_dyn_row_major(
116    la: &Layout<IxD>,
117    lb: &Layout<IxD>,
118) -> Result<LayoutMatMulConfig<IxD, IxD>> {
119    let na = la.ndim();
120    let nb = lb.ndim();
121    match (na, nb) {
122        (1, 1) => {
123            // rule 1: vector inner dot
124            rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
125            let lc = unsafe { Layout::new_unchecked(vec![], vec![], 0) };
126            Ok(LayoutMatMulConfig {
127                matmul_type: MatMulType::InnerDot,
128                lc: lc.clone(),
129                la_rest: None,
130                lb_rest: None,
131                lc_rest: None,
132                la_matmul: la.to_dim()?,
133                lb_matmul: lb.to_dim()?,
134                lc_matmul: lc.to_dim()?,
135            })
136        },
137        (2, 2) => {
138            // rule 2: matrix multiplication
139            // check and generate shape
140            rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
141            let sc = vec![la.shape()[0], lb.shape()[1]];
142            // layout order determination
143            let lc = sc.c();
144            // return layout configuration
145            Ok(LayoutMatMulConfig {
146                matmul_type: MatMulType::GEMM22,
147                lc: lc.clone(),
148                la_rest: None,
149                lb_rest: None,
150                lc_rest: None,
151                la_matmul: la.to_dim()?,
152                lb_matmul: lb.to_dim()?,
153                lc_matmul: lc.to_dim()?,
154            })
155        },
156        (1, 2..) => {
157            // rule 3: | `        K` | `..., K, N` | `   ..., N` |
158            // check and generate shape
159            let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
160            rstsr_assert_eq!(la.shape()[0], lb_matmul.shape()[0], InvalidLayout)?;
161            // layout order determination
162            let mut sc = lb_rest.shape().clone();
163            sc.push(lb_matmul.shape()[1]);
164            let lc = sc.c();
165            // return layout configuration
166            let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
167            Ok(LayoutMatMulConfig {
168                matmul_type: MatMulType::GEVM,
169                lc: lc.to_dim()?,
170                la_rest: None,
171                lb_rest: Some(lb_rest),
172                lc_rest: Some(lc_rest),
173                la_matmul: la.to_dim()?,
174                lb_matmul: lb_matmul.to_dim()?,
175                lc_matmul: lc_matmul.to_dim()?,
176            })
177        },
178        (2.., 1) => {
179            // rule 4: | `..., M, K` | `        K` | `   ..., M` |
180            // check and generate shape
181            let (la_rest, la_matmul) = la.dim_split_at(-2)?;
182            rstsr_assert_eq!(lb.shape()[0], la_matmul.shape()[1], InvalidLayout)?;
183            // layout order determination
184            let mut sc = la_rest.shape().clone();
185            sc.push(la_matmul.shape()[0]);
186            let lc = sc.c();
187            // return layout configuration
188            let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
189            Ok(LayoutMatMulConfig {
190                matmul_type: MatMulType::GEMV,
191                lc: lc.to_dim()?,
192                la_rest: Some(la_rest),
193                lb_rest: None,
194                lc_rest: Some(lc_rest),
195                la_matmul: la_matmul.to_dim()?,
196                lb_matmul: lb.to_dim()?,
197                lc_matmul: lc_matmul.to_dim()?,
198            })
199        },
200        (2, 3..) => {
201            // rule 5: | `     M, K` | `..., K, N` | `..., M, N` |
202            // check and generate shape
203            let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
204            rstsr_assert_eq!(la.shape()[1], lb_matmul.shape()[0], InvalidLayout)?;
205            // layout order determination
206            let mut sc = lb_rest.shape().clone();
207            sc.append(&mut vec![la.shape()[0], lb_matmul.shape()[1]]);
208            let lc = sc.c();
209            // return layout configuration
210            let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
211            Ok(LayoutMatMulConfig {
212                matmul_type: MatMulType::GEMM2X,
213                lc: lc.to_dim()?,
214                la_rest: None,
215                lb_rest: Some(lb_rest),
216                lc_rest: Some(lc_rest),
217                la_matmul: la.to_dim()?,
218                lb_matmul: lb_matmul.to_dim()?,
219                lc_matmul: lc_matmul.to_dim()?,
220            })
221        },
222        (3.., 2) => {
223            // rule 6: | `..., M, K` | `     K, N` | `..., M, N` |
224            // check and generate shape
225            let (la_rest, la_matmul) = la.dim_split_at(-2)?;
226            rstsr_assert_eq!(la_matmul.shape()[1], lb.shape()[0], InvalidLayout)?;
227            // layout order determination
228            let mut sc = la_rest.shape().clone();
229            sc.append(&mut vec![la_matmul.shape()[0], lb.shape()[1]]);
230            let lc = sc.c();
231            // return layout configuration
232            let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
233            Ok(LayoutMatMulConfig {
234                matmul_type: MatMulType::GEMMX2,
235                lc: lc.to_dim()?,
236                la_rest: Some(la_rest),
237                lb_rest: None,
238                lc_rest: Some(lc_rest),
239                la_matmul: la_matmul.to_dim()?,
240                lb_matmul: lb.to_dim()?,
241                lc_matmul: lc_matmul.to_dim()?,
242            })
243        },
244        (3.., 3..) => {
245            // check and generate shape
246            let (la_rest, la_matmul) = la.dim_split_at(-2)?;
247            let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
248            rstsr_assert_eq!(la_matmul.shape()[1], lb_matmul.shape()[0], InvalidLayout)?;
249            let (la_rest_b, lb_rest_b) = broadcast_layout(&la_rest, &lb_rest, RowMajor)?;
250            // layout order determination
251            let mut sc = la_rest_b.shape().clone();
252            sc.append(&mut vec![la_matmul.shape()[0], lb_matmul.shape()[1]]);
253            let lc = sc.c();
254            // return layout configuration
255            let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
256            Ok(LayoutMatMulConfig {
257                matmul_type: MatMulType::GEMMXX,
258                lc: lc.to_dim()?,
259                la_rest: Some(la_rest_b),
260                lb_rest: Some(lb_rest_b),
261                lc_rest: Some(lc_rest),
262                la_matmul: la.to_dim()?,
263                lb_matmul: lb_matmul.to_dim()?,
264                lc_matmul: lc_matmul.to_dim()?,
265            })
266        },
267        (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
268    }
269}
270
271fn layout_matmul_dyn_col_major(
272    la: &Layout<IxD>,
273    lb: &Layout<IxD>,
274) -> Result<LayoutMatMulConfig<IxD, IxD>> {
275    let na = la.ndim();
276    let nb = lb.ndim();
277    match (na, nb) {
278        (1, 1) => {
279            // rule 1: vector inner dot
280            rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
281            let lc = unsafe { Layout::new_unchecked(vec![], vec![], 0) };
282            Ok(LayoutMatMulConfig {
283                matmul_type: MatMulType::InnerDot,
284                lc: lc.clone(),
285                la_rest: None,
286                lb_rest: None,
287                lc_rest: None,
288                la_matmul: la.to_dim()?,
289                lb_matmul: lb.to_dim()?,
290                lc_matmul: lc.to_dim()?,
291            })
292        },
293        (2, 2) => {
294            // rule 2: matrix multiplication
295            // check and generate shape
296            rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
297            let sc = vec![la.shape()[0], lb.shape()[1]];
298            // layout order determination
299            let lc = sc.f();
300            // return layout configuration
301            Ok(LayoutMatMulConfig {
302                matmul_type: MatMulType::GEMM22,
303                lc: lc.clone(),
304                la_rest: None,
305                lb_rest: None,
306                lc_rest: None,
307                la_matmul: la.to_dim()?,
308                lb_matmul: lb.to_dim()?,
309                lc_matmul: lc.to_dim()?,
310            })
311        },
312        (1, 2) => {
313            // rule 3: | `        K` | `     K, N` | `        N` |
314            // check and generate shape
315            rstsr_assert_eq!(la.shape()[0], lb.shape()[0], InvalidLayout)?;
316            let sc = vec![lb.shape()[1]];
317            let lc = sc.f();
318            Ok(LayoutMatMulConfig {
319                matmul_type: MatMulType::GEVM,
320                lc: lc.to_dim()?,
321                la_rest: None,
322                lb_rest: None,
323                lc_rest: None,
324                la_matmul: la.to_dim()?,
325                lb_matmul: lb.to_dim()?,
326                lc_matmul: lc.to_dim()?,
327            })
328        },
329        (2, 1) => {
330            // rule 4: | `     M, K` | `        K` | `        M` |
331            // check and generate shape
332            rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
333            let sc = vec![la.shape()[0]];
334            let lc = sc.f();
335            // return layout configuration
336            Ok(LayoutMatMulConfig {
337                matmul_type: MatMulType::GEMV,
338                lc: lc.to_dim()?,
339                la_rest: None,
340                lb_rest: None,
341                lc_rest: None,
342                la_matmul: la.to_dim()?,
343                lb_matmul: lb.to_dim()?,
344                lc_matmul: lc.to_dim()?,
345            })
346        },
347        (1, 3..) | (3.., 1) | (2, 3..) | (3.., 2) | (3.., 3..) => {
348            rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")
349        },
350        (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
351    }
352}
353
354impl LayoutMatMulAPI<IxD, IxD> for LayoutMatMulConfig<IxD, IxD> {
355    type DC = IxD;
356    fn layout_matmul(la: &Layout<IxD>, lb: &Layout<IxD>, order: FlagOrder) -> Result<Self> {
357        match order {
358            RowMajor => layout_matmul_dyn_row_major(la, lb),
359            ColMajor => layout_matmul_dyn_col_major(la, lb),
360        }
361    }
362}
363
/// Implements [`LayoutMatMulAPI`] for a fixed pair of operand dimensionalities
/// `($DA, $DB)` whose output has dimensionality `$DC`.
///
/// The macro works in three steps:
/// 1. Convert both operands to dynamic dimensions.
/// 2. Resolve them through the `IxD` implementation.
/// 3. Convert only the output layout `lc` back to `$DC`.
///
/// The `*_rest` and `*_matmul` layouts are `Layout<IxD>` in every
/// configuration, so they are passed through unchanged.
macro_rules! impl_fixed {
    ($DA:ident, $DB:ident, $DC:ident) => {
        impl LayoutMatMulAPI<$DA, $DB> for LayoutMatMulConfig<$DA, $DB> {
            type DC = $DC;
            fn layout_matmul(la: &Layout<$DA>, lb: &Layout<$DB>, order: FlagOrder) -> Result<Self> {
                let la = la.to_dim::<IxD>()?;
                let lb = lb.to_dim::<IxD>()?;
                let cfg = LayoutMatMulConfig::layout_matmul(&la, &lb, order)?;
                Ok(LayoutMatMulConfig {
                    matmul_type: cfg.matmul_type,
                    // fallible conversion of the dynamic output layout back to `$DC`
                    lc: cfg.lc.into_dim()?,
                    la_rest: cfg.la_rest,
                    lb_rest: cfg.lb_rest,
                    lc_rest: cfg.lc_rest,
                    la_matmul: cfg.la_matmul,
                    lb_matmul: cfg.lb_matmul,
                    lc_matmul: cfg.lc_matmul,
                })
            }
        }
    };
}
386
387// rule 3
388impl_fixed!(Ix2, Ix1, Ix1);
389impl_fixed!(Ix3, Ix1, Ix2);
390impl_fixed!(Ix4, Ix1, Ix3);
391impl_fixed!(Ix5, Ix1, Ix4);
392impl_fixed!(Ix6, Ix1, Ix5);
393impl_fixed!(Ix7, Ix1, Ix6);
394impl_fixed!(Ix8, Ix1, Ix7);
395impl_fixed!(Ix9, Ix1, Ix8);
396
397// rule 4
398impl_fixed!(Ix1, Ix2, Ix1);
399impl_fixed!(Ix1, Ix3, Ix2);
400impl_fixed!(Ix1, Ix4, Ix3);
401impl_fixed!(Ix1, Ix5, Ix4);
402impl_fixed!(Ix1, Ix6, Ix5);
403impl_fixed!(Ix1, Ix7, Ix6);
404impl_fixed!(Ix1, Ix8, Ix7);
405impl_fixed!(Ix1, Ix9, Ix8);
406
407// rule 5
408impl_fixed!(Ix3, Ix2, Ix3);
409impl_fixed!(Ix4, Ix2, Ix4);
410impl_fixed!(Ix5, Ix2, Ix5);
411impl_fixed!(Ix6, Ix2, Ix6);
412impl_fixed!(Ix7, Ix2, Ix7);
413impl_fixed!(Ix8, Ix2, Ix8);
414impl_fixed!(Ix9, Ix2, Ix9);
415
416// rule 6
417impl_fixed!(Ix2, Ix3, Ix3);
418impl_fixed!(Ix2, Ix4, Ix4);
419impl_fixed!(Ix2, Ix5, Ix5);
420impl_fixed!(Ix2, Ix6, Ix6);
421impl_fixed!(Ix2, Ix7, Ix7);
422impl_fixed!(Ix2, Ix8, Ix8);
423impl_fixed!(Ix2, Ix9, Ix9);
424
425// rule 7
426impl_fixed!(Ix3, Ix3, Ix3);
427impl_fixed!(Ix4, Ix4, Ix4);
428impl_fixed!(Ix5, Ix5, Ix5);
429impl_fixed!(Ix6, Ix6, Ix6);
430impl_fixed!(Ix7, Ix7, Ix7);
431impl_fixed!(Ix8, Ix8, Ix8);
432impl_fixed!(Ix9, Ix9, Ix9);
433
434// partial fixed
435impl_fixed!(Ix1, IxD, IxD);
436impl_fixed!(Ix2, IxD, IxD);
437impl_fixed!(Ix3, IxD, IxD);
438impl_fixed!(Ix4, IxD, IxD);
439impl_fixed!(Ix5, IxD, IxD);
440impl_fixed!(Ix6, IxD, IxD);
441impl_fixed!(Ix7, IxD, IxD);
442impl_fixed!(Ix8, IxD, IxD);
443impl_fixed!(Ix9, IxD, IxD);
444
445impl_fixed!(IxD, Ix1, IxD);
446impl_fixed!(IxD, Ix2, IxD);
447impl_fixed!(IxD, Ix3, IxD);
448impl_fixed!(IxD, Ix4, IxD);
449impl_fixed!(IxD, Ix5, IxD);
450impl_fixed!(IxD, Ix6, IxD);
451impl_fixed!(IxD, Ix7, IxD);
452impl_fixed!(IxD, Ix8, IxD);
453impl_fixed!(IxD, Ix9, IxD);
454
#[cfg(test)]
mod test_fixed {

    #[test]
    fn test_layout_matmul() {
        use super::*;

        // rule 1: inner product of two vectors -> 0-dim output
        let la = [4].c();
        let lb = [4].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::InnerDot);
        assert_eq!(config.lc.shape(), &[]);
        assert_eq!(config.la_matmul.shape(), &[4]);
        assert_eq!(config.lb_matmul.shape(), &[4]);

        // rule 1 with mismatched vector lengths is rejected
        let la = [4].c();
        let lb = [5].c();
        assert!(LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).is_err());

        // rule 3: (K) x (..., K, N) -> (..., N)
        let la = [5].c();
        let lb = [3, 4, 5, 6].f().swapaxes(0, 1).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEVM);
        assert_eq!(config.lc, [4, 3, 6].c());
        assert_eq!(config.lb_matmul.shape(), &[5, 6]);

        // rule 4: (..., M, K) x (K) -> (..., M)
        let la = [3, 4, 5, 6].f().swapaxes(0, 1).unwrap();
        let lb = [6].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMV);
        assert_eq!(config.lc, [4, 3, 5].c());
        assert_eq!(config.la_matmul.shape(), &[5, 6]);

        // rule 5: (M, K) x (..., K, N) -> (..., M, N)
        let la = [7, 6].c();
        let lb = [2, 3, 4, 5, 6].f().swapaxes(-1, -2).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMM2X);
        assert_eq!(config.lc, [2, 3, 4, 7, 5].c());
        assert_eq!(config.lb_matmul.shape(), &[6, 5]);

        // rule 6: (..., M, K) x (K, N) -> (..., M, N)
        let la = [2, 3, 4, 5, 6].f().swapaxes(-1, -2).unwrap();
        let lb = [5, 7].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMMX2);
        assert_eq!(config.lc, [2, 3, 4, 6, 7].c());
        assert_eq!(config.la_matmul.shape(), &[6, 5]);

        // rule 7 with broadcasting of the batch dims: (2, 1, 4) vs (1, 3, 4)
        let la = [4, 1, 2, 5, 6].f().swapaxes(0, 2).unwrap();
        let lb = [4, 3, 1, 6, 7].f().swapaxes(0, 2).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMMXX);
        assert_eq!(config.lc, [2, 3, 4, 5, 7].c());
        assert_eq!(config.lb_matmul.shape(), &[6, 7]);
        assert_eq!(config.lc_rest.as_ref().unwrap().shape(), &[2, 3, 4]);

        // rule 7 without broadcasting
        let la = [4, 3, 2, 5, 6].f().swapaxes(0, 2).unwrap();
        let lb = [4, 3, 2, 6, 7].f().swapaxes(0, 2).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMMXX);
        assert_eq!(config.lc, [2, 3, 4, 5, 7].c());

        // col-major: broadcasting cases are rejected
        let la = [4, 3, 2, 5, 6].f().swapaxes(0, 2).unwrap();
        let lb = [4, 3, 2, 6, 7].f().swapaxes(0, 2).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor);
        assert!(config.is_err());

        let la = [6].c();
        let lb = [2, 6, 7].c();
        assert!(LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).is_err());

        // col-major rule 2
        let la = [5, 6].c();
        let lb = [6, 7].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMM22);
        assert_eq!(config.lc, [5, 7].f());

        // col-major rule 3 (no batch dims): (K) x (K, N) -> (N)
        let la = [6].c();
        let lb = [6, 7].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEVM);
        assert_eq!(config.lc, [7].f());

        // col-major rule 4 (no batch dims): (M, K) x (K) -> (M)
        let la = [5, 6].c();
        let lb = [6].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMV);
        assert_eq!(config.lc, [5].f());
    }
}