Skip to main content

cubecl_std/tensor/layout/
permuted.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, ir::LineSize};
3
4use crate::{
5    FastDivmod, FastDivmodArgs,
6    tensor::{
7        index_offset_contiguous_fastdivmod,
8        layout::{Coords1d, Layout, LayoutExpand},
9    },
10};
11
12/// Layout for mapping heavily permuted tensors that can't be indexed as linear or 2D strided to a
13/// linear index
14#[derive(CubeType, CubeLaunch, Clone)]
15pub struct PermutedLayout {
16    shape: Sequence<FastDivmod<usize>>,
17    strides: Sequence<usize>,
18    len: usize,
19    #[cube(comptime)]
20    line_size: LineSize,
21}
22
23#[cube]
24impl PermutedLayout {
25    pub fn new(
26        shape: Sequence<FastDivmod<usize>>,
27        strides: Sequence<usize>,
28        len: usize,
29        #[comptime] line_size: LineSize,
30    ) -> Self {
31        PermutedLayout {
32            shape,
33            strides,
34            len,
35            line_size,
36        }
37    }
38}
39
40impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> {
41    /// Create a new permuted layout for a possibly broadcast tensor, with a reference shape to be
42    /// broadcast to.
43    pub fn from_shape_strides(
44        client: &ComputeClient<R>,
45        shape: &[usize],
46        strides: &[usize],
47        line_size: LineSize,
48    ) -> Self {
49        let len = shape.iter().product::<usize>() / line_size;
50
51        let shape = shape
52            .iter()
53            .map(|it| FastDivmodArgs::<usize>::new(client, *it))
54            .collect();
55        let strides = strides.iter().map(|it| ScalarArg::new(*it)).collect();
56
57        Self::new(shape, strides, ScalarArg::new(len), line_size)
58    }
59
60    /// Create a new permuted layout for a possibly broadcast tensor, with a reference shape to be
61    /// broadcast to.
62    pub fn from_shapes_strides_ref(
63        client: &ComputeClient<R>,
64        shape: &[usize],
65        reference_shape: &[usize],
66        strides: &[usize],
67        line_size: LineSize,
68    ) -> Self {
69        debug_assert!(
70            shape.len() == reference_shape.len(),
71            "Shape and reference should have the same rank"
72        );
73        debug_assert!(
74            shape
75                .iter()
76                .zip(reference_shape)
77                .all(|(s, r)| s == r || *s == 1),
78            "Shape should be equal to reference or 1 on each dimension"
79        );
80
81        let strides: Vec<usize> = strides
82            .iter()
83            .zip(shape.iter().zip(reference_shape))
84            .map(|(stride, (s, r))| if *s == *r { *stride } else { 0 })
85            .collect();
86
87        Self::from_shape_strides(client, reference_shape, &strides, line_size)
88    }
89
90    pub fn from_handles_ref(
91        client: &ComputeClient<R>,
92        handle: &TensorHandleRef<'_, R>,
93        reference_handle: &TensorHandleRef<'_, R>,
94        line_size: LineSize,
95    ) -> Self {
96        Self::from_shapes_strides_ref(
97            client,
98            handle.shape,
99            reference_handle.shape,
100            handle.strides,
101            line_size,
102        )
103    }
104
105    pub fn from_handle(
106        client: &ComputeClient<R>,
107        handle: &TensorHandleRef<'_, R>,
108        line_size: LineSize,
109    ) -> Self {
110        Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
111    }
112}
113
114#[cube]
115impl Layout for PermutedLayout {
116    type Coordinates = Coords1d;
117    type SourceCoordinates = Coords1d;
118
119    fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
120        index_offset_contiguous_fastdivmod(pos, &self.shape, &self.strides, self.line_size)
121    }
122
123    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
124        (self.to_source_pos(pos), self.is_in_bounds(pos))
125    }
126
127    fn shape(&self) -> Self::Coordinates {
128        self.len
129    }
130
131    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
132        pos < self.len
133    }
134}