1use crate::prelude_dev::*;
26
/// Kind of matrix multiplication selected from the ranks of the two operands.
///
/// The variant names keep BLAS-like acronyms (GEMM / GEMV / GEVM); the digit or `X`
/// suffix tells whether each operand is exactly 2-D (`2`) or has extra leading axes (`X`).
// Variant names are public API and intentionally mirror BLAS spelling.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatMulType {
    /// 1-D · 1-D: inner (dot) product with a 0-D output.
    InnerDot,
    /// 2-D · 2-D: a single matrix-matrix product.
    GEMM22,
    /// 1-D · N-D (N ≥ 2): vector-matrix product.
    GEVM,
    /// N-D · 1-D (N ≥ 2): matrix-vector product.
    GEMV,
    /// 2-D · N-D (N ≥ 3): one matrix times a stack of matrices.
    GEMM2X,
    /// N-D · 2-D (N ≥ 3): a stack of matrices times one matrix.
    GEMMX2,
    /// N-D · M-D (N, M ≥ 3): stacks of matrices whose leading axes are broadcast.
    GEMMXX,
}
38
39#[derive(Clone, Debug)]
40pub struct LayoutMatMulConfig<DA, DB>
41where
42 DA: DimAPI,
43 DB: DimAPI,
44 Self: LayoutMatMulAPI<DA, DB>,
45{
46 pub matmul_type: MatMulType,
47 pub lc: Layout<<Self as LayoutMatMulAPI<DA, DB>>::DC>,
48 pub la_rest: Option<Layout<IxD>>,
49 pub lb_rest: Option<Layout<IxD>>,
50 pub lc_rest: Option<Layout<IxD>>,
51 pub la_matmul: Layout<IxD>,
52 pub lb_matmul: Layout<IxD>,
53 pub lc_matmul: Layout<IxD>,
54}
55
56pub trait LayoutMatMulAPI<DA, DB>
57where
58 DA: DimAPI,
59 DB: DimAPI,
60 Self: Sized,
61{
62 type DC: DimAPI;
63 fn layout_matmul(la: &Layout<DA>, lb: &Layout<DB>, order: FlagOrder) -> Result<Self>;
67}
68
69impl LayoutMatMulAPI<Ix1, Ix1> for LayoutMatMulConfig<Ix1, Ix1> {
71 type DC = Ix0;
72 fn layout_matmul(la: &Layout<Ix1>, lb: &Layout<Ix1>, _: FlagOrder) -> Result<Self> {
73 rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
75 let lc = unsafe { Layout::new_unchecked([], [], 0) };
76 Ok(LayoutMatMulConfig {
77 matmul_type: MatMulType::InnerDot,
78 lc: lc.clone(),
79 la_rest: None,
80 lb_rest: None,
81 lc_rest: None,
82 la_matmul: la.to_dim()?,
83 lb_matmul: lb.to_dim()?,
84 lc_matmul: lc.to_dim()?,
85 })
86 }
87}
88
89impl LayoutMatMulAPI<Ix2, Ix2> for LayoutMatMulConfig<Ix2, Ix2> {
91 type DC = Ix2;
92 fn layout_matmul(la: &Layout<Ix2>, lb: &Layout<Ix2>, order: FlagOrder) -> Result<Self> {
93 rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
95 let sc = [la.shape()[0], lb.shape()[1]];
96 let lc = match order {
98 RowMajor => sc.c(),
99 ColMajor => sc.f(),
100 };
101 Ok(LayoutMatMulConfig {
103 matmul_type: MatMulType::GEMM22,
104 lc: lc.clone(),
105 la_rest: None,
106 lb_rest: None,
107 lc_rest: None,
108 la_matmul: la.to_dim()?,
109 lb_matmul: lb.to_dim()?,
110 lc_matmul: lc.to_dim()?,
111 })
112 }
113}
114
115fn layout_matmul_dyn_row_major(la: &Layout<IxD>, lb: &Layout<IxD>) -> Result<LayoutMatMulConfig<IxD, IxD>> {
116 let na = la.ndim();
117 let nb = lb.ndim();
118 match (na, nb) {
119 (1, 1) => {
120 rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
122 let lc = unsafe { Layout::new_unchecked(vec![], vec![], 0) };
123 Ok(LayoutMatMulConfig {
124 matmul_type: MatMulType::InnerDot,
125 lc: lc.clone(),
126 la_rest: None,
127 lb_rest: None,
128 lc_rest: None,
129 la_matmul: la.to_dim()?,
130 lb_matmul: lb.to_dim()?,
131 lc_matmul: lc.to_dim()?,
132 })
133 },
134 (2, 2) => {
135 rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
138 let sc = vec![la.shape()[0], lb.shape()[1]];
139 let lc = sc.c();
141 Ok(LayoutMatMulConfig {
143 matmul_type: MatMulType::GEMM22,
144 lc: lc.clone(),
145 la_rest: None,
146 lb_rest: None,
147 lc_rest: None,
148 la_matmul: la.to_dim()?,
149 lb_matmul: lb.to_dim()?,
150 lc_matmul: lc.to_dim()?,
151 })
152 },
153 (1, 2..) => {
154 let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
157 rstsr_assert_eq!(la.shape()[0], lb_matmul.shape()[0], InvalidLayout)?;
158 let mut sc = lb_rest.shape().clone();
160 sc.push(lb_matmul.shape()[1]);
161 let lc = sc.c();
162 let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
164 Ok(LayoutMatMulConfig {
165 matmul_type: MatMulType::GEVM,
166 lc: lc.to_dim()?,
167 la_rest: None,
168 lb_rest: Some(lb_rest),
169 lc_rest: Some(lc_rest),
170 la_matmul: la.to_dim()?,
171 lb_matmul: lb_matmul.to_dim()?,
172 lc_matmul: lc_matmul.to_dim()?,
173 })
174 },
175 (2.., 1) => {
176 let (la_rest, la_matmul) = la.dim_split_at(-2)?;
179 rstsr_assert_eq!(lb.shape()[0], la_matmul.shape()[1], InvalidLayout)?;
180 let mut sc = la_rest.shape().clone();
182 sc.push(la_matmul.shape()[0]);
183 let lc = sc.c();
184 let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
186 Ok(LayoutMatMulConfig {
187 matmul_type: MatMulType::GEMV,
188 lc: lc.to_dim()?,
189 la_rest: Some(la_rest),
190 lb_rest: None,
191 lc_rest: Some(lc_rest),
192 la_matmul: la_matmul.to_dim()?,
193 lb_matmul: lb.to_dim()?,
194 lc_matmul: lc_matmul.to_dim()?,
195 })
196 },
197 (2, 3..) => {
198 let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
201 rstsr_assert_eq!(la.shape()[1], lb_matmul.shape()[0], InvalidLayout)?;
202 let mut sc = lb_rest.shape().clone();
204 sc.append(&mut vec![la.shape()[0], lb_matmul.shape()[1]]);
205 let lc = sc.c();
206 let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
208 Ok(LayoutMatMulConfig {
209 matmul_type: MatMulType::GEMM2X,
210 lc: lc.to_dim()?,
211 la_rest: None,
212 lb_rest: Some(lb_rest),
213 lc_rest: Some(lc_rest),
214 la_matmul: la.to_dim()?,
215 lb_matmul: lb_matmul.to_dim()?,
216 lc_matmul: lc_matmul.to_dim()?,
217 })
218 },
219 (3.., 2) => {
220 let (la_rest, la_matmul) = la.dim_split_at(-2)?;
223 rstsr_assert_eq!(la_matmul.shape()[1], lb.shape()[0], InvalidLayout)?;
224 let mut sc = la_rest.shape().clone();
226 sc.append(&mut vec![la_matmul.shape()[0], lb.shape()[1]]);
227 let lc = sc.c();
228 let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
230 Ok(LayoutMatMulConfig {
231 matmul_type: MatMulType::GEMMX2,
232 lc: lc.to_dim()?,
233 la_rest: Some(la_rest),
234 lb_rest: None,
235 lc_rest: Some(lc_rest),
236 la_matmul: la_matmul.to_dim()?,
237 lb_matmul: lb.to_dim()?,
238 lc_matmul: lc_matmul.to_dim()?,
239 })
240 },
241 (3.., 3..) => {
242 let (la_rest, la_matmul) = la.dim_split_at(-2)?;
244 let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
245 rstsr_assert_eq!(la_matmul.shape()[1], lb_matmul.shape()[0], InvalidLayout)?;
246 let (la_rest_b, lb_rest_b) = broadcast_layout(&la_rest, &lb_rest, RowMajor)?;
247 let mut sc = la_rest_b.shape().clone();
249 sc.append(&mut vec![la_matmul.shape()[0], lb_matmul.shape()[1]]);
250 let lc = sc.c();
251 let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
253 Ok(LayoutMatMulConfig {
254 matmul_type: MatMulType::GEMMXX,
255 lc: lc.to_dim()?,
256 la_rest: Some(la_rest_b),
257 lb_rest: Some(lb_rest_b),
258 lc_rest: Some(lc_rest),
259 la_matmul: la.to_dim()?,
260 lb_matmul: lb_matmul.to_dim()?,
261 lc_matmul: lc_matmul.to_dim()?,
262 })
263 },
264 (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
265 }
266}
267
268fn layout_matmul_dyn_col_major(la: &Layout<IxD>, lb: &Layout<IxD>) -> Result<LayoutMatMulConfig<IxD, IxD>> {
269 let na = la.ndim();
270 let nb = lb.ndim();
271 match (na, nb) {
272 (1, 1) => {
273 rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
275 let lc = unsafe { Layout::new_unchecked(vec![], vec![], 0) };
276 Ok(LayoutMatMulConfig {
277 matmul_type: MatMulType::InnerDot,
278 lc: lc.clone(),
279 la_rest: None,
280 lb_rest: None,
281 lc_rest: None,
282 la_matmul: la.to_dim()?,
283 lb_matmul: lb.to_dim()?,
284 lc_matmul: lc.to_dim()?,
285 })
286 },
287 (2, 2) => {
288 rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
291 let sc = vec![la.shape()[0], lb.shape()[1]];
292 let lc = sc.f();
294 Ok(LayoutMatMulConfig {
296 matmul_type: MatMulType::GEMM22,
297 lc: lc.clone(),
298 la_rest: None,
299 lb_rest: None,
300 lc_rest: None,
301 la_matmul: la.to_dim()?,
302 lb_matmul: lb.to_dim()?,
303 lc_matmul: lc.to_dim()?,
304 })
305 },
306 (1, 2) => {
307 rstsr_assert_eq!(la.shape()[0], lb.shape()[0], InvalidLayout)?;
310 let sc = vec![lb.shape()[1]];
311 let lc = sc.f();
312 Ok(LayoutMatMulConfig {
313 matmul_type: MatMulType::GEVM,
314 lc: lc.to_dim()?,
315 la_rest: None,
316 lb_rest: None,
317 lc_rest: None,
318 la_matmul: la.to_dim()?,
319 lb_matmul: lb.to_dim()?,
320 lc_matmul: lc.to_dim()?,
321 })
322 },
323 (2, 1) => {
324 rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
327 let sc = vec![la.shape()[0]];
328 let lc = sc.f();
329 Ok(LayoutMatMulConfig {
331 matmul_type: MatMulType::GEMV,
332 lc: lc.to_dim()?,
333 la_rest: None,
334 lb_rest: None,
335 lc_rest: None,
336 la_matmul: la.to_dim()?,
337 lb_matmul: lb.to_dim()?,
338 lc_matmul: lc.to_dim()?,
339 })
340 },
341 (1, 3..) | (3.., 1) | (2, 3..) | (3.., 2) | (3.., 3..) => {
342 rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")
343 },
344 (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
345 }
346}
347
348impl LayoutMatMulAPI<IxD, IxD> for LayoutMatMulConfig<IxD, IxD> {
349 type DC = IxD;
350 fn layout_matmul(la: &Layout<IxD>, lb: &Layout<IxD>, order: FlagOrder) -> Result<Self> {
351 match order {
352 RowMajor => layout_matmul_dyn_row_major(la, lb),
353 ColMajor => layout_matmul_dyn_col_major(la, lb),
354 }
355 }
356}
357
/// Implements [`LayoutMatMulAPI`] for the fixed-rank operand pair `($DA, $DB)` with output
/// dimension `$DC`.
///
/// The work happens in three steps:
/// 1. Both operands are converted to dynamic rank.
/// 2. The `IxD × IxD` implementation plans the product.
/// 3. Only the output layout `lc` is converted back, with the fallible `into_dim`, to the
///    fixed output dimension `$DC`.
///
/// The split (`*_rest` / `*_matmul`) layouts stay dynamic, because the config stores them
/// as `Layout<IxD>`.
macro_rules! impl_fixed {
    ($DA:ident, $DB:ident, $DC:ident) => {
        impl LayoutMatMulAPI<$DA, $DB> for LayoutMatMulConfig<$DA, $DB> {
            type DC = $DC;
            fn layout_matmul(la: &Layout<$DA>, lb: &Layout<$DB>, order: FlagOrder) -> Result<Self> {
                let la = la.to_dim::<IxD>()?;
                let lb = lb.to_dim::<IxD>()?;
                // Spell out the dynamic impl explicitly instead of relying on inference.
                let LayoutMatMulConfig {
                    matmul_type,
                    lc,
                    la_rest,
                    lb_rest,
                    lc_rest,
                    la_matmul,
                    lb_matmul,
                    lc_matmul,
                } = LayoutMatMulConfig::<IxD, IxD>::layout_matmul(&la, &lb, order)?;
                Ok(Self {
                    matmul_type,
                    lc: lc.into_dim()?,
                    la_rest,
                    lb_rest,
                    lc_rest,
                    la_matmul,
                    lb_matmul,
                    lc_matmul,
                })
            }
        }
    };
}
380
381impl_fixed!(Ix2, Ix1, Ix1);
383impl_fixed!(Ix3, Ix1, Ix2);
384impl_fixed!(Ix4, Ix1, Ix3);
385impl_fixed!(Ix5, Ix1, Ix4);
386impl_fixed!(Ix6, Ix1, Ix5);
387impl_fixed!(Ix7, Ix1, Ix6);
388impl_fixed!(Ix8, Ix1, Ix7);
389impl_fixed!(Ix9, Ix1, Ix8);
390
391impl_fixed!(Ix1, Ix2, Ix1);
393impl_fixed!(Ix1, Ix3, Ix2);
394impl_fixed!(Ix1, Ix4, Ix3);
395impl_fixed!(Ix1, Ix5, Ix4);
396impl_fixed!(Ix1, Ix6, Ix5);
397impl_fixed!(Ix1, Ix7, Ix6);
398impl_fixed!(Ix1, Ix8, Ix7);
399impl_fixed!(Ix1, Ix9, Ix8);
400
401impl_fixed!(Ix3, Ix2, Ix3);
403impl_fixed!(Ix4, Ix2, Ix4);
404impl_fixed!(Ix5, Ix2, Ix5);
405impl_fixed!(Ix6, Ix2, Ix6);
406impl_fixed!(Ix7, Ix2, Ix7);
407impl_fixed!(Ix8, Ix2, Ix8);
408impl_fixed!(Ix9, Ix2, Ix9);
409
410impl_fixed!(Ix2, Ix3, Ix3);
412impl_fixed!(Ix2, Ix4, Ix4);
413impl_fixed!(Ix2, Ix5, Ix5);
414impl_fixed!(Ix2, Ix6, Ix6);
415impl_fixed!(Ix2, Ix7, Ix7);
416impl_fixed!(Ix2, Ix8, Ix8);
417impl_fixed!(Ix2, Ix9, Ix9);
418
419impl_fixed!(Ix3, Ix3, Ix3);
421impl_fixed!(Ix4, Ix4, Ix4);
422impl_fixed!(Ix5, Ix5, Ix5);
423impl_fixed!(Ix6, Ix6, Ix6);
424impl_fixed!(Ix7, Ix7, Ix7);
425impl_fixed!(Ix8, Ix8, Ix8);
426impl_fixed!(Ix9, Ix9, Ix9);
427
428impl_fixed!(Ix1, IxD, IxD);
430impl_fixed!(Ix2, IxD, IxD);
431impl_fixed!(Ix3, IxD, IxD);
432impl_fixed!(Ix4, IxD, IxD);
433impl_fixed!(Ix5, IxD, IxD);
434impl_fixed!(Ix6, IxD, IxD);
435impl_fixed!(Ix7, IxD, IxD);
436impl_fixed!(Ix8, IxD, IxD);
437impl_fixed!(Ix9, IxD, IxD);
438
439impl_fixed!(IxD, Ix1, IxD);
440impl_fixed!(IxD, Ix2, IxD);
441impl_fixed!(IxD, Ix3, IxD);
442impl_fixed!(IxD, Ix4, IxD);
443impl_fixed!(IxD, Ix5, IxD);
444impl_fixed!(IxD, Ix6, IxD);
445impl_fixed!(IxD, Ix7, IxD);
446impl_fixed!(IxD, Ix8, IxD);
447impl_fixed!(IxD, Ix9, IxD);
448
#[cfg(test)]
mod test_fixed {
    use super::*;

    /// Exercises each row-major rank case once, then the column-major rejection of
    /// stacked operands and the column-major 2-D output order.
    #[test]
    fn test_layout_matmul() {
        // (1, 1): inner product with a 0-D output.
        let (a, b) = ([4].c(), [4].c());
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.matmul_type, MatMulType::InnerDot);
        assert_eq!(cfg.lc.shape(), &[]);
        assert_eq!(cfg.la_matmul.shape(), &[4]);
        assert_eq!(cfg.lb_matmul.shape(), &[4]);

        // (1, 4): vector times a stack of matrices.
        let (a, b) = ([5].c(), [3, 4, 5, 6].f().swapaxes(0, 1).unwrap());
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.lc, [4, 3, 6].c());

        // (4, 1): stack of matrices times a vector.
        let (a, b) = ([3, 4, 5, 6].f().swapaxes(0, 1).unwrap(), [6].c());
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.lc, [4, 3, 5].c());

        // (2, 5): one matrix times a stack of matrices.
        let (a, b) = ([7, 6].c(), [2, 3, 4, 5, 6].f().swapaxes(-1, -2).unwrap());
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.lc, [2, 3, 4, 7, 5].c());

        // (5, 2): stack of matrices times one matrix.
        let (a, b) = ([2, 3, 4, 5, 6].f().swapaxes(-1, -2).unwrap(), [5, 7].c());
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.lc, [2, 3, 4, 6, 7].c());

        // (5, 5) with broadcasting of size-1 leading axes.
        let (a, b) = (
            [4, 1, 2, 5, 6].f().swapaxes(0, 2).unwrap(),
            [4, 3, 1, 6, 7].f().swapaxes(0, 2).unwrap(),
        );
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.lc, [2, 3, 4, 5, 7].c());

        // (5, 5) with matching leading axes ...
        let (a, b) = (
            [4, 3, 2, 5, 6].f().swapaxes(0, 2).unwrap(),
            [4, 3, 2, 6, 7].f().swapaxes(0, 2).unwrap(),
        );
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, RowMajor).unwrap();
        assert_eq!(cfg.lc, [2, 3, 4, 5, 7].c());

        // ... is rejected in column-major order.
        assert!(LayoutMatMulConfig::layout_matmul(&a, &b, ColMajor).is_err());

        // Column-major 2-D · 2-D: output is created in F order.
        let (a, b) = ([5, 6].c(), [6, 7].c());
        let cfg = LayoutMatMulConfig::layout_matmul(&a, &b, ColMajor).unwrap();
        assert_eq!(cfg.lc, [5, 7].f());
    }
}
503}