cubecl_std/tensor/layout/
fixed_dim.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::tensor::layout::{
5 Layout, LayoutExpand,
6 as_dyn::{IntoDyn, IntoDynExpand},
7};
8
/// Layout that maps fixed-rank coordinates of type `D` to a linear source
/// position using per-dimension strides.
///
/// The strided offset is divided by `vector_size` in
/// [`Layout::to_source_pos`]. When `strides` are expressed in scalar elements,
/// the resulting position therefore indexes vectors of width `vector_size`.
#[derive(CubeType, CubeLaunch)]
pub struct FixedDimLayout<D: IntoDyn> {
    /// Logical extent of each dimension.
    ///
    /// Returned by `shape()` and compared against coordinates in
    /// `is_in_bounds` when `checked` is set.
    shape: D,
    /// Per-dimension strides, indexed in the same order as the coordinates.
    strides: Sequence<usize>,
    /// Comptime vector width; the strided offset is divided by this value.
    #[cube(comptime)]
    vector_size: VectorSize,
    /// Comptime flag: when `true`, `is_in_bounds` compares coordinates against
    /// `shape`; when `false`, every coordinate is reported as in bounds.
    #[cube(comptime)]
    checked: bool,
}
18
#[cube]
impl<D: IntoDyn> FixedDimLayout<D> {
    /// Creates a layout from its shape, strides and comptime configuration.
    ///
    /// * `shape` - logical extent of each dimension; consulted by
    ///   `is_in_bounds` only when `checked` is `true`.
    /// * `strides` - per-dimension strides used to compute the linear offset.
    /// * `vector_size` - comptime vector width; offsets are divided by it.
    /// * `checked` - comptime flag enabling the bounds check.
    pub fn new(
        shape: D,
        strides: Sequence<usize>,
        #[comptime] vector_size: VectorSize,
        #[comptime] checked: bool,
    ) -> Self {
        // The concrete type path (rather than `Self`) is spelled out here.
        // NOTE(review): presumably because the `#[cube]` expansion needs it —
        // confirm before changing.
        FixedDimLayout::<D> {
            shape,
            strides,
            vector_size,
            checked,
        }
    }
}
35
#[cube]
impl<D: IntoDyn> Layout for FixedDimLayout<D> {
    type Coordinates = D;
    type SourceCoordinates = usize;

    /// Maps `pos` to a linear source position.
    ///
    /// The position is the sum of `pos[i] * strides[i]` over every dimension
    /// of `pos`, divided by `vector_size`. The division is integer division,
    /// so the result is exact only when that offset is a multiple of
    /// `vector_size`.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        // NOTE(review): assumes `strides` holds at least as many entries as
        // `D` has dimensions. Also assumes coordinates are non-negative, since
        // each is cast with `as usize`. Confirm at call sites.
        let pos = pos.into_dyn();
        let mut offset = 0;

        // `#[unroll]` expands one iteration per dimension. This relies on
        // `pos.len()` being known at compile time (fixed rank `D`).
        #[unroll]
        for i in 0..pos.len() {
            offset += pos[i] as usize * self.strides[i];
        }

        offset / self.vector_size
    }

    /// Returns whether every coordinate of `pos` is strictly below the
    /// matching extent of `shape`.
    ///
    /// If the layout was built with `checked == false`, the comparison is
    /// skipped at compile time and this always returns `true`.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let mut in_bounds = true;
        // `checked` is a comptime field, so the unchecked variant emits no
        // comparisons at all.
        if comptime![self.checked] {
            let pos = pos.into_dyn();
            let shape = self.shape.clone().into_dyn();

            // NOTE(review): only the upper bound is tested. If `D`'s
            // coordinates can be negative, such positions are not rejected
            // here — confirm callers never produce them.
            #[unroll]
            for i in 0..pos.len() {
                in_bounds &= pos[i] < shape[i];
            }
        }
        in_bounds
    }

    /// Returns the source position of `pos` together with its bounds flag.
    ///
    /// Equivalent to calling `to_source_pos` and `is_in_bounds` separately.
    /// The position is computed even when the flag is `false`.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
    }

    /// Returns a copy of the layout's logical shape.
    fn shape(&self) -> Self::Coordinates {
        self.shape.clone()
    }
}
75
76impl<D: IntoDyn, R: Runtime> FixedDimLayoutLaunch<D, R> {
77 pub fn from_shape_handle(
78 handle: &TensorBinding<R>,
79 shape: D::RuntimeArg<R>,
80 vector_size: VectorSize,
81 ) -> Self {
82 let strides = handle.strides.iter().copied().collect();
83 Self::new(shape, strides, vector_size, true)
84 }
85
86 pub fn from_shape_handle_unchecked(
87 handle: &TensorBinding<R>,
88 shape: D::RuntimeArg<R>,
89 vector_size: VectorSize,
90 ) -> Self {
91 let strides = handle.strides.iter().copied().collect();
92 Self::new(shape, strides, vector_size, false)
93 }
94}