cubecl_convolution/components/global/layout/
write.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_matmul::components::global::memory::GlobalMemoryConfig;
4use cubecl_std::{
5    FastDivmod, FastDivmodArgs,
6    tensor::layout::{Coords3d, Layout, LayoutExpand},
7};
8
9use crate::{
10    components::{
11        ConvolutionProblem,
12        global::{
13            layout::{NhwcCoords, cast_seq},
14            read::im2col_tma::div_mod_seq,
15        },
16    },
17    kernels::layered::selector::RuntimeArgs,
18};
19
20/// Maps a 4D NHWC out tensor of shape `((n, h, w), c)` to a col-major 2D matmul tile with
21/// shape `(m, n)`
22#[derive(CubeType, CubeLaunch, Clone)]
23pub struct OutLayout {
24    /// Shape of DHW
25    pub shape_out: Sequence<FastDivmod>,
26
27    /// Shape of the conceptual `m` size
28    pub shape_m: u32,
29    /// Shape of the conceptual `n`size, or channels
30    pub shape_n: u32,
31
32    /// Global memory config for the backing tensor
33    #[cube(comptime)]
34    pub config: GlobalMemoryConfig,
35}
36
37#[cube]
38impl OutLayout {
39    pub fn new(args: &RuntimeArgs, #[comptime] config: GlobalMemoryConfig) -> OutLayout {
40        OutLayout {
41            shape_out: args.shape_out.clone(),
42            shape_m: args.shape_m,
43            shape_n: args.shape_n,
44            config,
45        }
46    }
47}
48
49#[cube]
50impl Layout for OutLayout {
51    type Coordinates = Coords3d;
52    type SourceCoordinates = NhwcCoords;
53
54    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
55        let (_, view_m, view_n) = coords;
56        let (batch, spatial) = div_mod_seq(view_m, &self.shape_out);
57
58        NhwcCoords {
59            batch,
60            spatial: cast_seq(spatial),
61            channel: view_n,
62        }
63    }
64
65    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
66        (self.to_source_pos(coords), self.is_in_bounds(coords))
67    }
68
69    fn shape(&self) -> Self::Coordinates {
70        (1, self.shape_m, self.shape_n)
71    }
72
73    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
74        let (_, m, n) = pos;
75        let check_m = comptime![self.config.check_row_bounds];
76        let check_n = comptime![self.config.check_col_bounds];
77        (!check_m || m < self.shape_m) && (!check_n || n < self.shape_n)
78    }
79}
80
81impl<'a, R: Runtime> OutLayoutLaunch<'a, R> {
82    pub fn from_args(
83        client: &ComputeClient<R::Server>,
84        problem: &ConvolutionProblem,
85        config: GlobalMemoryConfig,
86    ) -> Self {
87        let shape_out = problem
88            .out_shape
89            .iter()
90            .map(|s| FastDivmodArgs::new(client, *s as u32))
91            .collect();
92        let shape_m = ScalarArg::new(problem.m as u32);
93        let shape_n = ScalarArg::new(problem.n as u32);
94
95        Self::new(shape_out, shape_m, shape_n, config)
96    }
97}