cubecl_std/
reinterpret_slice.rs1use core::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, ir::VectorSize, unexpanded};
5
#[derive(CubeType)]
/// Read-only view over a `Slice<S>` whose contents are read back as values of type `T`.
///
/// Build it with `ReinterpretSlice::new`, which decides at comptime how the source
/// slice must be viewed so that one `T` lines up with whole source vectors.
pub struct ReinterpretSlice<S: CubePrimitive, T: CubePrimitive> {
    /// Underlying storage. `new` may have re-viewed it with a smaller vector size,
    /// so it is always accessed through `vector_size` rather than its static type.
    slice: Slice<S>,

    /// Comptime vector size used to view `slice` on every access.
    #[cube(comptime)]
    vector_size: VectorSize,

    /// `Some(k)` when one `T` spans `k` consecutive source vectors; `None` when a
    /// single (possibly re-viewed) source vector holds exactly one `T`.
    #[cube(comptime)]
    load_many: Option<usize>,

    /// Carries the target type `T` without storing a value of it.
    #[cube(comptime)]
    _phantom: PhantomData<T>,
}
27
#[cube]
impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSlice<S, T> {
    /// Wraps `slice` so that its contents can be read as values of type `T`.
    ///
    /// The layout is chosen at comptime by `optimize_vector_size`. It compares one
    /// source vector (`S::Scalar` size times the slice's vector size) with the
    /// scalar size of `T`:
    /// - equal: the slice is kept as-is, and each index maps to one source vector;
    /// - source vector larger: the slice is re-viewed with a smaller vector size,
    ///   so that one source vector matches one `T`;
    /// - source vector smaller: the slice is kept as-is, and each `T` is assembled
    ///   from several consecutive source vectors (`load_many`).
    ///
    /// # Panics
    ///
    /// Panics during kernel expansion (the check runs inside `comptime!`) when the
    /// sizes are incompatible ("incompatible number of bytes").
    pub fn new(slice: Slice<S>) -> ReinterpretSlice<S, T> {
        let in_vector_size = slice.vector_size();
        let source_size = S::Scalar::type_size();
        // NOTE(review): this is the *scalar* size of `T`. If `T` can itself be a
        // multi-element vector type, this may not be the full size of `T` — confirm.
        let target_size = T::Scalar::type_size();
        let (optimized_vector_size, load_many) = comptime!(optimize_vector_size(
            source_size,
            in_vector_size,
            target_size
        ));
        match comptime!(optimized_vector_size) {
            Some(vector_size) => {
                // Re-view the slice with the smaller vector size picked above.
                let size!(N2) = vector_size;
                let slice = slice.into_vectorized().with_vector_size::<N2>();

                // NOTE(review): the re-vectorized view is stored back as a `Slice<S>`.
                // Every accessor re-views it with the stored comptime `vector_size`
                // before indexing, which appears to be what keeps this sound —
                // confirm against `downcast_unchecked`'s safety contract.
                ReinterpretSlice::<S, T> {
                    slice: unsafe { slice.downcast_unchecked() },
                    vector_size,
                    load_many,
                    _phantom: PhantomData,
                }
            }
            None => ReinterpretSlice::<S, T> {
                slice,
                vector_size: in_vector_size,
                load_many,
                _phantom: PhantomData,
            },
        }
    }

    /// Reads the `index`-th value of type `T`.
    ///
    /// `index` counts `T` values, not source vectors. With `load_many == Some(amount)`,
    /// the value is built from source vectors `index * amount .. index * amount + amount`.
    /// Their elements are concatenated in order into one vector, which is then
    /// reinterpreted. Otherwise the single source vector at `index` is reinterpreted.
    pub fn read(&self, index: usize) -> T {
        let size!(N) = self.vector_size;
        let slice = self.slice.into_vectorized().with_vector_size::<N>();
        match comptime!(self.load_many) {
            Some(amount) => {
                let first = index * amount;
                // Total number of source scalars that make up one `T`.
                let size!(N2) = comptime!(amount * self.vector_size);
                let mut vector = Vector::<S::Scalar, N2>::empty();
                // Copy each source vector's elements into consecutive positions.
                #[unroll]
                for k in 0..amount {
                    let elem = slice[first + k];
                    #[unroll]
                    for j in 0..self.vector_size {
                        vector[k * self.vector_size + j] = elem[j];
                    }
                }
                T::reinterpret(vector)
            }
            None => T::reinterpret(slice[index]),
        }
    }
}
82
#[derive(CubeType)]
/// Read/write view over a `SliceMut<S>` whose contents are accessed as values of type `T`.
///
/// Build it with `ReinterpretSliceMut::new`, which decides at comptime how the source
/// slice must be viewed so that one `T` lines up with whole source vectors.
pub struct ReinterpretSliceMut<S: CubePrimitive, T: CubePrimitive> {
    /// Underlying storage. `new` may have re-viewed it with a smaller vector size,
    /// so it is always accessed through `vector_size` rather than its static type.
    slice: SliceMut<S>,

