cubecl_convolution/components/global/layout/
write.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_matmul::components::global::memory::GlobalMemoryConfig;
4use cubecl_std::{
5 FastDivmod, FastDivmodArgs,
6 tensor::layout::{Coords3d, Layout, LayoutExpand},
7};
8
9use crate::components::{
10 ConvolutionProblem,
11 global::{
12 args::RuntimeArgs,
13 layout::{NhwcCoords, cast_seq},
14 read::im2col_tma::div_mod_seq,
15 },
16};
17
#[derive(CubeType, CubeLaunch, Clone)]
/// Layout of the convolution output, exposed as a `(1, m, n)` view whose positions are mapped
/// back to [`NhwcCoords`] of the backing tensor (see the `Layout` impl below).
pub struct OutLayout {
    /// Divisors used by `div_mod_seq` in `to_source_pos` to split the view's `m` coordinate
    /// into a batch index and spatial indices. Built from `ConvolutionProblem::out_shape` in
    /// `OutLayoutLaunch::from_args`.
    // NOTE(review): presumably holds only the output spatial extents, with the batch index
    // being what remains after dividing them out — confirm against `div_mod_seq`.
    pub shape_out: Sequence<FastDivmod>,

    /// Extent of the view's `m` (row) dimension; taken from `problem.m` at launch.
    pub shape_m: u32,
    /// Extent of the view's `n` (column) dimension; taken from `problem.n` at launch.
    pub shape_n: u32,

    /// Comptime memory config. Its `check_row_bounds` / `check_col_bounds` flags decide which
    /// dimensions `is_in_bounds` actually compares against `shape_m` / `shape_n`.
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}
34
#[cube]
impl OutLayout {
    /// Creates an output layout from the kernel's runtime arguments.
    ///
    /// `shape_m` and `shape_n` are copied from `args`; `shape_out` and the comptime `config`
    /// are stored as given.
    pub fn new(
        args: &RuntimeArgs,
        shape_out: Sequence<FastDivmod>,
        #[comptime] config: GlobalMemoryConfig,
    ) -> OutLayout {
        OutLayout {
            shape_out,
            shape_m: args.shape_m,
            shape_n: args.shape_n,
            config,
        }
    }
}
50
// View coordinates are `(_, m, n)`: the leading coordinate is ignored, `m` is decomposed into
// batch + spatial indices, and `n` is used directly as the channel index.
#[cube]
impl Layout for OutLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = NhwcCoords;

    /// Maps a view position `(_, m, n)` to source coordinates.
    ///
    /// `m` is split by `div_mod_seq` against `shape_out` into a batch index and a sequence of
    /// spatial indices (converted with `cast_seq`); `n` becomes the channel unchanged.
    /// No bounds checking is performed here — use `to_source_pos_checked` for that.
    fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
        let (_, view_m, view_n) = coords;
        let (batch, spatial) = div_mod_seq(view_m, &self.shape_out);

        NhwcCoords {
            batch,
            spatial: cast_seq(spatial),
            channel: view_n,
        }
    }

    /// Same mapping as `to_source_pos`, paired with `is_in_bounds` for the same position.
    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    /// Shape of the view: a leading extent of 1, then `shape_m` x `shape_n`.
    fn shape(&self) -> Self::Coordinates {
        (1, self.shape_m, self.shape_n)
    }

    /// Returns whether `(_, m, n)` lies inside the view.
    ///
    /// Each dimension is only compared when its comptime flag is set (`check_row_bounds` for
    /// `m`, `check_col_bounds` for `n`); a disabled check always passes. The leading
    /// coordinate is never checked.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, m, n) = pos;
        // Resolved at kernel-compile time, so a disabled check contributes a constant `true`.
        let check_m = comptime![self.config.check_row_bounds];
        let check_n = comptime![self.config.check_col_bounds];
        (!check_m || m < self.shape_m) && (!check_n || n < self.shape_n)
    }
}
82
83impl<'a, R: Runtime> OutLayoutLaunch<'a, R> {
84 pub fn from_args(
85 client: &ComputeClient<R>,
86 problem: &ConvolutionProblem,
87 config: GlobalMemoryConfig,
88 ) -> Self {
89 let shape_out = problem
90 .out_shape
91 .iter()
92 .map(|s| FastDivmodArgs::new(client, *s as u32))
93 .collect();
94 let shape_m = ScalarArg::new(problem.m as u32);
95 let shape_n = ScalarArg::new(problem.n as u32);
96
97 Self::new(shape_out, shape_m, shape_n, config)
98 }
99}