1use alloc::boxed::Box;
2use core::{
3 marker::PhantomData,
4 ops::{Deref, DerefMut},
5};
6
7use crate::{self as cubecl, unexpanded};
8use cubecl::prelude::*;
9use cubecl_ir::{Branch, ElemType, ExpandElement, FloatKind, LineSize, RangeLoop, Type, Variable};
10use cubecl_macros::intrinsic;
11
/// Visibility marker for a [`Slice`] that only exposes read access.
///
/// This is the default visibility of [`Slice`].
#[derive(Clone, Copy)]
pub struct ReadOnly;
/// Visibility marker for a [`Slice`] that exposes both read and write access.
#[derive(Clone, Copy)]
pub struct ReadWrite;
16
/// A view over a range of `E` items, tagged with its access visibility `IO`.
///
/// The fields are placeholders for the unexpanded form of the type; the value that is
/// manipulated during expansion is [`SliceExpand`].
#[derive(Clone, Copy)]
pub struct Slice<E: CubePrimitive, IO: SliceVisibility = ReadOnly> {
    // Item type carried by the slice.
    _e: PhantomData<E>,
    // Access visibility (`ReadOnly` or `ReadWrite`).
    _io: PhantomData<IO>,
    // Stands in for the offset, which is only materialized in `SliceExpand`.
    _offset: PhantomData<usize>,
    // Number of items in the slice.
    length: usize,
}
29
/// The container a [`Slice`] was taken from.
///
/// The `CubeType` derive provides the matching `SliceOriginExpand` type used by
/// [`SliceExpand`].
#[derive(CubeType)]
pub enum SliceOrigin<E: CubePrimitive> {
    /// The slice views a [`Tensor`].
    Tensor(Tensor<E>),
    /// The slice views an [`Array`].
    Array(Array<E>),
    /// The slice views a [`SharedMemory`].
    SharedMemory(SharedMemory<E>),
}
36
37impl<E: CubePrimitive> SliceOriginExpand<E> {
38 pub fn line_size(&self) -> LineSize {
39 match self {
40 SliceOriginExpand::Tensor(t) => t.line_size(),
41 SliceOriginExpand::Array(t) => t.line_size(),
42 SliceOriginExpand::SharedMemory(t) => t.line_size(),
43 }
44 }
45}
46
/// Lets a [`Slice`] be written as an iterator yielding `E` items.
impl<E: CubePrimitive, IO: SliceVisibility> Iterator for Slice<E, IO> {
    type Item = E;

    /// Placeholder body: it only invokes `unexpanded!()`.
    // NOTE(review): presumably `unexpanded!()` panics when reached outside of an
    // expansion -- confirm against the macro definition.
    fn next(&mut self) -> Option<Self::Item> {
        unexpanded!()
    }
}
54
/// Marker trait implemented by the types that can be used as the `IO` parameter of a
/// [`Slice`]: [`ReadOnly`] and [`ReadWrite`].
pub trait SliceVisibility: Clone + Copy + Send + Sync + 'static {}

impl SliceVisibility for ReadOnly {}

