Skip to main content

cubecl_std/tensor/layout/
strided.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3
4use crate::{
5    FastDivmod,
6    tensor::{
7        launch::{BufferArg, ViewLayoutLaunchArg},
8        layout::{Coords1d, Layout, LayoutExpand},
9    },
10};
11
/// Layout for tensors strided only on the last dimension, i.e. freshly allocated ones. Treats the
/// tensor as 2D for the purposes of indexing, with the remaining dimensions being collapsed into
/// a single contiguous one
///
/// A linear (vectorized) position is split into a `(row, column)` pair using the size of the
/// last dimension, and the row is then scaled by `stride` to obtain the position in the
/// underlying buffer.
// NOTE(review): only the stride of the second-to-last dimension is ever read when this layout is
// registered; presumably all leading dimensions must be contiguous with respect to that stride —
// confirm with callers.
#[derive(CubeType, Clone)]
pub struct StridedLayout {
    /// Fast divider/modulo for the size of the last dimension (registered from
    /// `shape[rank - 1]`).
    shape: FastDivmod<usize>,
    /// Stride of the second-to-last dimension (registered from `strides[rank - 2]`). It is
    /// added to a scalar-element offset in `to_source_pos`, so it is used as a scalar count.
    stride: usize,
    /// Total number of addressable positions: product of all dimensions divided by the vector
    /// size.
    len: usize,
    /// Vector size of the viewed type, known at kernel compile time.
    #[cube(comptime)]
    vector_size: VectorSize,
}
23
/// Compile-time argument of [`StridedLayout`]: returned by its `ViewLayoutLaunchArg::register`
/// implementation and read back by `ViewLayoutLaunchArg::expand`.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct StridedLayoutCompilationArg {
    /// Compilation argument produced when registering the `FastDivmod` for the last dimension.
    shape: <FastDivmod<usize> as LaunchArg>::CompilationArg,
}
28
29impl ViewLayoutLaunchArg for StridedLayout {
30    type RuntimeArg<R: Runtime> = ();
31    type CompilationArg = StridedLayoutCompilationArg;
32
33    fn register<R: Runtime, B: BufferArg>(
34        _: Self::RuntimeArg<R>,
35        buffer: &B,
36        ty: Type,
37        launcher: &mut KernelLauncher<R>,
38    ) -> Self::CompilationArg {
39        let shape = buffer.shape();
40        let strides = buffer.strides();
41        let rank = shape.len();
42        let len = shape.iter().product::<usize>() / ty.vector_size();
43
44        let shape = <FastDivmod<usize> as LaunchArg>::register(shape[rank - 1], launcher);
45        <usize as LaunchArg>::register(strides[rank - 2], launcher);
46        <usize as LaunchArg>::register(len, launcher);
47        StridedLayoutCompilationArg { shape }
48    }
49
50    fn expand(
51        arg: &Self::CompilationArg,
52        ty: Type,
53        builder: &mut KernelBuilder,
54    ) -> <Self as CubeType>::ExpandType {
55        StridedLayoutExpand {
56            shape: <FastDivmod<usize> as LaunchArg>::expand(&arg.shape, builder),
57            stride: <usize as LaunchArg>::expand(&(), builder),
58            len: <usize as LaunchArg>::expand(&(), builder),
59            vector_size: ty.vector_size(),
60        }
61    }
62}
63
64#[cube]
65impl Layout for StridedLayout {
66    type Coordinates = Coords1d;
67    type SourceCoordinates = Coords1d;
68
69    fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
70        let offset_abs = pos * self.vector_size;
71        let (y, x) = self.shape.div_mod(offset_abs);
72        let offset = y * self.stride + x;
73        offset / self.vector_size
74    }
75
76    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
77        (self.to_source_pos(pos), self.is_in_bounds(pos))
78    }
79
80    fn shape(&self) -> Self::Coordinates {
81        self.len
82    }
83
84    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
85        pos < self.len
86    }
87}