cubecl_std/tensor/layout/
permuted.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, ir::VectorSize, zspace::Shape};
3
4use crate::{
5 FastDivmod,
6 tensor::{
7 index_offset_contiguous_fastdivmod,
8 launch::{BufferArg, ViewLayoutLaunchArg},
9 layout::{Coords1d, Layout, LayoutExpand},
10 },
11};
12
13#[derive(CubeType, Clone)]
16pub struct PermutedLayout {
17 shape: Sequence<FastDivmod<usize>>,
18 strides: Sequence<usize>,
19 len: usize,
20 #[cube(comptime)]
21 vector_size: VectorSize,
22}
23
24#[cube]
25impl PermutedLayout {
26 pub fn new(
27 shape: Sequence<FastDivmod<usize>>,
28 strides: Sequence<usize>,
29 len: usize,
30 #[comptime] vector_size: VectorSize,
31 ) -> Self {
32 PermutedLayout {
33 shape,
34 strides,
35 len,
36 vector_size,
37 }
38 }
39}
40
41#[derive(Default)]
42pub struct PermutedLayoutLaunch {
43 reference_shape: Option<Shape>,
44}
45
46#[derive(Debug, Hash, PartialEq, Eq, Clone)]
47pub struct PermutedLayoutCompilationArg {
48 shape: <Sequence<FastDivmod<usize>> as LaunchArg>::CompilationArg,
49 strides: <Sequence<usize> as LaunchArg>::CompilationArg,
50}
51
52impl ViewLayoutLaunchArg for PermutedLayout {
53 type RuntimeArg<R: Runtime> = PermutedLayoutLaunch;
54 type CompilationArg = PermutedLayoutCompilationArg;
55
56 fn register<R: Runtime, B: BufferArg>(
57 arg: Self::RuntimeArg<R>,
58 buffer: &B,
59 ty: Type,
60 launcher: &mut KernelLauncher<R>,
61 ) -> Self::CompilationArg {
62 let shape = buffer.shape();
63 let strides = buffer.strides();
64 let (shape, strides, len) = match arg.reference_shape {
65 Some(reference_shape) => {
66 let len = reference_shape.len();
67 let strides = strides_ref(shape, &reference_shape, strides);
68 (reference_shape.iter().copied().collect(), strides, len)
69 }
70 None => (
71 shape.iter().copied().collect(),
72 strides.iter().copied().collect(),
73 buffer.len(),
74 ),
75 };
76 let len = len / ty.vector_size();
77 let shape = <Sequence<FastDivmod<usize>> as LaunchArg>::register(shape, launcher);
78 let strides = <Sequence<usize> as LaunchArg>::register(strides, launcher);
79 <usize as LaunchArg>::register(len, launcher);
80 PermutedLayoutCompilationArg { shape, strides }
81 }
82
83 fn expand(
84 arg: &Self::CompilationArg,
85 ty: Type,
86 builder: &mut KernelBuilder,
87 ) -> <Self as CubeType>::ExpandType {
88 PermutedLayoutExpand {
89 shape: <Sequence<FastDivmod<usize>> as LaunchArg>::expand(&arg.shape, builder),
90 strides: <Sequence<usize> as LaunchArg>::expand(&arg.strides, builder),
91 len: <usize as LaunchArg>::expand(&(), builder),
92 vector_size: ty.vector_size(),
93 }
94 }
95}
96
97fn strides_ref<R: Runtime>(
98 shape: &[usize],
99 reference_shape: &[usize],
100 strides: &[usize],
101) -> SequenceArg<R, usize> {
102 debug_assert!(
103 shape.len() == reference_shape.len(),
104 "Shape and reference should have the same rank"
105 );
106 debug_assert!(
107 shape
108 .iter()
109 .zip(reference_shape.iter())
110 .all(|(s, r)| s == r || *s == 1),
111 "Shape should be equal to reference or 1 on each dimension"
112 );
113
114 strides
115 .iter()
116 .zip(shape.iter().zip(reference_shape.iter()))
117 .map(|(stride, (s, r))| if *s == *r { *stride } else { 0 })
118 .collect()
119}
120
121impl PermutedLayoutLaunch {
122 pub fn new() -> Self {
124 Self::default()
125 }
126
127 pub fn from_reference_shape(reference_shape: Shape) -> Self {
130 Self {
131 reference_shape: Some(reference_shape),
132 }
133 }
134
135 pub fn from_reference_handle<R: Runtime>(reference_handle: TensorBinding<R>) -> Self {
136 Self::from_reference_shape(reference_handle.shape)
137 }
138}
139
140#[cube]
141impl Layout for PermutedLayout {
142 type Coordinates = Coords1d;
143 type SourceCoordinates = Coords1d;
144
145 fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
146 index_offset_contiguous_fastdivmod(pos, &self.shape, &self.strides, self.vector_size)
147 }
148
149 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
150 (self.to_source_pos(pos), self.is_in_bounds(pos))
151 }
152
153 fn shape(&self) -> Self::Coordinates {
154 self.len
155 }
156
157 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
158 pos < self.len
159 }
160}