cubek_convolution/components/global/layout/weight.rs

1use cubecl::prelude::*;
2use cubecl::std::{
3    FastDivmod, FastDivmodArgs,
4    tensor::layout::{Coords3d, Layout, LayoutExpand},
5};
6use cubek_matmul::components::{
7    MatmulElems,
8    global::{GlobalConfig, memory::GlobalMemoryConfig},
9};
10
11use crate::components::{
12    ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
13    global::{args::RuntimeArgs, layout::NhwcCoords},
14};
15
/// Maps a 4D weight tensor of shape `(out_c, (k_h, k_w, in_c))` to a col-major 2D matmul tile with
/// shape `(n, k)`.
///
/// `k` is the flattened `(kernel spatial dims, in_c)` axis; `n` is `out_c`. Field order must stay
/// in sync with the positional `WeightLayoutLaunch::new` call in `from_args`.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct WeightLayout {
    /// Number of channels, including padding, used for decomposing `k`
    /// (fast runtime div/mod, since the divisor is only known at launch time).
    pub padded_channels: FastDivmod,

    /// Shape of the conceptual `k` size, including padding
    pub shape_k: u32,
    /// Shape of the conceptual `n` size, or `out_c`
    pub shape_n: u32,

    /// Size of the convolution kernel (comptime: baked into the kernel specialization)
    #[cube(comptime)]
    pub params: ConvolutionParams,
    /// Global memory config for the backing tensor (comptime)
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}
35
#[cube]
impl WeightLayout {
    /// Builds a [`WeightLayout`] inside the kernel from the runtime arguments and the
    /// comptime convolution config.
    ///
    /// The weight tensor is the `rhs` operand of the implicit GEMM, hence the backing
    /// memory config is taken from `rhs_global_memory_config`.
    pub fn new<E: Numeric, G: GlobalConfig>(
        args: &RuntimeArgs,
        #[comptime] config: ConvolutionConfig<G>,
    ) -> WeightLayout {
        WeightLayout {
            shape_k: args.shape_k,
            shape_n: args.shape_n,
            padded_channels: args.padded_channels,
            params: config.convolution_params,
            config: config.rhs_global_memory_config(),
        }
    }
}
51
#[cube]
impl Layout for WeightLayout {
    // Matmul-side coordinates: (batch, k, n). The leading coordinate is unused (shape 1).
    type Coordinates = Coords3d;
    type SourceCoordinates = NhwcCoords;

    /// Decomposes a matmul `(_, k, n)` position into a position in the NHWC weight tensor.
    ///
    /// `k` is laid out as `(k_0, ..., k_{d-1}, in_c)` with `in_c` innermost, so `in_c` is
    /// peeled off first via `padded_channels.div_mod`, then each kernel spatial dim is
    /// peeled from innermost to outermost.
    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
        let params = comptime![self.params];
        let (_, k, n) = coords;

        // rem = kernel-window index, in_c = input channel (innermost component of k).
        let (mut rem, in_c) = self.padded_channels.div_mod(k);

        let spatial_dims = comptime![params.dimensionality.num_dims()];
        let mut kernel_pos = Sequence::<i32>::new();

        // Walk dims from last (innermost) to first; `push` therefore collects them in
        // reverse, which `rev()` below undoes.
        #[unroll]
        for i in 0..spatial_dims {
            let dim = comptime![spatial_dims - i - 1];
            let ksize = comptime![params.kernel_size[dim as usize]];
            let k_pos = rem % ksize;
            rem /= ksize;

            kernel_pos.push(k_pos as i32);
        }

        let kernel_pos = kernel_pos.rev();

        NhwcCoords {
            // NOTE(review): for the weight tensor, `out_c` (the matmul `n` dim) occupies
            // the batch position of the NHWC coordinate — confirm against NhwcCoords users.
            batch: n,
            spatial: kernel_pos,
            channel: in_c,
        }
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    /// Shape as seen by the matmul: a single batch of `(k, n)`.
    fn shape(&self) -> Self::Coordinates {
        (1, self.shape_k, self.shape_n)
    }

    /// Bounds check on the conceptual `(k, n)` shape; each axis is only checked when the
    /// comptime memory config requests it (row = `k`, col = `n`).
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, k, n) = pos;
        let check_k = comptime![self.config.check_row_bounds];
        let check_n = comptime![self.config.check_col_bounds];
        (!check_k || k < self.shape_k) && (!check_n || n < self.shape_n)
    }
}
100
101impl<'a, R: Runtime> WeightLayoutLaunch<'a, R> {
102    pub fn from_args(
103        client: &ComputeClient<R>,
104        problem: &ConvolutionProblem,
105        params: ConvolutionParams,
106        config: GlobalMemoryConfig,
107        dtypes: &MatmulElems,
108    ) -> Self {
109        let load_width = client.properties().hardware.load_width;
110        let channel_align = load_width / dtypes.lhs_global.size_bits() as u32;
111        let padded_channels = (problem.channels as u32).next_multiple_of(channel_align);
112
113        let size_k = problem.kernel_size.iter().product::<u32>() * padded_channels;
114        let padded_channels = FastDivmodArgs::new(client, padded_channels);
115        let shape_k = ScalarArg::new(size_k);
116        let shape_n = ScalarArg::new(problem.n as u32);
117
118        WeightLayoutLaunch::new(padded_channels, shape_k, shape_n, params, config)
119    }
120}