impl SliceVisibility for ReadWrite {}
60
/// Expand-time representation of a [`Slice`].
pub struct SliceExpand<E: CubePrimitive, IO: SliceVisibility> {
    /// Container the slice reads from (and writes to, for `ReadWrite` slices).
    pub(crate) origin: SliceOriginExpand<E>,
    /// Access visibility marker.
    pub(crate) io: PhantomData<IO>,
    /// Added to every index before the origin is accessed.
    pub(crate) offset: ExpandElementTyped<usize>,
    /// Number of items in the slice.
    pub(crate) length: ExpandElementTyped<usize>,
    /// Line size override; when `None`, the line size of the origin applies.
    pub(crate) line_size: Option<LineSize>,
}
68
69impl<E: CubePrimitive, IO: SliceVisibility> SliceExpand<E, IO> {
70 pub fn __to_raw_parts(&self) -> (Variable, Variable) {
71 let expand = match self.origin.clone() {
72 SliceOriginExpand::Tensor(expand) => expand.expand,
73 SliceOriginExpand::Array(expand) => expand.expand,
74 SliceOriginExpand::SharedMemory(expand) => expand.expand,
75 };
76
77 (*expand, *self.offset.expand)
78 }
79}
80
81#[cube]
82impl<E: CubePrimitive, IO: SliceVisibility> Slice<Line<E>, IO> {
83 #[allow(unused_variables)]
89 pub fn with_line_size(&self, #[comptime] line_size: LineSize) -> Slice<Line<E>, IO> {
90 intrinsic!(|scope| {
91 let (input, offset) = self.__to_raw_parts();
92 let mut item = input.ty;
93
94 if line_size == item.line_size() {
95 return self;
96 }
97
98 let current = input.ty.line_size();
99 let mut out = self.clone();
100
101 if current < line_size {
102 let ratio = line_size / current;
103 let length = cubecl::frontend::div::expand(scope, self.length, ratio.into());
104 let offset = cubecl::frontend::div::expand(scope, self.offset, ratio.into());
105 out.length = length;
106 out.offset = offset;
107 } else {
108 let ratio = current / line_size;
109 let length = cubecl::frontend::mul::expand(scope, self.length, ratio.into());
110 let offset = cubecl::frontend::mul::expand(scope, self.offset, ratio.into());
111 out.length = length;
112 out.offset = offset;
113 }
114
115 out.line_size = Some(line_size);
116 out
117 })
118 }
119}
120
121#[cube]
122impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
123 pub fn into_lined(&self) -> Slice<Line<E>, IO> {
125 intrinsic!(|_scope| {
126 SliceExpand::<Line<E>, IO> {
127 origin: self.origin.cast_unchecked(),
128 io: self.io.clone(),
129 offset: self.offset.clone(),
130 length: self.length.clone(),
131 line_size: None,
132 }
133 })
134 }
135 pub fn downcast<T: CubePrimitive>(&self) -> Slice<T, IO> {
140 intrinsic!(|scope| {
141 if T::as_type(scope) != E::as_type(scope) && !is_tf32::<E, T>(scope) {
142 let elems = [T::as_type(scope).elem_type(), E::as_type(scope).elem_type()];
143 let is_flex32_cast = elems.contains(&ElemType::Float(FloatKind::F32))
144 && elems.contains(&ElemType::Float(FloatKind::Flex32));
145
146 if !is_flex32_cast {
147 panic!("Downcast should only be used to satisfy the Rust type system.")
148 }
149 }
150
151 SliceExpand::<T, IO> {
152 origin: self.origin.cast_unchecked(),
153 io: self.io.clone(),
154 offset: self.offset.clone(),
155 length: self.length.clone(),
156 line_size: self.line_size.clone(),
157 }
158 })
159 }
160}
161
162#[cube]
163impl<E: CubePrimitive> Slice<E, ReadOnly> {
164 pub fn as_mut_unchecked(&self) -> Slice<E, ReadWrite> {
165 intrinsic!(|scope| {
166 SliceExpand::<E, ReadWrite> {
167 origin: self.origin,
168 io: PhantomData,
169 offset: self.offset.clone(),
170 length: self.length.clone(),
171 line_size: self.line_size.clone(),
172 }
173 })
174 }
175}
176
177impl<E: CubePrimitive> SliceOriginExpand<E> {
178 fn cast_unchecked<T: CubePrimitive>(self) -> SliceOriginExpand<T> {
179 match self {
180 SliceOriginExpand::Tensor(expand) => {
181 SliceOriginExpand::<T>::Tensor(expand.expand.into())
182 }
183 SliceOriginExpand::Array(expand) => SliceOriginExpand::<T>::Array(expand.expand.into()),
184 SliceOriginExpand::SharedMemory(expand) => {
185 SliceOriginExpand::<T>::SharedMemory(expand.expand.into())
186 }
187 }
188 }
189}
190
191impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
192 pub fn new(_origin: SliceOrigin<E>, _offset: usize, _length: usize) -> Self {
193 unexpanded!()
194 }
195 pub fn __expand_new(
196 scope: &mut Scope,
197 origin: SliceOriginExpand<E>,
198 start: ExpandElementTyped<usize>,
199 end: ExpandElementTyped<usize>,
200 ) -> SliceExpand<E, IO> {
201 Self::__expand_new_expand(scope, origin, start, end)
202 }
203 pub fn __expand_new_expand(
204 scope: &mut Scope,
205 origin: SliceOriginExpand<E>,
206 start: ExpandElementTyped<usize>,
207 end: ExpandElementTyped<usize>,
208 ) -> SliceExpand<E, IO> {
209 let length = cubecl::frontend::sub::expand(scope, end, start.clone());
210
211 SliceExpand::<E, IO> {
212 origin,
213 io: PhantomData,
214 offset: start,
215 length,
216 line_size: None,
217 }
218 }
219}
220
#[cube]
impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
    /// Returns the number of items in the slice.
    pub fn len(&self) -> usize {
        self.length
    }
    /// Returns `true` when the slice holds no items.
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }
}
232
// A slice expands to `SliceExpand` whether it is used by value, by shared reference or by
// mutable reference.
impl<E: CubePrimitive, IO: SliceVisibility> CubeType for Slice<E, IO> {
    type ExpandType = SliceExpand<E, IO>;
}

impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &Slice<E, IO> {
    type ExpandType = SliceExpand<E, IO>;
}

impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &mut Slice<E, IO> {
    type ExpandType = SliceExpand<E, IO>;
}
244
impl<E: CubePrimitive, IO: SliceVisibility> IntoMut for SliceExpand<E, IO> {
    /// Returns the slice unchanged; nothing is registered in the scope.
    fn into_mut(self, _scope: &mut cubecl_ir::Scope) -> Self {
        self
    }
}
250
251impl<E: CubePrimitive, IO: SliceVisibility> CubeDebug for SliceExpand<E, IO> {}
252impl<E: CubePrimitive, IO: SliceVisibility> Clone for SliceExpand<E, IO> {
253 fn clone(&self) -> Self {
254 Self {
255 origin: self.origin.clone(),
256 offset: self.offset.clone(),
257 length: self.length.clone(),
258 line_size: self.line_size,
259 io: PhantomData,
260 }
261 }
262}
263
/// Read-only slices are sized containers of `E` items.
impl<E: CubePrimitive> SizedContainer for Slice<E, ReadOnly> {
    type Item = E;
}
268
269impl<E: CubePrimitive> Iterable<E> for SliceExpand<E, ReadOnly> {
270 fn expand(
271 self,
272 scope: &mut Scope,
273 mut body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
274 ) {
275 let index_ty = Type::new(u32::as_type(scope));
276 let len: ExpandElement = self.length.clone().into();
277
278 let mut child = scope.child();
279 let i = child.create_local_restricted(index_ty);
280
281 let index = i.clone().into();
282 let item = index::expand(&mut child, self, index);
283 body(&mut child, item);
284
285 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
286 i: *i,
287 start: 0usize.into(),
288 end: *len,
289 step: None,
290 inclusive: false,
291 scope: child,
292 })));
293 }
294
295 fn expand_unroll(
296 self,
297 _scope: &mut Scope,
298 _body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
299 ) {
300 unimplemented!("Can't unroll slice iterator")
301 }
302}
/// Indexing a slice with a `usize` yields an `E`.
impl<E: CubePrimitive, IO: SliceVisibility> CubeIndex for Slice<E, IO> {
    type Output = E;
    type Idx = usize;

    /// Expands `array[index]` by delegating to `__expand_read_method`.
    fn expand_index(
        scope: &mut Scope,
        array: Self::ExpandType,
        index: ExpandElementTyped<usize>,
    ) -> <Self::Output as CubeType>::ExpandType {
        array.__expand_read_method(scope, index)
    }
}
315
impl<E: CubePrimitive, IO: SliceVisibility> CubeIndexExpand for SliceExpand<E, IO> {
    type Output = E::ExpandType;
    type Idx = ExpandElementTyped<usize>;

    /// Expands an indexed read by delegating to `__expand_read_method`.
    fn expand_index(self, scope: &mut Scope, index: ExpandElementTyped<usize>) -> Self::Output {
        self.__expand_read_method(scope, index)
    }
    /// Expands an indexed read by delegating to `__expand_read_unchecked_method`.
    fn expand_index_unchecked(
        self,
        scope: &mut Scope,
        index: ExpandElementTyped<usize>,
    ) -> Self::Output {
        self.__expand_read_unchecked_method(scope, index)
    }
}
331
332impl<E: CubePrimitive, IO: SliceVisibility> List<E> for Slice<E, IO> {}
333impl<E: CubePrimitive, IO: SliceVisibility> ListExpand<E> for SliceExpand<E, IO> {
334 fn __expand_read_method(
335 &self,
336 scope: &mut cubecl_ir::Scope,
337 index: ExpandElementTyped<usize>,
338 ) -> <E as CubeType>::ExpandType {
339 read_offset::expand::<E>(
340 scope,
341 self.origin.clone(),
342 self.offset.clone(),
343 index,
344 self.line_size,
345 true,
346 )
347 }
348 fn __expand_read_unchecked_method(
349 &self,
350 scope: &mut cubecl_ir::Scope,
351 index: ExpandElementTyped<usize>,
352 ) -> <E as CubeType>::ExpandType {
353 read_offset::expand::<E>(
354 scope,
355 self.origin.clone(),
356 self.offset.clone(),
357 index,
358 self.line_size,
359 false,
360 )
361 }
362
363 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
364 Self::__expand_len(scope, self.clone())
365 }
366}
367
/// Lets a [`Slice`] be written as a native `[T]`.
impl<T: CubePrimitive, IO: SliceVisibility> Deref for Slice<T, IO> {
    type Target = [T];

