// cubecl_std/tensor/layout/simple.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::tensor::layout::{Coords1d, Layout, LayoutExpand};
5
6/// Layout for contiguous tensors, indexed in individual elements.
7/// Differs from `PlainLayout` because `PlainLayout` expects line indices, not element indices.
8#[derive(CubeType, CubeLaunch, Clone)]
9pub struct SimpleLayout {
10    len: u32,
11    #[cube(comptime)]
12    line_size: u32,
13}
14
15#[cube]
16impl SimpleLayout {
17    /// Create a new simple layout with a length and line size.
18    ///
19    /// # Note
20    /// Length should be in elements, not lines!
21    pub fn new(len: u32, #[comptime] line_size: u32) -> Self {
22        SimpleLayout { len, line_size }
23    }
24}
25
26impl<'a, R: Runtime> SimpleLayoutLaunch<'a, R> {
27    pub fn from_shape(shape: &[usize], line_size: u8) -> Self {
28        let len = shape.iter().product::<usize>();
29        Self::new(ScalarArg::new(len as u32), line_size as u32)
30    }
31
32    pub fn from_handle(handle: &TensorHandleRef<'_, R>, line_size: u8) -> Self {
33        Self::from_shape(handle.shape, line_size)
34    }
35}
36
37#[cube]
38impl Layout for SimpleLayout {
39    type Coordinates = Coords1d;
40    type SourceCoordinates = Coords1d;
41
42    fn to_source_pos(&self, pos: Self::Coordinates) -> u32 {
43        pos / self.line_size
44    }
45
46    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (u32, bool) {
47        (self.to_source_pos(pos), self.is_in_bounds(pos))
48    }
49
50    fn shape(&self) -> Self::Coordinates {
51        self.len
52    }
53
54    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
55        pos < self.len
56    }
57}