cubecl_std/tensor/layout/
permuted.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::{
5    FastDivmod, FastDivmodArgs,
6    tensor::{
7        index_offset_contiguous_fastdivmod,
8        layout::{Coords1d, Layout, LayoutExpand},
9    },
10};
11
/// Layout for mapping heavily permuted tensors that can't be indexed as linear or 2D strided to a
/// linear index
#[derive(CubeType, CubeLaunch, Clone)]
pub struct PermutedLayout {
    // Per-dimension sizes wrapped in `FastDivmod` so a linear index can be decomposed into
    // multi-dimensional coordinates without expensive division on device.
    shape: Sequence<FastDivmod>,
    // Per-dimension strides of the source tensor (zeroed on broadcast dims by the launch
    // constructors, so a broadcast dimension repeatedly reads the same element).
    strides: Sequence<u32>,
    // Total number of addressable positions; set to element count / line_size at launch,
    // i.e. it counts lines rather than scalar elements.
    len: u32,
    // Vectorization width, fixed at compile time so it can be used in comptime expressions.
    #[cube(comptime)]
    line_size: u32,
}
22
23impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> {
24    /// Create a new permuted layout for a possibly broadcast tensor, with a reference shape to be
25    /// broadcast to.
26    pub fn from_shape_strides(
27        client: &ComputeClient<R::Server>,
28        shape: &[usize],
29        strides: &[usize],
30        line_size: u8,
31    ) -> Self {
32        let len = shape.iter().product::<usize>() / line_size as usize;
33
34        let shape = shape
35            .iter()
36            .map(|it| FastDivmodArgs::new(client, *it as u32))
37            .collect();
38        let strides = strides
39            .iter()
40            .map(|it| ScalarArg::new(*it as u32))
41            .collect();
42
43        Self::new(shape, strides, ScalarArg::new(len as u32), line_size as u32)
44    }
45
46    /// Create a new permuted layout for a possibly broadcast tensor, with a reference shape to be
47    /// broadcast to.
48    pub fn from_shapes_strides_ref(
49        client: &ComputeClient<R::Server>,
50        shape: &[usize],
51        reference_shape: &[usize],
52        strides: &[usize],
53        line_size: u8,
54    ) -> Self {
55        debug_assert!(
56            shape.len() == reference_shape.len(),
57            "Shape and reference should have the same rank"
58        );
59        debug_assert!(
60            shape
61                .iter()
62                .zip(reference_shape)
63                .all(|(s, r)| s == r || *s == 1),
64            "Shape should be equal to reference or 1 on each dimension"
65        );
66
67        let strides: Vec<usize> = strides
68            .iter()
69            .zip(shape.iter().zip(reference_shape))
70            .map(|(stride, (s, r))| if *s == *r { *stride } else { 0 })
71            .collect();
72
73        Self::from_shape_strides(client, reference_shape, &strides, line_size)
74    }
75
76    pub fn from_handles_ref(
77        client: &ComputeClient<R::Server>,
78        handle: &TensorHandleRef<'_, R>,
79        reference_handle: &TensorHandleRef<'_, R>,
80        line_size: u8,
81    ) -> Self {
82        Self::from_shapes_strides_ref(
83            client,
84            handle.shape,
85            reference_handle.shape,
86            handle.strides,
87            line_size,
88        )
89    }
90
91    pub fn from_handle(
92        client: &ComputeClient<R::Server>,
93        handle: &TensorHandleRef<'_, R>,
94        line_size: u8,
95    ) -> Self {
96        Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
97    }
98}
99
#[cube]
impl Layout for PermutedLayout {
    // Both the virtual position and the source offset are 1D linear indices.
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    // Translate a linear position in the virtual contiguous layout into an offset in the
    // permuted source tensor by delegating to `index_offset_contiguous_fastdivmod`, which
    // uses the precomputed fast-divmod shape and the (possibly zeroed) strides.
    fn to_source_pos(&self, pos: Self::Coordinates) -> u32 {
        index_offset_contiguous_fastdivmod(
            pos,
            &self.shape,
            &self.strides,
            comptime![self.line_size],
        )
    }

    // The offset computation does not itself check bounds, so the bounds flag is computed
    // separately and returned alongside it.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (u32, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }

    // Extent of the coordinate space: `len` counts lines (element count / line_size), as set
    // by the launch-side constructors.
    fn shape(&self) -> Self::Coordinates {
        self.len
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        pos < self.len
    }
}
125}