// cubek_convolution/components/global/read/reader/layout.rs

1use crate::components::global::layout::NhwcCoords;
2use cubecl::prelude::*;
3use cubecl::std::{
4    FastDivmod,
5    tensor::layout::{Coords3d, Layout, LayoutExpand},
6};
7
8#[derive(CubeType, CubeLaunch)]
9pub struct TmaWeightLayout {
10    padded_channels: FastDivmod,
11    #[cube(comptime)]
12    kernel_size: Vec<u32>,
13}
14
15#[cube]
16impl TmaWeightLayout {
17    pub fn new(padded_channels: FastDivmod, #[comptime] kernel_size: Vec<u32>) -> Self {
18        TmaWeightLayout {
19            padded_channels,
20            kernel_size,
21        }
22    }
23}
24
25#[cube]
26impl Layout for TmaWeightLayout {
27    type Coordinates = Coords3d;
28    type SourceCoordinates = NhwcCoords;
29
30    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
31        let (_, k, n) = pos;
32        let (mut k_idx, in_c) = self.padded_channels.div_mod(k);
33        let k_rank = comptime![self.kernel_size.len() as u32];
34        let mut k_pos = Sequence::new();
35
36        #[unroll]
37        for i in 0..k_rank {
38            let dim = comptime![k_rank - i - 1];
39            let k_size = comptime![self.kernel_size[dim as usize]];
40            k_pos.push((k_idx % k_size) as i32);
41            k_idx /= k_size;
42        }
43
44        NhwcCoords {
45            batch: n,
46            spatial: k_pos.rev(),
47            channel: in_c,
48        }
49    }
50
51    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
52        true.runtime()
53    }
54
55    fn shape(&self) -> Self::Coordinates {
56        (u32::MAX, u32::MAX, u32::MAX).runtime()
57    }
58
59    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
60        (self.to_source_pos(pos), self.is_in_bounds(pos))
61    }
62}
63
64/// Dummy layout for launching, to be exited out later with `as_tensor_map`.
65#[derive(CubeType, CubeLaunch)]
66pub struct TmaDummyLayout {}
67
68#[cube]
69impl Layout for TmaDummyLayout {
70    type Coordinates = Coords3d;
71    type SourceCoordinates = Coords3d;
72
73    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
74        pos
75    }
76
77    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
78        true.runtime()
79    }
80
81    fn shape(&self) -> Self::Coordinates {
82        (u32::MAX, u32::MAX, u32::MAX).runtime()
83    }
84
85    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
86        (self.to_source_pos(pos), self.is_in_bounds(pos))
87    }
88}