    /// Comptime vector size used to view `slice` on every access.
    #[cube(comptime)]
    vector_size: VectorSize,

    /// `Some(k)` when one `T` spans `k` consecutive source vectors; `None` when a
    /// single (possibly re-viewed) source vector holds exactly one `T`.
    #[cube(comptime)]
    load_many: Option<usize>,

    /// Carries the target type `T` without storing a value of it.
    #[cube(comptime)]
    _phantom: PhantomData<T>,
}
103
#[cube]
impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSliceMut<S, T> {
    /// Wraps `slice` so that its contents can be read and written as values of type `T`.
    ///
    /// The layout is chosen at comptime by `optimize_vector_size`. It compares one
    /// source vector (`S::Scalar` size times the slice's vector size) with the
    /// scalar size of `T`:
    /// - equal: the slice is kept as-is, and each index maps to one source vector;
    /// - source vector larger: the slice is re-viewed with a smaller vector size,
    ///   so that one source vector matches one `T`;
    /// - source vector smaller: the slice is kept as-is, and each `T` spans several
    ///   consecutive source vectors (`load_many`).
    ///
    /// # Panics
    ///
    /// Panics during kernel expansion (the check runs inside `comptime!`) when the
    /// sizes are incompatible ("incompatible number of bytes").
    pub fn new(slice: SliceMut<S>) -> ReinterpretSliceMut<S, T> {
        let in_vector_size = slice.vector_size();
        let source_size = S::Scalar::type_size();
        // NOTE(review): this is the *scalar* size of `T`. If `T` can itself be a
        // multi-element vector type, this may not be the full size of `T` — confirm.
        let target_size = T::Scalar::type_size();
        let (optimized_vector_size, load_many) = comptime!(optimize_vector_size(
            source_size,
            in_vector_size,
            target_size
        ));
        match comptime!(optimized_vector_size) {
            Some(vector_size) => {
                // Re-view the slice with the smaller vector size picked above.
                let size!(N2) = vector_size;
                let slice = slice.into_vectorized().with_vector_size::<N2>();

                // NOTE(review): the re-vectorized view is stored back as a `SliceMut<S>`.
                // Every accessor re-views it with the stored comptime `vector_size`
                // before indexing, which appears to be what keeps this sound —
                // confirm against `downcast_unchecked`'s safety contract.
                ReinterpretSliceMut::<S, T> {
                    slice: unsafe { slice.downcast_unchecked() },
                    vector_size,
                    load_many,
                    _phantom: PhantomData,
                }
            }
            None => ReinterpretSliceMut::<S, T> {
                slice,
                vector_size: in_vector_size,
                load_many,
                _phantom: PhantomData,
            },
        }
    }

    /// Reads the `index`-th value of type `T`.
    ///
    /// `index` counts `T` values, not source vectors. With `load_many == Some(amount)`,
    /// the value is built from source vectors `index * amount .. index * amount + amount`.
    /// Their elements are concatenated in order into one vector, which is then
    /// reinterpreted. Otherwise the single source vector at `index` is reinterpreted.
    pub fn read(&self, index: usize) -> T {
        let size!(N) = self.vector_size;
        let slice = self.slice.into_vectorized().with_vector_size::<N>();
        match comptime!(self.load_many) {
            Some(amount) => {
                let first = index * amount;
                // Total number of source scalars that make up one `T`.
                let size!(N2) = comptime!(amount * self.vector_size);
                let mut vector = Vector::<S::Scalar, N2>::empty();
                // Copy each source vector's elements into consecutive positions.
                #[unroll]
                for k in 0..amount {
                    let elem = slice[first + k];
                    #[unroll]
                    for j in 0..self.vector_size {
                        vector[k * self.vector_size + j] = elem[j];
                    }
                }
                T::reinterpret(vector)
            }
            None => T::reinterpret(slice[index]),
        }
    }

    /// Writes `value` as the `index`-th value of type `T`.
    ///
    /// `value` is first reinterpreted as a vector of `S::Scalar`.
    /// With `load_many == Some(amount)`, that vector is split into `amount` equal
    /// chunks, stored in source vectors `index * amount ..` in order. Otherwise it is
    /// converted with `Vector::cast_from` and stored at `index`.
    pub fn write(&mut self, index: usize, value: T) {
        let size!(N) = self.vector_size;
        let mut slice = self.slice.into_vectorized().with_vector_size::<N>();
        // Presumably the number of `S::Scalar` elements that hold one `T` — confirm.
        let size!(N1) = S::reinterpret_vectorization::<T>();
        let reinterpreted = Vector::<S::Scalar, N1>::reinterpret(value);
        match comptime!(self.load_many) {
            Some(amount) => {
                let first = index * amount;
                // Elements per destination source vector; expected to match
                // `self.vector_size` — NOTE(review): confirm.
                let vector_size = comptime!(reinterpreted.size() / amount);

                // Scatter consecutive chunks of `reinterpreted` into consecutive
                // source vectors.
                #[unroll]
                for k in 0..amount {
                    let mut vector = Vector::empty();
                    #[unroll]
                    for j in 0..vector_size {
                        vector[j] = reinterpreted[k * vector_size + j];
                    }
                    slice[first + k] = vector;
                }
            }
            None => slice[index] = Vector::cast_from(reinterpreted),
        }
    }
}
182
183fn optimize_vector_size(
184 source_size: usize,
185 vector_size: VectorSize,
186 target_size: usize,
187) -> (Option<usize>, Option<usize>) {
188 let vector_source_size = source_size * vector_size;
189 match vector_source_size.cmp(&target_size) {
190 core::cmp::Ordering::Less => {
191 if !target_size.is_multiple_of(vector_source_size) {
192 panic!("incompatible number of bytes");
193 }
194
195 let ratio = target_size / vector_source_size;
196
197 (None, Some(ratio))
198 }
199 core::cmp::Ordering::Greater => {
200 if !vector_source_size.is_multiple_of(target_size) {
201 panic!("incompatible number of bytes");
202 }
203 let ratio = vector_source_size / target_size;
204
205 (Some(vector_size / ratio), None)
206 }
207 core::cmp::Ordering::Equal => (None, None),
208 }
209}
210
/// Comptime query for the size of `S`, as reported by `S::as_type(..).size()`.
///
/// This function is a stub: its body is `unexpanded!()`. The actual value comes
/// from `size_of::expand` in the module of the same name. Presumably the `#[cube]`
/// macro routes calls made inside kernels to that function — confirm.
pub fn size_of<S: CubePrimitive>() -> u32 {
    unexpanded!()
}
214
215pub mod size_of {
216 use super::*;
217 #[allow(unused, clippy::all)]
218 pub fn expand<S: CubePrimitive>(context: &mut cubecl::prelude::Scope) -> u32 {
219 S::as_type(context).size() as u32
220 }
221}