cubek_convolution/components/global/layout/
weight.rs

1use cubecl::prelude::*;
2use cubecl::std::{
3    FastDivmod, FastDivmodArgs,
4    tensor::layout::{Layout, LayoutExpand},
5};
6use cubek_matmul::{
7    components::global::{GlobalConfig, memory::GlobalMemoryConfig},
8    launch::BatchedCoords,
9};
10
11use crate::components::{
12    ConvGemmConfig, ConvolutionConfig, ConvolutionOperation, ConvolutionParams, ConvolutionProblem,
13    global::layout::NhwcCoords,
14};
15
/// Maps a 4D weight tensor of shape `(out_c, (k_h, k_w, in_c))` to a col-major 2D matmul tile with
/// shape `(n, k)`.
///
/// The `k` axis linearizes the kernel spatial positions together with the input channels,
/// with the channel dimension innermost; since the channel count may be padded, decomposing
/// `k` uses `padded_channels` rather than the raw channel count.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct WeightLayout {
    /// Number of channels, including padding, used for decomposing `k`
    pub padded_channels: FastDivmod<u32>,

    /// Shape of the combined kernel and channels dim, including padding
    pub rows: u32,
    /// Shape of the `out_c` dimension
    pub cols: u32,

    /// Convolution parameters (kernel size, dimensionality, operation direction),
    /// fixed at compile time
    #[cube(comptime)]
    pub params: ConvolutionParams,
    /// Global memory config for the backing tensor; its bounds-check flags
    /// gate the row/col checks in `is_in_bounds`
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}
35
36#[cube]
37impl WeightLayout {
38    pub fn new<E: Numeric, G: GlobalConfig>(
39        rows: u32,
40        cols: u32,
41        padded_channels: FastDivmod<u32>,
42        #[comptime] config: ConvolutionConfig<G>,
43    ) -> WeightLayout {
44        WeightLayout {
45            rows,
46            cols,
47            padded_channels,
48            params: config.params,
49            config: config.rhs_global_memory_config(),
50        }
51    }
52}
53
#[cube]
impl Layout for WeightLayout {
    /// Virtual coordinates are `(batch, k, n)`; the batch extent is always 1 (see `shape`).
    type Coordinates = BatchedCoords;
    type SourceCoordinates = NhwcCoords;

    /// Converts a virtual `(batch, k, n)` matmul coordinate into an NHWC position
    /// in the backing weight tensor.
    ///
    /// `k` packs the kernel spatial offsets and the (padded) channel, with the
    /// channel innermost. The batch component of `coords` is ignored.
    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
        let params = self.params.comptime();
        let (_, k, n) = coords;

        // Split `k` by the padded channel count: the remainder is the channel,
        // the quotient packs the kernel spatial offsets.
        // NOTE(review): assumes `div_mod` returns (quotient, remainder) — matches
        // usage below, but confirm against cubecl-std's FastDivmod.
        let (mut rem, k_channel) = self.padded_channels.div_mod(k);

        // Peel off one kernel offset per spatial dim, innermost (last) dim first.
        let spatial_dims = params.dimensionality.num_dims();
        let mut kernel_pos = Sequence::<i32>::new();

        #[unroll]
        for i in 0..spatial_dims {
            let dim = spatial_dims - i - 1;
            let ksize = params.kernel_size[dim];
            let k_pos = rem % ksize;
            rem /= ksize;

            kernel_pos.push(k_pos as i32);
        }

        // Offsets were pushed innermost-first; reverse into outermost-first order
        // as expected by `NhwcCoords::spatial`.
        let kernel_pos = kernel_pos.rev();

        // Whether `n` addresses the weight tensor's leading dim or its channel dim
        // depends on the direction of the convolution (resolved at comptime).
        let (batch, channel) = match params.operation {
            ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => (n, k_channel),
            ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                (k_channel, n)
            }
        };

        NhwcCoords {
            batch,
            spatial: kernel_pos,
            channel,
        }
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    /// Layout extent as `(batch, rows, cols)`; batch is a dummy dimension of size 1.
    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

    /// Bounds check for `(_, k, n)`. Each axis is only checked when the
    /// corresponding comptime flag is set in the memory config; the batch
    /// component is never checked.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, k, n) = pos;
        let check_k = self.config.check_row_bounds;
        let check_n = self.config.check_col_bounds;
        (!check_k || k < self.rows) && (!check_n || n < self.cols)
    }
}
109
110impl<'a, R: Runtime> WeightLayoutLaunch<'a, R> {
111    pub fn from_args(
112        client: &ComputeClient<R>,
113        problem: &ConvolutionProblem,
114        config: GlobalMemoryConfig,
115    ) -> Self {
116        match problem.operation {
117            ConvolutionOperation::Forward
118            | ConvolutionOperation::ForwardTransposed
119            | ConvolutionOperation::BackwardData => Self::from_args_rhs(client, problem, config),
120            ConvolutionOperation::BackwardWeight => Self::from_args_out(client, problem, config),
121        }
122    }
123
124    fn from_args_rhs(
125        client: &ComputeClient<R>,
126        problem: &ConvolutionProblem,
127        config: GlobalMemoryConfig,
128    ) -> Self {
129        let padded_channels = problem.padded_channels as u32;
130        let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
131        let shape_k = ScalarArg::new(problem.k as u32);
132        let shape_n = ScalarArg::new(problem.n as u32);
133
134        let params = ConvolutionParams::from_problem(problem);
135
136        WeightLayoutLaunch::new(padded_channels, shape_k, shape_n, params, config)
137    }
138
139    fn from_args_out(
140        client: &ComputeClient<R>,
141        problem: &ConvolutionProblem,
142        config: GlobalMemoryConfig,
143    ) -> Self {
144        let padded_channels = problem.padded_channels as u32;
145        let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
146        let shape_m = ScalarArg::new(problem.m as u32);
147        let shape_n = ScalarArg::new(problem.n as u32);
148
149        let params = ConvolutionParams::from_problem(problem);
150
151        WeightLayoutLaunch::new(padded_channels, shape_n, shape_m, params, config)
152    }
153}