cubecl_linalg/matmul/components/
problem.rs1use crate::matmul::kernels::MatmulInvalidProblem;
2
3use super::{MatrixLayout, batch};
4
5#[derive(Clone, Debug)]
6pub struct MatmulProblem {
8 pub m: usize,
9 pub n: usize,
10 pub k: usize,
11 pub batches: (Vec<usize>, Vec<usize>),
12 pub lhs_layout: MatrixLayout,
13 pub rhs_layout: MatrixLayout,
14 pub lhs_line_size: u8,
15 pub rhs_line_size: u8,
16 pub out_line_size: u8,
17}
18
19impl MatmulProblem {
20 pub(crate) fn batch_dims(&self) -> Vec<usize> {
21 self.batches
22 .0
23 .iter()
24 .rev()
25 .zip(self.batches.1.iter().rev())
26 .map(|(&dim_lhs, &dim_rhs)| std::cmp::max(dim_lhs, dim_rhs))
27 .collect()
28 }
29
30 pub(crate) fn num_batches(&self) -> usize {
32 self.batch_dims().iter().product()
33 }
34
35 pub fn check_config<B: batch::BatchConfig>(
42 &self,
43 config: &B,
44 ) -> Result<(), MatmulInvalidProblem> {
45 if self.m > config.max_m() as usize {
46 return Err(MatmulInvalidProblem::ExceededMSize {
47 m: self.m as u32,
48 max_m: config.max_m(),
49 });
50 }
51
52 if self.n > config.max_n() as usize {
53 return Err(MatmulInvalidProblem::ExceededNSize {
54 n: self.n as u32,
55 max_n: config.max_n(),
56 });
57 }
58
59 if self.num_batches() > config.max_batches() as usize {
60 return Err(MatmulInvalidProblem::ExceededBatchSize {
61 b: self.num_batches() as u32,
62 max_b: config.max_batches(),
63 });
64 }
65
66 match self.lhs_layout {
67 MatrixLayout::RowMajor => {
68 if self.k % self.lhs_line_size as usize != 0 {
69 return Err(MatmulInvalidProblem::InvalidLineSizeLhs {
70 size: self.k as u32,
71 line_size: self.lhs_line_size,
72 });
73 }
74 }
75 MatrixLayout::ColMajor => {
76 if self.m % self.lhs_line_size as usize != 0 {
77 return Err(MatmulInvalidProblem::InvalidLineSizeLhs {
78 size: self.m as u32,
79 line_size: self.lhs_line_size,
80 });
81 }
82 }
83 }
84
85 match self.rhs_layout {
86 MatrixLayout::RowMajor => {
87 if self.n % self.rhs_line_size as usize != 0 {
88 return Err(MatmulInvalidProblem::InvalidLineSizeRhs {
89 size: self.n as u32,
90 line_size: self.rhs_line_size,
91 });
92 }
93 }
94 MatrixLayout::ColMajor => {
95 if self.k % self.rhs_line_size as usize != 0 {
96 return Err(MatmulInvalidProblem::InvalidLineSizeRhs {
97 size: self.k as u32,
98 line_size: self.lhs_line_size,
99 });
100 }
101 }
102 }
103
104 if self.n % self.out_line_size as usize != 0 {
105 return Err(MatmulInvalidProblem::InvalidLineSizeOut {
106 size: self.n as u32,
107 line_size: self.out_line_size,
108 });
109 }
110
111 Ok(())
112 }
113}