cubecl_std/tensor/layout/
simple.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, zspace::Shape};
3
4use crate::tensor::layout::{Coords1d, Layout, LayoutExpand};
5
6#[derive(CubeType, CubeLaunch, Clone)]
9pub struct SimpleLayout {
10 len: usize,
11 #[cube(comptime)]
12 vector_size: VectorSize,
13}
14
15#[cube]
16impl SimpleLayout {
17 pub fn new(len: usize, #[comptime] vector_size: VectorSize) -> Self {
22 SimpleLayout { len, vector_size }
23 }
24}
25
26impl<R: Runtime> SimpleLayoutLaunch<R> {
27 pub fn from_shape(shape: &Shape, vector_size: VectorSize) -> Self {
28 let len = shape.iter().product::<usize>();
29 Self::new(len, vector_size)
30 }
31
32 pub fn from_handle(handle: TensorBinding<R>, vector_size: VectorSize) -> Self {
33 Self::from_shape(&handle.shape, vector_size)
34 }
35}
36
37#[cube]
38impl Layout for SimpleLayout {
39 type Coordinates = Coords1d;
40 type SourceCoordinates = Coords1d;
41
42 fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
43 pos / self.vector_size
44 }
45
46 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
47 (self.to_source_pos(pos), self.is_in_bounds(pos))
48 }
49
50 fn shape(&self) -> Self::Coordinates {
51 self.len
52 }
53
54 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
55 pos < self.len
56 }
57}