Skip to main content

rstsr_common/layout/
matmul.rs

1/*!
2
Layout manipulation for matmul and other linalg operations
4
5# Rules for matmul
6
See the [Python array API](https://data-apis.org/array-api/2024.12/specification/generated/array_api.matmul.html) specification for more information.
8
9The rules below are written for row-major; the last two axes of each operand
10are the matmul dimensions and any leading axes broadcast.
11
12| Id | A | B | C |
13|----|---|---|---|
14| 1. | `        N` | `        N` | `         ` |
15| 2. | `     M, K` | `     K, N` | `     M, N` |
16| 3. | `        K` | `..., K, N` | `   ..., N` |
17| 4. | `..., M, K` | `        K` | `   ..., M` |
18| 5. | `     M, K` | `..., K, N` | `..., M, N` |
19| 6. | `..., M, K` | `     K, N` | `..., M, N` |
20| 7. | `..., M, K` | `..., K, N` | `..., M, N` |
21
22For col-major, the same rules apply *with all axes reversed*: the matmul
23dimensions are the first two of each operand and any trailing axes broadcast.
24This is implemented by delegating to the row-major routine on reversed-and-
25swapped inputs, using the identity
26`C[t, m, n] = A[t, m, k] @ B[t, k, n]` (row-major) `==`
27`C[n, m, t] = B[n, k, t] @ A[k, m, t]` (col-major).
28
29*/
30
31use crate::prelude_dev::*;
32
/// Rules of matmul.
///
/// Each variant corresponds to one row of the rule table in the module
/// documentation, as selected by the row-major resolver from the number of
/// dimensions of the two operands.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MatMulType {
    /// Rule 1: 1-D by 1-D inner product.
    InnerDot,
    /// Rule 2: 2-D by 2-D matrix multiplication.
    GEMM22,
    /// Rule 3: 1-D `A` with a `B` of two or more dimensions.
    GEVM,
    /// Rule 4: an `A` of two or more dimensions with 1-D `B`.
    GEMV,
    /// Rule 5: 2-D `A` with a `B` of three or more dimensions.
    GEMM2X,
    /// Rule 6: an `A` of three or more dimensions with 2-D `B`.
    GEMMX2,
    /// Rule 7: both operands have three or more dimensions.
    GEMMXX,
}
44
/// Resolved layouts for a single matmul call.
///
/// The `*_rest` fields carry the part of each layout that was split off as
/// batch axes (when the resolver produced one) and the `*_matmul` fields carry
/// the part that takes part in the product itself. All of these are stored
/// with dynamic dimensionality (`IxD`); only `lc` keeps the statically
/// declared output dimension type.
#[derive(Clone, Debug)]
pub struct LayoutMatMulConfig<DA, DB>
where
    DA: DimAPI,
    DB: DimAPI,
    Self: LayoutMatMulAPI<DA, DB>,
{
    /// Which rule of the table was selected.
    pub matmul_type: MatMulType,
    /// Layout of the output `C`, typed by the implementation's `DC`.
    pub lc: Layout<<Self as LayoutMatMulAPI<DA, DB>>::DC>,
    /// Batch ("rest") layout of `A`, if the resolver produced one.
    pub la_rest: Option<Layout<IxD>>,
    /// Batch ("rest") layout of `B`, if the resolver produced one.
    pub lb_rest: Option<Layout<IxD>>,
    /// Batch ("rest") layout of `C`, if the resolver produced one.
    pub lc_rest: Option<Layout<IxD>>,
    /// Layout of the matmul dimensions of `A`.
    pub la_matmul: Layout<IxD>,
    /// Layout of the matmul dimensions of `B`.
    pub lb_matmul: Layout<IxD>,
    /// Layout of the matmul dimensions of `C`.
    pub lc_matmul: Layout<IxD>,
}
61
/// Resolution of matmul layouts for an `A` of dimension type `DA` and a `B`
/// of dimension type `DB`.
///
/// Implemented on [`LayoutMatMulConfig`] for each supported `(DA, DB)` pair;
/// the implementation also fixes the dimension type of the output.
pub trait LayoutMatMulAPI<DA, DB>
where
    DA: DimAPI,
    DB: DimAPI,
    Self: Sized,
{
    /// Dimension type of the output layout `lc`.
    type DC: DimAPI;
    /// Layout configuration for matmul.
    ///
    /// For order, currently we only accept deterministic order.
    fn layout_matmul(la: &Layout<DA>, lb: &Layout<DB>, order: FlagOrder) -> Result<Self>;
}
74
75// rule 1
76impl LayoutMatMulAPI<Ix1, Ix1> for LayoutMatMulConfig<Ix1, Ix1> {
77    type DC = Ix0;
78    fn layout_matmul(la: &Layout<Ix1>, lb: &Layout<Ix1>, _: FlagOrder) -> Result<Self> {
79        // check shape
80        rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
81        let lc = unsafe { Layout::new_unchecked([], [], 0) };
82        Ok(LayoutMatMulConfig {
83            matmul_type: MatMulType::InnerDot,
84            lc: lc.clone(),
85            la_rest: None,
86            lb_rest: None,
87            lc_rest: None,
88            la_matmul: la.to_dim()?,
89            lb_matmul: lb.to_dim()?,
90            lc_matmul: lc.to_dim()?,
91        })
92    }
93}
94
95// rule 2
96impl LayoutMatMulAPI<Ix2, Ix2> for LayoutMatMulConfig<Ix2, Ix2> {
97    type DC = Ix2;
98    fn layout_matmul(la: &Layout<Ix2>, lb: &Layout<Ix2>, order: FlagOrder) -> Result<Self> {
99        // check and generate shape
100        rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
101        let sc = [la.shape()[0], lb.shape()[1]];
102        // layout order determination
103        let lc = match order {
104            RowMajor => sc.c(),
105            ColMajor => sc.f(),
106        };
107        // return layout configuration
108        Ok(LayoutMatMulConfig {
109            matmul_type: MatMulType::GEMM22,
110            lc: lc.clone(),
111            la_rest: None,
112            lb_rest: None,
113            lc_rest: None,
114            la_matmul: la.to_dim()?,
115            lb_matmul: lb.to_dim()?,
116            lc_matmul: lc.to_dim()?,
117        })
118    }
119}
120
121fn layout_matmul_dyn_row_major(la: &Layout<IxD>, lb: &Layout<IxD>) -> Result<LayoutMatMulConfig<IxD, IxD>> {
122    let na = la.ndim();
123    let nb = lb.ndim();
124    match (na, nb) {
125        (1, 1) => {
126            // rule 1: vector inner dot
127            rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
128            let lc = unsafe { Layout::new_unchecked(vec![], vec![], 0) };
129            Ok(LayoutMatMulConfig {
130                matmul_type: MatMulType::InnerDot,
131                lc: lc.clone(),
132                la_rest: None,
133                lb_rest: None,
134                lc_rest: None,
135                la_matmul: la.to_dim()?,
136                lb_matmul: lb.to_dim()?,
137                lc_matmul: lc.to_dim()?,
138            })
139        },
140        (2, 2) => {
141            // rule 2: matrix multiplication
142            // check and generate shape
143            rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
144            let sc = vec![la.shape()[0], lb.shape()[1]];
145            // layout order determination
146            let lc = sc.c();
147            // return layout configuration
148            Ok(LayoutMatMulConfig {
149                matmul_type: MatMulType::GEMM22,
150                lc: lc.clone(),
151                la_rest: None,
152                lb_rest: None,
153                lc_rest: None,
154                la_matmul: la.to_dim()?,
155                lb_matmul: lb.to_dim()?,
156                lc_matmul: lc.to_dim()?,
157            })
158        },
159        (1, 2..) => {
160            // rule 3: | `        K` | `..., K, N` | `   ..., N` |
161            // check and generate shape
162            let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
163            rstsr_assert_eq!(la.shape()[0], lb_matmul.shape()[0], InvalidLayout)?;
164            // layout order determination
165            let mut sc = lb_rest.shape().clone();
166            sc.push(lb_matmul.shape()[1]);
167            let lc = sc.c();
168            // return layout configuration
169            let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
170            Ok(LayoutMatMulConfig {
171                matmul_type: MatMulType::GEVM,
172                lc: lc.to_dim()?,
173                la_rest: None,
174                lb_rest: Some(lb_rest),
175                lc_rest: Some(lc_rest),
176                la_matmul: la.to_dim()?,
177                lb_matmul: lb_matmul.to_dim()?,
178                lc_matmul: lc_matmul.to_dim()?,
179            })
180        },
181        (2.., 1) => {
182            // rule 4: | `..., M, K` | `        K` | `   ..., M` |
183            // check and generate shape
184            let (la_rest, la_matmul) = la.dim_split_at(-2)?;
185            rstsr_assert_eq!(lb.shape()[0], la_matmul.shape()[1], InvalidLayout)?;
186            // layout order determination
187            let mut sc = la_rest.shape().clone();
188            sc.push(la_matmul.shape()[0]);
189            let lc = sc.c();
190            // return layout configuration
191            let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
192            Ok(LayoutMatMulConfig {
193                matmul_type: MatMulType::GEMV,
194                lc: lc.to_dim()?,
195                la_rest: Some(la_rest),
196                lb_rest: None,
197                lc_rest: Some(lc_rest),
198                la_matmul: la_matmul.to_dim()?,
199                lb_matmul: lb.to_dim()?,
200                lc_matmul: lc_matmul.to_dim()?,
201            })
202        },
203        (2, 3..) => {
204            // rule 5: | `     M, K` | `..., K, N` | `..., M, N` |
205            // check and generate shape
206            let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
207            rstsr_assert_eq!(la.shape()[1], lb_matmul.shape()[0], InvalidLayout)?;
208            // layout order determination
209            let mut sc = lb_rest.shape().clone();
210            sc.append(&mut vec![la.shape()[0], lb_matmul.shape()[1]]);
211            let lc = sc.c();
212            // return layout configuration
213            let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
214            Ok(LayoutMatMulConfig {
215                matmul_type: MatMulType::GEMM2X,
216                lc: lc.to_dim()?,
217                la_rest: None,
218                lb_rest: Some(lb_rest),
219                lc_rest: Some(lc_rest),
220                la_matmul: la.to_dim()?,
221                lb_matmul: lb_matmul.to_dim()?,
222                lc_matmul: lc_matmul.to_dim()?,
223            })
224        },
225        (3.., 2) => {
226            // rule 6: | `..., M, K` | `     K, N` | `..., M, N` |
227            // check and generate shape
228            let (la_rest, la_matmul) = la.dim_split_at(-2)?;
229            rstsr_assert_eq!(la_matmul.shape()[1], lb.shape()[0], InvalidLayout)?;
230            // layout order determination
231            let mut sc = la_rest.shape().clone();
232            sc.append(&mut vec![la_matmul.shape()[0], lb.shape()[1]]);
233            let lc = sc.c();
234            // return layout configuration
235            let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
236            Ok(LayoutMatMulConfig {
237                matmul_type: MatMulType::GEMMX2,
238                lc: lc.to_dim()?,
239                la_rest: Some(la_rest),
240                lb_rest: None,
241                lc_rest: Some(lc_rest),
242                la_matmul: la_matmul.to_dim()?,
243                lb_matmul: lb.to_dim()?,
244                lc_matmul: lc_matmul.to_dim()?,
245            })
246        },
247        (3.., 3..) => {
248            // check and generate shape
249            let (la_rest, la_matmul) = la.dim_split_at(-2)?;
250            let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
251            rstsr_assert_eq!(la_matmul.shape()[1], lb_matmul.shape()[0], InvalidLayout)?;
252            let (la_rest_b, lb_rest_b) = broadcast_layout(&la_rest, &lb_rest, RowMajor)?;
253            // layout order determination
254            let mut sc = la_rest_b.shape().clone();
255            sc.append(&mut vec![la_matmul.shape()[0], lb_matmul.shape()[1]]);
256            let lc = sc.c();
257            // return layout configuration
258            let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
259            Ok(LayoutMatMulConfig {
260                matmul_type: MatMulType::GEMMXX,
261                lc: lc.to_dim()?,
262                la_rest: Some(la_rest_b),
263                lb_rest: Some(lb_rest_b),
264                lc_rest: Some(lc_rest),
265                la_matmul: la.to_dim()?,
266                lb_matmul: lb_matmul.to_dim()?,
267                lc_matmul: lc_matmul.to_dim()?,
268            })
269        },
270        (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
271    }
272}
273
274/// Resolve matmul layouts against a caller-provided `lc`.
275///
276/// Unlike [`layout_matmul_dyn_row_major`], which *constructs* `lc` from `la`
277/// and `lb`, this function takes the real `lc` as the source of truth for the
278/// batch shape (it comes from the caller and may be strided / non-zero-offset)
279/// and derives `la_rest` / `lb_rest` by broadcasting them **to** `lc_rest`.
280/// All returned layouts are split from the real `la` / `lb` / `lc`, so their
281/// strides and offsets are valid for indexing into the actual operand buffers.
282///
283/// This is the canonical entry point used by the `DeviceMatMulAPI`
284/// implementations (faer, BLAS backends, naive CPU); it centralizes the rule
285/// table and the batch-broadcasting so the device drivers only have to dispatch
286/// on [`MatMulType`] and iterate the rest layouts.
287pub fn layout_matmul_dyn_row_major_with_lc(
288    la: &Layout<IxD>,
289    lb: &Layout<IxD>,
290    lc: &Layout<IxD>,
291) -> Result<LayoutMatMulConfig<IxD, IxD>> {
292    let na = la.ndim();
293    let nb = lb.ndim();
294    let nc = lc.ndim();
295    match (na, nb) {
296        (1, 1) => {
297            // rule 1: vector inner dot
298            rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
299            rstsr_assert_eq!(nc, 0, InvalidLayout)?;
300            Ok(LayoutMatMulConfig {
301                matmul_type: MatMulType::InnerDot,
302                lc: lc.clone(),
303                la_rest: None,
304                lb_rest: None,
305                lc_rest: None,
306                la_matmul: la.clone(),
307                lb_matmul: lb.clone(),
308                lc_matmul: lc.clone(),
309            })
310        },
311        (2, 2) => {
312            // rule 2: matrix multiplication
313            rstsr_assert_eq!(nc, 2, InvalidLayout)?;
314            rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
315            Ok(LayoutMatMulConfig {
316                matmul_type: MatMulType::GEMM22,
317                lc: lc.clone(),
318                la_rest: None,
319                lb_rest: None,
320                lc_rest: None,
321                la_matmul: la.clone(),
322                lb_matmul: lb.clone(),
323                lc_matmul: lc.clone(),
324            })
325        },
326        (1, 2..) => {
327            // rule 3: | `        K` | `..., K, N` | `   ..., N` |
328            rstsr_assert_eq!(nb, nc + 1, InvalidLayout)?;
329            let (la_r, la_m) = la.dim_split_at(-1)?;
330            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
331            let (lc_r, lc_m) = lc.dim_split_at(-1)?;
332            let la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
333            Ok(LayoutMatMulConfig {
334                matmul_type: MatMulType::GEVM,
335                lc: lc.clone(),
336                la_rest: Some(la_rest),
337                lb_rest: Some(lb_r),
338                lc_rest: Some(lc_r),
339                la_matmul: la_m.dim_insert(0)?,
340                lb_matmul: lb_m,
341                lc_matmul: lc_m.dim_insert(0)?,
342            })
343        },
344        (2.., 1) => {
345            // rule 4: | `..., M, K` | `        K` | `   ..., M` |
346            rstsr_assert_eq!(na, nc + 1, InvalidLayout)?;
347            let (la_r, la_m) = la.dim_split_at(-2)?;
348            let (lb_r, lb_m) = lb.dim_split_at(-1)?;
349            let (lc_r, lc_m) = lc.dim_split_at(-1)?;
350            let lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
351            Ok(LayoutMatMulConfig {
352                matmul_type: MatMulType::GEMV,
353                lc: lc.clone(),
354                la_rest: Some(la_r),
355                lb_rest: Some(lb_rest),
356                lc_rest: Some(lc_r),
357                la_matmul: la_m,
358                lb_matmul: lb_m.dim_insert(1)?,
359                lc_matmul: lc_m.dim_insert(1)?,
360            })
361        },
362        (2, 3..) => {
363            // rule 5: | `     M, K` | `..., K, N` | `..., M, N` |
364            rstsr_assert_eq!(nb, nc, InvalidLayout)?;
365            let (la_r, la_m) = la.dim_split_at(-2)?;
366            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
367            let (lc_r, lc_m) = lc.dim_split_at(-2)?;
368            let la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
369            Ok(LayoutMatMulConfig {
370                matmul_type: MatMulType::GEMM2X,
371                lc: lc.clone(),
372                la_rest: Some(la_rest),
373                lb_rest: Some(lb_r),
374                lc_rest: Some(lc_r),
375                la_matmul: la_m,
376                lb_matmul: lb_m,
377                lc_matmul: lc_m,
378            })
379        },
380        (3.., 2) => {
381            // rule 6: | `..., M, K` | `     K, N` | `..., M, N` |
382            rstsr_assert_eq!(na, nc, InvalidLayout)?;
383            let (la_r, la_m) = la.dim_split_at(-2)?;
384            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
385            let (lc_r, lc_m) = lc.dim_split_at(-2)?;
386            let lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
387            Ok(LayoutMatMulConfig {
388                matmul_type: MatMulType::GEMMX2,
389                lc: lc.clone(),
390                la_rest: Some(la_r),
391                lb_rest: Some(lb_rest),
392                lc_rest: Some(lc_r),
393                la_matmul: la_m,
394                lb_matmul: lb_m,
395                lc_matmul: lc_m,
396            })
397        },
398        (3.., 3..) => {
399            // rule 7: | `..., M, K` | `..., K, N` | `..., M, N` |
400            rstsr_assert_eq!(na, nc, InvalidLayout)?;
401            rstsr_assert_eq!(nb, nc, InvalidLayout)?;
402            let (la_r, la_m) = la.dim_split_at(-2)?;
403            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
404            let (lc_r, lc_m) = lc.dim_split_at(-2)?;
405            // both A and B batch dims broadcast against C's batch dims
406            let la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
407            let lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
408            Ok(LayoutMatMulConfig {
409                matmul_type: MatMulType::GEMMXX,
410                lc: lc.clone(),
411                la_rest: Some(la_rest),
412                lb_rest: Some(lb_rest),
413                lc_rest: Some(lc_r),
414                la_matmul: la_m,
415                lb_matmul: lb_m,
416                lc_matmul: lc_m,
417            })
418        },
419        (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
420    }
421}
422
423fn layout_matmul_dyn_col_major(la: &Layout<IxD>, lb: &Layout<IxD>) -> Result<LayoutMatMulConfig<IxD, IxD>> {
424    // For col-major, we re-use the row-major implementation via the identity
425    //     C[t, m, n] = A[t, m, k] @ B[t, k, n]   (row-major)
426    // <=> C[n, m, t] = B[n, k, t] @ A[k, m, t]   (col-major)
427    // i.e. reverse all axes and swap A/B. So we delegate to
428    // `layout_matmul_dyn_row_major(lb_rev, la_rev)` and then reverse-axes (and
429    // swap A/B back) on every layout field that the row-major impl returns.
430    //
431    // Note that rules 5/6/7 (broadcasting matmul) are only supported by the
432    // row-major rules in the array API spec; for col-major we accept them
433    // here, but the corresponding `DeviceMatMulAPI` impl must follow the same
434    // reverse-axes-and-swap convention (see e.g. `device_faer/matmul.rs`).
435    let na = la.ndim();
436    let nb = lb.ndim();
437    if na == 0 || nb == 0 {
438        return rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed.");
439    }
440    let la_rev = la.reverse_axes();
441    let lb_rev = lb.reverse_axes();
442    let cfg = layout_matmul_dyn_row_major(&lb_rev, &la_rev)?;
443    Ok(LayoutMatMulConfig {
444        matmul_type: cfg.matmul_type,
445        lc: cfg.lc.reverse_axes(),
446        // row-major's `la_*` corresponds to (reversed) B, so it maps back to
447        // col-major's `lb_*` (after reversing axes again).
448        la_rest: cfg.lb_rest.map(|l| l.reverse_axes()),
449        lb_rest: cfg.la_rest.map(|l| l.reverse_axes()),
450        lc_rest: cfg.lc_rest.map(|l| l.reverse_axes()),
451        la_matmul: cfg.lb_matmul.reverse_axes(),
452        lb_matmul: cfg.la_matmul.reverse_axes(),
453        lc_matmul: cfg.lc_matmul.reverse_axes(),
454    })
455}
456
457impl LayoutMatMulAPI<IxD, IxD> for LayoutMatMulConfig<IxD, IxD> {
458    type DC = IxD;
459    fn layout_matmul(la: &Layout<IxD>, lb: &Layout<IxD>, order: FlagOrder) -> Result<Self> {
460        match order {
461            RowMajor => layout_matmul_dyn_row_major(la, lb),
462            ColMajor => layout_matmul_dyn_col_major(la, lb),
463        }
464    }
465}
466
/// Implements `LayoutMatMulAPI<$DA, $DB>` with output dimension `$DC` by
/// converting both operands to `IxD`, delegating to the `IxD` implementation
/// and converting the resulting output layout back to `$DC`.
macro_rules! impl_fixed {
    ($DA:ident, $DB:ident, $DC:ident) => {
        impl LayoutMatMulAPI<$DA, $DB> for LayoutMatMulConfig<$DA, $DB> {
            type DC = $DC;

            fn layout_matmul(la: &Layout<$DA>, lb: &Layout<$DB>, order: FlagOrder) -> Result<Self> {
                // erase the static dimensionality of both operands
                let la_dyn = la.to_dim::<IxD>()?;
                let lb_dyn = lb.to_dim::<IxD>()?;
                let resolved =
                    <LayoutMatMulConfig<IxD, IxD> as LayoutMatMulAPI<IxD, IxD>>::layout_matmul(&la_dyn, &lb_dyn, order)?;
                Ok(LayoutMatMulConfig {
                    matmul_type: resolved.matmul_type,
                    // only the output layout is statically typed; the rest
                    // and matmul layouts stay `IxD`
                    lc: resolved.lc.into_dim()?,
                    la_rest: resolved.la_rest,
                    lb_rest: resolved.lb_rest,
                    lc_rest: resolved.lc_rest,
                    la_matmul: resolved.la_matmul,
                    lb_matmul: resolved.lb_matmul,
                    lc_matmul: resolved.lc_matmul,
                })
            }
        }
    };
}
489
// Arguments are (dimension of A, dimension of B, dimension of C); the rule
// numbers refer to the table in the module documentation.

// rule 4: A has batch axes, B is a vector
impl_fixed!(Ix2, Ix1, Ix1);
impl_fixed!(Ix3, Ix1, Ix2);
impl_fixed!(Ix4, Ix1, Ix3);
impl_fixed!(Ix5, Ix1, Ix4);
impl_fixed!(Ix6, Ix1, Ix5);
impl_fixed!(Ix7, Ix1, Ix6);
impl_fixed!(Ix8, Ix1, Ix7);
impl_fixed!(Ix9, Ix1, Ix8);

// rule 3: A is a vector, B has batch axes
impl_fixed!(Ix1, Ix2, Ix1);
impl_fixed!(Ix1, Ix3, Ix2);
impl_fixed!(Ix1, Ix4, Ix3);
impl_fixed!(Ix1, Ix5, Ix4);
impl_fixed!(Ix1, Ix6, Ix5);
impl_fixed!(Ix1, Ix7, Ix6);
impl_fixed!(Ix1, Ix8, Ix7);
impl_fixed!(Ix1, Ix9, Ix8);

// rule 6: A has batch axes, B is a matrix
impl_fixed!(Ix3, Ix2, Ix3);
impl_fixed!(Ix4, Ix2, Ix4);
impl_fixed!(Ix5, Ix2, Ix5);
impl_fixed!(Ix6, Ix2, Ix6);
impl_fixed!(Ix7, Ix2, Ix7);
impl_fixed!(Ix8, Ix2, Ix8);
impl_fixed!(Ix9, Ix2, Ix9);

// rule 5: A is a matrix, B has batch axes
impl_fixed!(Ix2, Ix3, Ix3);
impl_fixed!(Ix2, Ix4, Ix4);
impl_fixed!(Ix2, Ix5, Ix5);
impl_fixed!(Ix2, Ix6, Ix6);
impl_fixed!(Ix2, Ix7, Ix7);
impl_fixed!(Ix2, Ix8, Ix8);
impl_fixed!(Ix2, Ix9, Ix9);

// rule 7: both operands have batch axes (same static dimensionality)
impl_fixed!(Ix3, Ix3, Ix3);
impl_fixed!(Ix4, Ix4, Ix4);
impl_fixed!(Ix5, Ix5, Ix5);
impl_fixed!(Ix6, Ix6, Ix6);
impl_fixed!(Ix7, Ix7, Ix7);
impl_fixed!(Ix8, Ix8, Ix8);
impl_fixed!(Ix9, Ix9, Ix9);

// partial fixed: one operand statically sized, the other dynamic
impl_fixed!(Ix1, IxD, IxD);
impl_fixed!(Ix2, IxD, IxD);
impl_fixed!(Ix3, IxD, IxD);
impl_fixed!(Ix4, IxD, IxD);
impl_fixed!(Ix5, IxD, IxD);
impl_fixed!(Ix6, IxD, IxD);
impl_fixed!(Ix7, IxD, IxD);
impl_fixed!(Ix8, IxD, IxD);
impl_fixed!(Ix9, IxD, IxD);

impl_fixed!(IxD, Ix1, IxD);
impl_fixed!(IxD, Ix2, IxD);
impl_fixed!(IxD, Ix3, IxD);
impl_fixed!(IxD, Ix4, IxD);
impl_fixed!(IxD, Ix5, IxD);
impl_fixed!(IxD, Ix6, IxD);
impl_fixed!(IxD, Ix7, IxD);
impl_fixed!(IxD, Ix8, IxD);
impl_fixed!(IxD, Ix9, IxD);
557
#[cfg(test)]
mod test_fixed {

    /// Checks the output layout (and, for rule 1, the matmul layouts and rule
    /// type) produced for each rule, first for row-major and then for the
    /// mirrored col-major cases.
    #[test]
    fn test_layout_matmul() {
        use super::*;
        // rule 1: vector inner dot
        let la = [4].c();
        let lb = [4].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.matmul_type, MatMulType::InnerDot);
        assert_eq!(config.lc.shape(), &[]);
        assert_eq!(config.la_matmul.shape(), &[4]);
        assert_eq!(config.lb_matmul.shape(), &[4]);

        // rule 3: K @ (..., K, N) -> (..., N), with a non-contiguous B
        let la = [5].c();
        let lb = [3, 4, 5, 6].f().swapaxes(0, 1).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.lc, [4, 3, 6].c());

        // rule 4: (..., M, K) @ K -> (..., M), with a non-contiguous A
        let la = [3, 4, 5, 6].f().swapaxes(0, 1).unwrap();
        let lb = [6].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.lc, [4, 3, 5].c());

        // rule 5: (M, K) @ (..., K, N) -> (..., M, N)
        let la = [7, 6].c();
        let lb = [2, 3, 4, 5, 6].f().swapaxes(-1, -2).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.lc, [2, 3, 4, 7, 5].c());

        // rule 6: (..., M, K) @ (K, N) -> (..., M, N)
        let la = [2, 3, 4, 5, 6].f().swapaxes(-1, -2).unwrap();
        let lb = [5, 7].c();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.lc, [2, 3, 4, 6, 7].c());

        // rule 7 with broadcasting: batch shapes (2, 1, 4) and (1, 3, 4)
        let la = [4, 1, 2, 5, 6].f().swapaxes(0, 2).unwrap();
        let lb = [4, 3, 1, 6, 7].f().swapaxes(0, 2).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.lc, [2, 3, 4, 5, 7].c());

        // rule 7 without broadcasting: identical batch shapes (2, 3, 4)
        let la = [4, 3, 2, 5, 6].f().swapaxes(0, 2).unwrap();
        let lb = [4, 3, 2, 6, 7].f().swapaxes(0, 2).unwrap();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
        assert_eq!(config.lc, [2, 3, 4, 5, 7].c());

        // col-major broadcasting (mirror of the row-major cases above; the
        // matmul dims are the first two, the trailing dims broadcast).
        // rule 2: (M, K) @ (K, N) -> (M, N)
        let la = [5, 6].f();
        let lb = [6, 7].f();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
        assert_eq!(config.lc, [5, 7].f());

        // rule 3 mirrored: K @ (K, N, ...) -> (N, ...)
        let la = [5].f();
        let lb = [5, 6, 3, 4].f();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
        assert_eq!(config.lc, [6, 3, 4].f());

        // rule 4 mirrored: (M, K, ...) @ K -> (M, ...)
        let la = [5, 6, 3, 4].f();
        let lb = [6].f();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
        assert_eq!(config.lc, [5, 3, 4].f());

        // rule 7 mirrored: full 5D x 5D batched matmul, including broadcast
        // on the trailing dims (`1` broadcasts against `2`/`3`).
        let la = [5, 6, 2, 1, 4].f();
        let lb = [6, 7, 1, 3, 4].f();
        let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
        assert_eq!(config.lc, [5, 7, 2, 3, 4].f());
    }
}