1use std::{marker::PhantomData, num::NonZero};
2
3use crate::{self as cubecl, unexpanded};
4use cubecl::prelude::*;
5use cubecl_ir::{Branch, Elem, ExpandElement, FloatKind, Item, RangeLoop, Variable};
6use cubecl_macros::intrinsic;
7
/// Marker type selecting read-only access for a [`Slice`].
///
/// This is the default `IO` parameter of [`Slice`]. It derives `Copy` because the
/// `#[derive(Clone, Copy)]` on [`Slice`] only produces a `Copy` impl when every type
/// parameter, including the `IO` marker, is itself `Copy`.
#[derive(Clone, Copy, Debug)]
pub struct ReadOnly;
/// Marker type selecting read-write access for a [`Slice`].
///
/// It derives `Copy` because the `#[derive(Clone, Copy)]` on [`Slice`] only produces a
/// `Copy` impl when every type parameter, including the `IO` marker, is itself `Copy`.
#[derive(Clone, Copy, Debug)]
pub struct ReadWrite;
12
/// Front-end handle for a view over elements of type `E`.
///
/// `IO` is a [`SliceVisibility`] marker and defaults to [`ReadOnly`]. The value manipulated
/// during expansion is [`SliceExpand`], the `ExpandType` declared by this type's `CubeType`
/// impl.
#[derive(Clone, Copy)]
pub struct Slice<E: CubePrimitive, IO: SliceVisibility = ReadOnly> {
    // Zero-sized markers tying the struct to its type parameters.
    _e: PhantomData<E>,
    _io: PhantomData<IO>,
    // Zero-sized marker only; no offset value is stored in this struct.
    _offset: PhantomData<u32>,
    // Returned by `len` and compared against zero by `is_empty`.
    length: u32,
}
25
/// The container a slice refers to.
///
/// The rest of this file works with the matching `SliceOriginExpand` type, presumably
/// generated by `#[derive(CubeType)]` — confirm in `cubecl_macros`.
#[derive(CubeType)]
pub enum SliceOrigin<E: CubePrimitive> {
    /// The slice refers to a [`Tensor`].
    Tensor(Tensor<E>),
    /// The slice refers to an [`Array`].
    Array(Array<E>),
    /// The slice refers to a [`SharedMemory`].
    SharedMemory(SharedMemory<E>),
}
32
/// Lets a [`Slice`] appear where Rust expects an iterator yielding `E`.
impl<E: CubePrimitive, IO: SliceVisibility> Iterator for Slice<E, IO> {
    type Item = E;

    /// The body is `unexpanded!()`; iteration during expansion is provided by the
    /// `Iterable` impl on [`SliceExpand`] instead.
    fn next(&mut self) -> Option<Self::Item> {
        unexpanded!()
    }
}
40
/// Marker trait bounding the `IO` parameter of [`Slice`] and [`SliceExpand`].
///
/// It has no items; in this file it is implemented for [`ReadOnly`] and [`ReadWrite`].
pub trait SliceVisibility {}

impl SliceVisibility for ReadOnly {}

