1use std::marker::PhantomData;
2
3use crate::{self as cubecl, unexpanded};
4use cubecl::prelude::*;
5use cubecl_ir::{Branch, ElemType, ExpandElement, FloatKind, RangeLoop, Type, Variable};
6use cubecl_macros::intrinsic;
7
/// Access-mode marker for a [`Slice`] that is only read from.
///
/// In this module, index-assignment (`CubeIndexMut` / `ListMut`) is implemented only for
/// [`ReadWrite`] slices, not for this mode.
#[derive(Copy, Clone)]
pub struct ReadOnly;

/// Access-mode marker for a [`Slice`] that may be both read from and written to.
#[derive(Copy, Clone)]
pub struct ReadWrite;
12
/// A view of `length` elements into a tensor, array or shared memory (see [`SliceOrigin`]).
///
/// On the host this type is only a placeholder: `Slice::new` and `Iterator::next` are
/// `unexpanded!()`, and kernel code works with the expanded form [`SliceExpand`].
/// `IO` selects the type-level access mode and defaults to [`ReadOnly`].
#[derive(Clone, Copy)]
pub struct Slice<E: CubePrimitive, IO: SliceVisibility = ReadOnly> {
    // Type-level element type; no value is stored.
    _e: PhantomData<E>,
    // Type-level access mode; no value is stored.
    _io: PhantomData<IO>,
    // Placeholder for the offset, which only exists on the expand side.
    _offset: PhantomData<u32>,
    // Number of elements, read by `len`/`is_empty`.
    length: u32,
}
25
/// The storage a [`Slice`] views into.
///
/// `#[derive(CubeType)]` provides the expand-time counterpart `SliceOriginExpand`, whose
/// variants mirror these and which `SliceExpand` stores as its origin.
#[derive(CubeType)]
pub enum SliceOrigin<E: CubePrimitive> {
    /// The slice views a tensor.
    Tensor(Tensor<E>),
    /// The slice views an array.
    Array(Array<E>),
    /// The slice views shared memory.
    SharedMemory(SharedMemory<E>),
}
32
33impl<E: CubePrimitive> SliceOriginExpand<E> {
34 pub fn line_size(&self) -> u32 {
35 match self {
36 SliceOriginExpand::Tensor(t) => t.line_size(),
37 SliceOriginExpand::Array(t) => t.line_size(),
38 SliceOriginExpand::SharedMemory(t) => t.line_size(),
39 }
40 }
41}
42
43impl<E: CubePrimitive, IO: SliceVisibility> Iterator for Slice<E, IO> {
44 type Item = E;
45
46 fn next(&mut self) -> Option<Self::Item> {
47 unexpanded!()
48 }
49}
50
51pub trait SliceVisibility: Clone + Copy + Send + Sync + 'static {}
52
53impl SliceVisibility for ReadOnly {}
54
55impl SliceVisibility for ReadWrite {}
56
/// Expand-time (kernel-building) form of [`Slice`].
///
/// Judging by the rescaling in `with_line_size` and the offset-plus-index addressing in
/// `read_offset`, `offset` and `length` are counted in lines of the *effective* line size.
/// That is `line_size` when it is `Some`, and otherwise the line size of `origin`.
pub struct SliceExpand<E: CubePrimitive, IO: SliceVisibility> {
    /// Storage the slice views into.
    pub(crate) origin: SliceOriginExpand<E>,
    /// Type-level access mode; carries no data.
    pub(crate) io: PhantomData<IO>,
    /// Start of the slice within `origin`.
    pub(crate) offset: ExpandElementTyped<u32>,
    /// Number of elements in the slice.
    pub(crate) length: ExpandElementTyped<u32>,
    /// Line-size override (set by `with_line_size`); `None` means use `origin`'s line size.
    pub(crate) line_size: Option<u32>,
}
64
65impl<E: CubePrimitive, IO: SliceVisibility> SliceExpand<E, IO> {
66 pub fn __to_raw_parts(&self) -> (Variable, Variable) {
67 let expand = match self.origin.clone() {
68 SliceOriginExpand::Tensor(expand) => expand.expand,
69 SliceOriginExpand::Array(expand) => expand.expand,
70 SliceOriginExpand::SharedMemory(expand) => expand.expand,
71 };
72
73 (*expand, *self.offset.expand)
74 }
75}
76
77#[cube]
78impl<E: CubePrimitive, IO: SliceVisibility> Slice<Line<E>, IO> {
79 #[allow(unused_variables)]
85 pub fn with_line_size(&self, #[comptime] line_size: u32) -> Slice<Line<E>, IO> {
86 intrinsic!(|scope| {
87 let (input, offset) = self.__to_raw_parts();
88 let mut item = input.ty;
89
90 if line_size == item.line_size() {
91 return self;
92 }
93
94 let current = input.ty.line_size();
95 let mut out = self.clone();
96
97 if current < line_size {
98 let ratio = line_size / current;
99 let length = cubecl::frontend::div::expand(scope, self.length, ratio.into());
100 let offset = cubecl::frontend::div::expand(scope, self.offset, ratio.into());
101 out.length = length;
102 out.offset = offset;
103 } else {
104 let ratio = current / line_size;
105 let length = cubecl::frontend::mul::expand(scope, self.length, ratio.into());
106 let offset = cubecl::frontend::mul::expand(scope, self.offset, ratio.into());
107 out.length = length;
108 out.offset = offset;
109 }
110
111 out.line_size = Some(line_size);
112 out
113 })
114 }
115}
116
#[cube]
impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
    /// Views this slice of `E` as a slice of `Line<E>` over the same storage.
    ///
    /// No IR is emitted. The origin is re-typed, offset and length are copied unchanged, and
    /// any line-size override is cleared, so the result uses the origin's own line size.
    pub fn into_lined(&self) -> Slice<Line<E>, IO> {
        intrinsic!(|_scope| {
            SliceExpand::<Line<E>, IO> {
                origin: self.origin.cast_unchecked(),
                io: self.io.clone(),
                offset: self.offset.clone(),
                length: self.length.clone(),
                // Clear the override: the effective line size falls back to the origin's.
                line_size: None,
            }
        })
    }
    /// Reinterprets the element type as `T` without converting any data.
    ///
    /// This is intended only to satisfy the Rust type system. Expansion panics unless one
    /// of these holds:
    /// - `T` and `E` resolve to the same IR type;
    /// - `is_tf32::<E, T>` accepts the pair;
    /// - the element types are f32 and flex32, in either order.
    ///
    /// The origin, offset, length and line-size override are carried over unchanged.
    pub fn try_cast_unchecked<T: CubePrimitive>(&self) -> Slice<T, IO> {
        intrinsic!(|scope| {
            if T::as_type(scope) != E::as_type(scope) && !is_tf32::<E, T>(scope) {
                // A cast between f32 and flex32 is also tolerated.
                let elems = [T::as_type(scope).elem_type(), E::as_type(scope).elem_type()];
                let is_flex32_cast = elems.contains(&ElemType::Float(FloatKind::F32))
                    && elems.contains(&ElemType::Float(FloatKind::Flex32));

                if !is_flex32_cast {
                    panic!(
                        "Try cast unchecked should only be used to satisfy the rust type system."
                    )
                }
            }

            SliceExpand::<T, IO> {
                origin: self.origin.cast_unchecked(),
                io: self.io.clone(),
                offset: self.offset.clone(),
                length: self.length.clone(),
                line_size: self.line_size.clone(),
            }
        })
    }
}
160
#[cube]
impl<E: CubePrimitive> Slice<E, ReadOnly> {
    /// Returns a read-write view with the same storage, offset, length and line-size override.
    ///
    /// It performs no checks and emits no IR; it only changes the type-level access mode. The
    /// caller is therefore responsible for writes through the result being valid.
    pub fn as_mut_unchecked(&self) -> Slice<E, ReadWrite> {
        intrinsic!(|scope| {
            SliceExpand::<E, ReadWrite> {
                origin: self.origin,
                io: PhantomData,
                offset: self.offset.clone(),
                length: self.length.clone(),
                line_size: self.line_size.clone(),
            }
        })
    }
}
175
176impl<E: CubePrimitive> SliceOriginExpand<E> {
177 fn cast_unchecked<T: CubePrimitive>(self) -> SliceOriginExpand<T> {
178 match self {
179 SliceOriginExpand::Tensor(expand) => {
180 SliceOriginExpand::<T>::Tensor(expand.expand.into())
181 }
182 SliceOriginExpand::Array(expand) => SliceOriginExpand::<T>::Array(expand.expand.into()),
183 SliceOriginExpand::SharedMemory(expand) => {
184 SliceOriginExpand::<T>::SharedMemory(expand.expand.into())
185 }
186 }
187 }
188}
189
190impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
191 pub fn new(_origin: SliceOrigin<E>, _offset: u32, _length: u32) -> Self {
192 unexpanded!()
193 }
194 pub fn __expand_new(
195 scope: &mut Scope,
196 origin: SliceOriginExpand<E>,
197 start: ExpandElementTyped<u32>,
198 end: ExpandElementTyped<u32>,
199 ) -> SliceExpand<E, IO> {
200 Self::__expand_new_expand(scope, origin, start, end)
201 }
202 pub fn __expand_new_expand(
203 scope: &mut Scope,
204 origin: SliceOriginExpand<E>,
205 start: ExpandElementTyped<u32>,
206 end: ExpandElementTyped<u32>,
207 ) -> SliceExpand<E, IO> {
208 let length = cubecl::frontend::sub::expand(scope, end, start.clone());
209
210 SliceExpand::<E, IO> {
211 origin,
212 io: PhantomData,
213 offset: start,
214 length,
215 line_size: None,
216 }
217 }
218}
219
#[cube]
impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
    /// Number of `E` elements in the slice (lines, when `E` is a `Line<_>`).
    pub fn len(&self) -> u32 {
        self.length
    }
    /// Returns `true` when the slice length is zero.
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }
}
231
232impl<E: CubePrimitive, IO: SliceVisibility> CubeType for Slice<E, IO> {
233 type ExpandType = SliceExpand<E, IO>;
234}
235
236impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &Slice<E, IO> {
237 type ExpandType = SliceExpand<E, IO>;
238}
239
240impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &mut Slice<E, IO> {
241 type ExpandType = SliceExpand<E, IO>;
242}
243
244impl<E: CubePrimitive, IO: SliceVisibility> IntoMut for SliceExpand<E, IO> {
245 fn into_mut(self, _scope: &mut cubecl_ir::Scope) -> Self {
246 self
247 }
248}
249
250impl<E: CubePrimitive, IO: SliceVisibility> CubeDebug for SliceExpand<E, IO> {}
251impl<E: CubePrimitive, IO: SliceVisibility> Clone for SliceExpand<E, IO> {
252 fn clone(&self) -> Self {
253 Self {
254 origin: self.origin.clone(),
255 offset: self.offset.clone(),
256 length: self.length.clone(),
257 line_size: self.line_size,
258 io: PhantomData,
259 }
260 }
261}
262
263impl<E: CubePrimitive> SizedContainer for Slice<E, ReadOnly> {
265 type Item = E;
266}
267
impl<E: CubePrimitive> Iterable<E> for SliceExpand<E, ReadOnly> {
    /// Emits a runtime loop over the slice.
    ///
    /// The loop is a `RangeLoop` from `0` up to (excluding) the slice length, with no explicit
    /// step. `body` is expanded exactly once, into a child scope, and receives the element at
    /// the loop index.
    fn expand(
        self,
        scope: &mut Scope,
        mut body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
    ) {
        let index_ty = Type::new(u32::as_type(scope));
        // Take the loop bound now, because `self` is moved into `index::expand` below.
        let len: ExpandElement = self.length.clone().into();

        // The body is recorded into its own child scope, which becomes the loop's scope.
        let mut child = scope.child();
        let i = child.create_local_restricted(index_ty);

        let index = i.clone().into();
        // Element access goes through the generic `index::expand` helper. This is presumably
        // the regular (checked) indexing path; confirm against its definition.
        let item = index::expand(&mut child, self, index);
        body(&mut child, item);

        // The loop is registered in the parent scope after the body has been recorded.
        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
            i: *i,
            start: 0u32.into(),
            end: *len,
            step: None,
            inclusive: false,
            scope: child,
        })));
    }

    /// Unrolled iteration is not supported for slices; this always panics.
    fn expand_unroll(
        self,
        _scope: &mut Scope,
        _body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
    ) {
        unimplemented!("Can't unroll slice iterator")
    }
}
302impl<E: CubePrimitive, IO: SliceVisibility> CubeIndex for Slice<E, IO> {
303 type Output = E;
304 type Idx = u32;
305
306 fn expand_index(
307 scope: &mut Scope,
308 array: Self::ExpandType,
309 index: ExpandElementTyped<u32>,
310 ) -> <Self::Output as CubeType>::ExpandType {
311 array.__expand_read_method(scope, index)
312 }
313}
314
315impl<E: CubePrimitive, IO: SliceVisibility> CubeIndexExpand for SliceExpand<E, IO> {
316 type Output = E::ExpandType;
317 type Idx = ExpandElementTyped<u32>;
318
319 fn expand_index(self, scope: &mut Scope, index: ExpandElementTyped<u32>) -> Self::Output {
320 self.__expand_read_method(scope, index)
321 }
322 fn expand_index_unchecked(
323 self,
324 scope: &mut Scope,
325 index: ExpandElementTyped<u32>,
326 ) -> Self::Output {
327 self.__expand_read_unchecked_method(scope, index)
328 }
329}
330
331impl<E: CubePrimitive, IO: SliceVisibility> List<E> for Slice<E, IO> {}
332impl<E: CubePrimitive, IO: SliceVisibility> ListExpand<E> for SliceExpand<E, IO> {
333 fn __expand_read_method(
334 &self,
335 scope: &mut cubecl_ir::Scope,
336 index: ExpandElementTyped<u32>,
337 ) -> <E as CubeType>::ExpandType {
338 read_offset::expand::<E>(
339 scope,
340 self.origin.clone(),
341 self.offset.clone(),
342 index,
343 self.line_size,
344 true,
345 )
346 }
347 fn __expand_read_unchecked_method(
348 &self,
349 scope: &mut cubecl_ir::Scope,
350 index: ExpandElementTyped<u32>,
351 ) -> <E as CubeType>::ExpandType {
352 read_offset::expand::<E>(
353 scope,
354 self.origin.clone(),
355 self.offset.clone(),
356 index,
357 self.line_size,
358 false,
359 )
360 }
361
362 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
363 Self::__expand_len(scope, self.clone())
364 }
365}
366
367impl<E: CubePrimitive, IO: SliceVisibility> Lined for Slice<E, IO> {}
368impl<E: CubePrimitive, IO: SliceVisibility> LinedExpand for SliceExpand<E, IO> {
369 fn line_size(&self) -> u32 {
370 self.line_size.unwrap_or_else(|| self.origin.line_size())
371 }
372}
373
374impl<E: CubePrimitive> CubeIndexMut for Slice<E, ReadWrite> {
375 fn expand_index_mut(
376 scope: &mut Scope,
377 array: Self::ExpandType,
378 index: ExpandElementTyped<u32>,
379 value: ExpandElementTyped<E>,
380 ) {
381 array.__expand_write_method(scope, index, value)
382 }
383}
384
385impl<E: CubePrimitive> CubeIndexMutExpand for SliceExpand<E, ReadWrite> {
386 fn expand_index_mut(
387 self,
388 scope: &mut Scope,
389 index: ExpandElementTyped<u32>,
390 value: Self::Output,
391 ) {
392 self.__expand_write_method(scope, index, value)
393 }
394}
395
396impl<E: CubePrimitive> ListMut<E> for Slice<E, ReadWrite> {}
397impl<E: CubePrimitive> ListMutExpand<E> for SliceExpand<E, ReadWrite> {
398 fn __expand_write_method(
399 &self,
400 scope: &mut cubecl_ir::Scope,
401 index: ExpandElementTyped<u32>,
402 value: ExpandElementTyped<E>,
403 ) {
404 write_offset::expand::<E>(
405 scope,
406 self.origin.clone(),
407 self.offset.clone(),
408 index,
409 value,
410 self.line_size,
411 )
412 }
413}
414
415mod read_offset {
416 use super::*;
417
418 pub fn expand<E: CubePrimitive>(
419 scope: &mut cubecl::prelude::Scope,
420 origin: SliceOriginExpand<E>,
421 offset: <u32 as cubecl::prelude::CubeType>::ExpandType,
422 index: <u32 as cubecl::prelude::CubeType>::ExpandType,
423 line_size: Option<u32>,
424 checked: bool,
425 ) -> <E as cubecl::prelude::CubeType>::ExpandType {
426 let index = cubecl::frontend::add::expand(scope, offset, index);
427
428 match origin {
429 SliceOriginExpand::Tensor(expand) => {
430 expand_index_native::<Tensor<E>>(scope, expand, index, line_size, checked)
431 }
432 SliceOriginExpand::Array(expand) => {
433 expand_index_native::<Array<E>>(scope, expand, index, line_size, checked)
434 }
435 SliceOriginExpand::SharedMemory(expand) => {
436 expand_index_native::<SharedMemory<E>>(scope, expand, index, line_size, checked)
437 }
438 }
439 }
440}
441
mod write_offset {
    use super::*;

