cubecl_convolution/components/global/layout/
write.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_matmul::components::global::memory::GlobalMemoryConfig;
4use cubecl_std::{
5 FastDivmod, FastDivmodArgs,
6 tensor::layout::{Coords3d, Layout, LayoutExpand},
7};
8
9use crate::{
10 components::{
11 ConvolutionProblem,
12 global::{
13 layout::{NhwcCoords, cast_seq},
14 read::im2col_tma::div_mod_seq,
15 },
16 },
17 kernels::layered::selector::RuntimeArgs,
18};
19
/// Layout that maps the convolution output, viewed as a matmul-style `(_, m, n)` matrix,
/// onto output tensor coordinates ([`NhwcCoords`]).
///
/// The `m` coordinate is decomposed into a batch index and spatial coordinates using
/// `shape_out`; the `n` coordinate is used directly as the channel index.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct OutLayout {
    /// Fast divisors for each dimension of the problem's output shape, used to split the
    /// flattened `m` coordinate into `(batch, spatial...)` via `div_mod_seq`.
    pub shape_out: Sequence<FastDivmod>,

    /// Number of rows (`m`) of the view; upper bound for `m` in `is_in_bounds`.
    pub shape_m: u32,
    /// Number of columns (`n`) of the view; upper bound for `n` in `is_in_bounds`.
    pub shape_n: u32,

    /// Compile-time global memory configuration. Its `check_row_bounds` /
    /// `check_col_bounds` flags select which dimensions `is_in_bounds` actually checks.
    #[cube(comptime)]
    pub config: GlobalMemoryConfig,
}
36
37#[cube]
38impl OutLayout {
39 pub fn new(args: &RuntimeArgs, #[comptime] config: GlobalMemoryConfig) -> OutLayout {
40 OutLayout {
41 shape_out: args.shape_out.clone(),
42 shape_m: args.shape_m,
43 shape_n: args.shape_n,
44 config,
45 }
46 }
47}
48
49#[cube]
50impl Layout for OutLayout {
51 type Coordinates = Coords3d;
52 type SourceCoordinates = NhwcCoords;
53
54 fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
55 let (_, view_m, view_n) = coords;
56 let (batch, spatial) = div_mod_seq(view_m, &self.shape_out);
57
58 NhwcCoords {
59 batch,
60 spatial: cast_seq(spatial),
61 channel: view_n,
62 }
63 }
64
65 fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
66 (self.to_source_pos(coords), self.is_in_bounds(coords))
67 }
68
69 fn shape(&self) -> Self::Coordinates {
70 (1, self.shape_m, self.shape_n)
71 }
72
73 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
74 let (_, m, n) = pos;
75 let check_m = comptime![self.config.check_row_bounds];
76 let check_n = comptime![self.config.check_col_bounds];
77 (!check_m || m < self.shape_m) && (!check_n || n < self.shape_n)
78 }
79}
80
81impl<'a, R: Runtime> OutLayoutLaunch<'a, R> {
82 pub fn from_args(
83 client: &ComputeClient<R::Server>,
84 problem: &ConvolutionProblem,
85 config: GlobalMemoryConfig,
86 ) -> Self {
87 let shape_out = problem
88 .out_shape
89 .iter()
90 .map(|s| FastDivmodArgs::new(client, *s as u32))
91 .collect();
92 let shape_m = ScalarArg::new(problem.m as u32);
93 let shape_n = ScalarArg::new(problem.n as u32);
94
95 Self::new(shape_out, shape_m, shape_n, config)
96 }
97}