1use alloc::boxed::Box;
2use core::{
3 marker::PhantomData,
4 ops::{Deref, DerefMut},
5};
6
7use crate::{self as cubecl, unexpanded};
8use cubecl::prelude::*;
9use cubecl_ir::{Branch, ElemType, FloatKind, ManagedVariable, RangeLoop, Variable, VectorSize};
10use cubecl_macros::intrinsic;
11
/// Access-mode marker for slices that are only read from.
///
/// This is the default `IO` parameter of [`Slice`].
#[derive(Clone, Copy)]
pub struct ReadOnly;
/// Access-mode marker for slices that can be both read from and written to.
///
/// The mutable traits in this file (`CubeIndexMut`, `ListMut`, `DerefMut`) are
/// only implemented for slices tagged with this marker.
#[derive(Clone, Copy)]
pub struct ReadWrite;
16
/// A view over a range of elements of type `E`, tagged with an access mode `IO`.
///
/// `IO` defaults to [`ReadOnly`]. The representation used while a kernel is
/// being expanded is [`SliceExpand`], which is this type's
/// `CubeType::ExpandType`.
#[derive(Clone, Copy)]
pub struct Slice<E: CubePrimitive, IO: SliceVisibility = ReadOnly> {
    // Ties the element type to the slice without storing any element.
    _e: PhantomData<E>,
    // Ties the access mode to the slice.
    _io: PhantomData<IO>,
    // Marker only; the actual offset lives in `SliceExpand::offset`.
    _offset: PhantomData<usize>,
    // Read by `len` / `is_empty`.
    length: usize,
}
29
/// The container a [`Slice`] was taken from.
///
/// NOTE(review): the `SliceOriginExpand` type used throughout this file is not
/// defined in the visible source; presumably it is generated by
/// `#[derive(CubeType)]` - confirm.
#[derive(CubeType)]
pub enum SliceOrigin<E: CubePrimitive> {
    // The slice views a `Tensor`.
    Tensor(Tensor<E>),
    // The slice views an `Array`.
    Array(Array<E>),
    // The slice views a `SharedMemory`.
    SharedMemory(SharedMemory<E>),
}
36
37impl<E: CubePrimitive> SliceOriginExpand<E> {
38 pub fn vector_size(&self) -> VectorSize {
39 match self {
40 SliceOriginExpand::Tensor(t) => t.vector_size(),
41 SliceOriginExpand::Array(t) => t.vector_size(),
42 SliceOriginExpand::SharedMemory(t) => t.vector_size(),
43 }
44 }
45}
46
/// Placeholder `Iterator` implementation for [`Slice`].
///
/// `next` is `unexpanded!()`. The loop that is actually emitted for a
/// read-only slice is built by the `Iterable` implementation on `SliceExpand`.
impl<E: CubePrimitive, IO: SliceVisibility> Iterator for Slice<E, IO> {
    type Item = E;

    fn next(&mut self) -> Option<Self::Item> {
        unexpanded!()
    }
}
54
/// Marker trait for the access-mode types accepted as the `IO` parameter of
/// [`Slice`] and [`SliceExpand`].
pub trait SliceVisibility: Clone + Copy + Send + Sync + 'static {}

// Read-only access mode.
impl SliceVisibility for ReadOnly {}

