cubecl_linalg/matmul/components/
problem.rs

1use crate::matmul::kernels::MatmulInvalidProblem;
2
3use super::{MatrixLayout, batch};
4
5#[derive(Clone, Debug)]
6/// Description of a matmul problem to solve, regardless of actual data
7pub struct MatmulProblem {
8    pub m: usize,
9    pub n: usize,
10    pub k: usize,
11    pub batches: (Vec<usize>, Vec<usize>),
12    pub lhs_layout: MatrixLayout,
13    pub rhs_layout: MatrixLayout,
14    pub lhs_line_size: u8,
15    pub rhs_line_size: u8,
16    pub out_line_size: u8,
17}
18
19impl MatmulProblem {
20    pub(crate) fn batch_dims(&self) -> Vec<usize> {
21        self.batches
22            .0
23            .iter()
24            .rev()
25            .zip(self.batches.1.iter().rev())
26            .map(|(&dim_lhs, &dim_rhs)| std::cmp::max(dim_lhs, dim_rhs))
27            .collect()
28    }
29
30    /// Returns the total number of batches
31    pub(crate) fn num_batches(&self) -> usize {
32        self.batch_dims().iter().product()
33    }
34
35    /// Asserts that the problem can be solved with the given batch matmul configs
36    ///
37    /// # Panics:
38    ///
39    ///  - If dimensions of the problem are larger than allowed by the config
40    ///  - If line sizes do not divide well the dimension in which they are aligned
41    pub fn check_config<B: batch::BatchConfig>(
42        &self,
43        config: &B,
44    ) -> Result<(), MatmulInvalidProblem> {
45        if self.m > config.max_m() as usize {
46            return Err(MatmulInvalidProblem::ExceededMSize {
47                m: self.m as u32,
48                max_m: config.max_m(),
49            });
50        }
51
52        if self.n > config.max_n() as usize {
53            return Err(MatmulInvalidProblem::ExceededNSize {
54                n: self.n as u32,
55                max_n: config.max_n(),
56            });
57        }
58
59        if self.num_batches() > config.max_batches() as usize {
60            return Err(MatmulInvalidProblem::ExceededBatchSize {
61                b: self.num_batches() as u32,
62                max_b: config.max_batches(),
63            });
64        }
65
66        match self.lhs_layout {
67            MatrixLayout::RowMajor => {
68                if self.k % self.lhs_line_size as usize != 0 {
69                    return Err(MatmulInvalidProblem::InvalidLineSizeLhs {
70                        size: self.k as u32,
71                        line_size: self.lhs_line_size,
72                    });
73                }
74            }
75            MatrixLayout::ColMajor => {
76                if self.m % self.lhs_line_size as usize != 0 {
77                    return Err(MatmulInvalidProblem::InvalidLineSizeLhs {
78                        size: self.m as u32,
79                        line_size: self.lhs_line_size,
80                    });
81                }
82            }
83        }
84
85        match self.rhs_layout {
86            MatrixLayout::RowMajor => {
87                if self.n % self.rhs_line_size as usize != 0 {
88                    return Err(MatmulInvalidProblem::InvalidLineSizeRhs {
89                        size: self.n as u32,
90                        line_size: self.rhs_line_size,
91                    });
92                }
93            }
94            MatrixLayout::ColMajor => {
95                if self.k % self.rhs_line_size as usize != 0 {
96                    return Err(MatmulInvalidProblem::InvalidLineSizeRhs {
97                        size: self.k as u32,
98                        line_size: self.lhs_line_size,
99                    });
100                }
101            }
102        }
103
104        if self.n % self.out_line_size as usize != 0 {
105            return Err(MatmulInvalidProblem::InvalidLineSizeOut {
106                size: self.n as u32,
107                line_size: self.out_line_size,
108            });
109        }
110
111        Ok(())
112    }
113}