cubek_convolution/components/global/layout/write.rs

1use cubecl::prelude::*;
2use cubecl::std::{
3    FastDivmod, FastDivmodArgs,
4    tensor::layout::{Coords3d, Layout, LayoutExpand},
5};
6use cubek_matmul::components::global::memory::GlobalMemoryConfig;
7
8use crate::components::{
9    ConvolutionProblem,
10    global::{
11        args::RuntimeArgs,
12        layout::{NhwcCoords, cast_seq},
13        read::im2col_tma::div_mod_seq,
14    },
15};
16
/// Layout for the convolution output tensor.
///
/// Maps a channels-last output tensor (e.g. NHWC, viewed as `((n, h, w), c)`) to the 2D
/// matmul view of shape `(m, n)`:
/// - the matmul `m` index is split into a batch index and spatial indices using
///   `shape_out`;
/// - the matmul `n` index is used directly as the channel index.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct OutLayout {
    /// Output spatial shape (e.g. `(h, w)` or `(d, h, w)`; the rank is not fixed here).
    /// Each entry is stored as a fast divmod so a flattened `m` index can be decomposed
    /// into spatial coordinates.
    pub shape_out: Sequence<FastDivmod>,

    /// Size of the matmul `m` dimension (flattened batch and spatial output positions).
    pub shape_m: u32,
    /// Size of the matmul `n` dimension, i.e. the number of output channels.
    pub shape_n: u32,

    /// Global memory config for the backing tensor. Its row/column bound-check flags
    /// decide at compile time which axes are bounds-checked.
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}
33
34#[cube]
35impl OutLayout {
36    pub fn new(
37        args: &RuntimeArgs,
38        shape_out: Sequence<FastDivmod>,
39        #[comptime] config: GlobalMemoryConfig,
40    ) -> OutLayout {
41        OutLayout {
42            shape_out,
43            shape_m: args.shape_m,
44            shape_n: args.shape_n,
45            config,
46        }
47    }
48}
49
50#[cube]
51impl Layout for OutLayout {
52    type Coordinates = Coords3d;
53    type SourceCoordinates = NhwcCoords;
54
55    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
56        let (_, view_m, view_n) = coords;
57        let (batch, spatial) = div_mod_seq(view_m, &self.shape_out);
58
59        NhwcCoords {
60            batch,
61            spatial: cast_seq(spatial),
62            channel: view_n,
63        }
64    }
65
66    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
67        (self.to_source_pos(coords), self.is_in_bounds(coords))
68    }
69
70    fn shape(&self) -> Self::Coordinates {
71        (1, self.shape_m, self.shape_n)
72    }
73
74    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
75        let (_, m, n) = pos;
76        let check_m = comptime![self.config.check_row_bounds];
77        let check_n = comptime![self.config.check_col_bounds];
78        (!check_m || m < self.shape_m) && (!check_n || n < self.shape_n)
79    }
80}
81
82impl<'a, R: Runtime> OutLayoutLaunch<'a, R> {
83    pub fn from_args(
84        client: &ComputeClient<R>,
85        problem: &ConvolutionProblem,
86        config: GlobalMemoryConfig,
87    ) -> Self {
88        let shape_out = problem
89            .out_shape
90            .iter()
91            .map(|s| FastDivmodArgs::new(client, *s as u32))
92            .collect();
93        let shape_m = ScalarArg::new(problem.m as u32);
94        let shape_n = ScalarArg::new(problem.n as u32);
95
96        Self::new(shape_out, shape_m, shape_n, config)
97    }
98}