    /// Placeholder body: it only invokes `unexpanded!()`.
    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}
375
/// Mutable dereferencing is only available on [`ReadWrite`] slices.
impl<T: CubePrimitive> DerefMut for Slice<T, ReadWrite> {
    /// Placeholder body: it only invokes `unexpanded!()`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        unexpanded!()
    }
}
381
382impl<E: CubePrimitive, IO: SliceVisibility> Lined for Slice<E, IO> {}
383impl<E: CubePrimitive, IO: SliceVisibility> LinedExpand for SliceExpand<E, IO> {
384 fn line_size(&self) -> LineSize {
385 self.line_size.unwrap_or_else(|| self.origin.line_size())
386 }
387}
388
/// Indexed assignment is only available on [`ReadWrite`] slices.
impl<E: CubePrimitive> CubeIndexMut for Slice<E, ReadWrite> {
    /// Expands `array[index] = value` by delegating to `__expand_write_method`.
    fn expand_index_mut(
        scope: &mut Scope,
        array: Self::ExpandType,
        index: ExpandElementTyped<usize>,
        value: ExpandElementTyped<E>,
    ) {
        array.__expand_write_method(scope, index, value)
    }
}
399
impl<E: CubePrimitive> CubeIndexMutExpand for SliceExpand<E, ReadWrite> {
    /// Expands an indexed assignment by delegating to `__expand_write_method`.
    fn expand_index_mut(
        self,
        scope: &mut Scope,
        index: ExpandElementTyped<usize>,
        value: Self::Output,
    ) {
        self.__expand_write_method(scope, index, value)
    }
}
410
411impl<E: CubePrimitive> ListMut<E> for Slice<E, ReadWrite> {}
412impl<E: CubePrimitive> ListMutExpand<E> for SliceExpand<E, ReadWrite> {
413 fn __expand_write_method(
414 &self,
415 scope: &mut cubecl_ir::Scope,
416 index: ExpandElementTyped<usize>,
417 value: ExpandElementTyped<E>,
418 ) {
419 write_offset::expand::<E>(
420 scope,
421 self.origin.clone(),
422 self.offset.clone(),
423 index,
424 value,
425 self.line_size,
426 )
427 }
428}
429
430mod read_offset {
431 use super::*;
432
433 pub fn expand<E: CubePrimitive>(
434 scope: &mut cubecl::prelude::Scope,
435 origin: SliceOriginExpand<E>,
436 offset: <usize as cubecl::prelude::CubeType>::ExpandType,
437 index: <usize as cubecl::prelude::CubeType>::ExpandType,
438 line_size: Option<LineSize>,
439 checked: bool,
440 ) -> <E as cubecl::prelude::CubeType>::ExpandType {
441 let index = cubecl::frontend::add::expand(scope, offset, index);
442
443 match origin {
444 SliceOriginExpand::Tensor(expand) => {
445 expand_index_native::<Tensor<E>>(scope, expand, index, line_size, checked)
446 }
447 SliceOriginExpand::Array(expand) => {
448 expand_index_native::<Array<E>>(scope, expand, index, line_size, checked)
449 }
450 SliceOriginExpand::SharedMemory(expand) => {
451 expand_index_native::<SharedMemory<E>>(scope, expand, index, line_size, checked)
452 }
453 }
454 }
455}
456
457mod write_offset {
458 use super::*;
459
460 pub fn expand<E: CubePrimitive>(
461 scope: &mut cubecl::prelude::Scope,
462 origin: SliceOriginExpand<E>,
463 offset: <usize as cubecl::prelude::CubeType>::ExpandType,
464 index: <usize as cubecl::prelude::CubeType>::ExpandType,
465 value: <E as cubecl::prelude::CubeType>::ExpandType,
466 line_size: Option<LineSize>,
467 ) {
468 let index = cubecl::frontend::add::expand(scope, offset, index);
469
470 match origin {
471 SliceOriginExpand::Tensor(expand) => expand_index_assign_native::<Tensor<E>>(
472 scope, expand, index, value, line_size, true,
473 ),
474 SliceOriginExpand::Array(expand) => expand_index_assign_native::<Array<E>>(
475 scope, expand, index, value, line_size, false,
476 ),
477 SliceOriginExpand::SharedMemory(expand) => {
478 expand_index_assign_native::<SharedMemory<E>>(
479 scope, expand, index, value, line_size, false,
480 )
481 }
482 }
483 }
484}