cubecl_convolution/components/global/layout/
weight.rs

use cubecl::prelude::*;
use cubecl_core::{self as cubecl};
use cubecl_matmul::components::{
    MatmulIdent,
    global::{GlobalConfig, memory::GlobalMemoryConfig},
};
use cubecl_std::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{Coords3d, Layout, LayoutExpand},
};

use crate::{
    components::{
        ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
        global::layout::NhwcCoords,
    },
    kernels::layered::selector::RuntimeArgs,
};

/// Maps a 4D weight tensor of shape `(out_c, (k_h, k_w, in_c))` to a col-major 2D matmul tile with
/// shape `(n, k)`
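///
/// A linear `k` index is decomposed with the input channel as the fastest-moving
/// axis: `in_c = k % channels`, then each spatial kernel offset from the
/// innermost dimension (`k_w`) outwards, as implemented in `to_source_pos`.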
#[derive(CubeType, CubeLaunch, Clone)]
pub struct WeightLayout {
    /// Number of channels, including padding, used for decomposing `k`
    pub channels: FastDivmod,

    /// Size of the conceptual `k` dimension, including padding
    pub shape_k: u32,
    /// Size of the conceptual `n` dimension, i.e. `out_c`
    pub shape_n: u32,

    /// Comptime convolution parameters, including kernel size and dimensionality
    #[cube(comptime)]
    pub params: ConvolutionParams,
    /// Global memory config for the backing tensor
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}

#[cube]
impl WeightLayout {
    pub fn new<E: Numeric, G: GlobalConfig>(
        args: &RuntimeArgs,
        #[comptime] config: ConvolutionConfig<G>,
    ) -> WeightLayout {
        WeightLayout {
            shape_k: args.shape_k,
            shape_n: args.shape_n,
            channels: args.padded_channels,
            params: config.convolution_params(),
            config: config.global_memory_config(MatmulIdent::Rhs),
        }
    }
}

#[cube]
impl Layout for WeightLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = NhwcCoords;

    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
        let params = comptime![self.params];
        let (_, k, n) = coords;

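        // `k` linearizes the spatial kernel offsets and input channel, with the
        // channel as the fastest-moving axis, so peel the channel off first.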
        let (mut rem, in_c) = self.channels.div_mod(k);

        let spatial_dims = comptime![params.dimensionality.num_dims()];
        let mut kernel_pos = Sequence::<i32>::new();

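        // Extract each spatial kernel offset, innermost dimension first; the
        // sequence is reversed afterwards to restore `(k_h, k_w, ...)` order.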
        #[unroll]
        for i in 0..spatial_dims {
            let dim = comptime![spatial_dims - i - 1];
            let ksize = comptime![params.kernel_size[dim as usize]];
            let k_pos = rem % ksize;
            rem /= ksize;

            kernel_pos.push(k_pos as i32);
        }

        let kernel_pos = kernel_pos.rev();

        NhwcCoords {
            batch: n,
            spatial: kernel_pos,
            channel: in_c,
        }
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        (1, self.shape_k, self.shape_n)
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, k, n) = pos;
        let check_k = comptime![self.config.check_row_bounds];
        let check_n = comptime![self.config.check_col_bounds];
        (!check_k || k < self.shape_k) && (!check_n || n < self.shape_n)
    }
}

impl<'a, R: Runtime> WeightLayoutLaunch<'a, R> {
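    /// Builds the launch-time layout arguments from a host-side convolution
    /// problem description.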
    pub fn from_args(
        client: &ComputeClient<R::Server>,
        problem: &ConvolutionProblem,
        params: ConvolutionParams,
        config: GlobalMemoryConfig,
    ) -> Self {
        let channels = FastDivmodArgs::new(client, problem.channels as u32);
        let shape_k = ScalarArg::new(problem.k as u32);
        let shape_n = ScalarArg::new(problem.n as u32);

        WeightLayoutLaunch::new(channels, shape_k, shape_n, params, config)
    }
}
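
// A plain-Rust sketch of the `k` decomposition performed in
// `WeightLayout::to_source_pos`, mirrored outside CubeCL so it runs on the
// host. The helper `decompose_k` is illustrative only, not part of this
// crate's API; it assumes the channel-fastest ordering used above.
#[cfg(test)]
mod weight_layout_index_sketch {
    /// Splits a linear `k` index into `([k_h, k_w], in_c)` for a 2D kernel:
    /// channel first, then spatial offsets from innermost to outermost.
    fn decompose_k(k: u32, channels: u32, kernel: [u32; 2]) -> ([u32; 2], u32) {
        let in_c = k % channels;
        let mut rem = k / channels;
        let k_w = rem % kernel[1];
        rem /= kernel[1];
        let k_h = rem % kernel[0];
        ([k_h, k_w], in_c)
    }

    #[test]
    fn k_decomposition_is_channel_fastest() {
        // 3x3 kernel with 4 (padded) channels => shape_k = 3 * 3 * 4 = 36.
        assert_eq!(decompose_k(0, 4, [3, 3]), ([0, 0], 0));
        // Advancing by one full channel block moves to the next kernel column.
        assert_eq!(decompose_k(4, 4, [3, 3]), ([0, 1], 0));
        // k = 13 => in_c = 13 % 4 = 1, rem = 3 => k_w = 0, k_h = 1.
        assert_eq!(decompose_k(13, 4, [3, 3]), ([1, 0], 1));
    }
}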