cubecl_std/tensor/layout/
strided.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::{
5    FastDivmod, FastDivmodArgs,
6    tensor::layout::{Coords1d, Layout, LayoutExpand},
7};
8
9/// Layout for tensors strided only on the last dimension, i.e. freshly allocated ones. Treats the
10/// tensor as 2D for the purposes of indexing, with the remaining dimensions being collapsed into
11/// a single contiguous one
12#[derive(CubeType, CubeLaunch, Clone)]
13pub struct StridedLayout {
14    shape: FastDivmod,
15    stride: u32,
16    len: u32,
17    #[cube(comptime)]
18    line_size: u32,
19}
20
21impl<'a, R: Runtime> StridedLayoutLaunch<'a, R> {
22    pub fn from_shape_strides(
23        client: &ComputeClient<R::Server>,
24        shape: &[usize],
25        strides: &[usize],
26        line_size: u8,
27    ) -> Self {
28        let rank = shape.len();
29        let len = shape.iter().product::<usize>() / line_size as usize;
30        Self::new(
31            FastDivmodArgs::new(client, shape[rank - 1] as u32),
32            ScalarArg::new(strides[rank - 2] as u32),
33            ScalarArg::new(len as u32),
34            line_size as u32,
35        )
36    }
37
38    pub fn from_handle(
39        client: &ComputeClient<R::Server>,
40        handle: &TensorHandleRef<'_, R>,
41        line_size: u8,
42    ) -> Self {
43        Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
44    }
45}
46
47#[cube]
48impl Layout for StridedLayout {
49    type Coordinates = Coords1d;
50    type SourceCoordinates = Coords1d;
51
52    fn to_source_pos(&self, pos: Self::Coordinates) -> u32 {
53        let offset_abs = pos * self.line_size;
54        let (y, x) = self.shape.div_mod(offset_abs);
55        let offset = y * self.stride + x;
56        offset / self.line_size
57    }
58
59    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (u32, bool) {
60        (self.to_source_pos(pos), self.is_in_bounds(pos))
61    }
62
63    fn shape(&self) -> Self::Coordinates {
64        self.len
65    }
66
67    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
68        pos < self.len
69    }
70}