cubek_convolution/components/global/layout/out.rs

use cubecl::prelude::*;
use cubecl::std::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{Layout, LayoutExpand},
};
use cubek_matmul::{components::global::memory::GlobalMemoryConfig, launch::BatchedCoords};

use crate::components::{
    ConvolutionOperation, ConvolutionProblem,
    global::layout::{NhwcCoords, cast_seq, div_mod_seq},
};

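/// Layout that maps batched matmul coordinates `(batch, row, col)` onto NHWC positions in the
/// tensor this layout indexes. `shape_out` holds fast divmod helpers for that tensor's spatial
/// dimensions, while `rows` and `cols` are the dimensions of the matmul view.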
#[derive(CubeType, CubeLaunch, Clone)]
pub struct OutLayout {
    pub shape_out: Sequence<FastDivmod<u32>>,

    pub rows: u32,
    pub cols: u32,

    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}

#[cube]
impl OutLayout {
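    /// Creates the layout from the matmul view dimensions and the per-dimension divisors.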
    pub fn new(
        rows: u32,
        cols: u32,
        shape_out: Sequence<FastDivmod<u32>>,
        #[comptime] config: GlobalMemoryConfig,
    ) -> OutLayout {
        OutLayout {
            shape_out,
            rows,
            cols,
            config,
        }
    }
}

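/// `Layout` implementation: translates batched matmul view coordinates into NHWC tensor
/// coordinates and performs the (optionally disabled) bounds checks.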
#[cube]
impl Layout for OutLayout {
    type Coordinates = BatchedCoords;
    type SourceCoordinates = NhwcCoords;

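    /// Decomposes the row index into `batch` and spatial coordinates using the fast divmod
    /// sequence; the column index maps directly to the channel.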
    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
        let (_, view_m, view_n) = coords;
        let (batch, spatial) = div_mod_seq(view_m, &self.shape_out);

        NhwcCoords {
            batch,
            spatial: cast_seq(spatial),
            channel: view_n,
        }
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

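    /// Row and column bounds are only checked when the corresponding flags are enabled in the
    /// global memory config; otherwise the access is assumed to be in bounds.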
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, row, col) = pos;
        (!self.config.check_row_bounds || row < self.rows)
            && (!self.config.check_col_bounds || col < self.cols)
    }
}

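/// Host-side constructors for the `CubeLaunch`-generated `OutLayoutLaunch` arguments. The
/// spatial shape and matmul dimensions are selected according to the convolution operation.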
impl<'a, R: Runtime> OutLayoutLaunch<'a, R> {
    pub fn from_args(
        client: &ComputeClient<R>,
        problem: &ConvolutionProblem,
        config: GlobalMemoryConfig,
    ) -> Self {
        match problem.operation {
            ConvolutionOperation::Forward => Self::from_args_fprop(client, problem, config),
            ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
                Self::from_args_dgrad(client, problem, config)
            }
            ConvolutionOperation::BackwardWeight => Self::from_args_wgrad(client, problem, config),
        }
    }

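    /// Forward pass: uses the convolution output's spatial shape, with `m` rows and `n` columns.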
    fn from_args_fprop(
        client: &ComputeClient<R>,
        problem: &ConvolutionProblem,
        config: GlobalMemoryConfig,
    ) -> Self {
        let shape_out = problem
            .out_shape
            .iter()
            .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
            .collect();
        let shape_m = ScalarArg::new(problem.m as u32);
        let shape_n = ScalarArg::new(problem.n as u32);

        Self::new(shape_out, shape_m, shape_n, config)
    }

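    /// Data gradient (and transposed forward): uses the input's spatial shape, again with `m`
    /// rows and `n` columns.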
    fn from_args_dgrad(
        client: &ComputeClient<R>,
        problem: &ConvolutionProblem,
        config: GlobalMemoryConfig,
    ) -> Self {
        let shape = problem
            .in_shape
            .iter()
            .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
            .collect();
        let shape_m = ScalarArg::new(problem.m as u32);
        let shape_n = ScalarArg::new(problem.n as u32);

        Self::new(shape, shape_m, shape_n, config)
    }

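    /// Weight gradient: uses the output's spatial shape, passing `k` as the row count and `m`
    /// as the column count.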
    fn from_args_wgrad(
        client: &ComputeClient<R>,
        problem: &ConvolutionProblem,
        config: GlobalMemoryConfig,
    ) -> Self {
        let shape_out = problem
            .out_shape
            .iter()
            .map(|s| FastDivmodArgs::<u32>::new(client, *s as u32))
            .collect();
        let shape_m = ScalarArg::new(problem.m as u32);
        let shape_k = ScalarArg::new(problem.k as u32);

        Self::new(shape_out, shape_k, shape_m, config)
    }
}