1use crate::prelude_dev::*;
32
/// Kind of matrix multiplication selected from the dimensionality of the two
/// operands `a` and `b`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MatMulType {
    /// 1-D x 1-D: inner product of two vectors.
    InnerDot,
    /// 2-D x 2-D: plain matrix-matrix product.
    GEMM22,
    /// 1-D x (2-D or higher): vector-matrix product.
    GEVM,
    /// (2-D or higher) x 1-D: matrix-vector product.
    GEMV,
    /// 2-D x (3-D or higher): one matrix multiplied with a stack of matrices.
    GEMM2X,
    /// (3-D or higher) x 2-D: a stack of matrices multiplied with one matrix.
    GEMMX2,
    /// (3-D or higher) x (3-D or higher): stack-by-stack matrix product.
    GEMMXX,
}
44
45#[derive(Clone, Debug)]
46pub struct LayoutMatMulConfig<DA, DB>
47where
48 DA: DimAPI,
49 DB: DimAPI,
50 Self: LayoutMatMulAPI<DA, DB>,
51{
52 pub matmul_type: MatMulType,
53 pub lc: Layout<<Self as LayoutMatMulAPI<DA, DB>>::DC>,
54 pub la_rest: Option<Layout<IxD>>,
55 pub lb_rest: Option<Layout<IxD>>,
56 pub lc_rest: Option<Layout<IxD>>,
57 pub la_matmul: Layout<IxD>,
58 pub lb_matmul: Layout<IxD>,
59 pub lc_matmul: Layout<IxD>,
60}
61
62pub trait LayoutMatMulAPI<DA, DB>
63where
64 DA: DimAPI,
65 DB: DimAPI,
66 Self: Sized,
67{
68 type DC: DimAPI;
69 fn layout_matmul(la: &Layout<DA>, lb: &Layout<DB>, order: FlagOrder) -> Result<Self>;
73}
74
75impl LayoutMatMulAPI<Ix1, Ix1> for LayoutMatMulConfig<Ix1, Ix1> {
77 type DC = Ix0;
78 fn layout_matmul(la: &Layout<Ix1>, lb: &Layout<Ix1>, _: FlagOrder) -> Result<Self> {
79 rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
81 let lc = unsafe { Layout::new_unchecked([], [], 0) };
82 Ok(LayoutMatMulConfig {
83 matmul_type: MatMulType::InnerDot,
84 lc: lc.clone(),
85 la_rest: None,
86 lb_rest: None,
87 lc_rest: None,
88 la_matmul: la.to_dim()?,
89 lb_matmul: lb.to_dim()?,
90 lc_matmul: lc.to_dim()?,
91 })
92 }
93}
94
95impl LayoutMatMulAPI<Ix2, Ix2> for LayoutMatMulConfig<Ix2, Ix2> {
97 type DC = Ix2;
98 fn layout_matmul(la: &Layout<Ix2>, lb: &Layout<Ix2>, order: FlagOrder) -> Result<Self> {
99 rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
101 let sc = [la.shape()[0], lb.shape()[1]];
102 let lc = match order {
104 RowMajor => sc.c(),
105 ColMajor => sc.f(),
106 };
107 Ok(LayoutMatMulConfig {
109 matmul_type: MatMulType::GEMM22,
110 lc: lc.clone(),
111 la_rest: None,
112 lb_rest: None,
113 lc_rest: None,
114 la_matmul: la.to_dim()?,
115 lb_matmul: lb.to_dim()?,
116 lc_matmul: lc.to_dim()?,
117 })
118 }
119}
120
121fn layout_matmul_dyn_row_major(la: &Layout<IxD>, lb: &Layout<IxD>) -> Result<LayoutMatMulConfig<IxD, IxD>> {
122 let na = la.ndim();
123 let nb = lb.ndim();
124 match (na, nb) {
125 (1, 1) => {
126 rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
128 let lc = unsafe { Layout::new_unchecked(vec![], vec![], 0) };
129 Ok(LayoutMatMulConfig {
130 matmul_type: MatMulType::InnerDot,
131 lc: lc.clone(),
132 la_rest: None,
133 lb_rest: None,
134 lc_rest: None,
135 la_matmul: la.to_dim()?,
136 lb_matmul: lb.to_dim()?,
137 lc_matmul: lc.to_dim()?,
138 })
139 },
140 (2, 2) => {
141 rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
144 let sc = vec![la.shape()[0], lb.shape()[1]];
145 let lc = sc.c();
147 Ok(LayoutMatMulConfig {
149 matmul_type: MatMulType::GEMM22,
150 lc: lc.clone(),
151 la_rest: None,
152 lb_rest: None,
153 lc_rest: None,
154 la_matmul: la.to_dim()?,
155 lb_matmul: lb.to_dim()?,
156 lc_matmul: lc.to_dim()?,
157 })
158 },
159 (1, 2..) => {
160 let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
163 rstsr_assert_eq!(la.shape()[0], lb_matmul.shape()[0], InvalidLayout)?;
164 let mut sc = lb_rest.shape().clone();
166 sc.push(lb_matmul.shape()[1]);
167 let lc = sc.c();
168 let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
170 Ok(LayoutMatMulConfig {
171 matmul_type: MatMulType::GEVM,
172 lc: lc.to_dim()?,
173 la_rest: None,
174 lb_rest: Some(lb_rest),
175 lc_rest: Some(lc_rest),
176 la_matmul: la.to_dim()?,
177 lb_matmul: lb_matmul.to_dim()?,
178 lc_matmul: lc_matmul.to_dim()?,
179 })
180 },
181 (2.., 1) => {
182 let (la_rest, la_matmul) = la.dim_split_at(-2)?;
185 rstsr_assert_eq!(lb.shape()[0], la_matmul.shape()[1], InvalidLayout)?;
186 let mut sc = la_rest.shape().clone();
188 sc.push(la_matmul.shape()[0]);
189 let lc = sc.c();
190 let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
192 Ok(LayoutMatMulConfig {
193 matmul_type: MatMulType::GEMV,
194 lc: lc.to_dim()?,
195 la_rest: Some(la_rest),
196 lb_rest: None,
197 lc_rest: Some(lc_rest),
198 la_matmul: la_matmul.to_dim()?,
199 lb_matmul: lb.to_dim()?,
200 lc_matmul: lc_matmul.to_dim()?,
201 })
202 },
203 (2, 3..) => {
204 let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
207 rstsr_assert_eq!(la.shape()[1], lb_matmul.shape()[0], InvalidLayout)?;
208 let mut sc = lb_rest.shape().clone();
210 sc.append(&mut vec![la.shape()[0], lb_matmul.shape()[1]]);
211 let lc = sc.c();
212 let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
214 Ok(LayoutMatMulConfig {
215 matmul_type: MatMulType::GEMM2X,
216 lc: lc.to_dim()?,
217 la_rest: None,
218 lb_rest: Some(lb_rest),
219 lc_rest: Some(lc_rest),
220 la_matmul: la.to_dim()?,
221 lb_matmul: lb_matmul.to_dim()?,
222 lc_matmul: lc_matmul.to_dim()?,
223 })
224 },
225 (3.., 2) => {
226 let (la_rest, la_matmul) = la.dim_split_at(-2)?;
229 rstsr_assert_eq!(la_matmul.shape()[1], lb.shape()[0], InvalidLayout)?;
230 let mut sc = la_rest.shape().clone();
232 sc.append(&mut vec![la_matmul.shape()[0], lb.shape()[1]]);
233 let lc = sc.c();
234 let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
236 Ok(LayoutMatMulConfig {
237 matmul_type: MatMulType::GEMMX2,
238 lc: lc.to_dim()?,
239 la_rest: Some(la_rest),
240 lb_rest: None,
241 lc_rest: Some(lc_rest),
242 la_matmul: la_matmul.to_dim()?,
243 lb_matmul: lb.to_dim()?,
244 lc_matmul: lc_matmul.to_dim()?,
245 })
246 },
247 (3.., 3..) => {
248 let (la_rest, la_matmul) = la.dim_split_at(-2)?;
250 let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
251 rstsr_assert_eq!(la_matmul.shape()[1], lb_matmul.shape()[0], InvalidLayout)?;
252 let (la_rest_b, lb_rest_b) = broadcast_layout(&la_rest, &lb_rest, RowMajor)?;
253 let mut sc = la_rest_b.shape().clone();
255 sc.append(&mut vec![la_matmul.shape()[0], lb_matmul.shape()[1]]);
256 let lc = sc.c();
257 let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
259 Ok(LayoutMatMulConfig {
260 matmul_type: MatMulType::GEMMXX,
261 lc: lc.to_dim()?,
262 la_rest: Some(la_rest_b),
263 lb_rest: Some(lb_rest_b),
264 lc_rest: Some(lc_rest),
265 la_matmul: la.to_dim()?,
266 lb_matmul: lb_matmul.to_dim()?,
267 lc_matmul: lc_matmul.to_dim()?,
268 })
269 },
270 (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
271 }
272}
273
274pub fn layout_matmul_dyn_row_major_with_lc(
288 la: &Layout<IxD>,
289 lb: &Layout<IxD>,
290 lc: &Layout<IxD>,
291) -> Result<LayoutMatMulConfig<IxD, IxD>> {
292 let na = la.ndim();
293 let nb = lb.ndim();
294 let nc = lc.ndim();
295 match (na, nb) {
296 (1, 1) => {
297 rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
299 rstsr_assert_eq!(nc, 0, InvalidLayout)?;
300 Ok(LayoutMatMulConfig {
301 matmul_type: MatMulType::InnerDot,
302 lc: lc.clone(),
303 la_rest: None,
304 lb_rest: None,
305 lc_rest: None,
306 la_matmul: la.clone(),
307 lb_matmul: lb.clone(),
308 lc_matmul: lc.clone(),
309 })
310 },
311 (2, 2) => {
312 rstsr_assert_eq!(nc, 2, InvalidLayout)?;
314 rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
315 Ok(LayoutMatMulConfig {
316 matmul_type: MatMulType::GEMM22,
317 lc: lc.clone(),
318 la_rest: None,
319 lb_rest: None,
320 lc_rest: None,
321 la_matmul: la.clone(),
322 lb_matmul: lb.clone(),
323 lc_matmul: lc.clone(),
324 })
325 },
326 (1, 2..) => {
327 rstsr_assert_eq!(nb, nc + 1, InvalidLayout)?;
329 let (la_r, la_m) = la.dim_split_at(-1)?;
330 let (lb_r, lb_m) = lb.dim_split_at(-2)?;
331 let (lc_r, lc_m) = lc.dim_split_at(-1)?;
332 let la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
333 Ok(LayoutMatMulConfig {
334 matmul_type: MatMulType::GEVM,
335 lc: lc.clone(),
336 la_rest: Some(la_rest),
337 lb_rest: Some(lb_r),
338 lc_rest: Some(lc_r),
339 la_matmul: la_m.dim_insert(0)?,
340 lb_matmul: lb_m,
341 lc_matmul: lc_m.dim_insert(0)?,
342 })
343 },
344 (2.., 1) => {
345 rstsr_assert_eq!(na, nc + 1, InvalidLayout)?;
347 let (la_r, la_m) = la.dim_split_at(-2)?;
348 let (lb_r, lb_m) = lb.dim_split_at(-1)?;
349 let (lc_r, lc_m) = lc.dim_split_at(-1)?;
350 let lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
351 Ok(LayoutMatMulConfig {
352 matmul_type: MatMulType::GEMV,
353 lc: lc.clone(),
354 la_rest: Some(la_r),
355 lb_rest: Some(lb_rest),
356 lc_rest: Some(lc_r),
357 la_matmul: la_m,
358 lb_matmul: lb_m.dim_insert(1)?,
359 lc_matmul: lc_m.dim_insert(1)?,
360 })
361 },
362 (2, 3..) => {
363 rstsr_assert_eq!(nb, nc, InvalidLayout)?;
365 let (la_r, la_m) = la.dim_split_at(-2)?;
366 let (lb_r, lb_m) = lb.dim_split_at(-2)?;
367 let (lc_r, lc_m) = lc.dim_split_at(-2)?;
368 let la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
369 Ok(LayoutMatMulConfig {
370 matmul_type: MatMulType::GEMM2X,
371 lc: lc.clone(),
372 la_rest: Some(la_rest),
373 lb_rest: Some(lb_r),
374 lc_rest: Some(lc_r),
375 la_matmul: la_m,
376 lb_matmul: lb_m,
377 lc_matmul: lc_m,
378 })
379 },
380 (3.., 2) => {
381 rstsr_assert_eq!(na, nc, InvalidLayout)?;
383 let (la_r, la_m) = la.dim_split_at(-2)?;
384 let (lb_r, lb_m) = lb.dim_split_at(-2)?;
385 let (lc_r, lc_m) = lc.dim_split_at(-2)?;
386 let lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
387 Ok(LayoutMatMulConfig {
388 matmul_type: MatMulType::GEMMX2,
389 lc: lc.clone(),
390 la_rest: Some(la_r),
391 lb_rest: Some(lb_rest),
392 lc_rest: Some(lc_r),
393 la_matmul: la_m,
394 lb_matmul: lb_m,
395 lc_matmul: lc_m,
396 })
397 },
398 (3.., 3..) => {
399 rstsr_assert_eq!(na, nc, InvalidLayout)?;
401 rstsr_assert_eq!(nb, nc, InvalidLayout)?;
402 let (la_r, la_m) = la.dim_split_at(-2)?;
403 let (lb_r, lb_m) = lb.dim_split_at(-2)?;
404 let (lc_r, lc_m) = lc.dim_split_at(-2)?;
405 let la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
407 let lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
408 Ok(LayoutMatMulConfig {
409 matmul_type: MatMulType::GEMMXX,
410 lc: lc.clone(),
411 la_rest: Some(la_rest),
412 lb_rest: Some(lb_rest),
413 lc_rest: Some(lc_r),
414 la_matmul: la_m,
415 lb_matmul: lb_m,
416 lc_matmul: lc_m,
417 })
418 },
419 (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
420 }
421}
422
423fn layout_matmul_dyn_col_major(la: &Layout<IxD>, lb: &Layout<IxD>) -> Result<LayoutMatMulConfig<IxD, IxD>> {
424 let na = la.ndim();
436 let nb = lb.ndim();
437 if na == 0 || nb == 0 {
438 return rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed.");
439 }
440 let la_rev = la.reverse_axes();
441 let lb_rev = lb.reverse_axes();
442 let cfg = layout_matmul_dyn_row_major(&lb_rev, &la_rev)?;
443 Ok(LayoutMatMulConfig {
444 matmul_type: cfg.matmul_type,
445 lc: cfg.lc.reverse_axes(),
446 la_rest: cfg.lb_rest.map(|l| l.reverse_axes()),
449 lb_rest: cfg.la_rest.map(|l| l.reverse_axes()),
450 lc_rest: cfg.lc_rest.map(|l| l.reverse_axes()),
451 la_matmul: cfg.lb_matmul.reverse_axes(),
452 lb_matmul: cfg.la_matmul.reverse_axes(),
453 lc_matmul: cfg.lc_matmul.reverse_axes(),
454 })
455}
456
457impl LayoutMatMulAPI<IxD, IxD> for LayoutMatMulConfig<IxD, IxD> {
458 type DC = IxD;
459 fn layout_matmul(la: &Layout<IxD>, lb: &Layout<IxD>, order: FlagOrder) -> Result<Self> {
460 match order {
461 RowMajor => layout_matmul_dyn_row_major(la, lb),
462 ColMajor => layout_matmul_dyn_col_major(la, lb),
463 }
464 }
465}
466
/// Implements `LayoutMatMulAPI<$DA, $DB>` with output dimension `$DC` by
/// converting both operand layouts to `IxD`, running the dynamic-dimension
/// implementation, and converting the resulting `lc` back to `$DC`.
macro_rules! impl_fixed {
    ($DA:ident, $DB:ident, $DC:ident) => {
        impl LayoutMatMulAPI<$DA, $DB> for LayoutMatMulConfig<$DA, $DB> {
            type DC = $DC;

            fn layout_matmul(la: &Layout<$DA>, lb: &Layout<$DB>, order: FlagOrder) -> Result<Self> {
                let la_dyn = la.to_dim::<IxD>()?;
                let lb_dyn = lb.to_dim::<IxD>()?;
                let dyn_cfg = LayoutMatMulConfig::<IxD, IxD>::layout_matmul(&la_dyn, &lb_dyn, order)?;
                // Only `lc` carries a static dimension; the split layouts are
                // already `IxD` and are moved over unchanged.
                Ok(Self {
                    matmul_type: dyn_cfg.matmul_type,
                    lc: dyn_cfg.lc.into_dim()?,
                    la_rest: dyn_cfg.la_rest,
                    lb_rest: dyn_cfg.lb_rest,
                    lc_rest: dyn_cfg.lc_rest,
                    la_matmul: dyn_cfg.la_matmul,
                    lb_matmul: dyn_cfg.lb_matmul,
                    lc_matmul: dyn_cfg.lc_matmul,
                })
            }
        }
    };
}
489
490impl_fixed!(Ix2, Ix1, Ix1);
492impl_fixed!(Ix3, Ix1, Ix2);
493impl_fixed!(Ix4, Ix1, Ix3);
494impl_fixed!(Ix5, Ix1, Ix4);
495impl_fixed!(Ix6, Ix1, Ix5);
496impl_fixed!(Ix7, Ix1, Ix6);
497impl_fixed!(Ix8, Ix1, Ix7);
498impl_fixed!(Ix9, Ix1, Ix8);
499
500impl_fixed!(Ix1, Ix2, Ix1);
502impl_fixed!(Ix1, Ix3, Ix2);
503impl_fixed!(Ix1, Ix4, Ix3);
504impl_fixed!(Ix1, Ix5, Ix4);
505impl_fixed!(Ix1, Ix6, Ix5);
506impl_fixed!(Ix1, Ix7, Ix6);
507impl_fixed!(Ix1, Ix8, Ix7);
508impl_fixed!(Ix1, Ix9, Ix8);
509
510impl_fixed!(Ix3, Ix2, Ix3);
512impl_fixed!(Ix4, Ix2, Ix4);
513impl_fixed!(Ix5, Ix2, Ix5);
514impl_fixed!(Ix6, Ix2, Ix6);
515impl_fixed!(Ix7, Ix2, Ix7);
516impl_fixed!(Ix8, Ix2, Ix8);
517impl_fixed!(Ix9, Ix2, Ix9);
518
519impl_fixed!(Ix2, Ix3, Ix3);
521impl_fixed!(Ix2, Ix4, Ix4);
522impl_fixed!(Ix2, Ix5, Ix5);
523impl_fixed!(Ix2, Ix6, Ix6);
524impl_fixed!(Ix2, Ix7, Ix7);
525impl_fixed!(Ix2, Ix8, Ix8);
526impl_fixed!(Ix2, Ix9, Ix9);
527
528impl_fixed!(Ix3, Ix3, Ix3);
530impl_fixed!(Ix4, Ix4, Ix4);
531impl_fixed!(Ix5, Ix5, Ix5);
532impl_fixed!(Ix6, Ix6, Ix6);
533impl_fixed!(Ix7, Ix7, Ix7);
534impl_fixed!(Ix8, Ix8, Ix8);
535impl_fixed!(Ix9, Ix9, Ix9);
536
537impl_fixed!(Ix1, IxD, IxD);
539impl_fixed!(Ix2, IxD, IxD);
540impl_fixed!(Ix3, IxD, IxD);
541impl_fixed!(Ix4, IxD, IxD);
542impl_fixed!(Ix5, IxD, IxD);
543impl_fixed!(Ix6, IxD, IxD);
544impl_fixed!(Ix7, IxD, IxD);
545impl_fixed!(Ix8, IxD, IxD);
546impl_fixed!(Ix9, IxD, IxD);
547
548impl_fixed!(IxD, Ix1, IxD);
549impl_fixed!(IxD, Ix2, IxD);
550impl_fixed!(IxD, Ix3, IxD);
551impl_fixed!(IxD, Ix4, IxD);
552impl_fixed!(IxD, Ix5, IxD);
553impl_fixed!(IxD, Ix6, IxD);
554impl_fixed!(IxD, Ix7, IxD);
555impl_fixed!(IxD, Ix8, IxD);
556impl_fixed!(IxD, Ix9, IxD);
557
#[cfg(test)]
mod test_fixed {

    /// Exercises the output layout produced for each operand-rank combination,
    /// first under row-major and then under column-major rules.
    #[test]
    fn test_layout_matmul() {
        use super::*;

        // --- row-major ---

        // 1-D x 1-D: inner product, scalar output
        let a = [4].c();
        let b = [4].c();
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.matmul_type, MatMulType::InnerDot);
        assert_eq!(cfg.lc.shape(), &[]);
        assert_eq!(cfg.la_matmul.shape(), &[4]);
        assert_eq!(cfg.lb_matmul.shape(), &[4]);

        // 1-D x 4-D (non-contiguous b)
        let a = [5].c();
        let b = [3, 4, 5, 6].f().swapaxes(0, 1).unwrap();
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.lc, [4, 3, 6].c());

        // 4-D (non-contiguous a) x 1-D
        let a = [3, 4, 5, 6].f().swapaxes(0, 1).unwrap();
        let b = [6].c();
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.lc, [4, 3, 5].c());

        // 2-D x 5-D
        let a = [7, 6].c();
        let b = [2, 3, 4, 5, 6].f().swapaxes(-1, -2).unwrap();
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.lc, [2, 3, 4, 7, 5].c());

        // 5-D x 2-D
        let a = [2, 3, 4, 5, 6].f().swapaxes(-1, -2).unwrap();
        let b = [5, 7].c();
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.lc, [2, 3, 4, 6, 7].c());

        // 5-D x 5-D with broadcasting of length-1 batch axes
        let a = [4, 1, 2, 5, 6].f().swapaxes(0, 2).unwrap();
        let b = [4, 3, 1, 6, 7].f().swapaxes(0, 2).unwrap();
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.lc, [2, 3, 4, 5, 7].c());

        // 5-D x 5-D with identical batch axes
        let a = [4, 3, 2, 5, 6].f().swapaxes(0, 2).unwrap();
        let b = [4, 3, 2, 6, 7].f().swapaxes(0, 2).unwrap();
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.lc, [2, 3, 4, 5, 7].c());

        // --- column-major ---

        // 2-D x 2-D
        let a = [5, 6].f();
        let b = [6, 7].f();
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, ColMajor).unwrap();
        assert_eq!(cfg.lc, [5, 7].f());

        // 1-D x 4-D: batch axes trail the matrix axes
        let a = [5].f();
        let b = [5, 6, 3, 4].f();
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, ColMajor).unwrap();
        assert_eq!(cfg.lc, [6, 3, 4].f());

        // 4-D x 1-D
        let a = [5, 6, 3, 4].f();
        let b = [6].f();
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, ColMajor).unwrap();
        assert_eq!(cfg.lc, [5, 3, 4].f());

        // 5-D x 5-D with broadcasting of length-1 batch axes
        let a = [5, 6, 2, 1, 4].f();
        let b = [6, 7, 1, 3, 4].f();
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, ColMajor).unwrap();
        assert_eq!(cfg.lc, [5, 7, 2, 3, 4].f());
    }
}