// cubecl_std/tensor/layout/permuted.rs
use cubecl::prelude::*;
use cubecl_core::{self as cubecl, ir::LineSize};

use crate::{
    FastDivmod, FastDivmodArgs,
    tensor::{
        index_offset_contiguous_fastdivmod,
        layout::{Coords1d, Layout, LayoutExpand},
    },
};

12/// Layout for mapping heavily permuted tensors that can't be indexed as linear or 2D strided to a
13/// linear index
14#[derive(CubeType, CubeLaunch, Clone)]
15pub struct PermutedLayout {
16    shape: Sequence<FastDivmod<usize>>,
17    strides: Sequence<usize>,
18    len: usize,
19    #[cube(comptime)]
20    line_size: LineSize,
21}
22
23impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> {
24    /// Create a new permuted layout for a possibly broadcast tensor, with a reference shape to be
25    /// broadcast to.
26    pub fn from_shape_strides(
27        client: &ComputeClient<R>,
28        shape: &[usize],
29        strides: &[usize],
30        line_size: LineSize,
31    ) -> Self {
32        let len = shape.iter().product::<usize>() / line_size;
33
34        let shape = shape
35            .iter()
36            .map(|it| FastDivmodArgs::<usize>::new(client, *it))
37            .collect();
38        let strides = strides.iter().map(|it| ScalarArg::new(*it)).collect();
39
40        Self::new(shape, strides, ScalarArg::new(len), line_size)
41    }
42
43    /// Create a new permuted layout for a possibly broadcast tensor, with a reference shape to be
44    /// broadcast to.
45    pub fn from_shapes_strides_ref(
46        client: &ComputeClient<R>,
47        shape: &[usize],
48        reference_shape: &[usize],
49        strides: &[usize],
50        line_size: LineSize,
51    ) -> Self {
52        debug_assert!(
53            shape.len() == reference_shape.len(),
54            "Shape and reference should have the same rank"
55        );
56        debug_assert!(
57            shape
58                .iter()
59                .zip(reference_shape)
60                .all(|(s, r)| s == r || *s == 1),
61            "Shape should be equal to reference or 1 on each dimension"
62        );
63
64        let strides: Vec<usize> = strides
65            .iter()
66            .zip(shape.iter().zip(reference_shape))
67            .map(|(stride, (s, r))| if *s == *r { *stride } else { 0 })
68            .collect();
69
70        Self::from_shape_strides(client, reference_shape, &strides, line_size)
71    }
72
73    pub fn from_handles_ref(
74        client: &ComputeClient<R>,
75        handle: &TensorHandleRef<'_, R>,
76        reference_handle: &TensorHandleRef<'_, R>,
77        line_size: LineSize,
78    ) -> Self {
79        Self::from_shapes_strides_ref(
80            client,
81            handle.shape,
82            reference_handle.shape,
83            handle.strides,
84            line_size,
85        )
86    }
87
88    pub fn from_handle(
89        client: &ComputeClient<R>,
90        handle: &TensorHandleRef<'_, R>,
91        line_size: LineSize,
92    ) -> Self {
93        Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
94    }
95}
96
97#[cube]
98impl Layout for PermutedLayout {
99    type Coordinates = Coords1d;
100    type SourceCoordinates = Coords1d;
101
102    fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
103        index_offset_contiguous_fastdivmod(pos, &self.shape, &self.strides, self.line_size)
104    }
105
106    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
107        (self.to_source_pos(pos), self.is_in_bounds(pos))
108    }
109
110    fn shape(&self) -> Self::Coordinates {
111        self.len
112    }
113
114    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
115        pos < self.len
116    }
117}