cubek_convolution/components/global/layout/weight.rs

use cubecl::prelude::*;
use cubecl::std::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{Layout, LayoutExpand},
};
use cubek_matmul::{
    components::global::{GlobalConfig, memory::GlobalMemoryConfig},
    launch::BatchedCoords,
};

use crate::components::{
    ConvGemmConfig, ConvolutionConfig, ConvolutionOperation, ConvolutionParams, ConvolutionProblem,
    global::layout::NhwcCoords,
};

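/// Maps matmul-style `(batch, k, n)` coordinates onto positions in an NHWC
/// weight tensor. The `k` axis linearizes the kernel's spatial positions
/// together with the padded channel count; how `n` is interpreted depends on
/// the convolution operation (see `to_source_pos`).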
#[derive(CubeType, CubeLaunch, Clone)]
pub struct WeightLayout {
    /// Fast divmod helper for the padded channel count, used to split `k`.
    pub padded_channels: FastDivmod<u32>,

    /// Row extent (`k`) of the virtual weight matrix.
    pub rows: u32,
    /// Column extent (`n`) of the virtual weight matrix.
    pub cols: u32,

    /// Comptime convolution parameters (kernel size, dimensionality, operation).
    #[cube(comptime)]
    pub params: ConvolutionParams,
    /// Comptime memory config controlling bounds checking.
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}

#[cube]
impl WeightLayout {
    /// Creates a `WeightLayout` from the virtual matrix shape and the
    /// convolution config, taking the RHS memory config from the global config.
    pub fn new<E: Numeric, G: GlobalConfig>(
        rows: u32,
        cols: u32,
        padded_channels: FastDivmod<u32>,
        #[comptime] config: ConvolutionConfig<G>,
    ) -> WeightLayout {
        WeightLayout {
            rows,
            cols,
            padded_channels,
            params: config.params,
            config: config.rhs_global_memory_config(),
        }
    }
}

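// How `to_source_pos` decodes the row index: `k` is treated as
// `kernel_linear * padded_channels + channel`, and `kernel_linear` is then
// peeled apart innermost-dimension-first by repeated modulo with each kernel
// size. Worked example (values assumed for illustration): for a 2D 3x3 kernel
// with padded_channels = 64, k = 70 splits into (kernel_linear, channel) =
// (1, 6); peeling gives width 1 (1 % 3), then height 0, so the NHWC spatial
// coordinates are [0, 1] with channel 6.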
#[cube]
impl Layout for WeightLayout {
    type Coordinates = BatchedCoords;
    type SourceCoordinates = NhwcCoords;

    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
        let params = self.params.comptime();
        let (_, k, n) = coords;

        // Split the row index into a linear kernel position and a channel.
        let (mut rem, k_channel) = self.padded_channels.div_mod(k);

        // Decode the kernel's spatial position, innermost dimension first.
        let spatial_dims = params.dimensionality.num_dims();
        let mut kernel_pos = Sequence::<i32>::new();

        #[unroll]
        for i in 0..spatial_dims {
            let dim = spatial_dims - i - 1;
            let ksize = params.kernel_size[dim];
            let k_pos = rem % ksize;
            rem /= ksize;

            kernel_pos.push(k_pos as i32);
        }

        // Positions were pushed innermost-first; reverse into NHWC order.
        let kernel_pos = kernel_pos.rev();

        // Which tensor axis `n` addresses depends on the operation: forward and
        // backward-weight use `n` for the batch axis, while the transposed and
        // backward-data variants use it for the channel axis.
        let (batch, channel) = match params.operation {
            ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => (n, k_channel),
            ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                (k_channel, n)
            }
        };

        NhwcCoords {
            batch,
            spatial: kernel_pos,
            channel,
        }
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, k, n) = pos;
        // Bounds checks are toggled at comptime via the memory config.
        let check_k = self.config.check_row_bounds;
        let check_n = self.config.check_col_bounds;
        (!check_k || k < self.rows) && (!check_n || n < self.cols)
    }
}

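// `WeightLayoutLaunch` below is the host-side companion type generated by the
// `CubeLaunch` derive on `WeightLayout`; these constructors assemble its
// runtime arguments from a `ConvolutionProblem`.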
impl<'a, R: Runtime> WeightLayoutLaunch<'a, R> {
    /// Builds the launch arguments for the weight layout, choosing between the
    /// RHS-shaped and output-shaped variants based on the operation.
    pub fn from_args(
        client: &ComputeClient<R>,
        problem: &ConvolutionProblem,
        config: GlobalMemoryConfig,
    ) -> Self {
        match problem.operation {
            ConvolutionOperation::Forward
            | ConvolutionOperation::ForwardTransposed
            | ConvolutionOperation::BackwardData => Self::from_args_rhs(client, problem, config),
            ConvolutionOperation::BackwardWeight => Self::from_args_out(client, problem, config),
        }
    }

    /// The weight is the RHS operand: rows span `k`, columns span `n`.
    fn from_args_rhs(
        client: &ComputeClient<R>,
        problem: &ConvolutionProblem,
        config: GlobalMemoryConfig,
    ) -> Self {
        let padded_channels = problem.padded_channels as u32;
        let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
        let shape_k = ScalarArg::new(problem.k as u32);
        let shape_n = ScalarArg::new(problem.n as u32);

        let params = ConvolutionParams::from_problem(problem);

        WeightLayoutLaunch::new(padded_channels, shape_k, shape_n, params, config)
    }

    /// For the backward-weight pass the output tensor takes the weight
    /// operand's role: the layout rows come from the problem's `n` extent and
    /// the columns from `m`.
    fn from_args_out(
        client: &ComputeClient<R>,
        problem: &ConvolutionProblem,
        config: GlobalMemoryConfig,
    ) -> Self {
        let padded_channels = problem.padded_channels as u32;
        let padded_channels = FastDivmodArgs::<u32>::new(client, padded_channels);
        let shape_m = ScalarArg::new(problem.m as u32);
        let shape_n = ScalarArg::new(problem.n as u32);

        let params = ConvolutionParams::from_problem(problem);

        WeightLayoutLaunch::new(padded_channels, shape_n, shape_m, params, config)
    }
}
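
// Minimal usage sketch (hypothetical caller; `client`, `problem`, and
// `rhs_config` are assumed to exist with the obvious types):
//
//     let weight_layout = WeightLayoutLaunch::from_args(&client, &problem, rhs_config);
//
// The returned launch argument would then accompany the weight tensor handle
// when the convolution kernel is dispatched.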