    /// Writes `value` at position `offset + index` of `origin`.
    ///
    /// It emits one addition for the absolute position, then dispatches on the storage kind.
    /// `line_size` is forwarded unchanged to `expand_index_assign_native`.
    ///
    /// NOTE(review): the final argument to `expand_index_assign_native` is `true` for tensors
    /// but `false` for arrays and shared memory. If it is a bounds-check flag, like `checked`
    /// in `read_offset`, then tensor writes through a slice are checked while array and
    /// shared-memory writes are not. Confirm this asymmetry is intentional.
    pub fn expand<E: CubePrimitive>(
        scope: &mut cubecl::prelude::Scope,
        origin: SliceOriginExpand<E>,
        offset: <u32 as cubecl::prelude::CubeType>::ExpandType,
        index: <u32 as cubecl::prelude::CubeType>::ExpandType,
        value: <E as cubecl::prelude::CubeType>::ExpandType,
        line_size: Option<u32>,
    ) {
        let index = cubecl::frontend::add::expand(scope, offset, index);

        match origin {
            SliceOriginExpand::Tensor(expand) => expand_index_assign_native::<Tensor<E>>(
                scope, expand, index, value, line_size, true,
            ),
            SliceOriginExpand::Array(expand) => expand_index_assign_native::<Array<E>>(
                scope, expand, index, value, line_size, false,
            ),
            SliceOriginExpand::SharedMemory(expand) => {
                expand_index_assign_native::<SharedMemory<E>>(
                    scope, expand, index, value, line_size, false,
                )
            }
        }
    }
}