cubek_convolution/components/global/layout/
out.rs

1use cubecl::prelude::*;
2use cubecl::std::{
3    FastDivmod, FastDivmodArgs,
4    tensor::layout::{Layout, LayoutExpand},
5};
6use cubek_matmul::{components::global::memory::GlobalMemoryConfig, launch::BatchedCoords};
7
8use crate::components::{
9    ConvolutionOperation, ConvolutionProblem,
10    global::layout::{NhwcCoords, cast_seq, div_mod_seq},
11};
12
/// Maps an NHWC (channels-last) out tensor of shape `((n, spatial...), c)` to a 2D matmul
/// view of shape `(rows, cols)` — `(m, n)` for forward/dgrad launches, `(k, m)` for wgrad.
///
/// A row index of the 2D view is decomposed into a batch index and one index per spatial
/// dimension using `shape_out`; a column index is used directly as the channel index.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct OutLayout {
    /// Spatial extents as fast divmods, one entry per spatial dimension (presumably
    /// `[h, w]` for 2D and `[d, h, w]` for 3D convolutions — confirm against callers).
    /// Used to split a row index into `(batch, spatial...)`.
    pub shape_out: Sequence<FastDivmod<u32>>,

    /// Number of rows of the 2D view (the conceptual `m` size).
    pub rows: u32,
    /// Number of columns of the 2D view (the conceptual `n` size, i.e. channels).
    pub cols: u32,

    /// Global memory config for the backing tensor. Comptime; its
    /// `check_row_bounds` / `check_col_bounds` flags enable the bounds checks.
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}
29
30#[cube]
31impl OutLayout {
32    pub fn new(
33        rows: u32,
34        cols: u32,
35        shape_out: Sequence<FastDivmod<u32>>,
36        #[comptime] config: GlobalMemoryConfig,
37    ) -> OutLayout {
38        OutLayout {
39            shape_out,
40            rows,
41            cols,
42            config,
43        }
44    }
45}
46
/// Layout from the batched 2D matmul coordinates `(batch, row, col)` to NHWC
/// tensor coordinates.
#[cube]
impl Layout for OutLayout {
    type Coordinates = BatchedCoords;
    type SourceCoordinates = NhwcCoords;

    /// Converts a `(batch, row, col)` view position into NHWC coordinates.
    ///
    /// The batch component of `coords` is ignored. The row is split by `div_mod_seq` over
    /// `shape_out` into a batch index and per-dimension spatial indices. The column becomes
    /// the channel index unchanged. No bounds checking is done here.
    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
        let (_, view_m, view_n) = coords;
        let (batch, spatial) = div_mod_seq(view_m, &self.shape_out);

        NhwcCoords {
            batch,
            // Convert the spatial indices to the element type `NhwcCoords` expects
            // (presumably signed, to allow padding offsets — confirm in `NhwcCoords`).
            spatial: cast_seq(spatial),
            channel: view_n,
        }
    }

    /// Returns the mapped position together with `is_in_bounds(coords)`.
    ///
    /// The position is computed even when the bounds check fails; callers must honour the flag.
    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    /// Shape of the 2D view: a single batch of `rows` x `cols`.
    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

    /// Checks `row < rows` and `col < cols`, each only when the matching comptime
    /// config flag is enabled. The batch component is not checked.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, row, col) = pos;
        (!self.config.check_row_bounds || row < self.rows)
            && (!self.config.check_col_bounds || col < self.cols)
    }
}
77
78impl<'a, R: Runtime> OutLayoutLaunch<'a, R> {
79    pub fn from_args(
80        client: &ComputeClient<R>,
81        problem: &ConvolutionProblem,
82        config: GlobalMemoryConfig,
83    ) -> Self {
84        match problem.operation {
85            ConvolutionOperation::Forward => Self::from_args_fprop(client, problem, config),
86            ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
87                Self::from_args_dgrad(client, problem, config)
88            }
89            ConvolutionOperation::BackwardWeight => Self::from_args_wgrad(client, problem, config),
90        }
91    }
92
93    fn from_args_fprop(
94        client: &ComputeClient<R>,
95        problem: &ConvolutionProblem,
96        config: GlobalMemoryConfig,
97    ) -> Self {
98        let shape_out = problem
99            .out_shape
100            .iter()
101            .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
102            .collect();
103        let shape_m = ScalarArg::new(problem.m as u32);
104        let shape_n = ScalarArg::new(problem.n as u32);
105
106        Self::new(shape_out, shape_m, shape_n, config)
107    }
108
109    fn from_args_dgrad(
110        client: &ComputeClient<R>,
111        problem: &ConvolutionProblem,
112        config: GlobalMemoryConfig,
113    ) -> Self {
114        let shape = problem
115            .in_shape
116            .iter()
117            .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
118            .collect();
119        let shape_m = ScalarArg::new(problem.m as u32);
120        let shape_n = ScalarArg::new(problem.n as u32);
121
122        Self::new(shape, shape_m, shape_n, config)
123    }
124
125    fn from_args_wgrad(
126        client: &ComputeClient<R>,
127        problem: &ConvolutionProblem,
128        config: GlobalMemoryConfig,
129    ) -> Self {
130        let shape_out = problem
131            .out_shape
132            .iter()
133            .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
134            .collect();
135        let shape_m = ScalarArg::new(problem.m as u32);
136        let shape_k = ScalarArg::new(problem.k as u32);
137
138        Self::new(shape_out, shape_k, shape_m, config)
139    }
140}