cubecl_std/tensor/layout/
strided.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3
4use crate::{
5 FastDivmod,
6 tensor::{
7 launch::{BufferArg, ViewLayoutLaunchArg},
8 layout::{Coords1d, Layout, LayoutExpand},
9 },
10};
11
12#[derive(CubeType, Clone)]
16pub struct StridedLayout {
17 shape: FastDivmod<usize>,
18 stride: usize,
19 len: usize,
20 #[cube(comptime)]
21 vector_size: VectorSize,
22}
23
24#[derive(Debug, Hash, PartialEq, Eq, Clone)]
25pub struct StridedLayoutCompilationArg {
26 shape: <FastDivmod<usize> as LaunchArg>::CompilationArg,
27}
28
29impl ViewLayoutLaunchArg for StridedLayout {
30 type RuntimeArg<R: Runtime> = ();
31 type CompilationArg = StridedLayoutCompilationArg;
32
33 fn register<R: Runtime, B: BufferArg>(
34 _: Self::RuntimeArg<R>,
35 buffer: &B,
36 ty: Type,
37 launcher: &mut KernelLauncher<R>,
38 ) -> Self::CompilationArg {
39 let shape = buffer.shape();
40 let strides = buffer.strides();
41 let rank = shape.len();
42 let len = shape.iter().product::<usize>() / ty.vector_size();
43
44 let shape = <FastDivmod<usize> as LaunchArg>::register(shape[rank - 1], launcher);
45 <usize as LaunchArg>::register(strides[rank - 2], launcher);
46 <usize as LaunchArg>::register(len, launcher);
47 StridedLayoutCompilationArg { shape }
48 }
49
50 fn expand(
51 arg: &Self::CompilationArg,
52 ty: Type,
53 builder: &mut KernelBuilder,
54 ) -> <Self as CubeType>::ExpandType {
55 StridedLayoutExpand {
56 shape: <FastDivmod<usize> as LaunchArg>::expand(&arg.shape, builder),
57 stride: <usize as LaunchArg>::expand(&(), builder),
58 len: <usize as LaunchArg>::expand(&(), builder),
59 vector_size: ty.vector_size(),
60 }
61 }
62}
63
64#[cube]
65impl Layout for StridedLayout {
66 type Coordinates = Coords1d;
67 type SourceCoordinates = Coords1d;
68
69 fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
70 let offset_abs = pos * self.vector_size;
71 let (y, x) = self.shape.div_mod(offset_abs);
72 let offset = y * self.stride + x;
73 offset / self.vector_size
74 }
75
76 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
77 (self.to_source_pos(pos), self.is_in_bounds(pos))
78 }
79
80 fn shape(&self) -> Self::Coordinates {
81 self.len
82 }
83
84 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
85 pos < self.len
86 }
87}