// cubek_convolution/components/global/layout/out.rs

use cubecl::prelude::*;
use cubecl::std::{
    FastDivmod,
    tensor::layout::{Layout, LayoutExpand},
};
use cubek_matmul::{components::global::memory::GlobalLayoutConfig, launch::BatchedCoords};

use crate::components::{
    ConvolutionOperation, ConvolutionProblem,
    global::layout::{NhwcCoords, cast_seq, div_mod_seq},
};
12
/// Layout that maps a batched 2D matmul view `(batch, row, col)` onto the
/// NHWC coordinates of a convolution's output tensor.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct OutLayout {
    /// Fast divmod helper per spatial dimension of the output; used to split a
    /// flattened row index into `(batch, spatial…)` (see `to_source_pos`).
    pub shape_out: Sequence<FastDivmod<u32>>,

    /// Number of rows of the 2D view (the matmul `m` extent — see
    /// `from_args_fprop`, which passes `problem.m` here).
    pub rows: u32,

    /// Number of columns of the 2D view (maps directly to the channel
    /// coordinate in `to_source_pos`).
    pub cols: u32,

    /// Compile-time layout settings; only `check_row_bounds` /
    /// `check_col_bounds` are read here (in `is_in_bounds`).
    #[cube(comptime)]
    pub config: GlobalLayoutConfig,
}
29
#[cube]
impl OutLayout {
    /// Constructs an [`OutLayout`] from the 2D view extents, the per-dimension
    /// output-shape divmod helpers, and the comptime bounds-check config.
    pub fn new(
        rows: u32,
        cols: u32,
        shape_out: Sequence<FastDivmod<u32>>,
        #[comptime] config: GlobalLayoutConfig,
    ) -> OutLayout {
        OutLayout {
            shape_out,
            rows,
            cols,
            config,
        }
    }
}
46
#[cube]
impl Layout for OutLayout {
    type Coordinates = BatchedCoords;
    type SourceCoordinates = NhwcCoords;

    /// Converts a `(batch, row, col)` view coordinate into NHWC coordinates.
    ///
    /// The row index is decomposed by repeated divmod over `shape_out` into a
    /// batch index plus per-dimension spatial coordinates; the column index is
    /// used verbatim as the channel. NOTE(review): the tuple's first (batch)
    /// element is ignored here — batch appears to be folded into the row index
    /// already; confirm against the callers that build `BatchedCoords`.
    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
        let (_, view_m, view_n) = coords;
        let (batch, spatial) = div_mod_seq(view_m, &self.shape_out);

        NhwcCoords {
            batch,
            // Spatial components are cast to the element type NhwcCoords expects.
            spatial: cast_seq(spatial),
            channel: view_n,
        }
    }

    /// Same as `to_source_pos`, additionally reporting whether `coords` is in
    /// bounds (per the comptime check flags).
    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    /// Extent of the view: a single batch of `rows` x `cols`.
    fn shape(&self) -> Self::Coordinates {
        (1, self.rows, self.cols)
    }

    /// Bounds check for `pos`. Each axis is only tested when its comptime flag
    /// is set, so fully in-bounds kernels compile the check away.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, row, col) = pos;
        (!self.config.check_row_bounds || row < self.rows)
            && (!self.config.check_col_bounds || col < self.cols)
    }
}
77
78impl<R: Runtime> OutLayoutLaunch<R> {
79 pub fn from_args(problem: &ConvolutionProblem, config: GlobalLayoutConfig) -> Self {
80 match problem.operation {
81 ConvolutionOperation::Forward => Self::from_args_fprop(problem, config),
82 ConvolutionOperation::ForwardTransposed | ConvolutionOperation::BackwardData => {
83 Self::from_args_dgrad(problem, config)
84 }
85 ConvolutionOperation::BackwardWeight => Self::from_args_wgrad(problem, config),
86 }
87 }
88
89 fn from_args_fprop(problem: &ConvolutionProblem, config: GlobalLayoutConfig) -> Self {
90 let shape_out = problem.out_shape.iter().map(|s| *s as u32).collect();
91 let shape_m = problem.m as u32;
92 let shape_n = problem.n as u32;
93
94 Self::new(shape_out, shape_m, shape_n, config)
95 }
96
97 fn from_args_dgrad(problem: &ConvolutionProblem, config: GlobalLayoutConfig) -> Self {
98 let shape = problem.in_shape.iter().map(|s| *s as u32).collect();
99 let shape_m = problem.m as u32;
100 let shape_n = problem.n as u32;
101
102 Self::new(shape, shape_m, shape_n, config)
103 }
104
105 fn from_args_wgrad(problem: &ConvolutionProblem, config: GlobalLayoutConfig) -> Self {
106 let shape_out = problem.out_shape.iter().map(|s| *s as u32).collect();
107 let shape_m = problem.m as u32;
108 let shape_k = problem.k as u32;
109
110 Self::new(shape_out, shape_k, shape_m, config)
111 }
112}