cubecl_convolution/components/global/read/reader/
layout.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_std::{
4 FastDivmod,
5 tensor::layout::{Coords3d, Layout, LayoutExpand},
6};
7
8#[derive(CubeType, CubeLaunch)]
9pub struct TmaWeightLayout {
10 padded_channels: FastDivmod,
11}
12
13#[cube]
14impl TmaWeightLayout {
15 pub fn new(padded_channels: FastDivmod) -> Self {
16 TmaWeightLayout { padded_channels }
17 }
18}
19
20#[cube]
21impl Layout for TmaWeightLayout {
22 type Coordinates = Coords3d;
23 type SourceCoordinates = Coords3d;
24
25 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
26 let (_, k, n) = pos;
27 let (k_idx, in_c) = self.padded_channels.div_mod(k);
28 (n, k_idx, in_c)
29 }
30
31 fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
32 true.runtime()
33 }
34
35 fn shape(&self) -> Self::Coordinates {
36 (u32::MAX, u32::MAX, u32::MAX).runtime()
37 }
38
39 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
40 (self.to_source_pos(pos), self.is_in_bounds(pos))
41 }
42}
43
44#[derive(CubeType, CubeLaunch)]
46pub struct TmaDummyLayout {}
47
48#[cube]
49impl Layout for TmaDummyLayout {
50 type Coordinates = Coords3d;
51 type SourceCoordinates = Coords3d;
52
53 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
54 pos
55 }
56
57 fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
58 true.runtime()
59 }
60
61 fn shape(&self) -> Self::Coordinates {
62 (u32::MAX, u32::MAX, u32::MAX).runtime()
63 }
64
65 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
66 (self.to_source_pos(pos), self.is_in_bounds(pos))
67 }
68}