1use std::marker::PhantomData;
2
3use crate::{self as cubecl, unexpanded};
4use cubecl::prelude::*;
5use cubecl_ir::{Branch, ElemType, ExpandElement, FloatKind, RangeLoop, Type, Variable};
6use cubecl_macros::intrinsic;
7
/// Slice access-mode marker: the slice can only be read.
///
/// Only `Slice<E, ReadOnly>` implements `SizedContainer` and `Iterable` in this file.
#[derive(Clone, Copy)]
pub struct ReadOnly;
/// Slice access-mode marker: the slice can be read and written.
///
/// Only `Slice<E, ReadWrite>` implements the write traits (`CubeIndexMut`, `ListMut`).
#[derive(Clone, Copy)]
pub struct ReadWrite;
12
/// A contiguous window (`offset`, `length`) into a tensor, array, or shared memory.
///
/// On the host side this struct is a type-level stand-in: `Slice::new` is `unexpanded!()`,
/// and the origin/offset/length actually used by generated code live in [`SliceExpand`].
/// `IO` selects the access mode and defaults to [`ReadOnly`].
#[derive(Clone, Copy)]
pub struct Slice<E: CubePrimitive, IO: SliceVisibility = ReadOnly> {
    _e: PhantomData<E>,
    _io: PhantomData<IO>,
    _offset: PhantomData<u32>,
    // Host-side placeholder; `len()` reads the expand-time `SliceExpand::length`.
    length: u32,
}
25
/// The buffer a [`Slice`] points into.
///
/// The matching `SliceOriginExpand` enum used at expand time is presumably generated by
/// `#[derive(CubeType)]` (it has the same three variants).
#[derive(CubeType)]
pub enum SliceOrigin<E: CubePrimitive> {
    Tensor(Tensor<E>),
    Array(Array<E>),
    SharedMemory(SharedMemory<E>),
}
32
33impl<E: CubePrimitive> SliceOriginExpand<E> {
34    pub fn line_size(&self) -> u32 {
35        match self {
36            SliceOriginExpand::Tensor(t) => t.line_size(),
37            SliceOriginExpand::Array(t) => t.line_size(),
38            SliceOriginExpand::SharedMemory(t) => t.line_size(),
39        }
40    }
41}
42
43impl<E: CubePrimitive, IO: SliceVisibility> Iterator for Slice<E, IO> {
44    type Item = E;
45
46    fn next(&mut self) -> Option<Self::Item> {
47        unexpanded!()
48    }
49}
50
/// Marker trait for slice access modes; implemented by [`ReadOnly`] and [`ReadWrite`].
pub trait SliceVisibility: Clone + Copy + Send + Sync + 'static {}

impl SliceVisibility for ReadOnly {}

