cubek_convolution/components/global/layout/
weight.rs1use cubecl::prelude::*;
2use cubecl::std::{
3 FastDivmod,
4 tensor::layout::{Layout, LayoutExpand},
5};
6use cubek_matmul::{
7 components::global::{GlobalConfig, memory::GlobalLayoutConfig},
8 launch::BatchedCoords,
9};
10
11use crate::components::{
12 ConvolutionOperation, ConvolutionParams, ConvolutionProblem, global::layout::NhwcCoords,
13};
14
/// Layout that maps matmul-style `(batch, row, col)` coordinates onto the
/// NHWC weight tensor of a convolution (see `to_source_pos` in the `Layout`
/// impl for the exact decomposition).
#[derive(CubeType, CubeLaunch, Clone)]
pub struct WeightLayout {
    /// Fast divisor used to split the `k` (row) coordinate into a linearized
    /// kernel position and a channel index.
    pub padded_channels: FastDivmod<u32>,

    /// Row extent (`k` dimension) of the virtual matrix view.
    pub rows: u32,
    /// Column extent (`n` dimension) of the virtual matrix view.
    pub cols: u32,

    /// Convolution parameters (kernel size, dimensionality, operation),
    /// resolved at compile time of the kernel.
    #[cube(comptime)]
    pub params: ConvolutionParams,
    /// Bounds-checking flags for rows/columns, resolved at compile time.
    #[cube(comptime)]
    pub config: GlobalLayoutConfig,
}
34
35#[cube]
36impl WeightLayout {
37 pub fn new<E: Numeric, G: GlobalConfig>(
38 rows: u32,
39 cols: u32,
40 padded_channels: FastDivmod<u32>,
41 #[comptime] config: GlobalLayoutConfig,
42 #[comptime] params: ConvolutionParams,
43 ) -> WeightLayout {
44 WeightLayout {
45 rows,
46 cols,
47 padded_channels,
48 config,
49 params,
50 }
51 }
52}
53
#[cube]
impl Layout for WeightLayout {
    /// Matmul-style coordinates: `(batch, row, col)`.
    type Coordinates = BatchedCoords;
    type SourceCoordinates = NhwcCoords;

    /// Translates a `(batch, k, n)` matrix coordinate into an NHWC position
    /// in the weight tensor. The tile batch component is ignored.
    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
        let params = self.params.comptime();
        let (_, k, n) = coords;

        // Split the row coordinate: `k_channel` is the channel, `rem` is the
        // linearized kernel position across all spatial dims.
        let (mut rem, k_channel) = self.padded_channels.div_mod(k);

        let spatial_dims = params.dimensionality.num_dims();
        let mut kernel_pos = Sequence::<i32>::new();

        // Peel spatial dims off `rem`, innermost (last) dimension first.
        #[unroll]
        for i in 0..spatial_dims {
            let dim = spatial_dims - i - 1;
            let ksize = params.kernel_size[dim];
            let k_pos = rem % ksize;
            rem /= ksize;

            kernel_pos.push(k_pos as i32);
        }

        // Positions were pushed innermost-first; reverse back to outer-to-inner order.
        let kernel_pos = kernel_pos.rev();

        // Which of (`n`, channel-from-`k`) lands in the NHWC batch vs. channel
        // slot depends on the convolution operation: Forward/BackwardWeight use
        // (n, k_channel), the transposed/backward-data variants swap them.
        let (batch, channel) = match params.operation {
            ConvolutionOperation::Forward | ConvolutionOperation::BackwardWeight => (n, k_channel),
            ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                (k_channel, n)
            }
        };

        NhwcCoords {
            batch,
            spatial: kernel_pos,
            channel,
        }
    }

    /// Same as `to_source_pos`, paired with the bounds-check result.
    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    /// Shape of the virtual matrix: a single batch of `rows` x `cols`.
    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

    /// In-bounds test for `(_, k, n)`; each axis is only checked when the
    /// corresponding flag in `config` requests it.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, k, n) = pos;
        let check_k = self.config.check_row_bounds;
        let check_n = self.config.check_col_bounds;
        (!check_k || k < self.rows) && (!check_n || n < self.cols)
    }
}
109
110impl<R: Runtime> WeightLayoutLaunch<R> {
111 pub fn from_args(problem: &ConvolutionProblem, config: GlobalLayoutConfig) -> Self {
112 match problem.operation {
113 ConvolutionOperation::Forward
114 | ConvolutionOperation::ForwardTransposed
115 | ConvolutionOperation::BackwardData => Self::from_args_rhs(problem, config),
116 ConvolutionOperation::BackwardWeight => Self::from_args_out(problem, config),
117 }
118 }
119
120 fn from_args_rhs(problem: &ConvolutionProblem, config: GlobalLayoutConfig) -> Self {
121 let padded_channels = problem.padded_channels as u32;
122 let shape_k = problem.k as u32;
123 let shape_n = problem.n as u32;
124
125 let params = ConvolutionParams::from_problem(problem);
126
127 WeightLayoutLaunch::new(padded_channels, shape_k, shape_n, params, config)
128 }
129
130 fn from_args_out(problem: &ConvolutionProblem, config: GlobalLayoutConfig) -> Self {
131 let padded_channels = problem.padded_channels as u32;
132 let shape_m = problem.m as u32;
133 let shape_n = problem.n as u32;
134
135 let params = ConvolutionParams::from_problem(problem);
136
137 WeightLayoutLaunch::new(padded_channels, shape_n, shape_m, params, config)
138 }
139}