Skip to main content

cubecl_std/tensor/layout/
fixed_dim.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::tensor::layout::{
5    Layout, LayoutExpand,
6    as_dyn::{IntoDyn, IntoDynExpand},
7};
8
/// Layout mapping coordinates of the fixed-rank type `D` to a linear source position.
///
/// The source position is the dot product of the coordinates with `strides`, divided by
/// `line_size`. When `checked` is set, bounds checks compare each coordinate against `shape`.
#[derive(CubeType, CubeLaunch)]
pub struct FixedDimLayout<D: IntoDyn> {
    /// Extent of each dimension; used by the bounds check and returned by `shape()`.
    shape: D,
    /// Per-dimension strides multiplied with the coordinates.
    /// NOTE(review): presumably expressed in elements rather than lines, since the summed
    /// offset is divided by `line_size` afterwards — confirm against callers.
    strides: Sequence<usize>,
    /// Comptime line size; the accumulated offset is divided by this value.
    #[cube(comptime)]
    line_size: LineSize,
    /// Comptime flag: when `false`, the bounds check is skipped and every position is
    /// reported as in bounds.
    #[cube(comptime)]
    checked: bool,
}
18
#[cube]
impl<D: IntoDyn> FixedDimLayout<D> {
    /// Builds a layout from its shape, strides, and comptime configuration.
    ///
    /// * `shape` - extent of each dimension, used for bounds checking.
    /// * `strides` - per-dimension strides combined with coordinates in `to_source_pos`.
    /// * `line_size` - comptime divisor applied to the summed offset.
    /// * `checked` - comptime switch enabling the comparison against `shape` in `is_in_bounds`.
    pub fn new(
        shape: D,
        strides: Sequence<usize>,
        #[comptime] line_size: LineSize,
        #[comptime] checked: bool,
    ) -> Self {
        FixedDimLayout::<D> {
            shape,
            strides,
            line_size,
            checked,
        }
    }
}
35
36#[cube]
37impl<D: IntoDyn> Layout for FixedDimLayout<D> {
38    type Coordinates = D;
39    type SourceCoordinates = usize;
40
41    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
42        let pos = pos.into_dyn();
43        let mut offset = 0;
44
45        #[unroll]
46        for i in 0..pos.len() {
47            offset += pos[i] as usize * self.strides[i];
48        }
49
50        offset / self.line_size
51    }
52
53    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
54        let mut in_bounds = true;
55        if comptime![self.checked] {
56            let pos = pos.into_dyn();
57            let shape = self.shape.clone().into_dyn();
58
59            #[unroll]
60            for i in 0..pos.len() {
61                in_bounds &= pos[i] < shape[i];
62            }
63        }
64        in_bounds
65    }
66
67    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
68        (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
69    }
70
71    fn shape(&self) -> Self::Coordinates {
72        self.shape.clone()
73    }
74}
75
76impl<'a, D: IntoDyn, R: Runtime> FixedDimLayoutLaunch<'a, D, R> {
77    pub fn from_shape_handle(
78        handle: &TensorHandleRef<'a, R>,
79        shape: D::RuntimeArg<'a, R>,
80        line_size: LineSize,
81    ) -> Self {
82        let strides = handle.strides.iter().copied().map(ScalarArg::new).collect();
83        Self::new(shape, strides, line_size, true)
84    }
85
86    pub fn from_shape_handle_unchecked(
87        handle: &TensorHandleRef<'a, R>,
88        shape: D::RuntimeArg<'a, R>,
89        line_size: LineSize,
90    ) -> Self {
91        let strides = handle.strides.iter().copied().map(ScalarArg::new).collect();
92        Self::new(shape, strides, line_size, false)
93    }
94}