cubecl_convolution/components/global/layout/
weight.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_matmul::components::global::{GlobalConfig, memory::GlobalMemoryConfig};
4use cubecl_std::{
5 FastDivmod, FastDivmodArgs,
6 tensor::layout::{Coords3d, Layout, LayoutExpand},
7};
8
9use crate::{
10 components::{
11 ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
12 global::layout::NhwcCoords,
13 },
14 kernels::layered::selector::RuntimeArgs,
15};
16
17#[derive(CubeType, CubeLaunch, Clone)]
20pub struct WeightLayout {
21 pub channels: FastDivmod,
23
24 pub shape_k: u32,
26 pub shape_n: u32,
28
29 #[cube(comptime)]
31 pub params: ConvolutionParams,
32 #[cube(comptime)]
34 pub config: GlobalMemoryConfig,
35}
36
37#[cube]
38impl WeightLayout {
39 pub fn new<E: Numeric, G: GlobalConfig>(
40 args: &RuntimeArgs,
41 #[comptime] config: ConvolutionConfig<G>,
42 ) -> WeightLayout {
43 WeightLayout {
44 shape_k: args.shape_k,
45 shape_n: args.shape_n,
46 channels: args.padded_channels,
47 params: config.convolution_params,
48 config: config.rhs_global_memory_config(),
49 }
50 }
51}
52
53#[cube]
54impl Layout for WeightLayout {
55 type Coordinates = Coords3d;
56 type SourceCoordinates = NhwcCoords;
57
58 fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
59 let params = comptime![self.params];
60 let (_, k, n) = coords;
61
62 let (mut rem, in_c) = self.channels.div_mod(k);
63
64 let spatial_dims = comptime![params.dimensionality.num_dims()];
65 let mut kernel_pos = Sequence::<i32>::new();
66
67 #[unroll]
68 for i in 0..spatial_dims {
69 let dim = comptime![spatial_dims - i - 1];
70 let ksize = comptime![params.kernel_size[dim as usize]];
71 let k_pos = rem % ksize;
72 rem /= ksize;
73
74 kernel_pos.push(k_pos as i32);
75 }
76
77 let kernel_pos = kernel_pos.rev();
78
79 NhwcCoords {
80 batch: n,
81 spatial: kernel_pos,
82 channel: in_c,
83 }
84 }
85
86 fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
87 (self.to_source_pos(coords), self.is_in_bounds(coords))
88 }
89
90 fn shape(&self) -> Self::Coordinates {
91 (1, self.shape_k, self.shape_n)
92 }
93
94 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
95 let (_, k, n) = pos;
96 let check_k = comptime![self.config.check_row_bounds];
97 let check_n = comptime![self.config.check_col_bounds];
98 (!check_k || k < self.shape_k) && (!check_n || n < self.shape_n)
99 }
100}
101
102impl<'a, R: Runtime> WeightLayoutLaunch<'a, R> {
103 pub fn from_args(
104 client: &ComputeClient<R::Server>,
105 problem: &ConvolutionProblem,
106 params: ConvolutionParams,
107 config: GlobalMemoryConfig,
108 ) -> Self {
109 let channels = FastDivmodArgs::new(client, problem.channels as u32);
110 let shape_k = ScalarArg::new(problem.k as u32);
111 let shape_n = ScalarArg::new(problem.n as u32);
112
113 WeightLayoutLaunch::new(channels, shape_k, shape_n, params, config)
114 }
115}