// Read-write access mode.
impl SliceVisibility for ReadWrite {}
60
/// Expansion-time representation of a [`Slice`].
pub struct SliceExpand<E: CubePrimitive, IO: SliceVisibility> {
    // The container (tensor, array or shared memory) the slice views.
    pub(crate) origin: SliceOriginExpand<E>,
    // Access-mode marker; carries no data.
    pub(crate) io: PhantomData<IO>,
    // Added to every index before the origin is read or written.
    pub(crate) offset: NativeExpand<usize>,
    // Value returned by `len`, and the exclusive end bound of the iteration loop.
    pub(crate) length: NativeExpand<usize>,
    // When `Some`, takes precedence over the origin's own vector size.
    pub(crate) vector_size: Option<VectorSize>,
}
68
69impl<E: CubePrimitive, IO: SliceVisibility> SliceExpand<E, IO> {
70 pub fn __to_raw_parts(&self) -> (Variable, Variable) {
71 let expand = match self.origin.clone() {
72 SliceOriginExpand::Tensor(expand) => expand.expand,
73 SliceOriginExpand::Array(expand) => expand.expand,
74 SliceOriginExpand::SharedMemory(expand) => expand.expand,
75 };
76
77 (*expand, *self.offset.expand)
78 }
79}
80
81#[cube]
82impl<E: Scalar, N: Size, IO: SliceVisibility> Slice<Vector<E, N>, IO> {
83 #[allow(unused_variables)]
89 pub fn with_vector_size<N2: Size>(&self) -> Slice<Vector<E, N2>, IO> {
90 intrinsic!(|scope| {
91 let vector_size = N2::__expand_value(scope);
92 let (input, offset) = self.__to_raw_parts();
93 let mut item = input.ty;
94
95 let current = input.ty.vector_size();
96 let mut out = self
97 .clone()
98 .__expand_downcast_unchecked_method::<Vector<E, N2>>(scope);
99
100 if vector_size == item.vector_size() {
101 return out;
102 }
103
104 if current < vector_size {
105 let ratio = vector_size / current;
106 let length = cubecl::frontend::div::expand(scope, self.length, ratio.into());
107 let offset = cubecl::frontend::div::expand(scope, self.offset, ratio.into());
108 out.length = length;
109 out.offset = offset;
110 } else {
111 let ratio = current / vector_size;
112 let length = cubecl::frontend::mul::expand(scope, self.length, ratio.into());
113 let offset = cubecl::frontend::mul::expand(scope, self.offset, ratio.into());
114 out.length = length;
115 out.offset = offset;
116 }
117
118 out.vector_size = Some(vector_size);
119 out
120 })
121 }
122}
123
124#[cube]
125impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
126 pub fn into_vectorized(&self) -> Slice<Vector<E::Scalar, E::Size>, IO> {
129 intrinsic!(|scope| {
130 SliceExpand::<Vector<E::Scalar, E::Size>, IO> {
131 origin: self.origin.cast_unchecked(),
132 io: self.io.clone(),
133 offset: self.offset.clone(),
134 length: self.length.clone(),
135 vector_size: self.vector_size,
136 }
137 })
138 }
139 pub fn downcast<T: CubePrimitive>(&self) -> Slice<T, IO> {
144 intrinsic!(|scope| {
145 if T::as_type(scope) != E::as_type(scope) && !is_tf32::<E, T>(scope) {
146 let elems = [T::as_type(scope).elem_type(), E::as_type(scope).elem_type()];
147 let is_flex32_cast = elems.contains(&ElemType::Float(FloatKind::F32))
148 && elems.contains(&ElemType::Float(FloatKind::Flex32));
149
150 if !is_flex32_cast {
151 panic!("Downcast should only be used to satisfy the Rust type system.")
152 }
153 }
154
155 unsafe { self.__expand_downcast_unchecked_method(scope) }
156 })
157 }
158
159 pub unsafe fn downcast_unchecked<T: CubePrimitive>(&self) -> Slice<T, IO> {
165 intrinsic!(|scope| {
166 SliceExpand::<T, IO> {
167 origin: self.origin.cast_unchecked(),
168 io: self.io.clone(),
169 offset: self.offset.clone(),
170 length: self.length.clone(),
171 vector_size: self.vector_size.clone(),
172 }
173 })
174 }
175}
176
177#[cube]
178impl<E: CubePrimitive> Slice<E, ReadOnly> {
179 pub fn as_mut_unchecked(&self) -> Slice<E, ReadWrite> {
180 intrinsic!(|scope| {
181 SliceExpand::<E, ReadWrite> {
182 origin: self.origin,
183 io: PhantomData,
184 offset: self.offset.clone(),
185 length: self.length.clone(),
186 vector_size: self.vector_size.clone(),
187 }
188 })
189 }
190}
191
192impl<E: CubePrimitive> SliceOriginExpand<E> {
193 fn cast_unchecked<T: CubePrimitive>(self) -> SliceOriginExpand<T> {
194 match self {
195 SliceOriginExpand::Tensor(expand) => {
196 SliceOriginExpand::<T>::Tensor(expand.expand.into())
197 }
198 SliceOriginExpand::Array(expand) => SliceOriginExpand::<T>::Array(expand.expand.into()),
199 SliceOriginExpand::SharedMemory(expand) => {
200 SliceOriginExpand::<T>::SharedMemory(expand.expand.into())
201 }
202 }
203 }
204}
205
206impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
207 pub fn new(_origin: SliceOrigin<E>, _offset: usize, _length: usize) -> Self {
208 unexpanded!()
209 }
210 pub fn __expand_new(
211 scope: &mut Scope,
212 origin: SliceOriginExpand<E>,
213 start: NativeExpand<usize>,
214 end: NativeExpand<usize>,
215 ) -> SliceExpand<E, IO> {
216 Self::__expand_new_expand(scope, origin, start, end)
217 }
218 pub fn __expand_new_expand(
219 scope: &mut Scope,
220 origin: SliceOriginExpand<E>,
221 start: NativeExpand<usize>,
222 end: NativeExpand<usize>,
223 ) -> SliceExpand<E, IO> {
224 let length = cubecl::frontend::sub::expand(scope, end, start.clone());
225
226 SliceExpand::<E, IO> {
227 origin,
228 io: PhantomData,
229 offset: start,
230 length,
231 vector_size: None,
232 }
233 }
234}
235
#[cube]
impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
    /// Returns the length stored in the slice.
    pub fn len(&self) -> usize {
        self.length
    }
    /// Returns `true` when the stored length is zero.
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }
}
247
/// A [`Slice`] expands to a [`SliceExpand`] with the same type parameters.
impl<E: CubePrimitive, IO: SliceVisibility> CubeType for Slice<E, IO> {
    type ExpandType = SliceExpand<E, IO>;
}
251
/// A shared reference to a [`Slice`] expands to the same [`SliceExpand`] as
/// the owned type.
impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &Slice<E, IO> {
    type ExpandType = SliceExpand<E, IO>;
}
255
/// A mutable reference to a [`Slice`] expands to the same [`SliceExpand`] as
/// the owned type.
impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &mut Slice<E, IO> {
    type ExpandType = SliceExpand<E, IO>;
}
259
/// `into_mut` is the identity for slices: the value is returned as-is and the
/// scope is not touched.
impl<E: CubePrimitive, IO: SliceVisibility> IntoMut for SliceExpand<E, IO> {
    fn into_mut(self, _scope: &mut cubecl_ir::Scope) -> Self {
        self
    }
}
265
// Uses the trait's default items only; nothing is overridden here.
impl<E: CubePrimitive, IO: SliceVisibility> CubeDebug for SliceExpand<E, IO> {}
267impl<E: CubePrimitive, IO: SliceVisibility> Clone for SliceExpand<E, IO> {
268 fn clone(&self) -> Self {
269 Self {
270 origin: self.origin.clone(),
271 offset: self.offset.clone(),
272 length: self.length.clone(),
273 vector_size: self.vector_size,
274 io: PhantomData,
275 }
276 }
277}
278
/// Read-only slices are sized containers whose items are of type `E`.
impl<E: CubePrimitive> SizedContainer for Slice<E, ReadOnly> {
    type Item = E;
}
283
284impl<E: CubePrimitive> Iterable<E> for SliceExpand<E, ReadOnly> {
285 fn expand(
286 self,
287 scope: &mut Scope,
288 mut body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
289 ) {
290 let index_ty = u32::as_type(scope);
291 let len: ManagedVariable = self.length.clone().into();
292
293 let mut child = scope.child();
294 let i = child.create_local_restricted(index_ty);
295
296 let index = i.clone().into();
297 let item = index::expand(&mut child, self, index);
298 body(&mut child, item);
299
300 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
301 i: *i,
302 start: 0usize.into(),
303 end: *len,
304 step: None,
305 inclusive: false,
306 scope: child,
307 })));
308 }
309
310 fn expand_unroll(
311 self,
312 _scope: &mut Scope,
313 _body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
314 ) {
315 unimplemented!("Can't unroll slice iterator")
316 }
317}
/// Indexing a [`Slice`] with a `usize` yields an `E`.
impl<E: CubePrimitive, IO: SliceVisibility> CubeIndex for Slice<E, IO> {
    type Output = E;
    type Idx = usize;