impl SliceVisibility for ReadWrite {}
46
/// Expansion-time representation of a [`Slice`].
pub struct SliceExpand<E: CubePrimitive, IO: SliceVisibility> {
    /// Container the slice reads from (and, for `ReadWrite`, writes to).
    pub(crate) origin: SliceOriginExpand<E>,
    /// Zero-sized marker carrying the visibility parameter.
    pub(crate) io: PhantomData<IO>,
    /// Added to every index before the origin is accessed.
    pub(crate) offset: ExpandElementTyped<u32>,
    /// Number of elements in the slice; computed as `end - start` at construction.
    pub(crate) length: ExpandElementTyped<u32>,
    /// Line size forwarded to the native index helpers; `None` at construction and set by
    /// `with_line_size`.
    pub(crate) line_size: Option<u32>,
}
54
55impl<E: CubePrimitive, IO: SliceVisibility> SliceExpand<E, IO> {
56 pub fn __to_raw_parts(&self) -> (Variable, Variable) {
57 let expand = match self.origin.clone() {
58 SliceOriginExpand::Tensor(expand) => expand.expand,
59 SliceOriginExpand::Array(expand) => expand.expand,
60 SliceOriginExpand::SharedMemory(expand) => expand.expand,
61 };
62
63 (*expand, *self.offset.expand)
64 }
65}
66
67#[cube]
68impl<E: CubePrimitive, IO: SliceVisibility> Slice<Line<E>, IO> {
69 #[allow(unused_variables)]
75 pub fn with_line_size(&self, #[comptime] line_size: u32) -> Slice<Line<E>, IO> {
76 intrinsic!(|scope| {
77 let (input, offset) = self.__to_raw_parts();
78 let mut item = input.item;
79
80 if line_size as u8 == item.vectorization.unwrap_or(NonZero::new(1).unwrap()).get() {
81 return self;
82 }
83
84 let current = input.item.vectorization.map(|a| a.get()).unwrap_or(1) as u32;
85 let mut out = self.clone();
86
87 if current < line_size {
88 let ratio = line_size / current;
89 let length = cubecl::frontend::div::expand(scope, self.length, ratio.into());
90 let offset = cubecl::frontend::div::expand(scope, self.offset, ratio.into());
91 out.length = length;
92 out.offset = offset;
93 } else {
94 let ratio = current / line_size;
95 let length = cubecl::frontend::mul::expand(scope, self.length, ratio.into());
96 let offset = cubecl::frontend::mul::expand(scope, self.offset, ratio.into());
97 out.length = length;
98 out.offset = offset;
99 }
100
101 out.line_size = Some(line_size);
102 out
103 })
104 }
105}
106
107#[cube]
108impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
109 pub fn into_lined(&self) -> Slice<Line<E>, IO> {
111 intrinsic!(|_scope| {
112 SliceExpand::<Line<E>, IO> {
113 origin: self.origin.cast_unchecked(),
114 io: self.io.clone(),
115 offset: self.offset.clone(),
116 length: self.length.clone(),
117 line_size: None,
118 }
119 })
120 }
121 pub fn try_cast_unchecked<T: CubePrimitive>(&self) -> Slice<T, IO> {
127 intrinsic!(|scope| {
128 if T::as_elem(scope) != E::as_elem(scope) && !is_tf32::<E, T>(scope) {
129 let elems = [T::as_elem(scope), E::as_elem(scope)];
130 let is_flex32_cast = elems.contains(&Elem::Float(FloatKind::F32))
131 && elems.contains(&Elem::Float(FloatKind::Flex32));
132
133 if !is_flex32_cast {
134 panic!(
135 "Try cast unchecked should only be used to satisfy the rust type system."
136 )
137 }
138 }
139
140 SliceExpand::<T, IO> {
141 origin: self.origin.cast_unchecked(),
142 io: self.io.clone(),
143 offset: self.offset.clone(),
144 length: self.length.clone(),
145 line_size: self.line_size.clone(),
146 }
147 })
148 }
149}
150
151impl<E: CubePrimitive> SliceOriginExpand<E> {
152 fn cast_unchecked<T: CubePrimitive>(self) -> SliceOriginExpand<T> {
153 match self {
154 SliceOriginExpand::Tensor(expand) => {
155 SliceOriginExpand::<T>::Tensor(expand.expand.into())
156 }
157 SliceOriginExpand::Array(expand) => SliceOriginExpand::<T>::Array(expand.expand.into()),
158 SliceOriginExpand::SharedMemory(expand) => {
159 SliceOriginExpand::<T>::SharedMemory(expand.expand.into())
160 }
161 }
162 }
163}
164
165impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
166 pub fn new(_origin: SliceOrigin<E>, _offset: u32, _length: u32) -> Self {
167 unexpanded!()
168 }
169 pub fn __expand_new(
170 scope: &mut Scope,
171 origin: SliceOriginExpand<E>,
172 start: ExpandElementTyped<u32>,
173 end: ExpandElementTyped<u32>,
174 ) -> SliceExpand<E, IO> {
175 Self::__expand_new_expand(scope, origin, start, end)
176 }
177 pub fn __expand_new_expand(
178 scope: &mut Scope,
179 origin: SliceOriginExpand<E>,
180 start: ExpandElementTyped<u32>,
181 end: ExpandElementTyped<u32>,
182 ) -> SliceExpand<E, IO> {
183 let length = cubecl::frontend::sub::expand(scope, end, start.clone());
184
185 SliceExpand::<E, IO> {
186 origin,
187 io: PhantomData,
188 offset: start,
189 length,
190 line_size: None,
191 }
192 }
193}
194
#[cube]
impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
    /// Returns the length of the slice.
    pub fn len(&self) -> u32 {
        self.length
    }

    /// Returns `true` when the length of the slice is zero.
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }
}
206
/// Associates [`Slice`] with [`SliceExpand`] as its expansion-time type.
impl<E: CubePrimitive, IO: SliceVisibility> CubeType for Slice<E, IO> {
    type ExpandType = SliceExpand<E, IO>;
}
210
impl<E: CubePrimitive, IO: SliceVisibility> IntoMut for SliceExpand<E, IO> {
    /// Returns `self` unchanged; the scope is not used.
    fn into_mut(self, _scope: &mut cubecl_ir::Scope) -> Self {
        self
    }
}
216
// Empty impl: no `CubeDebug` item is overridden for slices.
impl<E: CubePrimitive, IO: SliceVisibility> CubeDebug for SliceExpand<E, IO> {}
218impl<E: CubePrimitive, IO: SliceVisibility> Clone for SliceExpand<E, IO> {
219 fn clone(&self) -> Self {
220 Self {
221 origin: self.origin.clone(),
222 offset: self.offset.clone(),
223 length: self.length.clone(),
224 line_size: self.line_size,
225 io: PhantomData,
226 }
227 }
228}
229
/// Declares read-only slices as sized containers whose items are `E`.
impl<E: CubePrimitive> SizedContainer for Slice<E, ReadOnly> {
    type Item = E;
}
234
235impl<E: CubePrimitive> Iterable<E> for SliceExpand<E, ReadOnly> {
236 fn expand(
237 self,
238 scope: &mut Scope,
239 mut body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
240 ) {
241 let index_ty = Item::new(u32::as_elem(scope));
242 let len: ExpandElement = self.length.clone().into();
243
244 let mut child = scope.child();
245 let i = child.create_local_restricted(index_ty);
246
247 let index = i.clone().into();
248 let item = index::expand(&mut child, self, index);
249 body(&mut child, item);
250
251 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
252 i: *i,
253 start: 0u32.into(),
254 end: *len,
255 step: None,
256 inclusive: false,
257 scope: child,
258 })));
259 }
260
261 fn expand_unroll(
262 self,
263 _scope: &mut Scope,
264 _body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
265 ) {
266 unimplemented!("Can't unroll slice iterator")
267 }
268}
impl<E: CubePrimitive> CubeIndex for Slice<E, ReadOnly> {
    type Output = E;

