cubek_convolution/components/global/layout/
write.rs1use cubecl::prelude::*;
2use cubecl::std::{
3 FastDivmod, FastDivmodArgs,
4 tensor::layout::{Coords3d, Layout, LayoutExpand},
5};
6use cubek_matmul::components::global::memory::GlobalMemoryConfig;
7
8use crate::components::{
9 ConvolutionProblem,
10 global::{
11 args::RuntimeArgs,
12 layout::{NhwcCoords, cast_seq},
13 read::im2col_tma::div_mod_seq,
14 },
15};
16
17#[derive(CubeType, CubeLaunch, Clone)]
20pub struct OutLayout {
21 pub shape_out: Sequence<FastDivmod>,
23
24 pub shape_m: u32,
26 pub shape_n: u32,
28
29 #[cube(comptime)]
31 pub config: GlobalMemoryConfig,
32}
33
34#[cube]
35impl OutLayout {
36 pub fn new(
37 args: &RuntimeArgs,
38 shape_out: Sequence<FastDivmod>,
39 #[comptime] config: GlobalMemoryConfig,
40 ) -> OutLayout {
41 OutLayout {
42 shape_out,
43 shape_m: args.shape_m,
44 shape_n: args.shape_n,
45 config,
46 }
47 }
48}
49
50#[cube]
51impl Layout for OutLayout {
52 type Coordinates = Coords3d;
53 type SourceCoordinates = NhwcCoords;
54
55 fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
56 let (_, view_m, view_n) = coords;
57 let (batch, spatial) = div_mod_seq(view_m, &self.shape_out);
58
59 NhwcCoords {
60 batch,
61 spatial: cast_seq(spatial),
62 channel: view_n,
63 }
64 }
65
66 fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
67 (self.to_source_pos(coords), self.is_in_bounds(coords))
68 }
69
70 fn shape(&self) -> Self::Coordinates {
71 (1, self.shape_m, self.shape_n)
72 }
73
74 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
75 let (_, m, n) = pos;
76 let check_m = comptime![self.config.check_row_bounds];
77 let check_n = comptime![self.config.check_col_bounds];
78 (!check_m || m < self.shape_m) && (!check_n || n < self.shape_n)
79 }
80}
81
82impl<'a, R: Runtime> OutLayoutLaunch<'a, R> {
83 pub fn from_args(
84 client: &ComputeClient<R>,
85 problem: &ConvolutionProblem,
86 config: GlobalMemoryConfig,
87 ) -> Self {
88 let shape_out = problem
89 .out_shape
90 .iter()
91 .map(|s| FastDivmodArgs::new(client, *s as u32))
92 .collect();
93 let shape_m = ScalarArg::new(problem.m as u32);
94 let shape_n = ScalarArg::new(problem.n as u32);
95
96 Self::new(shape_out, shape_m, shape_n, config)
97 }
98}