    /// Forwards to `__expand_read_method`, which passes `checked = true` to
    /// the read helper.
    fn expand_index(
        scope: &mut Scope,
        array: Self::ExpandType,
        index: NativeExpand<usize>,
    ) -> <Self::Output as CubeType>::ExpandType {
        array.__expand_read_method(scope, index)
    }
}
330
/// Expansion-side indexing for slices.
impl<E: CubePrimitive, IO: SliceVisibility> CubeIndexExpand for SliceExpand<E, IO> {
    type Output = E::ExpandType;
    type Idx = NativeExpand<usize>;

    /// Reads the element at `index` through `__expand_read_method`.
    fn expand_index(self, scope: &mut Scope, index: NativeExpand<usize>) -> Self::Output {
        self.__expand_read_method(scope, index)
    }
    /// Reads the element at `index` through `__expand_read_unchecked_method`.
    fn expand_index_unchecked(self, scope: &mut Scope, index: NativeExpand<usize>) -> Self::Output {
        self.__expand_read_unchecked_method(scope, index)
    }
}
342
// Uses the trait's default items only; the expansion side is `ListExpand for SliceExpand`.
impl<E: CubePrimitive, IO: SliceVisibility> List<E> for Slice<E, IO> {}
344impl<E: CubePrimitive, IO: SliceVisibility> ListExpand<E> for SliceExpand<E, IO> {
345 fn __expand_read_method(
346 &self,
347 scope: &mut cubecl_ir::Scope,
348 index: NativeExpand<usize>,
349 ) -> <E as CubeType>::ExpandType {
350 read_offset::expand::<E>(
351 scope,
352 self.origin.clone(),
353 self.offset.clone(),
354 index,
355 self.vector_size,
356 true,
357 )
358 }
359 fn __expand_read_unchecked_method(
360 &self,
361 scope: &mut cubecl_ir::Scope,
362 index: NativeExpand<usize>,
363 ) -> <E as CubeType>::ExpandType {
364 read_offset::expand::<E>(
365 scope,
366 self.origin.clone(),
367 self.offset.clone(),
368 index,
369 self.vector_size,
370 false,
371 )
372 }
373
374 fn __expand_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
375 Self::__expand_len(scope, self.clone())
376 }
377}
378
/// Placeholder `Deref` to `[T]`; the body is `unexpanded!()`.
impl<T: CubePrimitive, IO: SliceVisibility> Deref for Slice<T, IO> {
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}
386
/// Placeholder `DerefMut`, available for [`ReadWrite`] slices only; the body
/// is `unexpanded!()`.
impl<T: CubePrimitive> DerefMut for Slice<T, ReadWrite> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unexpanded!()
    }
}
392
// Uses the trait's default items only; the expansion side is `VectorizedExpand for SliceExpand`.
impl<E: CubePrimitive, IO: SliceVisibility> Vectorized for Slice<E, IO> {}
394impl<E: CubePrimitive, IO: SliceVisibility> VectorizedExpand for SliceExpand<E, IO> {
395 fn vector_size(&self) -> VectorSize {
396 self.vector_size
397 .unwrap_or_else(|| self.origin.vector_size())
398 }
399}
400
/// Index assignment for [`ReadWrite`] slices.
impl<E: CubePrimitive> CubeIndexMut for Slice<E, ReadWrite> {
    /// Writes `value` at `index` by forwarding to `__expand_write_method`.
    fn expand_index_mut(
        scope: &mut Scope,
        array: Self::ExpandType,
        index: NativeExpand<usize>,
        value: NativeExpand<E>,
    ) {
        array.__expand_write_method(scope, index, value)
    }
}
411
/// Expansion-side index assignment for [`ReadWrite`] slices.
impl<E: CubePrimitive> CubeIndexMutExpand for SliceExpand<E, ReadWrite> {
    /// Writes `value` at `index` by forwarding to `__expand_write_method`.
    fn expand_index_mut(self, scope: &mut Scope, index: NativeExpand<usize>, value: Self::Output) {
        self.__expand_write_method(scope, index, value)
    }
}
417
// Uses the trait's default items only; the expansion side is `ListMutExpand for SliceExpand`.
impl<E: CubePrimitive> ListMut<E> for Slice<E, ReadWrite> {}
419impl<E: CubePrimitive> ListMutExpand<E> for SliceExpand<E, ReadWrite> {
420 fn __expand_write_method(
421 &self,
422 scope: &mut cubecl_ir::Scope,
423 index: NativeExpand<usize>,
424 value: NativeExpand<E>,
425 ) {
426 write_offset::expand::<E>(
427 scope,
428 self.origin.clone(),
429 self.offset.clone(),
430 index,
431 value,
432 self.vector_size,
433 )
434 }
435}
436
437mod read_offset {
438 use super::*;
439
440 pub fn expand<E: CubePrimitive>(
441 scope: &mut cubecl::prelude::Scope,
442 origin: SliceOriginExpand<E>,
443 offset: <usize as cubecl::prelude::CubeType>::ExpandType,
444 index: <usize as cubecl::prelude::CubeType>::ExpandType,
445 vector_size: Option<VectorSize>,
446 checked: bool,
447 ) -> <E as cubecl::prelude::CubeType>::ExpandType {
448 let index = cubecl::frontend::add::expand(scope, offset, index);
449
450 match origin {
451 SliceOriginExpand::Tensor(expand) => {
452 expand_index_native::<Tensor<E>>(scope, expand, index, vector_size, checked)
453 }
454 SliceOriginExpand::Array(expand) => {
455 expand_index_native::<Array<E>>(scope, expand, index, vector_size, checked)
456 }
457 SliceOriginExpand::SharedMemory(expand) => {
458 expand_index_native::<SharedMemory<E>>(scope, expand, index, vector_size, checked)
459 }
460 }
461 }
462}
463
464mod write_offset {
465 use super::*;
466
467 pub fn expand<E: CubePrimitive>(
468 scope: &mut cubecl::prelude::Scope,
469 origin: SliceOriginExpand<E>,
470 offset: <usize as cubecl::prelude::CubeType>::ExpandType,
471 index: <usize as cubecl::prelude::CubeType>::ExpandType,
472 value: <E as cubecl::prelude::CubeType>::ExpandType,
473 vector_size: Option<VectorSize>,
474 ) {
475 let index = cubecl::frontend::add::expand(scope, offset, index);
476
477 match origin {
478 SliceOriginExpand::Tensor(expand) => expand_index_assign_native::<Tensor<E>>(
479 scope,
480 expand,
481 index,
482 value,
483 vector_size,
484 true,
485 ),
486 SliceOriginExpand::Array(expand) => expand_index_assign_native::<Array<E>>(
487 scope,
488 expand,
489 index,
490 value,
491 vector_size,
492 false,
493 ),
494 SliceOriginExpand::SharedMemory(expand) => {
495 expand_index_assign_native::<SharedMemory<E>>(
496 scope,
497 expand,
498 index,
499 value,
500 vector_size,
501 false,
502 )
503 }
504 }
505 }
506}