    /// Expands an indexed read on a read-only slice by delegating to
    /// `__expand_read_method`.
    fn expand_index(
        scope: &mut Scope,
        array: Self::ExpandType,
        index: ExpandElementTyped<u32>,
    ) -> <Self::Output as CubeType>::ExpandType {
        array.__expand_read_method(scope, index)
    }
}
280
impl<E: CubePrimitive> CubeIndexExpand for SliceExpand<E, ReadOnly> {
    type Output = E::ExpandType;

    /// Indexed read; delegates to `__expand_read_method`.
    fn expand_index(self, scope: &mut Scope, index: ExpandElementTyped<u32>) -> Self::Output {
        self.__expand_read_method(scope, index)
    }

    /// Indexed read; delegates to `__expand_read_unchecked_method`.
    fn expand_index_unchecked(
        self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
    ) -> Self::Output {
        self.__expand_read_unchecked_method(scope, index)
    }
}
295
// Empty impl: no `List` item is overridden for read-only slices.
impl<E: CubePrimitive> List<E> for Slice<E, ReadOnly> {}
297impl<E: CubePrimitive> ListExpand<E> for SliceExpand<E, ReadOnly> {
298 fn __expand_read_method(
299 &self,
300 scope: &mut cubecl_ir::Scope,
301 index: ExpandElementTyped<u32>,
302 ) -> <E as CubeType>::ExpandType {
303 read_offset::expand::<E>(
304 scope,
305 self.origin.clone(),
306 self.offset.clone(),
307 index,
308 self.line_size,
309 true,
310 )
311 }
312 fn __expand_read_unchecked_method(
313 &self,
314 scope: &mut cubecl_ir::Scope,
315 index: ExpandElementTyped<u32>,
316 ) -> <E as CubeType>::ExpandType {
317 read_offset::expand::<E>(
318 scope,
319 self.origin.clone(),
320 self.offset.clone(),
321 index,
322 self.line_size,
323 false,
324 )
325 }
326}
327
impl<E: CubePrimitive> CubeIndex for Slice<E, ReadWrite> {
    type Output = E;

    /// Expands an indexed read on a read-write slice by delegating to
    /// `__expand_read_method`.
    fn expand_index(
        scope: &mut Scope,
        array: Self::ExpandType,
        index: ExpandElementTyped<u32>,
    ) -> <Self::Output as CubeType>::ExpandType {
        array.__expand_read_method(scope, index)
    }
}
339
impl<E: CubePrimitive> CubeIndexExpand for SliceExpand<E, ReadWrite> {
    type Output = E::ExpandType;

    /// Indexed read; delegates to `__expand_read_method`.
    fn expand_index(self, scope: &mut Scope, index: ExpandElementTyped<u32>) -> Self::Output {
        self.__expand_read_method(scope, index)
    }

