cubecl_std/tensor/layout/plain.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::tensor::layout::{Coords1d, Layout, LayoutExpand};
5
/// Layout for contiguous tensors.
///
/// Positions map one-to-one onto source positions (see the `Layout` impl), and a
/// position is in bounds when it is strictly less than `len`.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct PlainLayout {
    /// Number of addressable positions. When built through
    /// `PlainLayoutLaunch::from_shape`, this is the tensor's element count divided
    /// by the line size.
    len: u32,
}
11
12#[cube]
13impl PlainLayout {
14    pub fn new(len: u32) -> Self {
15        PlainLayout { len }
16    }
17}
18
19impl<'a, R: Runtime> PlainLayoutLaunch<'a, R> {
20    pub fn from_shape(shape: &[usize], line_size: u8) -> Self {
21        let len = shape.iter().product::<usize>();
22        let len = len / line_size as usize;
23        Self::new(ScalarArg::new(len as u32))
24    }
25
26    pub fn from_handle(handle: &TensorHandleRef<'_, R>, line_size: u8) -> Self {
27        Self::from_shape(handle.shape, line_size)
28    }
29}
30
31#[cube]
32impl Layout for PlainLayout {
33    type Coordinates = Coords1d;
34    type SourceCoordinates = Coords1d;
35
36    fn to_source_pos(&self, pos: Self::Coordinates) -> u32 {
37        pos
38    }
39
40    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (u32, bool) {
41        (self.to_source_pos(pos), self.is_in_bounds(pos))
42    }
43
44    fn shape(&self) -> Self::Coordinates {
45        self.len
46    }
47
48    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
49        pos < self.len
50    }
51}