cubecl_std/tensor/layout/
fixed_dim.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::tensor::layout::{
5 Layout, LayoutExpand,
6 as_dyn::{IntoDyn, IntoDynExpand},
7};
8
9#[derive(CubeType, CubeLaunch)]
10pub struct FixedDimLayout<D: IntoDyn> {
11 shape: D,
12 strides: Sequence<usize>,
13 #[cube(comptime)]
14 line_size: LineSize,
15 #[cube(comptime)]
16 checked: bool,
17}
18
19#[cube]
20impl<D: IntoDyn> FixedDimLayout<D> {
21 pub fn new(
22 shape: D,
23 strides: Sequence<usize>,
24 #[comptime] line_size: LineSize,
25 #[comptime] checked: bool,
26 ) -> Self {
27 FixedDimLayout::<D> {
28 shape,
29 strides,
30 line_size,
31 checked,
32 }
33 }
34}
35
36#[cube]
37impl<D: IntoDyn> Layout for FixedDimLayout<D> {
38 type Coordinates = D;
39 type SourceCoordinates = usize;
40
41 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
42 let pos = pos.into_dyn();
43 let mut offset = 0;
44
45 #[unroll]
46 for i in 0..pos.len() {
47 offset += pos[i] as usize * self.strides[i];
48 }
49
50 offset / self.line_size
51 }
52
53 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
54 let mut in_bounds = true;
55 if comptime![self.checked] {
56 let pos = pos.into_dyn();
57 let shape = self.shape.clone().into_dyn();
58
59 #[unroll]
60 for i in 0..pos.len() {
61 in_bounds &= pos[i] < shape[i];
62 }
63 }
64 in_bounds
65 }
66
67 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
68 (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
69 }
70
71 fn shape(&self) -> Self::Coordinates {
72 self.shape.clone()
73 }
74}
75
76impl<'a, D: IntoDyn, R: Runtime> FixedDimLayoutLaunch<'a, D, R> {
77 pub fn from_shape_handle(
78 handle: &TensorHandleRef<'a, R>,
79 shape: D::RuntimeArg<'a, R>,
80 line_size: LineSize,
81 ) -> Self {
82 let strides = handle.strides.iter().copied().map(ScalarArg::new).collect();
83 Self::new(shape, strides, line_size, true)
84 }
85
86 pub fn from_shape_handle_unchecked(
87 handle: &TensorHandleRef<'a, R>,
88 shape: D::RuntimeArg<'a, R>,
89 line_size: LineSize,
90 ) -> Self {
91 let strides = handle.strides.iter().copied().map(ScalarArg::new).collect();
92 Self::new(shape, strides, line_size, false)
93 }
94}