1use std::{
2 marker::PhantomData,
3 ops::{Deref, DerefMut},
4};
5
6use crate::{self as cubecl, unexpanded};
7use cubecl::prelude::*;
8use cubecl_ir::{Branch, ElemType, ExpandElement, FloatKind, LineSize, RangeLoop, Type, Variable};
9use cubecl_macros::intrinsic;
10
/// Visibility marker for slices that are only read.
///
/// This is the default `IO` parameter of [`Slice`].
#[derive(Copy, Clone)]
pub struct ReadOnly;
/// Visibility marker for slices that can also be written to.
///
/// The mutable indexing traits in this file are implemented only for
/// `Slice<E, ReadWrite>` / `SliceExpand<E, ReadWrite>`.
#[derive(Copy, Clone)]
pub struct ReadWrite;
15
16#[derive(Clone, Copy)]
22pub struct Slice<E: CubePrimitive, IO: SliceVisibility = ReadOnly> {
23 _e: PhantomData<E>,
24 _io: PhantomData<IO>,
25 _offset: PhantomData<usize>,
26 length: usize,
27}
28
/// Container a slice is taken from.
///
/// The `CubeType` derive is what provides the `SliceOriginExpand` counterpart used
/// throughout this file.
#[derive(CubeType)]
pub enum SliceOrigin<E: CubePrimitive> {
    // Slice backed by a tensor.
    Tensor(Tensor<E>),
    // Slice backed by an array.
    Array(Array<E>),
    // Slice backed by shared memory.
    SharedMemory(SharedMemory<E>),
}
35
36impl<E: CubePrimitive> SliceOriginExpand<E> {
37 pub fn line_size(&self) -> LineSize {
38 match self {
39 SliceOriginExpand::Tensor(t) => t.line_size(),
40 SliceOriginExpand::Array(t) => t.line_size(),
41 SliceOriginExpand::SharedMemory(t) => t.line_size(),
42 }
43 }
44}
45
46impl<E: CubePrimitive, IO: SliceVisibility> Iterator for Slice<E, IO> {
47 type Item = E;
48
49 fn next(&mut self) -> Option<Self::Item> {
50 unexpanded!()
51 }
52}
53
54pub trait SliceVisibility: Clone + Copy + Send + Sync + 'static {}
55
56impl SliceVisibility for ReadOnly {}
57
58impl SliceVisibility for ReadWrite {}
59
60pub struct SliceExpand<E: CubePrimitive, IO: SliceVisibility> {
61 pub(crate) origin: SliceOriginExpand<E>,
62 pub(crate) io: PhantomData<IO>,
63 pub(crate) offset: ExpandElementTyped<usize>,
64 pub(crate) length: ExpandElementTyped<usize>,
65 pub(crate) line_size: Option<LineSize>,
66}
67
68impl<E: CubePrimitive, IO: SliceVisibility> SliceExpand<E, IO> {
69 pub fn __to_raw_parts(&self) -> (Variable, Variable) {
70 let expand = match self.origin.clone() {
71 SliceOriginExpand::Tensor(expand) => expand.expand,
72 SliceOriginExpand::Array(expand) => expand.expand,
73 SliceOriginExpand::SharedMemory(expand) => expand.expand,
74 };
75
76 (*expand, *self.offset.expand)
77 }
78}
79
80#[cube]
81impl<E: CubePrimitive, IO: SliceVisibility> Slice<Line<E>, IO> {
82 #[allow(unused_variables)]
88 pub fn with_line_size(&self, #[comptime] line_size: LineSize) -> Slice<Line<E>, IO> {
89 intrinsic!(|scope| {
90 let (input, offset) = self.__to_raw_parts();
91 let mut item = input.ty;
92
93 if line_size == item.line_size() {
94 return self;
95 }
96
97 let current = input.ty.line_size();
98 let mut out = self.clone();
99
100 if current < line_size {
101 let ratio = line_size / current;
102 let length = cubecl::frontend::div::expand(scope, self.length, ratio.into());
103 let offset = cubecl::frontend::div::expand(scope, self.offset, ratio.into());
104 out.length = length;
105 out.offset = offset;
106 } else {
107 let ratio = current / line_size;
108 let length = cubecl::frontend::mul::expand(scope, self.length, ratio.into());
109 let offset = cubecl::frontend::mul::expand(scope, self.offset, ratio.into());
110 out.length = length;
111 out.offset = offset;
112 }
113
114 out.line_size = Some(line_size);
115 out
116 })
117 }
118}
119
120#[cube]
121impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
122 pub fn into_lined(&self) -> Slice<Line<E>, IO> {
124 intrinsic!(|_scope| {
125 SliceExpand::<Line<E>, IO> {
126 origin: self.origin.cast_unchecked(),
127 io: self.io.clone(),
128 offset: self.offset.clone(),
129 length: self.length.clone(),
130 line_size: None,
131 }
132 })
133 }
134 pub fn downcast<T: CubePrimitive>(&self) -> Slice<T, IO> {
139 intrinsic!(|scope| {
140 if T::as_type(scope) != E::as_type(scope) && !is_tf32::<E, T>(scope) {
141 let elems = [T::as_type(scope).elem_type(), E::as_type(scope).elem_type()];
142 let is_flex32_cast = elems.contains(&ElemType::Float(FloatKind::F32))
143 && elems.contains(&ElemType::Float(FloatKind::Flex32));
144
145 if !is_flex32_cast {
146 panic!("Downcast should only be used to satisfy the Rust type system.")
147 }
148 }
149
150 SliceExpand::<T, IO> {
151 origin: self.origin.cast_unchecked(),
152 io: self.io.clone(),
153 offset: self.offset.clone(),
154 length: self.length.clone(),
155 line_size: self.line_size.clone(),
156 }
157 })
158 }
159}
160
161#[cube]
162impl<E: CubePrimitive> Slice<E, ReadOnly> {
163 pub fn as_mut_unchecked(&self) -> Slice<E, ReadWrite> {
164 intrinsic!(|scope| {
165 SliceExpand::<E, ReadWrite> {
166 origin: self.origin,
167 io: PhantomData,
168 offset: self.offset.clone(),
169 length: self.length.clone(),
170 line_size: self.line_size.clone(),
171 }
172 })
173 }
174}
175
176impl<E: CubePrimitive> SliceOriginExpand<E> {
177 fn cast_unchecked<T: CubePrimitive>(self) -> SliceOriginExpand<T> {
178 match self {
179 SliceOriginExpand::Tensor(expand) => {
180 SliceOriginExpand::<T>::Tensor(expand.expand.into())
181 }
182 SliceOriginExpand::Array(expand) => SliceOriginExpand::<T>::Array(expand.expand.into()),
183 SliceOriginExpand::SharedMemory(expand) => {
184 SliceOriginExpand::<T>::SharedMemory(expand.expand.into())
185 }
186 }
187 }
188}
189
190impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
191 pub fn new(_origin: SliceOrigin<E>, _offset: usize, _length: usize) -> Self {
192 unexpanded!()
193 }
194 pub fn __expand_new(
195 scope: &mut Scope,
196 origin: SliceOriginExpand<E>,
197 start: ExpandElementTyped<usize>,
198 end: ExpandElementTyped<usize>,
199 ) -> SliceExpand<E, IO> {
200 Self::__expand_new_expand(scope, origin, start, end)
201 }
202 pub fn __expand_new_expand(
203 scope: &mut Scope,
204 origin: SliceOriginExpand<E>,
205 start: ExpandElementTyped<usize>,
206 end: ExpandElementTyped<usize>,
207 ) -> SliceExpand<E, IO> {
208 let length = cubecl::frontend::sub::expand(scope, end, start.clone());
209
210 SliceExpand::<E, IO> {
211 origin,
212 io: PhantomData,
213 offset: start,
214 length,
215 line_size: None,
216 }
217 }
218}
219
#[cube]
impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
    /// Returns the number of items in the slice.
    pub fn len(&self) -> usize {
        self.length
    }
    /// Returns `true` when the slice length is zero.
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }
}
231
232impl<E: CubePrimitive, IO: SliceVisibility> CubeType for Slice<E, IO> {
233 type ExpandType = SliceExpand<E, IO>;
234}
235
236impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &Slice<E, IO> {
237 type ExpandType = SliceExpand<E, IO>;
238}
239
240impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &mut Slice<E, IO> {
241 type ExpandType = SliceExpand<E, IO>;
242}
243
244impl<E: CubePrimitive, IO: SliceVisibility> IntoMut for SliceExpand<E, IO> {
245 fn into_mut(self, _scope: &mut cubecl_ir::Scope) -> Self {
246 self
247 }
248}
249
250impl<E: CubePrimitive, IO: SliceVisibility> CubeDebug for SliceExpand<E, IO> {}
251impl<E: CubePrimitive, IO: SliceVisibility> Clone for SliceExpand<E, IO> {
252 fn clone(&self) -> Self {
253 Self {
254 origin: self.origin.clone(),
255 offset: self.offset.clone(),
256 length: self.length.clone(),
257 line_size: self.line_size,
258 io: PhantomData,
259 }
260 }
261}
262
263impl<E: CubePrimitive> SizedContainer for Slice<E, ReadOnly> {
265 type Item = E;
266}
267
268impl<E: CubePrimitive> Iterable<E> for SliceExpand<E, ReadOnly> {
269 fn expand(
270 self,
271 scope: &mut Scope,
272 mut body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
273 ) {
274 let index_ty = Type::new(u32::as_type(scope));
275 let len: ExpandElement = self.length.clone().into();
276
277 let mut child = scope.child();
278 let i = child.create_local_restricted(index_ty);
279
280 let index = i.clone().into();
281 let item = index::expand(&mut child, self, index);
282 body(&mut child, item);
283
284 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
285 i: *i,
286 start: 0usize.into(),
287 end: *len,
288 step: None,
289 inclusive: false,
290 scope: child,
291 })));
292 }
293
294 fn expand_unroll(
295 self,
296 _scope: &mut Scope,
297 _body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
298 ) {
299 unimplemented!("Can't unroll slice iterator")
300 }
301}
302impl<E: CubePrimitive, IO: SliceVisibility> CubeIndex for Slice<E, IO> {
303 type Output = E;
304 type Idx = usize;
305
306 fn expand_index(
307 scope: &mut Scope,
308 array: Self::ExpandType,
309 index: ExpandElementTyped<usize>,
310 ) -> <Self::Output as CubeType>::ExpandType {
311 array.__expand_read_method(scope, index)
312 }
313}
314
315impl<E: CubePrimitive, IO: SliceVisibility> CubeIndexExpand for SliceExpand<E, IO> {
316 type Output = E::ExpandType;
317 type Idx = ExpandElementTyped<usize>;
318
319 fn expand_index(self, scope: &mut Scope, index: ExpandElementTyped<usize>) -> Self::Output {
320 self.__expand_read_method(scope, index)
321 }
322 fn expand_index_unchecked(
323 self,
324 scope: &mut Scope,
325 index: ExpandElementTyped<usize>,
326 ) -> Self::Output {
327 self.__expand_read_unchecked_method(scope, index)
328 }
329}
330
331impl<E: CubePrimitive, IO: SliceVisibility> List<E> for Slice<E, IO> {}
332impl<E: CubePrimitive, IO: SliceVisibility> ListExpand<E> for SliceExpand<E, IO> {
333 fn __expand_read_method(
334 &self,
335 scope: &mut cubecl_ir::Scope,
336 index: ExpandElementTyped<usize>,
337 ) -> <E as CubeType>::ExpandType {
338 read_offset::expand::<E>(
339 scope,
340 self.origin.clone(),
341 self.offset.clone(),
342 index,
343 self.line_size,
344 true,
345 )
346 }
347 fn __expand_read_unchecked_method(
348 &self,
349 scope: &mut cubecl_ir::Scope,
350 index: ExpandElementTyped<usize>,
351 ) -> <E as CubeType>::ExpandType {
352 read_offset::expand::<E>(
353 scope,
354 self.origin.clone(),
355 self.offset.clone(),
356 index,
357 self.line_size,
358 false,
359 )
360 }
361
362 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
363 Self::__expand_len(scope, self.clone())
364 }
365}
366
367impl<T: CubePrimitive, IO: SliceVisibility> Deref for Slice<T, IO> {
368 type Target = [T];
369
370 fn deref(&self) -> &Self::Target {
371 unexpanded!()
372 }
373}
374
375impl<T: CubePrimitive> DerefMut for Slice<T, ReadWrite> {
376 fn deref_mut(&mut self) -> &mut Self::Target {
377 unexpanded!()
378 }
379}
380
381impl<E: CubePrimitive, IO: SliceVisibility> Lined for Slice<E, IO> {}
382impl<E: CubePrimitive, IO: SliceVisibility> LinedExpand for SliceExpand<E, IO> {
383 fn line_size(&self) -> LineSize {
384 self.line_size.unwrap_or_else(|| self.origin.line_size())
385 }
386}
387
388impl<E: CubePrimitive> CubeIndexMut for Slice<E, ReadWrite> {
389 fn expand_index_mut(
390 scope: &mut Scope,
391 array: Self::ExpandType,
392 index: ExpandElementTyped<usize>,
393 value: ExpandElementTyped<E>,
394 ) {
395 array.__expand_write_method(scope, index, value)
396 }
397}
398
399impl<E: CubePrimitive> CubeIndexMutExpand for SliceExpand<E, ReadWrite> {
400 fn expand_index_mut(
401 self,
402 scope: &mut Scope,
403 index: ExpandElementTyped<usize>,
404 value: Self::Output,
405 ) {
406 self.__expand_write_method(scope, index, value)
407 }
408}
409
410impl<E: CubePrimitive> ListMut<E> for Slice<E, ReadWrite> {}
411impl<E: CubePrimitive> ListMutExpand<E> for SliceExpand<E, ReadWrite> {
412 fn __expand_write_method(
413 &self,
414 scope: &mut cubecl_ir::Scope,
415 index: ExpandElementTyped<usize>,
416 value: ExpandElementTyped<E>,
417 ) {
418 write_offset::expand::<E>(
419 scope,
420 self.origin.clone(),
421 self.offset.clone(),
422 index,
423 value,
424 self.line_size,
425 )
426 }
427}
428
429mod read_offset {
430 use super::*;
431
432 pub fn expand<E: CubePrimitive>(
433 scope: &mut cubecl::prelude::Scope,
434 origin: SliceOriginExpand<E>,
435 offset: <usize as cubecl::prelude::CubeType>::ExpandType,
436 index: <usize as cubecl::prelude::CubeType>::ExpandType,
437 line_size: Option<LineSize>,
438 checked: bool,
439 ) -> <E as cubecl::prelude::CubeType>::ExpandType {
440 let index = cubecl::frontend::add::expand(scope, offset, index);
441
442 match origin {
443 SliceOriginExpand::Tensor(expand) => {
444 expand_index_native::<Tensor<E>>(scope, expand, index, line_size, checked)
445 }
446 SliceOriginExpand::Array(expand) => {
447 expand_index_native::<Array<E>>(scope, expand, index, line_size, checked)
448 }
449 SliceOriginExpand::SharedMemory(expand) => {
450 expand_index_native::<SharedMemory<E>>(scope, expand, index, line_size, checked)
451 }
452 }
453 }
454}
455
456mod write_offset {
457 use super::*;
458
459 pub fn expand<E: CubePrimitive>(
460 scope: &mut cubecl::prelude::Scope,
461 origin: SliceOriginExpand<E>,
462 offset: <usize as cubecl::prelude::CubeType>::ExpandType,
463 index: <usize as cubecl::prelude::CubeType>::ExpandType,
464 value: <E as cubecl::prelude::CubeType>::ExpandType,
465 line_size: Option<LineSize>,
466 ) {
467 let index = cubecl::frontend::add::expand(scope, offset, index);
468
469 match origin {
470 SliceOriginExpand::Tensor(expand) => expand_index_assign_native::<Tensor<E>>(
471 scope, expand, index, value, line_size, true,
472 ),
473 SliceOriginExpand::Array(expand) => expand_index_assign_native::<Array<E>>(
474 scope, expand, index, value, line_size, false,
475 ),
476 SliceOriginExpand::SharedMemory(expand) => {
477 expand_index_assign_native::<SharedMemory<E>>(
478 scope, expand, index, value, line_size, false,
479 )
480 }
481 }
482 }
483}