cubecl_convolution/components/global/layout/weight.rs

use cubecl::prelude::*;
use cubecl_core::{self as cubecl};
use cubecl_matmul::components::{
    MatmulElems,
    global::{GlobalConfig, memory::GlobalMemoryConfig},
};
use cubecl_std::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{Coords3d, Layout, LayoutExpand},
};

use crate::components::{
    ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
    global::{args::RuntimeArgs, layout::NhwcCoords},
};

/// Maps a 4D weight tensor of shape `(out_c, k_h, k_w, in_c)` to a col-major 2D matmul tile of
/// shape `(n, k)`, where `n = out_c` and `k` is the flattened `(k_h, k_w, in_c)` dimension, with
/// the channel count padded for alignment.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct WeightLayout {
    /// Number of channels, including padding, used for decomposing `k`
    pub padded_channels: FastDivmod,

    /// Shape of the conceptual `k` dimension, including padding
    pub shape_k: u32,
    /// Shape of the conceptual `n` dimension, i.e. `out_c`
    pub shape_n: u32,

    /// Convolution parameters, e.g. kernel size and dimensionality
    #[cube(comptime)]
    pub params: ConvolutionParams,
    /// Global memory config for the backing tensor
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}
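
// A worked example of the `k` decomposition performed by `to_source_pos` below
// (illustrative numbers, not from any particular config): with
// `kernel_size = [3, 3]` and `padded_channels = 64`, `shape_k = 3 * 3 * 64 = 576`.
// The coordinate `k = 200` splits into `in_c = 200 % 64 = 8` and
// `rem = 200 / 64 = 3`; `rem` then splits, innermost dimension first, into
// `k_w = 3 % 3 = 0` and `k_h = 3 / 3 = 1`. Equivalently,
// `k = (k_h * kernel_size[1] + k_w) * padded_channels + in_c`.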

#[cube]
impl WeightLayout {
    /// Creates a new weight layout from the runtime arguments and convolution config.
    pub fn new<E: Numeric, G: GlobalConfig>(
        args: &RuntimeArgs,
        #[comptime] config: ConvolutionConfig<G>,
    ) -> WeightLayout {
        WeightLayout {
            shape_k: args.shape_k,
            shape_n: args.shape_n,
            padded_channels: args.padded_channels,
            params: config.convolution_params,
            config: config.rhs_global_memory_config(),
        }
    }
}
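
// Note: `padded_channels` and `shape_k` arrive precomputed in `RuntimeArgs`;
// the host-side counterpart `WeightLayoutLaunch::from_args` at the bottom of
// this file computes the same quantities directly from the problem definition.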

#[cube]
impl Layout for WeightLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = NhwcCoords;

    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
        let params = comptime![self.params];
        let (_, k, n) = coords;

        // Split `k` into the input channel (innermost) and the remaining
        // flattened kernel position.
        let (mut rem, in_c) = self.padded_channels.div_mod(k);

        let spatial_dims = comptime![params.dimensionality.num_dims()];
        let mut kernel_pos = Sequence::<i32>::new();

        // Peel off one kernel coordinate per spatial dimension, innermost first.
        #[unroll]
        for i in 0..spatial_dims {
            let dim = comptime![spatial_dims - i - 1];
            let ksize = comptime![params.kernel_size[dim as usize]];
            let k_pos = rem % ksize;
            rem /= ksize;

            kernel_pos.push(k_pos as i32);
        }

        // Coordinates were pushed innermost-first; reverse into (h, w, ...) order.
        let kernel_pos = kernel_pos.rev();

        // For weights, the NHWC `batch` position indexes the output channel.
        NhwcCoords {
            batch: n,
            spatial: kernel_pos,
            channel: in_c,
        }
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        (1, self.shape_k, self.shape_n)
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, k, n) = pos;
        // Bounds checks are comptime flags, so unneeded checks compile away.
        let check_k = comptime![self.config.check_row_bounds];
        let check_n = comptime![self.config.check_col_bounds];
        (!check_k || k < self.shape_k) && (!check_n || n < self.shape_n)
    }
}

impl<'a, R: Runtime> WeightLayoutLaunch<'a, R> {
    pub fn from_args(
        client: &ComputeClient<R>,
        problem: &ConvolutionProblem,
        params: ConvolutionParams,
        config: GlobalMemoryConfig,
        dtypes: &MatmulElems,
    ) -> Self {
        // Pad the channel count to a multiple of the number of elements per
        // hardware load, so global memory accesses stay aligned.
        let load_width = client.properties().hardware.load_width;
        let channel_align = load_width / dtypes.lhs_global.size_bits() as u32;
        let padded_channels = (problem.channels as u32).next_multiple_of(channel_align);

        // `k` flattens every kernel position together with the padded channels.
        let size_k = problem.kernel_size.iter().product::<u32>() * padded_channels;
        let padded_channels = FastDivmodArgs::new(client, padded_channels);
        let shape_k = ScalarArg::new(size_k);
        let shape_n = ScalarArg::new(problem.n as u32);

        WeightLayoutLaunch::new(padded_channels, shape_k, shape_n, params, config)
    }
}
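
// A host-side sketch of the same index math, for illustration only (plain Rust,
// no CubeCL; `decompose_k` is a hypothetical helper, not part of this crate):
#[cfg(test)]
mod weight_layout_index_math {
    /// Mirrors `to_source_pos`: split `k` into (spatial kernel position, in_c),
    /// peeling the innermost dimension first, storing coordinates in (h, w) order.
    fn decompose_k(k: u32, padded_channels: u32, kernel_size: &[u32]) -> (Vec<u32>, u32) {
        let in_c = k % padded_channels;
        let mut rem = k / padded_channels;
        let mut spatial = vec![0; kernel_size.len()];
        for (i, ksize) in kernel_size.iter().enumerate().rev() {
            spatial[i] = rem % ksize;
            rem /= ksize;
        }
        (spatial, in_c)
    }

    #[test]
    fn decomposes_k_like_the_layout() {
        // kernel_size = [3, 3], padded_channels = 64, k = 200
        let (spatial, in_c) = decompose_k(200, 64, &[3, 3]);
        assert_eq!(in_c, 8);
        assert_eq!(spatial, vec![1, 0]); // (k_h, k_w)
    }
}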