1use super::Line;
2use crate::{
3 frontend::{
4 Array, CubePrimitive, CubeType, ExpandElementTyped, Init, SharedMemory, SizedContainer,
5 Tensor, indexation::Index,
6 },
7 ir::{Instruction, Scope},
8 prelude::{CubeDebug, List, ListExpand, ListMut, ListMutExpand, index, index_assign},
9 unexpanded,
10};
11use cubecl_common::tf32;
12use cubecl_ir::{ExpandElement, Operator};
13use std::marker::PhantomData;
14
/// Read-only view over a range of elements of type `E`.
///
/// The host-side value carries no data, only a `PhantomData` marker. Kernel code operates on
/// its expand form, `ExpandElementTyped<Slice<E>>`.
#[derive(Copy, Clone)]
pub struct Slice<E> {
    _e: PhantomData<E>,
}

/// Writable counterpart of [`Slice`]: a view whose elements may also be assigned.
///
/// Like [`Slice`], the host-side value only holds a `PhantomData` marker. Kernel code operates
/// on `ExpandElementTyped<SliceMut<E>>`.
#[derive(Copy, Clone)]
pub struct SliceMut<E> {
    _e: PhantomData<E>,
}
34
35#[allow(unused)]
36mod metadata {
37 use core::num::NonZero;
38
39 use cubecl_ir::{Elem, FloatKind, Item, NonSemantic};
40
41 use crate::prelude::cube_comment;
42
43 use super::*;
44
45 impl<E> Slice<E> {
46 #[allow(clippy::len_without_is_empty)]
48 pub fn len(&self) -> u32 {
49 unexpanded!()
50 }
51
52 pub fn into_lined(&self) -> Slice<Line<E>>
54 where
55 E: CubePrimitive,
56 {
57 unexpanded!()
58 }
59 pub fn try_cast_unchecked<T>(&self) -> Slice<T>
64 where
65 E: CubePrimitive,
66 T: CubePrimitive,
67 {
68 unexpanded!()
69 }
70 }
71
72 impl<E: CubePrimitive> Slice<Line<E>> {
73 pub fn with_line_size(&self, line_size: u32) -> Slice<Line<E>> {
80 unexpanded!()
81 }
82 }
83
84 impl<E> SliceMut<E> {
85 #[allow(clippy::len_without_is_empty)]
87 pub fn len(&self) -> u32 {
88 unexpanded!()
89 }
90
91 pub fn into_lined(self) -> SliceMut<Line<E>>
93 where
94 E: CubePrimitive,
95 {
96 unexpanded!()
97 }
98
99 pub fn try_cast_unchecked<T>(&self) -> SliceMut<T>
104 where
105 E: CubePrimitive,
106 T: CubePrimitive,
107 {
108 unexpanded!()
109 }
110 }
111
112 impl<E: CubePrimitive> SliceMut<Line<E>> {
113 pub fn with_line_size(&self, line_size: u32) -> SliceMut<Line<E>> {
120 unexpanded!()
121 }
122 }
123
124 impl<E: CubePrimitive> ExpandElementTyped<Slice<E>> {
125 pub fn __expand_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
127 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
128 elem.__expand_len_method(scope)
129 }
130
131 pub fn __expand_into_lined_method(
133 self,
134 _scope: &mut Scope,
135 ) -> ExpandElementTyped<Slice<Line<E>>>
136 where
137 E: CubePrimitive,
138 {
139 self.expand.into()
140 }
141
142 pub fn __expand_try_cast_unchecked_method<T>(
144 self,
145 scope: &mut Scope,
146 ) -> ExpandElementTyped<Slice<T>>
147 where
148 E: CubePrimitive,
149 T: CubePrimitive,
150 {
151 if T::as_elem(scope) != E::as_elem(scope) && !is_tf32::<E, T>(scope) {
152 let elems = [T::as_elem(scope), E::as_elem(scope)];
153 let is_flex32_cast = elems.contains(&Elem::Float(FloatKind::F32))
154 && elems.contains(&Elem::Float(FloatKind::Flex32));
155
156 if !is_flex32_cast {
157 panic!(
158 "Try cast unchecked should only be used to satisfy the rust type system."
159 )
160 }
161 }
162
163 self.expand.into()
164 }
165
166 pub fn __expand_clone_method(self, _scope: &mut Scope) -> ExpandElementTyped<Slice<Line<E>>>
167 where
168 E: CubePrimitive,
169 {
170 self.expand.into()
171 }
172 }
173
174 impl<E: CubePrimitive> ExpandElementTyped<Slice<Line<E>>> {
175 pub fn __expand_with_line_size_method(
177 self,
178 scope: &mut Scope,
179 line_size: u32,
180 ) -> ExpandElementTyped<Slice<Line<E>>>
181 where
182 E: CubePrimitive,
183 {
184 let input = self.clone().into_variable();
185 let mut item = input.item;
186
187 if line_size as u8 == item.vectorization.unwrap_or(NonZero::new(1).unwrap()).get() {
188 return self;
189 }
190
191 item.vectorization = NonZero::new(line_size as u8);
192 let out = scope.create_slice(item);
193
194 scope.register(Instruction::new(
195 Operator::ReinterpretSlice(cubecl_ir::ReinterpretSliceOperator {
196 input,
197 line_size,
198 }),
199 *out,
200 ));
201
202 out.into()
203 }
204 }
205
206 impl<E: CubePrimitive> ExpandElementTyped<SliceMut<E>> {
207 pub fn __expand_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
209 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
210 elem.__expand_len_method(scope)
211 }
212
213 pub fn __expand_into_lined_method(
215 self,
216 _scope: &mut Scope,
217 ) -> ExpandElementTyped<SliceMut<Line<E>>>
218 where
219 E: CubePrimitive,
220 {
221 self.expand.into()
222 }
223
224 pub fn __expand_try_cast_unchecked_method<T>(
226 self,
227 scope: &mut Scope,
228 ) -> ExpandElementTyped<SliceMut<T>>
229 where
230 E: CubePrimitive,
231 T: CubePrimitive,
232 {
233 if T::as_elem(scope) != E::as_elem(scope) && !is_tf32::<E, T>(scope) {
234 panic!("Try cast unchecked should only be used to satisfy the rust type system.")
235 }
236
237 self.expand.into()
238 }
239 }
240
241 impl<E: CubePrimitive> ExpandElementTyped<SliceMut<Line<E>>> {
242 pub fn __expand_with_line_size_method(
244 self,
245 scope: &mut Scope,
246 line_size: u32,
247 ) -> ExpandElementTyped<SliceMut<Line<E>>>
248 where
249 E: CubePrimitive,
250 {
251 let input = self.clone().into_variable();
252 let mut item = input.item;
253
254 if line_size as u8 == item.vectorization.unwrap_or(NonZero::new(1).unwrap()).get() {
255 return self;
256 }
257
258 item.vectorization = NonZero::new(line_size as u8);
259 let out = scope.create_slice(item);
260
261 scope.register(Instruction::new(
262 Operator::ReinterpretSlice(cubecl_ir::ReinterpretSliceOperator {
263 input,
264 line_size,
265 }),
266 *out,
267 ));
268 out.into()
269 }
270 }
271}
272
273pub(crate) fn is_tf32<C: CubePrimitive, T: CubePrimitive>(scope: &mut Scope) -> bool {
274 let ty_c = C::as_elem(scope);
275 let ty_t = T::as_elem(scope);
276 let ty_f32 = f32::as_elem(scope);
277 let ty_tf32 = tf32::as_elem(scope);
278
279 (ty_c == ty_f32 && ty_t == ty_tf32) || (ty_c == ty_tf32 && ty_t == ty_f32)
280}
281
282mod indexation {
284 use cubecl_ir::{BinaryOperator, Instruction, Operator};
285
286 use crate::prelude::{CubeIndex, CubeIndexMut};
287
288 use super::*;
289
290 impl<E: CubePrimitive> Slice<E> {
291 pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
297 where
298 Self: CubeIndex<I>,
299 {
300 unexpanded!()
301 }
302 }
303
304 impl<E: CubePrimitive> SliceMut<E> {
305 pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
311 where
312 Self: CubeIndex<I>,
313 {
314 unexpanded!()
315 }
316
317 pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
323 where
324 Self: CubeIndexMut<I>,
325 {
326 unexpanded!()
327 }
328 }
329
330 impl<E: CubePrimitive> ExpandElementTyped<Slice<E>> {
331 pub fn __expand_index_unchecked_method(
332 self,
333 scope: &mut Scope,
334 i: ExpandElementTyped<u32>,
335 ) -> ExpandElementTyped<E> {
336 let out = scope.create_local(self.expand.item);
337 scope.register(Instruction::new(
338 Operator::UncheckedIndex(BinaryOperator {
339 lhs: *self.expand,
340 rhs: i.expand.consume(),
341 }),
342 *out,
343 ));
344 out.into()
345 }
346 }
347
348 impl<E: CubePrimitive> ExpandElementTyped<SliceMut<E>> {
349 pub fn __expand_index_unchecked_method(
350 self,
351 scope: &mut Scope,
352 i: ExpandElementTyped<u32>,
353 ) -> ExpandElementTyped<E> {
354 let out = scope.create_local(self.expand.item);
355 scope.register(Instruction::new(
356 Operator::UncheckedIndex(BinaryOperator {
357 lhs: *self.expand,
358 rhs: i.expand.consume(),
359 }),
360 *out,
361 ));
362 out.into()
363 }
364
365 pub fn __expand_index_assign_unchecked_method(
366 self,
367 scope: &mut Scope,
368 i: ExpandElementTyped<u32>,
369 value: ExpandElementTyped<E>,
370 ) {
371 scope.register(Instruction::new(
372 Operator::UncheckedIndexAssign(BinaryOperator {
373 lhs: i.expand.consume(),
374 rhs: value.expand.consume(),
375 }),
376 *self.expand,
377 ));
378 }
379 }
380}
381
382impl<E: CubeType> CubeType for Slice<E> {
383 type ExpandType = ExpandElementTyped<Slice<E>>;
384}
385
386impl<C: CubeType> Init for ExpandElementTyped<Slice<C>> {
387 fn init(self, _scope: &mut Scope) -> Self {
388 self
390 }
391}
392
393impl<E: CubeType> CubeType for SliceMut<E> {
394 type ExpandType = ExpandElementTyped<SliceMut<E>>;
395}
396
397impl<E: CubeType> CubeType for &mut SliceMut<E> {
398 type ExpandType = ExpandElementTyped<SliceMut<E>>;
399}
400
401impl<C: CubeType> Init for ExpandElementTyped<SliceMut<C>> {
402 fn init(self, _scope: &mut Scope) -> Self {
403 self
405 }
406}
407
408impl<C: CubeType<ExpandType = ExpandElementTyped<C>>> SizedContainer for Slice<C> {
409 type Item = C;
410}
411
412impl<T: CubeType> Iterator for Slice<T> {
413 type Item = T;
414
415 fn next(&mut self) -> Option<Self::Item> {
416 unexpanded!()
417 }
418}
419
420pub trait SliceOperator<E: CubeType>: CubeType<ExpandType = Self::Expand> {
421 type Expand: SliceOperatorExpand<E>;
422
423 #[allow(unused_variables)]
427 fn slice<Start: Index, End: Index>(&self, start: Start, end: End) -> Slice<E> {
428 unexpanded!()
429 }
430 fn __expand_slice(
432 scope: &mut Scope,
433 expand: Self::Expand,
434 start: ExpandElementTyped<u32>,
435 end: ExpandElementTyped<u32>,
436 ) -> ExpandElementTyped<Slice<E>> {
437 expand.__expand_slice_method(scope, start, end)
438 }
439
440 #[allow(unused_variables)]
444 fn slice_mut<Start: Index, End: Index>(&mut self, start: Start, end: End) -> SliceMut<E> {
445 unexpanded!()
446 }
447
448 fn __expand_slice_mut(
450 scope: &mut Scope,
451 expand: Self::Expand,
452 start: ExpandElementTyped<u32>,
453 end: ExpandElementTyped<u32>,
454 ) -> ExpandElementTyped<SliceMut<E>> {
455 expand.__expand_slice_mut_method(scope, start, end)
456 }
457
458 #[allow(unused_variables)]
460 fn to_slice(&self) -> Slice<E> {
461 unexpanded!()
462 }
463
464 fn __expand_to_slice(scope: &mut Scope, expand: Self::Expand) -> ExpandElementTyped<Slice<E>> {
466 expand.__expand_to_slice_method(scope)
467 }
468
469 #[allow(unused_variables, clippy::wrong_self_convention)]
471 fn to_slice_mut(&mut self) -> SliceMut<E> {
472 unexpanded!()
473 }
474
475 fn __expand_to_slice_mut(
477 scope: &mut Scope,
478 expand: Self::Expand,
479 ) -> ExpandElementTyped<SliceMut<E>> {
480 expand.__expand_to_slice_mut_method(scope)
481 }
482}
483
484pub trait SliceOperatorExpand<E: CubeType>: Into<ExpandElement> + Clone + Init + CubeDebug {
485 fn slice_base<Start: Index, End: Index>(
486 &self,
487 scope: &mut Scope,
488 start: Start,
489 end: End,
490 ) -> ExpandElement;
491
492 fn __expand_slice_method(
493 &self,
494 scope: &mut Scope,
495 start: ExpandElementTyped<u32>,
496 end: ExpandElementTyped<u32>,
497 ) -> ExpandElementTyped<Slice<E>> {
498 ExpandElementTyped::new(self.slice_base(scope, start, end))
499 }
500
501 fn __expand_slice_mut_method(
502 &self,
503 scope: &mut Scope,
504 start: ExpandElementTyped<u32>,
505 end: ExpandElementTyped<u32>,
506 ) -> ExpandElementTyped<SliceMut<E>> {
507 ExpandElementTyped::new(self.slice_base(scope, start, end))
508 }
509
510 fn __expand_to_slice_method(&self, _scope: &mut Scope) -> ExpandElementTyped<Slice<E>> {
511 let expand = self.clone().into();
512 ExpandElementTyped::new(expand)
513 }
514
515 fn __expand_to_slice_mut_method(&self, _scope: &mut Scope) -> ExpandElementTyped<SliceMut<E>> {
516 let expand = self.clone().into();
517 ExpandElementTyped::new(expand)
518 }
519}
520
521macro_rules! slice_op {
522 ($type:ident) => {
523 impl<E: CubePrimitive> SliceOperator<E> for $type<E> {
524 type Expand = ExpandElementTyped<$type<E>>;
525 }
526
527 impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$type<E>> {
528 fn slice_base<Start: Index, End: Index>(
529 &self,
530 scope: &mut Scope,
531 start: Start,
532 end: End,
533 ) -> ExpandElement {
534 slice_expand(scope, self.clone(), start, end)
535 }
536 }
537 };
538 (slice $type:ident) => {
539 impl<E: CubePrimitive> SliceOperator<E> for $type<E> {
540 type Expand = ExpandElementTyped<$type<E>>;
541 }
542
543 impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$type<E>> {
544 fn slice_base<Start: Index, End: Index>(
545 &self,
546 scope: &mut Scope,
547 start: Start,
548 end: End,
549 ) -> ExpandElement {
550 slice_expand(scope, self.clone(), start, end)
551 }
552 }
553 };
554}
555
556slice_op!(Array);
557slice_op!(Tensor);
558slice_op!(SharedMemory);
559slice_op!(slice Slice);
560slice_op!(slice SliceMut);
561
562pub fn slice_expand<I: Into<ExpandElement>, S1: Index, S2: Index>(
563 scope: &mut Scope,
564 input: I,
565 start: S1,
566 end: S2, ) -> ExpandElement {
568 let input = input.into();
569 let out = scope.create_slice(input.item);
570
571 scope.register(Instruction::new(
572 Operator::Slice(cubecl_ir::SliceOperator {
573 input: *input,
574 start: start.value(),
575 end: end.value(),
576 }),
577 *out,
578 ));
579
580 out
581}
582
583impl<T: CubePrimitive> List<T> for Slice<T> {
584 fn __expand_read(
585 scope: &mut Scope,
586 this: ExpandElementTyped<Slice<T>>,
587 idx: ExpandElementTyped<u32>,
588 ) -> ExpandElementTyped<T> {
589 index::expand(scope, this, idx)
590 }
591}
592
593impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Slice<T>> {
594 fn __expand_read_method(
595 self,
596 scope: &mut Scope,
597 idx: ExpandElementTyped<u32>,
598 ) -> ExpandElementTyped<T> {
599 index::expand(scope, self, idx)
600 }
601}
602
603impl<T: CubePrimitive> List<T> for SliceMut<T> {
604 fn __expand_read(
605 scope: &mut Scope,
606 this: ExpandElementTyped<SliceMut<T>>,
607 idx: ExpandElementTyped<u32>,
608 ) -> ExpandElementTyped<T> {
609 index::expand(scope, this, idx)
610 }
611}
612
613impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SliceMut<T>> {
614 fn __expand_read_method(
615 self,
616 scope: &mut Scope,
617 idx: ExpandElementTyped<u32>,
618 ) -> ExpandElementTyped<T> {
619 index::expand(scope, self, idx)
620 }
621}
622
623impl<T: CubePrimitive> ListMut<T> for SliceMut<T> {
624 fn __expand_write(
625 scope: &mut Scope,
626 this: ExpandElementTyped<SliceMut<T>>,
627 idx: ExpandElementTyped<u32>,
628 value: ExpandElementTyped<T>,
629 ) {
630 index_assign::expand(scope, this, idx, value);
631 }
632}
633
634impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SliceMut<T>> {
635 fn __expand_write_method(
636 self,
637 scope: &mut Scope,
638 idx: ExpandElementTyped<u32>,
639 value: ExpandElementTyped<T>,
640 ) {
641 index_assign::expand(scope, self, idx, value);
642 }
643}