    /// Indexed read; delegates to `__expand_read_unchecked_method`.
    fn expand_index_unchecked(
        self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
    ) -> Self::Output {
        self.__expand_read_unchecked_method(scope, index)
    }
}
354
// Empty impl: no `List` item is overridden for read-write slices.
impl<E: CubePrimitive> List<E> for Slice<E, ReadWrite> {}
356impl<E: CubePrimitive> ListExpand<E> for SliceExpand<E, ReadWrite> {
357 fn __expand_read_method(
358 &self,
359 scope: &mut cubecl_ir::Scope,
360 index: ExpandElementTyped<u32>,
361 ) -> <E as CubeType>::ExpandType {
362 read_offset::expand::<E>(
363 scope,
364 self.origin.clone(),
365 self.offset.clone(),
366 index,
367 self.line_size,
368 true,
369 )
370 }
371 fn __expand_read_unchecked_method(
372 &self,
373 scope: &mut cubecl_ir::Scope,
374 index: ExpandElementTyped<u32>,
375 ) -> <E as CubeType>::ExpandType {
376 read_offset::expand::<E>(
377 scope,
378 self.origin.clone(),
379 self.offset.clone(),
380 index,
381 self.line_size,
382 false,
383 )
384 }
385}
386
impl<E: CubePrimitive> CubeIndexMut for Slice<E, ReadWrite> {
    /// Expands an indexed write on a read-write slice by delegating to
    /// `__expand_write_method`.
    fn expand_index_mut(
        scope: &mut Scope,
        array: Self::ExpandType,
        index: ExpandElementTyped<u32>,
        value: ExpandElementTyped<E>,
    ) {
        array.__expand_write_method(scope, index, value)
    }
}
397
impl<E: CubePrimitive> CubeIndexMutExpand for SliceExpand<E, ReadWrite> {
    type Output = E::ExpandType;

    /// Indexed write of `value` at `index`; delegates to `__expand_write_method`.
    fn expand_index_mut(
        self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
        value: Self::Output,
    ) {
        self.__expand_write_method(scope, index, value)
    }
}
410
// Empty impl: no `ListMut` item is overridden for read-write slices.
impl<E: CubePrimitive> ListMut<E> for Slice<E, ReadWrite> {}
412impl<E: CubePrimitive> ListMutExpand<E> for SliceExpand<E, ReadWrite> {
413 fn __expand_write_method(
414 &self,
415 scope: &mut cubecl_ir::Scope,
416 index: ExpandElementTyped<u32>,
417 value: ExpandElementTyped<E>,
418 ) {
419 write_offset::expand::<E>(
420 scope,
421 self.origin.clone(),
422 self.offset.clone(),
423 index,
424 value,
425 self.line_size,
426 )
427 }
428}
429
430mod read_offset {
431 use super::*;
432
433 pub fn expand<E: CubePrimitive>(
434 scope: &mut cubecl::prelude::Scope,
435 origin: SliceOriginExpand<E>,
436 offset: <u32 as cubecl::prelude::CubeType>::ExpandType,
437 index: <u32 as cubecl::prelude::CubeType>::ExpandType,
438 line_size: Option<u32>,
439 checked: bool,
440 ) -> <E as cubecl::prelude::CubeType>::ExpandType {
441 let index = cubecl::frontend::add::expand(scope, offset, index);
442
443 match origin {
444 SliceOriginExpand::Tensor(expand) => {
445 expand_index_native::<Tensor<E>>(scope, expand, index, line_size, checked)
446 }
447 SliceOriginExpand::Array(expand) => {
448 expand_index_native::<Array<E>>(scope, expand, index, line_size, checked)
449 }
450 SliceOriginExpand::SharedMemory(expand) => {
451 expand_index_native::<SharedMemory<E>>(scope, expand, index, line_size, checked)
452 }
453 }
454 }
455}
456
457mod write_offset {
458 use super::*;
459
460 pub fn expand<E: CubePrimitive>(
461 scope: &mut cubecl::prelude::Scope,
462 origin: SliceOriginExpand<E>,
463 offset: <u32 as cubecl::prelude::CubeType>::ExpandType,
464 index: <u32 as cubecl::prelude::CubeType>::ExpandType,
465 value: <E as cubecl::prelude::CubeType>::ExpandType,
466 line_size: Option<u32>,
467 ) {
468 let index = cubecl::frontend::add::expand(scope, offset, index);
469
470 match origin {
471 SliceOriginExpand::Tensor(expand) => expand_index_assign_native::<Tensor<E>>(
472 scope, expand, index, value, line_size, true,
473 ),
474 SliceOriginExpand::Array(expand) => expand_index_assign_native::<Array<E>>(
475 scope, expand, index, value, line_size, false,
476 ),
477 SliceOriginExpand::SharedMemory(expand) => {
478 expand_index_assign_native::<SharedMemory<E>>(
479 scope, expand, index, value, line_size, false,
480 )
481 }
482 }
483 }
484}