cubecl_convolution/components/global/read/reader/layout.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_std::{
4    FastDivmod,
5    tensor::layout::{Coords3d, Layout, LayoutExpand},
6};
7
8#[derive(CubeType, CubeLaunch)]
9pub struct TmaWeightLayout {
10    padded_channels: FastDivmod,
11}
12
13#[cube]
14impl TmaWeightLayout {
15    pub fn new(padded_channels: FastDivmod) -> Self {
16        TmaWeightLayout { padded_channels }
17    }
18}
19
20#[cube]
21impl Layout for TmaWeightLayout {
22    type Coordinates = Coords3d;
23    type SourceCoordinates = Coords3d;
24
25    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
26        let (_, k, n) = pos;
27        let (k_idx, in_c) = self.padded_channels.div_mod(k);
28        (n, k_idx, in_c)
29    }
30
31    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
32        true.runtime()
33    }
34
35    fn shape(&self) -> Self::Coordinates {
36        (u32::MAX, u32::MAX, u32::MAX).runtime()
37    }
38
39    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
40        (self.to_source_pos(pos), self.is_in_bounds(pos))
41    }
42}
43
44/// Dummy layout for launching, to be exited out later with `as_tensor_map`.
45#[derive(CubeType, CubeLaunch)]
46pub struct TmaDummyLayout {}
47
48#[cube]
49impl Layout for TmaDummyLayout {
50    type Coordinates = Coords3d;
51    type SourceCoordinates = Coords3d;
52
53    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
54        pos
55    }
56
57    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
58        true.runtime()
59    }
60
61    fn shape(&self) -> Self::Coordinates {
62        (u32::MAX, u32::MAX, u32::MAX).runtime()
63    }
64
65    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
66        (self.to_source_pos(pos), self.is_in_bounds(pos))
67    }
68}