cubecl_linalg/matmul/components/tile/
accelerated.rs

1use crate::matmul::components::config::MatmulConfig;
2use crate::matmul::components::tile::{TileConfig, TileMatmul, TileMatmulFamily};
3use crate::matmul::components::{
4    Ident, InvalidConfigError, MatmulConfigFactory, MatmulPrecision, MatmulProblem, MatmulSize,
5    MatrixLayout, as_cmma_layout,
6};
7use crate::matmul::kernels::MatmulAvailabilityError;
8use cubecl_core::ir::{Elem, FloatKind};
9use cubecl_core::{self as cubecl, Feature};
10use cubecl_core::{cmma, prelude::*};
11
12use super::{Tile, TileMatmulConfigInput};
13
/// Marker type for the tile matmul implemented with `cmma` matrix operations.
///
/// It carries no data; all behavior comes from the trait implementations in
/// this file.
pub struct Accelerated;
15
16impl TileMatmulFamily for Accelerated {
17    type Matmul<MP: MatmulPrecision> = Accelerated;
18
19    fn tile_shape(config: &Self::Config) -> MatmulSize {
20        config.size
21    }
22
23    fn requires_tensor_cores() -> bool {
24        true
25    }
26}
27
28#[cube]
29impl<MP: MatmulPrecision> TileMatmul<MP> for Accelerated {
30    type Config = Config;
31    type Lhs = cmma::Matrix<MP::ES>;
32    type Rhs = cmma::Matrix<MP::ES>;
33    type Accumulator = cmma::Matrix<MP::EA>;
34
35    fn execute(
36        lhs: &Self::Lhs,
37        rhs: &Self::Rhs,
38        out: &mut Self::Accumulator,
39        #[comptime] _config: Config,
40    ) {
41        cmma::execute::<MP::ES, MP::ES, MP::EA, MP::EA>(lhs, rhs, out, out);
42    }
43
44    fn allocate_lhs(#[comptime] config: Config) -> Self::Lhs {
45        let size = config.size;
46        let layout = config.matrix_layout(Ident::Lhs);
47        unsafe {
48            cmma::Matrix::<MP::ES>::uninitialized(
49                cmma::MatrixIdent::A, // Check versus Ident
50                size.m,
51                size.n,
52                size.k,
53                as_cmma_layout(layout),
54            )
55        }
56    }
57
58    fn allocate_rhs(#[comptime] config: Config) -> Self::Rhs {
59        let size = config.size;
60        let layout = config.matrix_layout(Ident::Rhs);
61        unsafe {
62            cmma::Matrix::<MP::ES>::uninitialized(
63                cmma::MatrixIdent::B,
64                size.m,
65                size.n,
66                size.k,
67                as_cmma_layout(layout),
68            )
69        }
70    }
71
72    fn fill_lhs(tile: &Tile<MP::ES>, lhs: &mut Self::Lhs, #[comptime] config: Config) {
73        let (slice, stride) = tile.as_unlined::<Config>(Ident::Lhs, config);
74        cmma::load(lhs, &slice, stride);
75    }
76
77    fn fill_rhs(tile: &Tile<MP::ES>, rhs: &mut Self::Rhs, #[comptime] config: Config) {
78        let (slice, stride) = tile.as_unlined::<Config>(Ident::Rhs, config);
79        cmma::load(rhs, &slice, stride);
80    }
81
82    fn fill_accumulator(
83        tile: &Tile<MP::EA>,
84        acc: &mut Self::Accumulator,
85        #[comptime] config: Config,
86    ) {
87        let layout = comptime!(as_cmma_layout(config.matrix_layout(Ident::Out)));
88        let (slice, stride) = tile.as_unlined::<Config>(Ident::Out, config);
89        cmma::load_with_layout(acc, &slice, stride, layout);
90    }
91
92    fn read_accumulator<C: Numeric>(
93        out: &Self::Accumulator,
94        slice: &mut SliceMut<Line<C>>,
95        #[comptime] config: Config,
96    ) {
97        let acc = cmma::cast::<MP::EA, C>(out);
98        cmma::store(slice, &acc, config.size.n, cmma::MatrixLayout::RowMajor);
99    }
100
101    fn allocate_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
102        let size = config.size;
103        unsafe {
104            cmma::Matrix::<MP::EA>::uninitialized(
105                cmma::MatrixIdent::Accumulator,
106                size.m,
107                size.n,
108                size.k,
109                cmma::MatrixLayout::Undefined,
110            )
111        }
112    }
113
114    fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] _config: Self::Config) {
115        cmma::fill(acc, MP::EA::from_int(0));
116    }
117}
118
119impl MatmulConfigFactory for Accelerated {
120    type Input = TileMatmulConfigInput;
121    type Config = Config;
122
123    fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
124        if config.plane_dim != 32 {
125            return Err(Box::new(
126                "Error: Expected plane dimension to be 32, but found {}. Please ensure that cube dimension x is set correctly.",
127            ));
128        }
129        Ok(())
130    }
131
132    fn check_availability<R: Runtime, MP: MatmulPrecision>(
133        client: &ComputeClient<R::Server, R::Channel>,
134        config: &Self::Config,
135    ) -> Result<(), MatmulAvailabilityError> {
136        if config.stage_dynamic_line_size
137            && !client
138                .properties()
139                .feature_enabled(Feature::DynamicLineSize)
140        {
141            return Err(MatmulAvailabilityError::DynamicLineSizeUnavailable);
142        }
143
144        let es = MP::ES::as_elem_native().expect("to be a native type");
145        let ea = MP::EA::as_elem_native().expect("to be a native type");
146
147        let es = match es {
148            Elem::Float(FloatKind::Flex32) => Elem::Float(FloatKind::F32),
149            _ => es,
150        };
151
152        let ea = match ea {
153            Elem::Float(FloatKind::Flex32) => Elem::Float(FloatKind::F32),
154            _ => ea,
155        };
156
157        let size = config.size;
158        if !client.properties().feature_enabled(Feature::Cmma {
159            a: es,
160            b: es,
161            c: ea,
162            m: size.m as u8,
163            k: size.k as u8,
164            n: size.n as u8,
165        }) {
166            return Err(MatmulAvailabilityError::CmmaInstructionUnavailable {
167                input: es,
168                output: ea,
169                shape: Some(MatmulSize {
170                    m: size.m,
171                    n: size.n,
172                    k: size.k,
173                }),
174            });
175        }
176
177        if !(MP::ES::is_supported(client) && MP::EA::is_supported(client)) {
178            return Err(MatmulAvailabilityError::TypesUnavailable {
179                input: es,
180                output: ea,
181            });
182        }
183
184        Ok(())
185    }
186
187    fn make_config(
188        input: Self::Input,
189        problem: &MatmulProblem,
190        cube_dim: &CubeDim,
191        _cube_count: &CubeCount,
192        _quantized: bool,
193    ) -> Self::Config {
194        let (lhs_line_size, rhs_line_size, stage_line_size_update) =
195            if input.vectorization.stage_line_size == 0 {
196                (
197                    problem.lhs_line_size as u32,
198                    problem.rhs_line_size as u32,
199                    false,
200                )
201            } else {
202                (
203                    input.vectorization.stage_line_size as u32,
204                    input.vectorization.stage_line_size as u32,
205                    true,
206                )
207            };
208        Config::new(
209            input.size,
210            cube_dim.x,
211            problem.lhs_layout,
212            problem.rhs_layout,
213            stage_line_size_update,
214            lhs_line_size,
215            rhs_line_size,
216            problem.out_line_size as u32,
217        )
218    }
219}
220
221#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
222/// Configuration for Accelerated instruction
223pub struct Config {
224    size: MatmulSize,
225    plane_dim: u32,
226    lhs_layout: MatrixLayout,
227    rhs_layout: MatrixLayout,
228    stage_dynamic_line_size: bool,
229    lhs_line_size: u32,
230    rhs_line_size: u32,
231    out_line_size: u32,
232}
233
234impl TileConfig for Config {
235    fn plane_dim(&self) -> u32 {
236        self.plane_dim
237    }
238
239    fn matrix_layout(&self, ident: Ident) -> MatrixLayout {
240        match ident {
241            Ident::Lhs => self.lhs_layout,
242            Ident::Rhs => self.rhs_layout,
243            Ident::Out => MatrixLayout::RowMajor,
244        }
245    }
246
247    fn stage_line_size(&self, ident: Ident) -> u32 {
248        match ident {
249            Ident::Lhs => self.lhs_line_size,
250            Ident::Rhs => self.rhs_line_size,
251            Ident::Out => self.out_line_size,
252        }
253    }
254
255    fn tile_shape(&self) -> &MatmulSize {
256        &self.size
257    }
258}
259
// No items are defined in this impl; it only registers `Config` as a `MatmulConfig`.
impl MatmulConfig for Config {}
261
262impl Config {
263    #[allow(clippy::too_many_arguments)]
264    pub fn new(
265        size: MatmulSize,
266        plane_dim: u32,
267        lhs_layout: MatrixLayout,
268        rhs_layout: MatrixLayout,
269        stage_dynamic_line_size: bool,
270        lhs_line_size: u32,
271        rhs_line_size: u32,
272        out_line_size: u32,
273    ) -> Self {
274        Self {
275            size,
276            plane_dim,
277            lhs_layout,
278            rhs_layout,
279            stage_dynamic_line_size,
280            lhs_line_size,
281            rhs_line_size,
282            out_line_size,
283        }
284    }
285}