cubek_convolution/components/global/layout/
tma_out_grad.rs1use cubecl::{
2 prelude::*,
3 std::tensor::layout::{Coords2d, Layout, LayoutExpand},
4};
5use cubek_matmul::launch::BatchedCoords;
6
7use crate::components::ConvolutionProblem;
8
9#[derive(CubeType, CubeLaunch)]
13pub struct TmaOutGradLayout {
14 rows: u32,
15 cols: u32,
16}
17
18#[cube]
19impl Layout for TmaOutGradLayout {
20 type Coordinates = BatchedCoords;
21 type SourceCoordinates = Coords2d;
22
23 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
24 let (_, row, col) = pos;
25 (row, col)
26 }
27
28 fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
29 true.runtime()
30 }
31
32 fn shape(&self) -> Self::Coordinates {
33 (1, self.rows, self.cols)
34 }
35
36 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
37 (self.to_source_pos(pos), self.is_in_bounds(pos))
38 }
39}
40
41impl<R: Runtime> TmaOutGradLayoutLaunch<R> {
42 pub fn from_problem(problem: &ConvolutionProblem) -> Self {
43 TmaOutGradLayoutLaunch::new(problem.k as u32, problem.m as u32)
44 }
45}