impl SliceVisibility for ReadWrite {}
56
/// Expand-time representation of a [`Slice`].
///
/// Reads and writes add `offset` to the slice-relative index before indexing `origin`.
pub struct SliceExpand<E: CubePrimitive, IO: SliceVisibility> {
    // Buffer the slice points into.
    pub(crate) origin: SliceOriginExpand<E>,
    // Zero-sized access-mode marker.
    pub(crate) io: PhantomData<IO>,
    // Start of the slice within `origin`, in units of the effective line size
    // (rescaled by `with_line_size`).
    pub(crate) offset: ExpandElementTyped<u32>,
    // Number of indexable items, in the same units as `offset`.
    pub(crate) length: ExpandElementTyped<u32>,
    // Line size set by `with_line_size`; `None` means "use the origin's line size".
    pub(crate) line_size: Option<u32>,
}
64
65impl<E: CubePrimitive, IO: SliceVisibility> SliceExpand<E, IO> {
66    pub fn __to_raw_parts(&self) -> (Variable, Variable) {
67        let expand = match self.origin.clone() {
68            SliceOriginExpand::Tensor(expand) => expand.expand,
69            SliceOriginExpand::Array(expand) => expand.expand,
70            SliceOriginExpand::SharedMemory(expand) => expand.expand,
71        };
72
73        (*expand, *self.offset.expand)
74    }
75}
76
77#[cube]
78impl<E: CubePrimitive, IO: SliceVisibility> Slice<Line<E>, IO> {
79    #[allow(unused_variables)]
85    pub fn with_line_size(&self, #[comptime] line_size: u32) -> Slice<Line<E>, IO> {
86        intrinsic!(|scope| {
87            let (input, offset) = self.__to_raw_parts();
88            let mut item = input.ty;
89
90            if line_size == item.line_size() {
91                return self;
92            }
93
94            let current = input.ty.line_size();
95            let mut out = self.clone();
96
97            if current < line_size {
98                let ratio = line_size / current;
99                let length = cubecl::frontend::div::expand(scope, self.length, ratio.into());
100                let offset = cubecl::frontend::div::expand(scope, self.offset, ratio.into());
101                out.length = length;
102                out.offset = offset;
103            } else {
104                let ratio = current / line_size;
105                let length = cubecl::frontend::mul::expand(scope, self.length, ratio.into());
106                let offset = cubecl::frontend::mul::expand(scope, self.offset, ratio.into());
107                out.length = length;
108                out.offset = offset;
109            }
110
111            out.line_size = Some(line_size);
112            out
113        })
114    }
115}
116
117#[cube]
118impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
119    pub fn into_lined(&self) -> Slice<Line<E>, IO> {
121        intrinsic!(|_scope| {
122            SliceExpand::<Line<E>, IO> {
123                origin: self.origin.cast_unchecked(),
124                io: self.io.clone(),
125                offset: self.offset.clone(),
126                length: self.length.clone(),
127                line_size: None,
128            }
129        })
130    }
131    pub fn try_cast_unchecked<T: CubePrimitive>(&self) -> Slice<T, IO> {
137        intrinsic!(|scope| {
138            if T::as_type(scope) != E::as_type(scope) && !is_tf32::<E, T>(scope) {
139                let elems = [T::as_type(scope).elem_type(), E::as_type(scope).elem_type()];
140                let is_flex32_cast = elems.contains(&ElemType::Float(FloatKind::F32))
141                    && elems.contains(&ElemType::Float(FloatKind::Flex32));
142
143                if !is_flex32_cast {
144                    panic!(
145                        "Try cast unchecked should only be used to satisfy the rust type system."
146                    )
147                }
148            }
149
150            SliceExpand::<T, IO> {
151                origin: self.origin.cast_unchecked(),
152                io: self.io.clone(),
153                offset: self.offset.clone(),
154                length: self.length.clone(),
155                line_size: self.line_size.clone(),
156            }
157        })
158    }
159}
160
161#[cube]
162impl<E: CubePrimitive> Slice<E, ReadOnly> {
163    pub fn as_mut_unchecked(&self) -> Slice<E, ReadWrite> {
164        intrinsic!(|scope| {
165            SliceExpand::<E, ReadWrite> {
166                origin: self.origin,
167                io: PhantomData,
168                offset: self.offset.clone(),
169                length: self.length.clone(),
170                line_size: self.line_size.clone(),
171            }
172        })
173    }
174}
175
176impl<E: CubePrimitive> SliceOriginExpand<E> {
177    fn cast_unchecked<T: CubePrimitive>(self) -> SliceOriginExpand<T> {
178        match self {
179            SliceOriginExpand::Tensor(expand) => {
180                SliceOriginExpand::<T>::Tensor(expand.expand.into())
181            }
182            SliceOriginExpand::Array(expand) => SliceOriginExpand::<T>::Array(expand.expand.into()),
183            SliceOriginExpand::SharedMemory(expand) => {
184                SliceOriginExpand::<T>::SharedMemory(expand.expand.into())
185            }
186        }
187    }
188}
189
190impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
191    pub fn new(_origin: SliceOrigin<E>, _offset: u32, _length: u32) -> Self {
192        unexpanded!()
193    }
194    pub fn __expand_new(
195        scope: &mut Scope,
196        origin: SliceOriginExpand<E>,
197        start: ExpandElementTyped<u32>,
198        end: ExpandElementTyped<u32>,
199    ) -> SliceExpand<E, IO> {
200        Self::__expand_new_expand(scope, origin, start, end)
201    }
202    pub fn __expand_new_expand(
203        scope: &mut Scope,
204        origin: SliceOriginExpand<E>,
205        start: ExpandElementTyped<u32>,
206        end: ExpandElementTyped<u32>,
207    ) -> SliceExpand<E, IO> {
208        let length = cubecl::frontend::sub::expand(scope, end, start.clone());
209
210        SliceExpand::<E, IO> {
211            origin,
212            io: PhantomData,
213            offset: start,
214            length,
215            line_size: None,
216        }
217    }
218}
219
#[cube]
impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
    /// Returns the number of items in the slice (its `length`).
    pub fn len(&self) -> u32 {
        self.length
    }
    /// Returns `true` when the slice's `length` is zero.
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }
}
231
232impl<E: CubePrimitive, IO: SliceVisibility> CubeType for Slice<E, IO> {
233    type ExpandType = SliceExpand<E, IO>;
234}
235
236impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &Slice<E, IO> {
237    type ExpandType = SliceExpand<E, IO>;
238}
239
240impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &mut Slice<E, IO> {
241    type ExpandType = SliceExpand<E, IO>;
242}
243
244impl<E: CubePrimitive, IO: SliceVisibility> IntoMut for SliceExpand<E, IO> {
245    fn into_mut(self, _scope: &mut cubecl_ir::Scope) -> Self {
246        self
247    }
248}
249
250impl<E: CubePrimitive, IO: SliceVisibility> CubeDebug for SliceExpand<E, IO> {}
251impl<E: CubePrimitive, IO: SliceVisibility> Clone for SliceExpand<E, IO> {
252    fn clone(&self) -> Self {
253        Self {
254            origin: self.origin.clone(),
255            offset: self.offset.clone(),
256            length: self.length.clone(),
257            line_size: self.line_size,
258            io: PhantomData,
259        }
260    }
261}
262
263impl<E: CubePrimitive> SizedContainer for Slice<E, ReadOnly> {
265    type Item = E;
266}
267
/// Expand-time loop support for read-only slices.
impl<E: CubePrimitive> Iterable<E> for SliceExpand<E, ReadOnly> {
    /// Emits a `RangeLoop` over `0..length`.
    ///
    /// The loop body reads item `i` of the slice via `index::expand` and passes it to `body`.
    fn expand(
        self,
        scope: &mut Scope,
        mut body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
    ) {
        let index_ty = Type::new(u32::as_type(scope));
        let len: ExpandElement = self.length.clone().into();

        // The loop body is built in a child scope, which is then moved into the RangeLoop
        // registered on the parent scope below.
        let mut child = scope.child();
        let i = child.create_local_restricted(index_ty);

        let index = i.clone().into();
        let item = index::expand(&mut child, self, index);
        body(&mut child, item);

        // Exclusive range [0, len) with the default step.
        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
            i: *i,
            start: 0u32.into(),
            end: *len,
            step: None,
            inclusive: false,
            scope: child,
        })));
    }

    /// Not supported for slices: always panics via `unimplemented!`.
    fn expand_unroll(
        self,
        _scope: &mut Scope,
        _body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
    ) {
        unimplemented!("Can't unroll slice iterator")
    }
}
302impl<E: CubePrimitive, IO: SliceVisibility> CubeIndex for Slice<E, IO> {
303    type Output = E;
304    type Idx = u32;
305
306    fn expand_index(
307        scope: &mut Scope,
308        array: Self::ExpandType,
309        index: ExpandElementTyped<u32>,
310    ) -> <Self::Output as CubeType>::ExpandType {
311        array.__expand_read_method(scope, index)
312    }
313}
314
315impl<E: CubePrimitive, IO: SliceVisibility> CubeIndexExpand for SliceExpand<E, IO> {
316    type Output = E::ExpandType;
317    type Idx = ExpandElementTyped<u32>;
318
319    fn expand_index(self, scope: &mut Scope, index: ExpandElementTyped<u32>) -> Self::Output {
320        self.__expand_read_method(scope, index)
321    }
322    fn expand_index_unchecked(
323        self,
324        scope: &mut Scope,
325        index: ExpandElementTyped<u32>,
326    ) -> Self::Output {
327        self.__expand_read_unchecked_method(scope, index)
328    }
329}
330
331impl<E: CubePrimitive, IO: SliceVisibility> List<E> for Slice<E, IO> {}
332impl<E: CubePrimitive, IO: SliceVisibility> ListExpand<E> for SliceExpand<E, IO> {
333    fn __expand_read_method(
334        &self,
335        scope: &mut cubecl_ir::Scope,
336        index: ExpandElementTyped<u32>,
337    ) -> <E as CubeType>::ExpandType {
338        read_offset::expand::<E>(
339            scope,
340            self.origin.clone(),
341            self.offset.clone(),
342            index,
343            self.line_size,
344            true,
345        )
346    }
347    fn __expand_read_unchecked_method(
348        &self,
349        scope: &mut cubecl_ir::Scope,
350        index: ExpandElementTyped<u32>,
351    ) -> <E as CubeType>::ExpandType {
352        read_offset::expand::<E>(
353            scope,
354            self.origin.clone(),
355            self.offset.clone(),
356            index,
357            self.line_size,
358            false,
359        )
360    }
361
362    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
363        Self::__expand_len(scope, self.clone())
364    }
365}
366
367impl<E: CubePrimitive, IO: SliceVisibility> Lined for Slice<E, IO> {}
368impl<E: CubePrimitive, IO: SliceVisibility> LinedExpand for SliceExpand<E, IO> {
369    fn line_size(&self) -> u32 {
370        self.line_size.unwrap_or_else(|| self.origin.line_size())
371    }
372}
373
374impl<E: CubePrimitive> CubeIndexMut for Slice<E, ReadWrite> {
375    fn expand_index_mut(
376        scope: &mut Scope,
377        array: Self::ExpandType,
378        index: ExpandElementTyped<u32>,
379        value: ExpandElementTyped<E>,
380    ) {
381        array.__expand_write_method(scope, index, value)
382    }
383}
384
385impl<E: CubePrimitive> CubeIndexMutExpand for SliceExpand<E, ReadWrite> {
386    fn expand_index_mut(
387        self,
388        scope: &mut Scope,
389        index: ExpandElementTyped<u32>,
390        value: Self::Output,
391    ) {
392        self.__expand_write_method(scope, index, value)
393    }
394}
395
396impl<E: CubePrimitive> ListMut<E> for Slice<E, ReadWrite> {}
397impl<E: CubePrimitive> ListMutExpand<E> for SliceExpand<E, ReadWrite> {
398    fn __expand_write_method(
399        &self,
400        scope: &mut cubecl_ir::Scope,
401        index: ExpandElementTyped<u32>,
402        value: ExpandElementTyped<E>,
403    ) {
404        write_offset::expand::<E>(
405            scope,
406            self.origin.clone(),
407            self.offset.clone(),
408            index,
409            value,
410            self.line_size,
411        )
412    }
413}
414
/// Expand-time read helper shared by the slice read methods.
mod read_offset {
    use super::*;

