cubecl_std/tensor/layout/
strided.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::{
5 FastDivmod, FastDivmodArgs,
6 tensor::layout::{Coords1d, Layout, LayoutExpand},
7};
8
/// Layout mapping a contiguous 1D line index onto a strided tensor buffer.
///
/// The innermost dimension is treated as contiguous (its column offset is added without a
/// stride), and every outer dimension is folded into a single "row" index that is scaled by
/// `stride`.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct StridedLayout {
    // NOTE(review): field order matters — the launch constructor (presumably generated by
    // `CubeLaunch`) takes its arguments positionally in this order.
    /// Fast divisor by the innermost dimension size (in elements); splits an element offset
    /// into `(row, column)`.
    shape: FastDivmod,
    /// Element stride between consecutive rows in the source buffer.
    stride: u32,
    /// Total number of lines (elements / `line_size`) exposed by the layout.
    len: u32,
    /// Number of elements per line; fixed at kernel compile time.
    #[cube(comptime)]
    line_size: u32,
}
20
21impl<'a, R: Runtime> StridedLayoutLaunch<'a, R> {
22 pub fn from_shape_strides(
23 client: &ComputeClient<R::Server>,
24 shape: &[usize],
25 strides: &[usize],
26 line_size: u8,
27 ) -> Self {
28 let rank = shape.len();
29 let len = shape.iter().product::<usize>() / line_size as usize;
30 Self::new(
31 FastDivmodArgs::new(client, shape[rank - 1] as u32),
32 ScalarArg::new(strides[rank - 2] as u32),
33 ScalarArg::new(len as u32),
34 line_size as u32,
35 )
36 }
37
38 pub fn from_handle(
39 client: &ComputeClient<R::Server>,
40 handle: &TensorHandleRef<'_, R>,
41 line_size: u8,
42 ) -> Self {
43 Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
44 }
45}
46
47#[cube]
48impl Layout for StridedLayout {
49 type Coordinates = Coords1d;
50 type SourceCoordinates = Coords1d;
51
52 fn to_source_pos(&self, pos: Self::Coordinates) -> u32 {
53 let offset_abs = pos * self.line_size;
54 let (y, x) = self.shape.div_mod(offset_abs);
55 let offset = y * self.stride + x;
56 offset / self.line_size
57 }
58
59 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (u32, bool) {
60 (self.to_source_pos(pos), self.is_in_bounds(pos))
61 }
62
63 fn shape(&self) -> Self::Coordinates {
64 self.len
65 }
66
67 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
68 pos < self.len
69 }
70}