cubek_convolution/components/global/read/reader/
layout.rs1use crate::components::global::layout::NhwcCoords;
2use cubecl::prelude::*;
3use cubecl::std::{
4 FastDivmod,
5 tensor::layout::{Coords3d, Layout, LayoutExpand},
6};
7
8#[derive(CubeType, CubeLaunch)]
9pub struct TmaWeightLayout {
10 padded_channels: FastDivmod,
11 #[cube(comptime)]
12 kernel_size: Vec<u32>,
13}
14
15#[cube]
16impl TmaWeightLayout {
17 pub fn new(padded_channels: FastDivmod, #[comptime] kernel_size: Vec<u32>) -> Self {
18 TmaWeightLayout {
19 padded_channels,
20 kernel_size,
21 }
22 }
23}
24
25#[cube]
26impl Layout for TmaWeightLayout {
27 type Coordinates = Coords3d;
28 type SourceCoordinates = NhwcCoords;
29
30 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
31 let (_, k, n) = pos;
32 let (mut k_idx, in_c) = self.padded_channels.div_mod(k);
33 let k_rank = comptime![self.kernel_size.len() as u32];
34 let mut k_pos = Sequence::new();
35
36 #[unroll]
37 for i in 0..k_rank {
38 let dim = comptime![k_rank - i - 1];
39 let k_size = comptime![self.kernel_size[dim as usize]];
40 k_pos.push((k_idx % k_size) as i32);
41 k_idx /= k_size;
42 }
43
44 NhwcCoords {
45 batch: n,
46 spatial: k_pos.rev(),
47 channel: in_c,
48 }
49 }
50
51 fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
52 true.runtime()
53 }
54
55 fn shape(&self) -> Self::Coordinates {
56 (u32::MAX, u32::MAX, u32::MAX).runtime()
57 }
58
59 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
60 (self.to_source_pos(pos), self.is_in_bounds(pos))
61 }
62}
63
64#[derive(CubeType, CubeLaunch)]
66pub struct TmaDummyLayout {}
67
68#[cube]
69impl Layout for TmaDummyLayout {
70 type Coordinates = Coords3d;
71 type SourceCoordinates = Coords3d;
72
73 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
74 pos
75 }
76
77 fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
78 true.runtime()
79 }
80
81 fn shape(&self) -> Self::Coordinates {
82 (u32::MAX, u32::MAX, u32::MAX).runtime()
83 }
84
85 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
86 (self.to_source_pos(pos), self.is_in_bounds(pos))
87 }
88}