    /// Emits a read of item `offset + index` from `origin`.
    ///
    /// `line_size` and `checked` are forwarded unchanged to `expand_index_native`.
    /// `ListExpand` passes `checked = true` for normal reads and `false` for the unchecked
    /// variant. A `None` line size presumably means "use the origin's own line size"; confirm
    /// in `expand_index_native`.
    pub fn expand<E: CubePrimitive>(
        scope: &mut cubecl::prelude::Scope,
        origin: SliceOriginExpand<E>,
        offset: <u32 as cubecl::prelude::CubeType>::ExpandType,
        index: <u32 as cubecl::prelude::CubeType>::ExpandType,
        line_size: Option<u32>,
        checked: bool,
    ) -> <E as cubecl::prelude::CubeType>::ExpandType {
        // Translate the slice-relative index into an index into the origin.
        let index = cubecl::frontend::add::expand(scope, offset, index);

        match origin {
            SliceOriginExpand::Tensor(expand) => {
                expand_index_native::<Tensor<E>>(scope, expand, index, line_size, checked)
            }
            SliceOriginExpand::Array(expand) => {
                expand_index_native::<Array<E>>(scope, expand, index, line_size, checked)
            }
            SliceOriginExpand::SharedMemory(expand) => {
                expand_index_native::<SharedMemory<E>>(scope, expand, index, line_size, checked)
            }
        }
    }
}
441
/// Expand-time write helper used by `ListMutExpand::__expand_write_method`.
mod write_offset {
    use super::*;

