Skip to main content

cubecl_std/tensor/layout/
permuted.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, ir::VectorSize, zspace::Shape};
3
4use crate::{
5    FastDivmod,
6    tensor::{
7        index_offset_contiguous_fastdivmod,
8        launch::{BufferArg, ViewLayoutLaunchArg},
9        layout::{Coords1d, Layout, LayoutExpand},
10    },
11};
12
13/// Layout for mapping heavily permuted tensors that can't be indexed as linear or 2D strided to a
14/// linear index
15#[derive(CubeType, Clone)]
16pub struct PermutedLayout {
17    shape: Sequence<FastDivmod<usize>>,
18    strides: Sequence<usize>,
19    len: usize,
20    #[cube(comptime)]
21    vector_size: VectorSize,
22}
23
24#[cube]
25impl PermutedLayout {
26    pub fn new(
27        shape: Sequence<FastDivmod<usize>>,
28        strides: Sequence<usize>,
29        len: usize,
30        #[comptime] vector_size: VectorSize,
31    ) -> Self {
32        PermutedLayout {
33            shape,
34            strides,
35            len,
36            vector_size,
37        }
38    }
39}
40
41#[derive(Default)]
42pub struct PermutedLayoutLaunch {
43    reference_shape: Option<Shape>,
44}
45
46#[derive(Debug, Hash, PartialEq, Eq, Clone)]
47pub struct PermutedLayoutCompilationArg {
48    shape: <Sequence<FastDivmod<usize>> as LaunchArg>::CompilationArg,
49    strides: <Sequence<usize> as LaunchArg>::CompilationArg,
50}
51
52impl ViewLayoutLaunchArg for PermutedLayout {
53    type RuntimeArg<R: Runtime> = PermutedLayoutLaunch;
54    type CompilationArg = PermutedLayoutCompilationArg;
55
56    fn register<R: Runtime, B: BufferArg>(
57        arg: Self::RuntimeArg<R>,
58        buffer: &B,
59        ty: Type,
60        launcher: &mut KernelLauncher<R>,
61    ) -> Self::CompilationArg {
62        let shape = buffer.shape();
63        let strides = buffer.strides();
64        let (shape, strides, len) = match arg.reference_shape {
65            Some(reference_shape) => {
66                let len = reference_shape.len();
67                let strides = strides_ref(shape, &reference_shape, strides);
68                (reference_shape.iter().copied().collect(), strides, len)
69            }
70            None => (
71                shape.iter().copied().collect(),
72                strides.iter().copied().collect(),
73                buffer.len(),
74            ),
75        };
76        let len = len / ty.vector_size();
77        let shape = <Sequence<FastDivmod<usize>> as LaunchArg>::register(shape, launcher);
78        let strides = <Sequence<usize> as LaunchArg>::register(strides, launcher);
79        <usize as LaunchArg>::register(len, launcher);
80        PermutedLayoutCompilationArg { shape, strides }
81    }
82
83    fn expand(
84        arg: &Self::CompilationArg,
85        ty: Type,
86        builder: &mut KernelBuilder,
87    ) -> <Self as CubeType>::ExpandType {
88        PermutedLayoutExpand {
89            shape: <Sequence<FastDivmod<usize>> as LaunchArg>::expand(&arg.shape, builder),
90            strides: <Sequence<usize> as LaunchArg>::expand(&arg.strides, builder),
91            len: <usize as LaunchArg>::expand(&(), builder),
92            vector_size: ty.vector_size(),
93        }
94    }
95}
96
97fn strides_ref<R: Runtime>(
98    shape: &[usize],
99    reference_shape: &[usize],
100    strides: &[usize],
101) -> SequenceArg<R, usize> {
102    debug_assert!(
103        shape.len() == reference_shape.len(),
104        "Shape and reference should have the same rank"
105    );
106    debug_assert!(
107        shape
108            .iter()
109            .zip(reference_shape.iter())
110            .all(|(s, r)| s == r || *s == 1),
111        "Shape should be equal to reference or 1 on each dimension"
112    );
113
114    strides
115        .iter()
116        .zip(shape.iter().zip(reference_shape.iter()))
117        .map(|(stride, (s, r))| if *s == *r { *stride } else { 0 })
118        .collect()
119}
120
121impl PermutedLayoutLaunch {
122    /// Create a new permuted layout without a reference shape.
123    pub fn new() -> Self {
124        Self::default()
125    }
126
127    /// Create a new permuted layout for a possibly broadcast tensor, with a reference shape to be
128    /// broadcast to.
129    pub fn from_reference_shape(reference_shape: Shape) -> Self {
130        Self {
131            reference_shape: Some(reference_shape),
132        }
133    }
134
135    pub fn from_reference_handle<R: Runtime>(reference_handle: TensorBinding<R>) -> Self {
136        Self::from_reference_shape(reference_handle.shape)
137    }
138}
139
140#[cube]
141impl Layout for PermutedLayout {
142    type Coordinates = Coords1d;
143    type SourceCoordinates = Coords1d;
144
145    fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
146        index_offset_contiguous_fastdivmod(pos, &self.shape, &self.strides, self.vector_size)
147    }
148
149    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
150        (self.to_source_pos(pos), self.is_in_bounds(pos))
151    }
152
153    fn shape(&self) -> Self::Coordinates {
154        self.len
155    }
156
157    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
158        pos < self.len
159    }
160}