cubecl_matmul/components/
line_size.rs1use cubecl_core::{LineSizeError, Runtime, tensor_line_size_parallel};
2
3use crate::components::{MatrixLayout, error::MatmulSetupError};
4use std::fmt::Debug;
5
/// Line size selected for each operand of a matmul.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct MatmulLineSizes {
    /// Line size used for the left-hand-side input.
    pub lhs: u8,
    /// Line size used for the right-hand-side input.
    pub rhs: u8,
    /// Line size used for the output.
    pub out: u8,
}
14
/// Candidate line sizes for each operand of a matmul, from which a final
/// [`MatmulLineSizes`] is chosen.
#[derive(Debug, Clone)]
pub struct AvailableLineSizes {
    /// Candidate line sizes for the left-hand-side input.
    pub lhs: Vec<u8>,
    /// Candidate line sizes for the right-hand-side input.
    pub rhs: Vec<u8>,
    /// Candidate line sizes for the output.
    pub out: Vec<u8>,
}
25
26impl AvailableLineSizes {
27 pub fn from_type_sizes<R: Runtime>(elem_lhs: usize, elem_rhs: usize, elem_out: usize) -> Self {
28 AvailableLineSizes {
29 lhs: R::io_optimized_line_sizes_unchecked(elem_lhs).collect(),
30 rhs: R::io_optimized_line_sizes_unchecked(elem_rhs).collect(),
31 out: R::io_optimized_line_sizes_unchecked(elem_out).collect(),
32 }
33 }
34
35 pub fn filter_lhs_with_tensor(
37 self,
38 strides: &[usize],
39 shape: &[usize],
40 layout: MatrixLayout,
41 ) -> Self {
42 let lhs_vec: Vec<u8> = self.lhs.to_vec();
43 let rank = strides.len();
44
45 let target = tensor_line_size_parallel(
46 lhs_vec.iter().copied(),
47 shape,
48 strides,
49 match layout {
50 MatrixLayout::RowMajor => rank - 1,
51 MatrixLayout::ColMajor => rank - 2,
52 },
53 );
54
55 self.filter_lhs(move |x| *x == target)
56 }
57
58 pub fn filter_rhs_with_tensor(
60 self,
61 strides: &[usize],
62 shape: &[usize],
63 layout: MatrixLayout,
64 ) -> Self {
65 let rhs_vec: Vec<u8> = self.rhs.to_vec();
66 let rank = strides.len();
67
68 let target = tensor_line_size_parallel(
69 rhs_vec.iter().copied(),
70 shape,
71 strides,
72 match layout {
73 MatrixLayout::RowMajor => rank - 1,
74 MatrixLayout::ColMajor => rank - 2,
75 },
76 );
77
78 self.filter_rhs(move |x| *x == target)
79 }
80
81 pub fn filter_out_with_tensor(self, strides: &[usize], shape: &[usize]) -> Self {
83 let out_vec: Vec<u8> = self.out.to_vec();
84 let rank = strides.len();
85
86 let target = tensor_line_size_parallel(out_vec.iter().copied(), shape, strides, rank - 1);
87
88 self.filter_out(move |x| *x == target)
89 }
90
91 pub fn filter_lhs<F>(self, pred: F) -> Self
93 where
94 F: FnMut(&u8) -> bool,
95 {
96 Self {
97 lhs: self.lhs.iter().copied().filter(pred).collect(),
98 rhs: self.rhs,
99 out: self.out,
100 }
101 }
102
103 pub fn filter_rhs<F>(self, pred: F) -> Self
105 where
106 F: FnMut(&u8) -> bool,
107 {
108 Self {
109 lhs: self.lhs,
110 rhs: self.rhs.iter().copied().filter(pred).collect(),
111 out: self.out,
112 }
113 }
114
115 pub fn filter_out<F>(self, pred: F) -> Self
117 where
118 F: FnMut(&u8) -> bool,
119 {
120 Self {
121 lhs: self.lhs,
122 rhs: self.rhs,
123 out: self.out.iter().copied().filter(pred).collect(),
124 }
125 }
126
127 pub fn pick_max(self) -> Result<MatmulLineSizes, MatmulSetupError> {
129 let pick = |v: Vec<u8>| {
130 v.into_iter()
131 .max()
132 .ok_or(MatmulSetupError::LineSize(LineSizeError::NoValidLineSize))
133 };
134
135 Ok(MatmulLineSizes {
136 lhs: pick(self.lhs)?,
137 rhs: pick(self.rhs)?,
138 out: pick(self.out)?,
139 })
140 }
141}