cubecl_convolution/components/global/read/reader/layout.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_std::{
4    FastDivmod,
5    tensor::layout::{Coords3d, Layout, LayoutExpand},
6};
7
8use crate::components::global::layout::NhwcCoords;
9
10#[derive(CubeType, CubeLaunch)]
11pub struct TmaWeightLayout {
12    padded_channels: FastDivmod,
13    #[cube(comptime)]
14    kernel_size: Vec<u32>,
15}
16
17#[cube]
18impl TmaWeightLayout {
19    pub fn new(padded_channels: FastDivmod, #[comptime] kernel_size: Vec<u32>) -> Self {
20        TmaWeightLayout {
21            padded_channels,
22            kernel_size,
23        }
24    }
25}
26
27#[cube]
28impl Layout for TmaWeightLayout {
29    type Coordinates = Coords3d;
30    type SourceCoordinates = NhwcCoords;
31
32    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
33        let (_, k, n) = pos;
34        let (mut k_idx, in_c) = self.padded_channels.div_mod(k);
35        let k_rank = comptime![self.kernel_size.len() as u32];
36        let mut k_pos = Sequence::new();
37
38        #[unroll]
39        for i in 0..k_rank {
40            let dim = comptime![k_rank - i - 1];
41            let k_size = comptime![self.kernel_size[dim as usize]];
42            k_pos.push((k_idx % k_size) as i32);
43            k_idx /= k_size;
44        }
45
46        NhwcCoords {
47            batch: n,
48            spatial: k_pos.rev(),
49            channel: in_c,
50        }
51    }
52
53    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
54        true.runtime()
55    }
56
57    fn shape(&self) -> Self::Coordinates {
58        (u32::MAX, u32::MAX, u32::MAX).runtime()
59    }
60
61    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
62        (self.to_source_pos(pos), self.is_in_bounds(pos))
63    }
64}
65
66/// Dummy layout for launching, to be exited out later with `as_tensor_map`.
67#[derive(CubeType, CubeLaunch)]
68pub struct TmaDummyLayout {}
69
70#[cube]
71impl Layout for TmaDummyLayout {
72    type Coordinates = Coords3d;
73    type SourceCoordinates = Coords3d;
74
75    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
76        pos
77    }
78
79    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
80        true.runtime()
81    }
82
83    fn shape(&self) -> Self::Coordinates {
84        (u32::MAX, u32::MAX, u32::MAX).runtime()
85    }
86
87    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
88        (self.to_source_pos(pos), self.is_in_bounds(pos))
89    }
90}