cubecl_convolution/components/global/layout/
write.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_matmul::components::global::memory::GlobalMemoryConfig;
4use cubecl_std::{
5    FastDivmod, FastDivmodArgs,
6    tensor::layout::{Coords3d, Layout, LayoutExpand},
7};
8
9use crate::components::{
10    ConvolutionProblem,
11    global::{
12        args::RuntimeArgs,
13        layout::{NhwcCoords, cast_seq},
14        read::im2col_tma::div_mod_seq,
15    },
16};
17
/// Maps a 4D NHWC out tensor of shape `((n, h, w), c)` to a col-major 2D matmul tile with
/// shape `(m, n)`.
///
/// The `m` coordinate is split into a batch index and spatial indices using `shape_out`,
/// while the `n` coordinate is used directly as the channel index.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct OutLayout {
    /// Shape of DHW, stored as one `FastDivmod` per output dimension.
    /// On the host side these are built from `ConvolutionProblem::out_shape`.
    pub shape_out: Sequence<FastDivmod>,

    /// Size of the conceptual `m` dimension
    pub shape_m: u32,
    /// Size of the conceptual `n` dimension, i.e. channels
    pub shape_n: u32,

    /// Global memory config for the backing tensor. Its `check_row_bounds` and
    /// `check_col_bounds` flags select which bounds checks `is_in_bounds` performs.
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}
34
35#[cube]
36impl OutLayout {
37    pub fn new(
38        args: &RuntimeArgs,
39        shape_out: Sequence<FastDivmod>,
40        #[comptime] config: GlobalMemoryConfig,
41    ) -> OutLayout {
42        OutLayout {
43            shape_out,
44            shape_m: args.shape_m,
45            shape_n: args.shape_n,
46            config,
47        }
48    }
49}
50
51#[cube]
52impl Layout for OutLayout {
53    type Coordinates = Coords3d;
54    type SourceCoordinates = NhwcCoords;
55
56    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
57        let (_, view_m, view_n) = coords;
58        let (batch, spatial) = div_mod_seq(view_m, &self.shape_out);
59
60        NhwcCoords {
61            batch,
62            spatial: cast_seq(spatial),
63            channel: view_n,
64        }
65    }
66
67    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
68        (self.to_source_pos(coords), self.is_in_bounds(coords))
69    }
70
71    fn shape(&self) -> Self::Coordinates {
72        (1, self.shape_m, self.shape_n)
73    }
74
75    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
76        let (_, m, n) = pos;
77        let check_m = comptime![self.config.check_row_bounds];
78        let check_n = comptime![self.config.check_col_bounds];
79        (!check_m || m < self.shape_m) && (!check_n || n < self.shape_n)
80    }
81}
82
83impl<'a, R: Runtime> OutLayoutLaunch<'a, R> {
84    pub fn from_args(
85        client: &ComputeClient<R>,
86        problem: &ConvolutionProblem,
87        config: GlobalMemoryConfig,
88    ) -> Self {
89        let shape_out = problem
90            .out_shape
91            .iter()
92            .map(|s| FastDivmodArgs::new(client, *s as u32))
93            .collect();
94        let shape_m = ScalarArg::new(problem.m as u32);
95        let shape_n = ScalarArg::new(problem.n as u32);
96
97        Self::new(shape_out, shape_m, shape_n, config)
98    }
99}