1use crate::prelude_dev::*;
26
/// Kind of multiplication selected from the ranks of the two operands.
///
/// The rank combinations listed on each variant are the ones used by the
/// layout builders in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatMulType {
    /// 1-D by 1-D: inner (dot) product.
    InnerDot,
    /// 2-D by 2-D: plain matrix-matrix product.
    GEMM22,
    /// 1-D by N-D (N >= 2): vector-matrix product.
    GEVM,
    /// N-D (N >= 2) by 1-D: matrix-vector product.
    GEMV,
    /// 2-D by N-D (N >= 3): one matrix applied to a batch of matrices.
    GEMM2X,
    /// N-D (N >= 3) by 2-D: a batch of matrices applied to one matrix.
    GEMMX2,
    /// N-D by N-D (both >= 3): batched product with broadcast leading axes.
    GEMMXX,
}
38
39#[derive(Clone, Debug)]
40pub struct LayoutMatMulConfig<DA, DB>
41where
42 DA: DimAPI,
43 DB: DimAPI,
44 Self: LayoutMatMulAPI<DA, DB>,
45{
46 pub matmul_type: MatMulType,
47 pub lc: Layout<<Self as LayoutMatMulAPI<DA, DB>>::DC>,
48 pub la_rest: Option<Layout<IxD>>,
49 pub lb_rest: Option<Layout<IxD>>,
50 pub lc_rest: Option<Layout<IxD>>,
51 pub la_matmul: Layout<IxD>,
52 pub lb_matmul: Layout<IxD>,
53 pub lc_matmul: Layout<IxD>,
54}
55
56pub trait LayoutMatMulAPI<DA, DB>
57where
58 DA: DimAPI,
59 DB: DimAPI,
60 Self: Sized,
61{
62 type DC: DimAPI;
63 fn layout_matmul(la: &Layout<DA>, lb: &Layout<DB>, order: FlagOrder) -> Result<Self>;
67}
68
69impl LayoutMatMulAPI<Ix1, Ix1> for LayoutMatMulConfig<Ix1, Ix1> {
71 type DC = Ix0;
72 fn layout_matmul(la: &Layout<Ix1>, lb: &Layout<Ix1>, _: FlagOrder) -> Result<Self> {
73 rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
75 let lc = unsafe { Layout::new_unchecked([], [], 0) };
76 Ok(LayoutMatMulConfig {
77 matmul_type: MatMulType::InnerDot,
78 lc: lc.clone(),
79 la_rest: None,
80 lb_rest: None,
81 lc_rest: None,
82 la_matmul: la.to_dim()?,
83 lb_matmul: lb.to_dim()?,
84 lc_matmul: lc.to_dim()?,
85 })
86 }
87}
88
89impl LayoutMatMulAPI<Ix2, Ix2> for LayoutMatMulConfig<Ix2, Ix2> {
91 type DC = Ix2;
92 fn layout_matmul(la: &Layout<Ix2>, lb: &Layout<Ix2>, order: FlagOrder) -> Result<Self> {
93 rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
95 let sc = [la.shape()[0], lb.shape()[1]];
96 let lc = match order {
98 RowMajor => sc.c(),
99 ColMajor => sc.f(),
100 };
101 Ok(LayoutMatMulConfig {
103 matmul_type: MatMulType::GEMM22,
104 lc: lc.clone(),
105 la_rest: None,
106 lb_rest: None,
107 lc_rest: None,
108 la_matmul: la.to_dim()?,
109 lb_matmul: lb.to_dim()?,
110 lc_matmul: lc.to_dim()?,
111 })
112 }
113}
114
115fn layout_matmul_dyn_row_major(
116 la: &Layout<IxD>,
117 lb: &Layout<IxD>,
118) -> Result<LayoutMatMulConfig<IxD, IxD>> {
119 let na = la.ndim();
120 let nb = lb.ndim();
121 match (na, nb) {
122 (1, 1) => {
123 rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
125 let lc = unsafe { Layout::new_unchecked(vec![], vec![], 0) };
126 Ok(LayoutMatMulConfig {
127 matmul_type: MatMulType::InnerDot,
128 lc: lc.clone(),
129 la_rest: None,
130 lb_rest: None,
131 lc_rest: None,
132 la_matmul: la.to_dim()?,
133 lb_matmul: lb.to_dim()?,
134 lc_matmul: lc.to_dim()?,
135 })
136 },
137 (2, 2) => {
138 rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
141 let sc = vec![la.shape()[0], lb.shape()[1]];
142 let lc = sc.c();
144 Ok(LayoutMatMulConfig {
146 matmul_type: MatMulType::GEMM22,
147 lc: lc.clone(),
148 la_rest: None,
149 lb_rest: None,
150 lc_rest: None,
151 la_matmul: la.to_dim()?,
152 lb_matmul: lb.to_dim()?,
153 lc_matmul: lc.to_dim()?,
154 })
155 },
156 (1, 2..) => {
157 let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
160 rstsr_assert_eq!(la.shape()[0], lb_matmul.shape()[0], InvalidLayout)?;
161 let mut sc = lb_rest.shape().clone();
163 sc.push(lb_matmul.shape()[1]);
164 let lc = sc.c();
165 let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
167 Ok(LayoutMatMulConfig {
168 matmul_type: MatMulType::GEVM,
169 lc: lc.to_dim()?,
170 la_rest: None,
171 lb_rest: Some(lb_rest),
172 lc_rest: Some(lc_rest),
173 la_matmul: la.to_dim()?,
174 lb_matmul: lb_matmul.to_dim()?,
175 lc_matmul: lc_matmul.to_dim()?,
176 })
177 },
178 (2.., 1) => {
179 let (la_rest, la_matmul) = la.dim_split_at(-2)?;
182 rstsr_assert_eq!(lb.shape()[0], la_matmul.shape()[1], InvalidLayout)?;
183 let mut sc = la_rest.shape().clone();
185 sc.push(la_matmul.shape()[0]);
186 let lc = sc.c();
187 let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
189 Ok(LayoutMatMulConfig {
190 matmul_type: MatMulType::GEMV,
191 lc: lc.to_dim()?,
192 la_rest: Some(la_rest),
193 lb_rest: None,
194 lc_rest: Some(lc_rest),
195 la_matmul: la_matmul.to_dim()?,
196 lb_matmul: lb.to_dim()?,
197 lc_matmul: lc_matmul.to_dim()?,
198 })
199 },
200 (2, 3..) => {
201 let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
204 rstsr_assert_eq!(la.shape()[1], lb_matmul.shape()[0], InvalidLayout)?;
205 let mut sc = lb_rest.shape().clone();
207 sc.append(&mut vec![la.shape()[0], lb_matmul.shape()[1]]);
208 let lc = sc.c();
209 let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
211 Ok(LayoutMatMulConfig {
212 matmul_type: MatMulType::GEMM2X,
213 lc: lc.to_dim()?,
214 la_rest: None,
215 lb_rest: Some(lb_rest),
216 lc_rest: Some(lc_rest),
217 la_matmul: la.to_dim()?,
218 lb_matmul: lb_matmul.to_dim()?,
219 lc_matmul: lc_matmul.to_dim()?,
220 })
221 },
222 (3.., 2) => {
223 let (la_rest, la_matmul) = la.dim_split_at(-2)?;
226 rstsr_assert_eq!(la_matmul.shape()[1], lb.shape()[0], InvalidLayout)?;
227 let mut sc = la_rest.shape().clone();
229 sc.append(&mut vec![la_matmul.shape()[0], lb.shape()[1]]);
230 let lc = sc.c();
231 let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
233 Ok(LayoutMatMulConfig {
234 matmul_type: MatMulType::GEMMX2,
235 lc: lc.to_dim()?,
236 la_rest: Some(la_rest),
237 lb_rest: None,
238 lc_rest: Some(lc_rest),
239 la_matmul: la_matmul.to_dim()?,
240 lb_matmul: lb.to_dim()?,
241 lc_matmul: lc_matmul.to_dim()?,
242 })
243 },
244 (3.., 3..) => {
245 let (la_rest, la_matmul) = la.dim_split_at(-2)?;
247 let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
248 rstsr_assert_eq!(la_matmul.shape()[1], lb_matmul.shape()[0], InvalidLayout)?;
249 let (la_rest_b, lb_rest_b) = broadcast_layout(&la_rest, &lb_rest, RowMajor)?;
250 let mut sc = la_rest_b.shape().clone();
252 sc.append(&mut vec![la_matmul.shape()[0], lb_matmul.shape()[1]]);
253 let lc = sc.c();
254 let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
256 Ok(LayoutMatMulConfig {
257 matmul_type: MatMulType::GEMMXX,
258 lc: lc.to_dim()?,
259 la_rest: Some(la_rest_b),
260 lb_rest: Some(lb_rest_b),
261 lc_rest: Some(lc_rest),
262 la_matmul: la.to_dim()?,
263 lb_matmul: lb_matmul.to_dim()?,
264 lc_matmul: lc_matmul.to_dim()?,
265 })
266 },
267 (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
268 }
269}
270
271fn layout_matmul_dyn_col_major(
272 la: &Layout<IxD>,
273 lb: &Layout<IxD>,
274) -> Result<LayoutMatMulConfig<IxD, IxD>> {
275 let na = la.ndim();
276 let nb = lb.ndim();
277 match (na, nb) {
278 (1, 1) => {
279 rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
281 let lc = unsafe { Layout::new_unchecked(vec![], vec![], 0) };
282 Ok(LayoutMatMulConfig {
283 matmul_type: MatMulType::InnerDot,
284 lc: lc.clone(),
285 la_rest: None,
286 lb_rest: None,
287 lc_rest: None,
288 la_matmul: la.to_dim()?,
289 lb_matmul: lb.to_dim()?,
290 lc_matmul: lc.to_dim()?,
291 })
292 },
293 (2, 2) => {
294 rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
297 let sc = vec![la.shape()[0], lb.shape()[1]];
298 let lc = sc.f();
300 Ok(LayoutMatMulConfig {
302 matmul_type: MatMulType::GEMM22,
303 lc: lc.clone(),
304 la_rest: None,
305 lb_rest: None,
306 lc_rest: None,
307 la_matmul: la.to_dim()?,
308 lb_matmul: lb.to_dim()?,
309 lc_matmul: lc.to_dim()?,
310 })
311 },
312 (1, 2) => {
313 rstsr_assert_eq!(la.shape()[0], lb.shape()[0], InvalidLayout)?;
316 let sc = vec![lb.shape()[1]];
317 let lc = sc.f();
318 Ok(LayoutMatMulConfig {
319 matmul_type: MatMulType::GEVM,
320 lc: lc.to_dim()?,
321 la_rest: None,
322 lb_rest: None,
323 lc_rest: None,
324 la_matmul: la.to_dim()?,
325 lb_matmul: lb.to_dim()?,
326 lc_matmul: lc.to_dim()?,
327 })
328 },
329 (2, 1) => {
330 rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
333 let sc = vec![la.shape()[0]];
334 let lc = sc.f();
335 Ok(LayoutMatMulConfig {
337 matmul_type: MatMulType::GEMV,
338 lc: lc.to_dim()?,
339 la_rest: None,
340 lb_rest: None,
341 lc_rest: None,
342 la_matmul: la.to_dim()?,
343 lb_matmul: lb.to_dim()?,
344 lc_matmul: lc.to_dim()?,
345 })
346 },
347 (1, 3..) | (3.., 1) | (2, 3..) | (3.., 2) | (3.., 3..) => {
348 rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")
349 },
350 (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
351 }
352}
353
354impl LayoutMatMulAPI<IxD, IxD> for LayoutMatMulConfig<IxD, IxD> {
355 type DC = IxD;
356 fn layout_matmul(la: &Layout<IxD>, lb: &Layout<IxD>, order: FlagOrder) -> Result<Self> {
357 match order {
358 RowMajor => layout_matmul_dyn_row_major(la, lb),
359 ColMajor => layout_matmul_dyn_col_major(la, lb),
360 }
361 }
362}
363
/// Implements `LayoutMatMulAPI<$DA, $DB>` with output dimensionality `$DC`
/// by converting both operands to `IxD`, planning with the dynamic-rank
/// implementation, and converting the output layout back to `$DC`.
macro_rules! impl_fixed {
    ($DA:ident, $DB:ident, $DC:ident) => {
        impl LayoutMatMulAPI<$DA, $DB> for LayoutMatMulConfig<$DA, $DB> {
            type DC = $DC;
            fn layout_matmul(la: &Layout<$DA>, lb: &Layout<$DB>, order: FlagOrder) -> Result<Self> {
                let la_dyn = la.to_dim::<IxD>()?;
                let lb_dyn = lb.to_dim::<IxD>()?;
                let LayoutMatMulConfig {
                    matmul_type,
                    lc,
                    la_rest,
                    lb_rest,
                    lc_rest,
                    la_matmul,
                    lb_matmul,
                    lc_matmul,
                } = LayoutMatMulConfig::<IxD, IxD>::layout_matmul(&la_dyn, &lb_dyn, order)?;
                // Only the full output layout is typed by `$DC`; every other
                // field is already stored as `IxD`.
                Ok(Self {
                    matmul_type,
                    lc: lc.into_dim()?,
                    la_rest,
                    lb_rest,
                    lc_rest,
                    la_matmul,
                    lb_matmul,
                    lc_matmul,
                })
            }
        }
    };
}
386
387impl_fixed!(Ix2, Ix1, Ix1);
389impl_fixed!(Ix3, Ix1, Ix2);
390impl_fixed!(Ix4, Ix1, Ix3);
391impl_fixed!(Ix5, Ix1, Ix4);
392impl_fixed!(Ix6, Ix1, Ix5);
393impl_fixed!(Ix7, Ix1, Ix6);
394impl_fixed!(Ix8, Ix1, Ix7);
395impl_fixed!(Ix9, Ix1, Ix8);
396
397impl_fixed!(Ix1, Ix2, Ix1);
399impl_fixed!(Ix1, Ix3, Ix2);
400impl_fixed!(Ix1, Ix4, Ix3);
401impl_fixed!(Ix1, Ix5, Ix4);
402impl_fixed!(Ix1, Ix6, Ix5);
403impl_fixed!(Ix1, Ix7, Ix6);
404impl_fixed!(Ix1, Ix8, Ix7);
405impl_fixed!(Ix1, Ix9, Ix8);
406
407impl_fixed!(Ix3, Ix2, Ix3);
409impl_fixed!(Ix4, Ix2, Ix4);
410impl_fixed!(Ix5, Ix2, Ix5);
411impl_fixed!(Ix6, Ix2, Ix6);
412impl_fixed!(Ix7, Ix2, Ix7);
413impl_fixed!(Ix8, Ix2, Ix8);
414impl_fixed!(Ix9, Ix2, Ix9);
415
416impl_fixed!(Ix2, Ix3, Ix3);
418impl_fixed!(Ix2, Ix4, Ix4);
419impl_fixed!(Ix2, Ix5, Ix5);
420impl_fixed!(Ix2, Ix6, Ix6);
421impl_fixed!(Ix2, Ix7, Ix7);
422impl_fixed!(Ix2, Ix8, Ix8);
423impl_fixed!(Ix2, Ix9, Ix9);
424
425impl_fixed!(Ix3, Ix3, Ix3);
427impl_fixed!(Ix4, Ix4, Ix4);
428impl_fixed!(Ix5, Ix5, Ix5);
429impl_fixed!(Ix6, Ix6, Ix6);
430impl_fixed!(Ix7, Ix7, Ix7);
431impl_fixed!(Ix8, Ix8, Ix8);
432impl_fixed!(Ix9, Ix9, Ix9);
433
434impl_fixed!(Ix1, IxD, IxD);
436impl_fixed!(Ix2, IxD, IxD);
437impl_fixed!(Ix3, IxD, IxD);
438impl_fixed!(Ix4, IxD, IxD);
439impl_fixed!(Ix5, IxD, IxD);
440impl_fixed!(Ix6, IxD, IxD);
441impl_fixed!(Ix7, IxD, IxD);
442impl_fixed!(Ix8, IxD, IxD);
443impl_fixed!(Ix9, IxD, IxD);
444
445impl_fixed!(IxD, Ix1, IxD);
446impl_fixed!(IxD, Ix2, IxD);
447impl_fixed!(IxD, Ix3, IxD);
448impl_fixed!(IxD, Ix4, IxD);
449impl_fixed!(IxD, Ix5, IxD);
450impl_fixed!(IxD, Ix6, IxD);
451impl_fixed!(IxD, Ix7, IxD);
452impl_fixed!(IxD, Ix8, IxD);
453impl_fixed!(IxD, Ix9, IxD);
454
#[cfg(test)]
mod test_fixed {

    /// Checks the planned output layout for each supported rank combination,
    /// plus the column-major rejection of batched operands.
    #[test]
    fn test_layout_matmul() {
        use super::*;

        // 1-D by 1-D: inner product, 0-dimensional output.
        {
            let lhs = [4].c();
            let rhs = [4].c();
            let cfg = LayoutMatMulConfig::layout_matmul(&lhs, &rhs, RowMajor).unwrap();
            assert_eq!(cfg.matmul_type, MatMulType::InnerDot);
            assert_eq!(cfg.lc.shape(), &[]);
            assert_eq!(cfg.la_matmul.shape(), &[4]);
            assert_eq!(cfg.lb_matmul.shape(), &[4]);
        }

        // 1-D by 4-D.
        {
            let lhs = [5].c();
            let rhs = [3, 4, 5, 6].f().swapaxes(0, 1).unwrap();
            let cfg = LayoutMatMulConfig::layout_matmul(&lhs, &rhs, RowMajor).unwrap();
            assert_eq!(cfg.lc, [4, 3, 6].c());
        }

        // 4-D by 1-D.
        {
            let lhs = [3, 4, 5, 6].f().swapaxes(0, 1).unwrap();
            let rhs = [6].c();
            let cfg = LayoutMatMulConfig::layout_matmul(&lhs, &rhs, RowMajor).unwrap();
            assert_eq!(cfg.lc, [4, 3, 5].c());
        }

        // 2-D by 5-D.
        {
            let lhs = [7, 6].c();
            let rhs = [2, 3, 4, 5, 6].f().swapaxes(-1, -2).unwrap();
            let cfg = LayoutMatMulConfig::layout_matmul(&lhs, &rhs, RowMajor).unwrap();
            assert_eq!(cfg.lc, [2, 3, 4, 7, 5].c());
        }

        // 5-D by 2-D.
        {
            let lhs = [2, 3, 4, 5, 6].f().swapaxes(-1, -2).unwrap();
            let rhs = [5, 7].c();
            let cfg = LayoutMatMulConfig::layout_matmul(&lhs, &rhs, RowMajor).unwrap();
            assert_eq!(cfg.lc, [2, 3, 4, 6, 7].c());
        }

        // 5-D by 5-D with leading axes that need broadcasting.
        {
            let lhs = [4, 1, 2, 5, 6].f().swapaxes(0, 2).unwrap();
            let rhs = [4, 3, 1, 6, 7].f().swapaxes(0, 2).unwrap();
            let cfg = LayoutMatMulConfig::layout_matmul(&lhs, &rhs, RowMajor).unwrap();
            assert_eq!(cfg.lc, [2, 3, 4, 5, 7].c());
        }

        // 5-D by 5-D with identical leading axes.
        {
            let lhs = [4, 3, 2, 5, 6].f().swapaxes(0, 2).unwrap();
            let rhs = [4, 3, 2, 6, 7].f().swapaxes(0, 2).unwrap();
            let cfg = LayoutMatMulConfig::layout_matmul(&lhs, &rhs, RowMajor).unwrap();
            assert_eq!(cfg.lc, [2, 3, 4, 5, 7].c());
        }

        // Column-major order rejects batched operands.
        {
            let lhs = [4, 3, 2, 5, 6].f().swapaxes(0, 2).unwrap();
            let rhs = [4, 3, 2, 6, 7].f().swapaxes(0, 2).unwrap();
            let outcome = LayoutMatMulConfig::layout_matmul(&lhs, &rhs, ColMajor);
            assert!(outcome.is_err());
        }

        // Column-major 2-D by 2-D yields an F-contiguous output.
        {
            let lhs = [5, 6].c();
            let rhs = [6, 7].c();
            let cfg = LayoutMatMulConfig::layout_matmul(&lhs, &rhs, ColMajor).unwrap();
            assert_eq!(cfg.lc, [5, 7].f());
        }
    }
}