    /// Emits a write of `value` to item `offset + index` of `origin`.
    ///
    /// `line_size` is forwarded unchanged to `expand_index_assign_native`.
    pub fn expand<E: CubePrimitive>(
        scope: &mut cubecl::prelude::Scope,
        origin: SliceOriginExpand<E>,
        offset: <u32 as cubecl::prelude::CubeType>::ExpandType,
        index: <u32 as cubecl::prelude::CubeType>::ExpandType,
        value: <E as cubecl::prelude::CubeType>::ExpandType,
        line_size: Option<u32>,
    ) {
        // Translate the slice-relative index into an index into the origin.
        let index = cubecl::frontend::add::expand(scope, offset, index);

        // NOTE(review): the final flag passed to `expand_index_assign_native` is `true` for
        // tensors but `false` for arrays and shared memory. By contrast, `read_offset`
        // forwards one caller-chosen `checked` flag for every origin. Confirm this asymmetry
        // is intentional.
        match origin {
            SliceOriginExpand::Tensor(expand) => expand_index_assign_native::<Tensor<E>>(
                scope, expand, index, value, line_size, true,
            ),
            SliceOriginExpand::Array(expand) => expand_index_assign_native::<Array<E>>(
                scope, expand, index, value, line_size, false,
            ),
            SliceOriginExpand::SharedMemory(expand) => {
                expand_index_assign_native::<SharedMemory<E>>(
                    scope, expand, index, value, line_size, false,
                )
            }
        }
    }
}