cubecl_convolution/components/global/read/reader/
layout.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_std::{
4 FastDivmod,
5 tensor::layout::{Coords3d, Layout, LayoutExpand},
6};
7
8use crate::components::global::layout::NhwcCoords;
9
/// Layout mapping TMA weight-tile coordinates `(_, k, n)` onto [`NhwcCoords`] of the
/// weight tensor.
///
/// The `k` coordinate is decomposed (see `Layout::to_source_pos`) as
/// `kernel_position * padded_channels + channel`, where `kernel_position` is the
/// flattened spatial kernel index.
#[derive(CubeType, CubeLaunch)]
pub struct TmaWeightLayout {
    /// Divisor that splits `k` into a flattened kernel position (quotient) and a
    /// channel index (remainder).
    // NOTE(review): presumably the input-channel count rounded up for TMA alignment — confirm at the call site.
    padded_channels: FastDivmod,
    /// Compile-time extent of each spatial kernel dimension; index 0 is the outermost
    /// (slowest-varying) dimension.
    #[cube(comptime)]
    kernel_size: Vec<u32>,
}

#[cube]
impl TmaWeightLayout {
    /// Builds a weight layout.
    ///
    /// * `padded_channels` - divisor used to separate the channel index from the
    ///   flattened kernel position in the `k` coordinate.
    /// * `kernel_size` - comptime extents of the spatial kernel dimensions, outermost first.
    pub fn new(padded_channels: FastDivmod, #[comptime] kernel_size: Vec<u32>) -> Self {
        TmaWeightLayout {
            padded_channels,
            kernel_size,
        }
    }
}
26
27#[cube]
28impl Layout for TmaWeightLayout {
29 type Coordinates = Coords3d;
30 type SourceCoordinates = NhwcCoords;
31
32 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
33 let (_, k, n) = pos;
34 let (mut k_idx, in_c) = self.padded_channels.div_mod(k);
35 let k_rank = comptime![self.kernel_size.len() as u32];
36 let mut k_pos = Sequence::new();
37
38 #[unroll]
39 for i in 0..k_rank {
40 let dim = comptime![k_rank - i - 1];
41 let k_size = comptime![self.kernel_size[dim as usize]];
42 k_pos.push((k_idx % k_size) as i32);
43 k_idx /= k_size;
44 }
45
46 NhwcCoords {
47 batch: n,
48 spatial: k_pos.rev(),
49 channel: in_c,
50 }
51 }
52
53 fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
54 true.runtime()
55 }
56
57 fn shape(&self) -> Self::Coordinates {
58 (u32::MAX, u32::MAX, u32::MAX).runtime()
59 }
60
61 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
62 (self.to_source_pos(pos), self.is_in_bounds(pos))
63 }
64}
65
/// Identity layout over 3D coordinates: positions pass through unchanged, every
/// position is reported in bounds, and the shape is `u32::MAX` in each dimension.
// NOTE(review): presumably used where the TMA descriptor itself handles addressing and
// bounds — confirm with the callers.
#[derive(CubeType, CubeLaunch)]
pub struct TmaDummyLayout {}
69
70#[cube]
71impl Layout for TmaDummyLayout {
72 type Coordinates = Coords3d;
73 type SourceCoordinates = Coords3d;
74
75 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
76 pos
77 }
78
79 fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
80 true.runtime()
81 }
82
83 fn shape(&self) -> Self::Coordinates {
84 (u32::MAX, u32::MAX, u32::MAX).runtime()
85 }
86
87 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
88 (self.to_source_pos(pos), self.is_in_bounds(pos))
89 }
90}