cubecl_matmul/components/line_size.rs

use cubecl_core::{LineSizeError, Runtime, tensor_line_size_parallel};

use crate::components::{MatrixLayout, error::MatmulSetupError};
use std::fmt::Debug;

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
/// Line size used for each tensor in global memory accesses.
/// Represents the number of elements processed per SIMD load/store.
pub struct MatmulLineSizes {
    pub lhs: u8,
    pub rhs: u8,
    pub out: u8,
}

#[derive(Clone, Debug)]
/// Candidate line sizes supported for each tensor.
///
/// These lists begin with compiler-supported sizes and are progressively
/// filtered based on problem shape divisibility and hardware constraints.
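///
/// # Example
///
/// An illustrative sketch of the selection pipeline (the runtime `R`, the
/// element sizes, and the `*_strides`/`*_shape` slices are placeholder
/// inputs; only the methods shown are defined in this module):
///
/// ```ignore
/// let line_sizes = AvailableLineSizes::from_type_sizes::<R>(2, 2, 4)
///     .filter_lhs_with_tensor(lhs_strides, lhs_shape, MatrixLayout::RowMajor)
///     .filter_rhs_with_tensor(rhs_strides, rhs_shape, MatrixLayout::ColMajor)
///     .filter_out_with_tensor(out_strides, out_shape)
///     .pick_max()?;
/// ```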
pub struct AvailableLineSizes {
    pub lhs: Vec<u8>,
    pub rhs: Vec<u8>,
    pub out: Vec<u8>,
}

impl AvailableLineSizes {
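    /// Seed the candidate lists with the I/O-optimized line sizes reported by
    /// the runtime for each element size, without checking them against any
    /// particular tensor; the `filter_*` methods below narrow them down.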
    pub fn from_type_sizes<R: Runtime>(elem_lhs: usize, elem_rhs: usize, elem_out: usize) -> Self {
        AvailableLineSizes {
            lhs: R::io_optimized_line_sizes_unchecked(elem_lhs).collect(),
            rhs: R::io_optimized_line_sizes_unchecked(elem_rhs).collect(),
            out: R::io_optimized_line_sizes_unchecked(elem_out).collect(),
        }
    }

    /// Filter available line sizes considering tensor shapes and strides for Lhs
    pub fn filter_lhs_with_tensor(
        self,
        strides: &[usize],
        shape: &[usize],
        layout: MatrixLayout,
    ) -> Self {
        let lhs_vec: Vec<u8> = self.lhs.to_vec();
        let rank = strides.len();

        // Select the line size along the axis expected to have unit stride for
        // this layout: the last dimension for row-major, the second-to-last
        // for column-major.
        let target = tensor_line_size_parallel(
            lhs_vec.iter().copied(),
            shape,
            strides,
            match layout {
                MatrixLayout::RowMajor => rank - 1,
                MatrixLayout::ColMajor => rank - 2,
            },
        );

        self.filter_lhs(move |x| *x == target)
    }

    /// Filter available line sizes considering tensor shapes and strides for Rhs
    pub fn filter_rhs_with_tensor(
        self,
        strides: &[usize],
        shape: &[usize],
        layout: MatrixLayout,
    ) -> Self {
        let rhs_vec: Vec<u8> = self.rhs.to_vec();
        let rank = strides.len();

        // Same axis selection as for Lhs.
        let target = tensor_line_size_parallel(
            rhs_vec.iter().copied(),
            shape,
            strides,
            match layout {
                MatrixLayout::RowMajor => rank - 1,
                MatrixLayout::ColMajor => rank - 2,
            },
        );

        self.filter_rhs(move |x| *x == target)
    }

    /// Filter available line sizes considering tensor shapes and strides for output
    pub fn filter_out_with_tensor(self, strides: &[usize], shape: &[usize]) -> Self {
        let out_vec: Vec<u8> = self.out.to_vec();
        let rank = strides.len();

        // No layout parameter here: the last dimension is taken as the contiguous axis.
        let target = tensor_line_size_parallel(out_vec.iter().copied(), shape, strides, rank - 1);

        self.filter_out(move |x| *x == target)
    }

    /// Filter available line sizes for Lhs
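    ///
    /// # Example
    ///
    /// An illustrative sketch, assuming a caller-provided tile width `tile_k`
    /// (not part of this module): keep only candidates that divide it evenly.
    ///
    /// ```ignore
    /// let sizes = sizes.filter_lhs(|&size| tile_k % size as usize == 0);
    /// ```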
    pub fn filter_lhs<F>(self, pred: F) -> Self
    where
        F: FnMut(&u8) -> bool,
    {
        Self {
            lhs: self.lhs.iter().copied().filter(pred).collect(),
            rhs: self.rhs,
            out: self.out,
        }
    }

    /// Filter available line sizes for Rhs
    pub fn filter_rhs<F>(self, pred: F) -> Self
    where
        F: FnMut(&u8) -> bool,
    {
        Self {
            lhs: self.lhs,
            rhs: self.rhs.iter().copied().filter(pred).collect(),
            out: self.out,
        }
    }

    /// Filter available line sizes for output
    pub fn filter_out<F>(self, pred: F) -> Self
    where
        F: FnMut(&u8) -> bool,
    {
        Self {
            lhs: self.lhs,
            rhs: self.rhs,
            out: self.out.iter().copied().filter(pred).collect(),
        }
    }

    /// Pick the largest remaining line size for each tensor
    ///
    /// Returns a `MatmulSetupError::LineSize` error if any of the candidate
    /// lists ended up empty after filtering.
    pub fn pick_max(self) -> Result<MatmulLineSizes, MatmulSetupError> {
        let pick = |v: Vec<u8>| {
            v.into_iter()
                .max()
                .ok_or(MatmulSetupError::LineSize(LineSizeError::NoValidLineSize))
        };

        Ok(MatmulLineSizes {
            lhs: pick(self.lhs)?,
            rhs: pick(self.rhs)?,
            out: pick(self.out)?,
        })
    }
}