//! Output-tensor layout for global memory access.
//! Path: cubek_convolution/components/global/layout/out.rs
1use cubecl::prelude::*;
2use cubecl::std::{
3    FastDivmod,
4    tensor::layout::{Layout, LayoutExpand},
5};
6use cubek_matmul::{components::global::memory::GlobalLayoutConfig, launch::BatchedCoords};
7
8use crate::components::{
9    ConvolutionOperation, ConvolutionProblem,
10    global::layout::{NhwcCoords, cast_seq, div_mod_seq},
11};
12
/// Maps a 4D NHWC out tensor of shape `((n, h, w), c)` to a col-major 2D matmul tile with
/// shape `(m, n)`.
///
/// The row coordinate of the matmul view is decomposed back into `(batch, spatial…)`
/// using the per-dimension [`FastDivmod`] values in `shape_out`; the column coordinate
/// maps directly to the channel dimension.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct OutLayout {
    /// Shape of the spatial dimensions (e.g. DHW/HW), one `FastDivmod` per dimension,
    /// used to split the flattened row index into batch + spatial coordinates.
    pub shape_out: Sequence<FastDivmod<u32>>,

    /// Shape of the conceptual `m` size
    pub rows: u32,
    /// Shape of the conceptual `n` size, or channels
    pub cols: u32,

    /// Global memory config for the backing tensor
    /// (carries the comptime row/col bounds-check flags).
    #[cube(comptime)]
    pub config: GlobalLayoutConfig,
}
29
#[cube]
impl OutLayout {
    /// Creates an [`OutLayout`] over a conceptual `rows x cols` matmul view.
    ///
    /// `shape_out` supplies one [`FastDivmod`] per spatial dimension so the row
    /// index can later be split into `(batch, spatial…)` without runtime division.
    /// `config` is comptime and holds the bounds-checking flags for the backing tensor.
    pub fn new(
        rows: u32,
        cols: u32,
        shape_out: Sequence<FastDivmod<u32>>,
        #[comptime] config: GlobalLayoutConfig,
    ) -> OutLayout {
        OutLayout {
            shape_out,
            rows,
            cols,
            config,
        }
    }
}
46
#[cube]
impl Layout for OutLayout {
    // Matmul-view coordinates: (batch, row, col).
    type Coordinates = BatchedCoords;
    // Backing-tensor coordinates: (batch, spatial…, channel).
    type SourceCoordinates = NhwcCoords;

    /// Translates a matmul-view coordinate into an NHWC tensor coordinate.
    ///
    /// The view's batch component is ignored (this layout exposes a single batch,
    /// see `shape`); the row is split into `(batch, spatial…)` via `div_mod_seq`
    /// over `shape_out`, and the column maps directly to the channel.
    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
        let (_, view_m, view_n) = coords;
        let (batch, spatial) = div_mod_seq(view_m, &self.shape_out);

        NhwcCoords {
            batch,
            spatial: cast_seq(spatial),
            channel: view_n,
        }
    }

    /// Same as `to_source_pos`, additionally reporting whether `coords` lies
    /// within the view bounds (subject to the comptime check flags).
    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    /// Shape of the view: a single batch of `rows x cols`.
    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

    /// Bounds check for a view coordinate. Each axis is only checked when the
    /// corresponding comptime flag in `config` is set; otherwise it is assumed
    /// in-bounds (check compiled out).
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, row, col) = pos;
        (!self.config.check_row_bounds || row < self.rows)
            && (!self.config.check_col_bounds || col < self.cols)
    }
}
77
78impl<R: Runtime> OutLayoutLaunch<R> {
79    pub fn from_args(problem: &ConvolutionProblem, config: GlobalLayoutConfig) -> Self {
80        match problem.operation {
81            ConvolutionOperation::Forward => Self::from_args_fprop(problem, config),
82            ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
83                Self::from_args_dgrad(problem, config)
84            }
85            ConvolutionOperation::BackwardWeight => Self::from_args_wgrad(problem, config),
86        }
87    }
88
89    fn from_args_fprop(problem: &ConvolutionProblem, config: GlobalLayoutConfig) -> Self {
90        let shape_out = problem.out_shape.iter().map(|s| *s as u32).collect();
91        let shape_m = problem.m as u32;
92        let shape_n = problem.n as u32;
93
94        Self::new(shape_out, shape_m, shape_n, config)
95    }
96
97    fn from_args_dgrad(problem: &ConvolutionProblem, config: GlobalLayoutConfig) -> Self {
98        let shape = problem.in_shape.iter().map(|s| *s as u32).collect();
99        let shape_m = problem.m as u32;
100        let shape_n = problem.n as u32;
101
102        Self::new(shape, shape_m, shape_n, config)
103    }
104
105    fn from_args_wgrad(problem: &ConvolutionProblem, config: GlobalLayoutConfig) -> Self {
106        let shape_out = problem.out_shape.iter().map(|s| *s as u32).collect();
107        let shape_m = problem.m as u32;
108        let shape_k = problem.k as u32;
109
110        Self::new(shape_out, shape_k, shape_m, config)
111    }
112}