rstsr_common/layout/
matmul.rs

1/*!
2
Layout manipulation for matmul and other linalg operations
4
5# Rules for matmul
6
We follow the [Python array API](https://data-apis.org/array-api/2023.12/specification/generated/array_api.matmul.html) specification of `matmul`; refer to it for more information.
8
Please note that the following rules apply only to row-major order.
10
11| Id | A | B | C |
12|----|---|---|---|
13| 1. | `        N` | `        N` | `         ` |
14| 2. | `     M, K` | `     K, N` | `     M, N` |
15| 3. | `        K` | `..., K, N` | `   ..., N` |
16| 4. | `..., M, K` | `        K` | `   ..., M` |
17| 5. | `     M, K` | `..., K, N` | `..., M, N` |
18| 6. | `..., M, K` | `     K, N` | `..., M, N` |
19| 7. | `..., M, K` | `..., K, N` | `..., M, N` |
20
For col-major, only rules 1 and 2, plus the non-broadcasting cases of rules 3 (`K` × `K, N`) and 4 (`M, K` × `K`), are valid.
22
23*/
24
25use crate::prelude_dev::*;
26
/// Kind of product a matmul resolves to: one variant per rule of the module-level table.
///
/// For the `GEMM*` variants, the two-character suffix describes A then B: `2` marks a
/// plain 2-D operand, `X` marks an operand that also carries leading batch dimensions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatMulType {
    /// Rule 1: `N` · `N` → scalar.
    InnerDot,
    /// Rule 2: `M, K` × `K, N` → `M, N`.
    GEMM22,
    /// Rule 3: `K` × `..., K, N` → `..., N`.
    GEVM,
    /// Rule 4: `..., M, K` × `K` → `..., M`.
    GEMV,
    /// Rule 5: `M, K` × `..., K, N` → `..., M, N`.
    GEMM2X,
    /// Rule 6: `..., M, K` × `K, N` → `..., M, N`.
    GEMMX2,
    /// Rule 7: `..., M, K` × `..., K, N` → `..., M, N` (leading dimensions broadcast).
    GEMMXX,
}
38
39#[derive(Clone, Debug)]
40pub struct LayoutMatMulConfig<DA, DB>
41where
42    DA: DimAPI,
43    DB: DimAPI,
44    Self: LayoutMatMulAPI<DA, DB>,
45{
46    pub matmul_type: MatMulType,
47    pub lc: Layout<<Self as LayoutMatMulAPI<DA, DB>>::DC>,
48    pub la_rest: Option<Layout<IxD>>,
49    pub lb_rest: Option<Layout<IxD>>,
50    pub lc_rest: Option<Layout<IxD>>,
51    pub la_matmul: Layout<IxD>,
52    pub lb_matmul: Layout<IxD>,
53    pub lc_matmul: Layout<IxD>,
54}
55
56pub trait LayoutMatMulAPI<DA, DB>
57where
58    DA: DimAPI,
59    DB: DimAPI,
60    Self: Sized,
61{
62    type DC: DimAPI;
63    /// Layout configuration for matmul.
64    ///
65    /// For order, currently we only accept deterministic order.
66    fn layout_matmul(la: &Layout<DA>, lb: &Layout<DB>, order: FlagOrder) -> Result<Self>;
67}
68
69// rule 1
70impl LayoutMatMulAPI<Ix1, Ix1> for LayoutMatMulConfig<Ix1, Ix1> {
71    type DC = Ix0;
72    fn layout_matmul(la: &Layout<Ix1>, lb: &Layout<Ix1>, _: FlagOrder) -> Result<Self> {
73        // check shape
74        rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
75        let lc = unsafe { Layout::new_unchecked([], [], 0) };
76        Ok(LayoutMatMulConfig {
77            matmul_type: MatMulType::InnerDot,
78            lc: lc.clone(),
79            la_rest: None,
80            lb_rest: None,
81            lc_rest: None,
82            la_matmul: la.to_dim()?,
83            lb_matmul: lb.to_dim()?,
84            lc_matmul: lc.to_dim()?,
85        })
86    }
87}
88
89// rule 2
90impl LayoutMatMulAPI<Ix2, Ix2> for LayoutMatMulConfig<Ix2, Ix2> {
91    type DC = Ix2;
92    fn layout_matmul(la: &Layout<Ix2>, lb: &Layout<Ix2>, order: FlagOrder) -> Result<Self> {
93        // check and generate shape
94        rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
95        let sc = [la.shape()[0], lb.shape()[1]];
96        // layout order determination
97        let lc = match order {
98            RowMajor => sc.c(),
99            ColMajor => sc.f(),
100        };
101        // return layout configuration
102        Ok(LayoutMatMulConfig {
103            matmul_type: MatMulType::GEMM22,
104            lc: lc.clone(),
105            la_rest: None,
106            lb_rest: None,
107            lc_rest: None,
108            la_matmul: la.to_dim()?,
109            lb_matmul: lb.to_dim()?,
110            lc_matmul: lc.to_dim()?,
111        })
112    }
113}
114
115fn layout_matmul_dyn_row_major(la: &Layout<IxD>, lb: &Layout<IxD>) -> Result<LayoutMatMulConfig<IxD, IxD>> {
116    let na = la.ndim();
117    let nb = lb.ndim();
118    match (na, nb) {
119        (1, 1) => {
120            // rule 1: vector inner dot
121            rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
122            let lc = unsafe { Layout::new_unchecked(vec![], vec![], 0) };
123            Ok(LayoutMatMulConfig {
124                matmul_type: MatMulType::InnerDot,
125                lc: lc.clone(),
126                la_rest: None,
127                lb_rest: None,
128                lc_rest: None,
129                la_matmul: la.to_dim()?,
130                lb_matmul: lb.to_dim()?,
131                lc_matmul: lc.to_dim()?,
132            })
133        },
134        (2, 2) => {
135            // rule 2: matrix multiplication
136            // check and generate shape
137            rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
138            let sc = vec![la.shape()[0], lb.shape()[1]];
139            // layout order determination
140            let lc = sc.c();
141            // return layout configuration
142            Ok(LayoutMatMulConfig {
143                matmul_type: MatMulType::GEMM22,
144                lc: lc.clone(),
145                la_rest: None,
146                lb_rest: None,
147                lc_rest: None,
148                la_matmul: la.to_dim()?,
149                lb_matmul: lb.to_dim()?,
150                lc_matmul: lc.to_dim()?,
151            })
152        },
153        (1, 2..) => {
154            // rule 3: | `        K` | `..., K, N` | `   ..., N` |
155            // check and generate shape
156            let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
157            rstsr_assert_eq!(la.shape()[0], lb_matmul.shape()[0], InvalidLayout)?;
158            // layout order determination
159            let mut sc = lb_rest.shape().clone();
160            sc.push(lb_matmul.shape()[1]);
161            let lc = sc.c();
162            // return layout configuration
163            let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
164            Ok(LayoutMatMulConfig {
165                matmul_type: MatMulType::GEVM,
166                lc: lc.to_dim()?,
167                la_rest: None,
168                lb_rest: Some(lb_rest),
169                lc_rest: Some(lc_rest),
170                la_matmul: la.to_dim()?,
171                lb_matmul: lb_matmul.to_dim()?,
172                lc_matmul: lc_matmul.to_dim()?,
173            })
174        },
175        (2.., 1) => {
176            // rule 4: | `..., M, K` | `        K` | `   ..., M` |
177            // check and generate shape
178            let (la_rest, la_matmul) = la.dim_split_at(-2)?;
179            rstsr_assert_eq!(lb.shape()[0], la_matmul.shape()[1], InvalidLayout)?;
180            // layout order determination
181            let mut sc = la_rest.shape().clone();
182            sc.push(la_matmul.shape()[0]);
183            let lc = sc.c();
184            // return layout configuration
185            let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
186            Ok(LayoutMatMulConfig {
187                matmul_type: MatMulType::GEMV,
188                lc: lc.to_dim()?,
189                la_rest: Some(la_rest),
190                lb_rest: None,
191                lc_rest: Some(lc_rest),
192                la_matmul: la_matmul.to_dim()?,
193                lb_matmul: lb.to_dim()?,
194                lc_matmul: lc_matmul.to_dim()?,
195            })
196        },
197        (2, 3..) => {
198            // rule 5: | `     M, K` | `..., K, N` | `..., M, N` |
199            // check and generate shape
200            let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
201            rstsr_assert_eq!(la.shape()[1], lb_matmul.shape()[0], InvalidLayout)?;
202            // layout order determination
203            let mut sc = lb_rest.shape().clone();
204            sc.append(&mut vec![la.shape()[0], lb_matmul.shape()[1]]);
205            let lc = sc.c();
206            // return layout configuration
207            let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
208            Ok(LayoutMatMulConfig {
209                matmul_type: MatMulType::GEMM2X,
210                lc: lc.to_dim()?,
211                la_rest: None,
212                lb_rest: Some(lb_rest),
213                lc_rest: Some(lc_rest),
214                la_matmul: la.to_dim()?,
215                lb_matmul: lb_matmul.to_dim()?,
216                lc_matmul: lc_matmul.to_dim()?,
217            })
218        },
219        (3.., 2) => {
220            // rule 6: | `..., M, K` | `     K, N` | `..., M, N` |
221            // check and generate shape
222            let (la_rest, la_matmul) = la.dim_split_at(-2)?;
223            rstsr_assert_eq!(la_matmul.shape()[1], lb.shape()[0], InvalidLayout)?;
224            // layout order determination
225            let mut sc = la_rest.shape().clone();
226            sc.append(&mut vec![la_matmul.shape()[0], lb.shape()[1]]);
227            let lc = sc.c();
228            // return layout configuration
229            let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
230            Ok(LayoutMatMulConfig {
231                matmul_type: MatMulType::GEMMX2,
232                lc: lc.to_dim()?,
233                la_rest: Some(la_rest),
234                lb_rest: None,
235                lc_rest: Some(lc_rest),
236                la_matmul: la_matmul.to_dim()?,
237                lb_matmul: lb.to_dim()?,
238                lc_matmul: lc_matmul.to_dim()?,
239            })
240        },
241        (3.., 3..) => {
242            // check and generate shape
243            let (la_rest, la_matmul) = la.dim_split_at(-2)?;
244            let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
245            rstsr_assert_eq!(la_matmul.shape()[1], lb_matmul.shape()[0], InvalidLayout)?;
246            let (la_rest_b, lb_rest_b) = broadcast_layout(&la_rest, &lb_rest, RowMajor)?;
247            // layout order determination
248            let mut sc = la_rest_b.shape().clone();
249            sc.append(&mut vec![la_matmul.shape()[0], lb_matmul.shape()[1]]);
250            let lc = sc.c();
251            // return layout configuration
252            let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
253            Ok(LayoutMatMulConfig {
254                matmul_type: MatMulType::GEMMXX,
255                lc: lc.to_dim()?,
256                la_rest: Some(la_rest_b),
257                lb_rest: Some(lb_rest_b),
258                lc_rest: Some(lc_rest),
259                la_matmul: la.to_dim()?,
260                lb_matmul: lb_matmul.to_dim()?,
261                lc_matmul: lc_matmul.to_dim()?,
262            })
263        },
264        (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
265    }
266}
267
268fn layout_matmul_dyn_col_major(la: &Layout<IxD>, lb: &Layout<IxD>) -> Result<LayoutMatMulConfig<IxD, IxD>> {
269    let na = la.ndim();
270    let nb = lb.ndim();
271    match (na, nb) {
272        (1, 1) => {
273            // rule 1: vector inner dot
274            rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
275            let lc = unsafe { Layout::new_unchecked(vec![], vec![], 0) };
276            Ok(LayoutMatMulConfig {
277                matmul_type: MatMulType::InnerDot,
278                lc: lc.clone(),
279                la_rest: None,
280                lb_rest: None,
281                lc_rest: None,
282                la_matmul: la.to_dim()?,
283                lb_matmul: lb.to_dim()?,
284                lc_matmul: lc.to_dim()?,
285            })
286        },
287        (2, 2) => {
288            // rule 2: matrix multiplication
289            // check and generate shape
290            rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
291            let sc = vec![la.shape()[0], lb.shape()[1]];
292            // layout order determination
293            let lc = sc.f();
294            // return layout configuration
295            Ok(LayoutMatMulConfig {
296                matmul_type: MatMulType::GEMM22,
297                lc: lc.clone(),
298                la_rest: None,
299                lb_rest: None,
300                lc_rest: None,
301                la_matmul: la.to_dim()?,
302                lb_matmul: lb.to_dim()?,
303                lc_matmul: lc.to_dim()?,
304            })
305        },
306        (1, 2) => {
307            // rule 3: | `        K` | `     K, N` | `        N` |
308            // check and generate shape
309            rstsr_assert_eq!(la.shape()[0], lb.shape()[0], InvalidLayout)?;
310            let sc = vec![lb.shape()[1]];
311            let lc = sc.f();
312            Ok(LayoutMatMulConfig {
313                matmul_type: MatMulType::GEVM,
314                lc: lc.to_dim()?,
315                la_rest: None,
316                lb_rest: None,
317                lc_rest: None,
318                la_matmul: la.to_dim()?,
319                lb_matmul: lb.to_dim()?,
320                lc_matmul: lc.to_dim()?,
321            })
322        },
323        (2, 1) => {
324            // rule 4: | `     M, K` | `        K` | `        M` |
325            // check and generate shape
326            rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
327            let sc = vec![la.shape()[0]];
328            let lc = sc.f();
329            // return layout configuration
330            Ok(LayoutMatMulConfig {
331                matmul_type: MatMulType::GEMV,
332                lc: lc.to_dim()?,
333                la_rest: None,
334                lb_rest: None,
335                lc_rest: None,
336                la_matmul: la.to_dim()?,
337                lb_matmul: lb.to_dim()?,
338                lc_matmul: lc.to_dim()?,
339            })
340        },
341        (1, 3..) | (3.., 1) | (2, 3..) | (3.., 2) | (3.., 3..) => {
342            rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")
343        },
344        (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
345    }
346}
347
348impl LayoutMatMulAPI<IxD, IxD> for LayoutMatMulConfig<IxD, IxD> {
349    type DC = IxD;
350    fn layout_matmul(la: &Layout<IxD>, lb: &Layout<IxD>, order: FlagOrder) -> Result<Self> {
351        match order {
352            RowMajor => layout_matmul_dyn_row_major(la, lb),
353            ColMajor => layout_matmul_dyn_col_major(la, lb),
354        }
355    }
356}
357
/// Implements [`LayoutMatMulAPI`] for a combination of fixed (or mixed fixed/dynamic)
/// input dimension types by delegating to the `IxD × IxD` implementation.
///
/// - `$DA`, `$DB`: dimension types of the two inputs.
/// - `$DC`: dimension type of the output layout `lc`.
///
/// Both inputs are converted to `IxD` and resolved dynamically. Then only `lc` is
/// converted back to `$DC`; every other field is already `IxD`-typed and is passed through
/// unchanged. Both conversions are fallible, and their errors are propagated.
macro_rules! impl_fixed {
    ($DA:ident, $DB:ident, $DC:ident) => {
        impl LayoutMatMulAPI<$DA, $DB> for LayoutMatMulConfig<$DA, $DB> {
            type DC = $DC;

            fn layout_matmul(la: &Layout<$DA>, lb: &Layout<$DB>, order: FlagOrder) -> Result<Self> {
                let la = la.to_dim::<IxD>()?;
                let lb = lb.to_dim::<IxD>()?;
                let LayoutMatMulConfig {
                    matmul_type,
                    lc,
                    la_rest,
                    lb_rest,
                    lc_rest,
                    la_matmul,
                    lb_matmul,
                    lc_matmul,
                } = <LayoutMatMulConfig<IxD, IxD> as LayoutMatMulAPI<IxD, IxD>>::layout_matmul(&la, &lb, order)?;
                Ok(LayoutMatMulConfig {
                    matmul_type,
                    lc: lc.into_dim()?,
                    la_rest,
                    lb_rest,
                    lc_rest,
                    la_matmul,
                    lb_matmul,
                    lc_matmul,
                })
            }
        }
    };
}
380
381// rule 3
382impl_fixed!(Ix2, Ix1, Ix1);
383impl_fixed!(Ix3, Ix1, Ix2);
384impl_fixed!(Ix4, Ix1, Ix3);
385impl_fixed!(Ix5, Ix1, Ix4);
386impl_fixed!(Ix6, Ix1, Ix5);
387impl_fixed!(Ix7, Ix1, Ix6);
388impl_fixed!(Ix8, Ix1, Ix7);
389impl_fixed!(Ix9, Ix1, Ix8);
390
391// rule 4
392impl_fixed!(Ix1, Ix2, Ix1);
393impl_fixed!(Ix1, Ix3, Ix2);
394impl_fixed!(Ix1, Ix4, Ix3);
395impl_fixed!(Ix1, Ix5, Ix4);
396impl_fixed!(Ix1, Ix6, Ix5);
397impl_fixed!(Ix1, Ix7, Ix6);
398impl_fixed!(Ix1, Ix8, Ix7);
399impl_fixed!(Ix1, Ix9, Ix8);
400
401// rule 5
402impl_fixed!(Ix3, Ix2, Ix3);
403impl_fixed!(Ix4, Ix2, Ix4);
404impl_fixed!(Ix5, Ix2, Ix5);
405impl_fixed!(Ix6, Ix2, Ix6);
406impl_fixed!(Ix7, Ix2, Ix7);
407impl_fixed!(Ix8, Ix2, Ix8);
408impl_fixed!(Ix9, Ix2, Ix9);
409
410// rule 6
411impl_fixed!(Ix2, Ix3, Ix3);
412impl_fixed!(Ix2, Ix4, Ix4);
413impl_fixed!(Ix2, Ix5, Ix5);
414impl_fixed!(Ix2, Ix6, Ix6);
415impl_fixed!(Ix2, Ix7, Ix7);
416impl_fixed!(Ix2, Ix8, Ix8);
417impl_fixed!(Ix2, Ix9, Ix9);
418
419// rule 7
420impl_fixed!(Ix3, Ix3, Ix3);
421impl_fixed!(Ix4, Ix4, Ix4);
422impl_fixed!(Ix5, Ix5, Ix5);
423impl_fixed!(Ix6, Ix6, Ix6);
424impl_fixed!(Ix7, Ix7, Ix7);
425impl_fixed!(Ix8, Ix8, Ix8);
426impl_fixed!(Ix9, Ix9, Ix9);
427
428// partial fixed
429impl_fixed!(Ix1, IxD, IxD);
430impl_fixed!(Ix2, IxD, IxD);
431impl_fixed!(Ix3, IxD, IxD);
432impl_fixed!(Ix4, IxD, IxD);
433impl_fixed!(Ix5, IxD, IxD);
434impl_fixed!(Ix6, IxD, IxD);
435impl_fixed!(Ix7, IxD, IxD);
436impl_fixed!(Ix8, IxD, IxD);
437impl_fixed!(Ix9, IxD, IxD);
438
439impl_fixed!(IxD, Ix1, IxD);
440impl_fixed!(IxD, Ix2, IxD);
441impl_fixed!(IxD, Ix3, IxD);
442impl_fixed!(IxD, Ix4, IxD);
443impl_fixed!(IxD, Ix5, IxD);
444impl_fixed!(IxD, Ix6, IxD);
445impl_fixed!(IxD, Ix7, IxD);
446impl_fixed!(IxD, Ix8, IxD);
447impl_fixed!(IxD, Ix9, IxD);
448
#[cfg(test)]
mod test_fixed {

    /// Exercises every matmul rule through the fixed-dimension impls, in both orders,
    /// including rejection of mismatched contracted dimensions.
    #[test]
    fn test_layout_matmul() {
        use super::*;

        // rule 1: vector inner dot, 0-D output
        let la = [4].c();
        let lb = [4].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::InnerDot);
        assert_eq!(config.lc.shape(), &[]);
        assert_eq!(config.la_matmul.shape(), &[4]);
        assert_eq!(config.lb_matmul.shape(), &[4]);

        // rule 1: vectors of different length are rejected
        let la = [4].c();
        let lb = [5].c();
        assert!(LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).is_err());

        // rule 3: `K` x `..., K, N`
        let la = [5].c();
        let lb = [3, 4, 5, 6].f().swapaxes(0, 1).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEVM);
        assert_eq!(config.lc, [4, 3, 6].c());
        assert_eq!(config.lb_matmul.shape(), &[5, 6]);
        assert_eq!(config.lc_matmul.shape(), &[6]);

        // rule 4: `..., M, K` x `K`
        let la = [3, 4, 5, 6].f().swapaxes(0, 1).unwrap();
        let lb = [6].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMV);
        assert_eq!(config.lc, [4, 3, 5].c());
        assert_eq!(config.la_matmul.shape(), &[5, 6]);
        assert_eq!(config.lc_matmul.shape(), &[5]);

        // rule 5: `M, K` x `..., K, N`
        let la = [7, 6].c();
        let lb = [2, 3, 4, 5, 6].f().swapaxes(-1, -2).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMM2X);
        assert_eq!(config.lc, [2, 3, 4, 7, 5].c());
        assert_eq!(config.lc_matmul.shape(), &[7, 5]);

        // rule 6: `..., M, K` x `K, N`
        let la = [2, 3, 4, 5, 6].f().swapaxes(-1, -2).unwrap();
        let lb = [5, 7].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMMX2);
        assert_eq!(config.lc, [2, 3, 4, 6, 7].c());
        assert_eq!(config.lc_matmul.shape(), &[6, 7]);

        // rule 7 with broadcasting of the leading dimensions
        let la = [4, 1, 2, 5, 6].f().swapaxes(0, 2).unwrap();
        let lb = [4, 3, 1, 6, 7].f().swapaxes(0, 2).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMMXX);
        assert_eq!(config.lc, [2, 3, 4, 5, 7].c());
        assert_eq!(config.la_rest.as_ref().unwrap().shape(), &[2, 3, 4]);
        assert_eq!(config.lb_matmul.shape(), &[6, 7]);
        assert_eq!(config.lc_matmul.shape(), &[5, 7]);

        // rule 7 without broadcasting
        let la = [4, 3, 2, 5, 6].f().swapaxes(0, 2).unwrap();
        let lb = [4, 3, 2, 6, 7].f().swapaxes(0, 2).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMMXX);
        assert_eq!(config.lc, [2, 3, 4, 5, 7].c());

        // col-major: broadcasting matmul is rejected
        let la = [4, 3, 2, 5, 6].f().swapaxes(0, 2).unwrap();
        let lb = [4, 3, 2, 6, 7].f().swapaxes(0, 2).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor);
        assert!(config.is_err());

        // col-major rule 2
        let la = [5, 6].c();
        let lb = [6, 7].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEMM22);
        assert_eq!(config.lc, [5, 7].f());

        // col-major rule 2: mismatched contracted dimension is rejected
        let la = [5, 6].c();
        let lb = [5, 7].c();
        assert!(LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).is_err());

        // col-major rule 3 (unbatched): `K` x `K, N`
        let la = [6].c();
        let lb = [6, 7].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::GEVM);
        assert_eq!(config.lc, [7].f());
    }
}