cubecl_core/frontend/container/array/base.rs

1use std::{marker::PhantomData, num::NonZero};
2
3use cubecl_ir::{ExpandElement, Scope};
4
5use crate::frontend::{CubePrimitive, ExpandElementBaseInit, ExpandElementTyped};
6use crate::prelude::{List, ListExpand, ListMut, ListMutExpand, SizedContainer};
7use crate::{
8    frontend::CubeType,
9    ir::{Item, Metadata},
10    unexpanded,
11};
12use crate::{
13    frontend::indexation::Index,
14    prelude::{assign, index, index_assign},
15};
16
/// A contiguous array of elements of type `E`.
///
/// On the host side this is only a zero-sized type marker. The methods defined on it in this
/// file either build a placeholder value or call `unexpanded!()`. The actual work is done by
/// their `__expand_*` counterparts.
pub struct Array<E> {
    // Ties the element type to the array without storing any element.
    _val: PhantomData<E>,
}
21
22/// Module that contains the implementation details of the new function.
23mod new {
24    use super::*;
25    use crate::ir::Variable;
26
27    impl<T: CubePrimitive + Clone> Array<T> {
28        /// Create a new array of the given length.
29        #[allow(unused_variables)]
30        pub fn new<L: Index>(length: L) -> Self {
31            Array { _val: PhantomData }
32        }
33
34        /// Create an array from data.
35        pub fn from_data<C: CubePrimitive>(_data: impl IntoIterator<Item = C>) -> Self {
36            Array { _val: PhantomData }
37        }
38
39        /// Expand function of [new](Array::new).
40        pub fn __expand_new(
41            scope: &mut Scope,
42            size: ExpandElementTyped<u32>,
43        ) -> <Self as CubeType>::ExpandType {
44            let size = size
45                .constant()
46                .expect("Array need constant initialization value")
47                .as_u32();
48            let elem = T::as_elem(scope);
49            scope.create_local_array(Item::new(elem), size).into()
50        }
51
52        /// Expand function of [from_data](Array::from_data).
53        pub fn __expand_from_data<C: CubePrimitive>(
54            scope: &mut Scope,
55            data: ArrayData<C>,
56        ) -> <Self as CubeType>::ExpandType {
57            let var = scope.create_const_array(Item::new(T::as_elem(scope)), data.values);
58            ExpandElementTyped::new(var)
59        }
60    }
61
62    /// Type useful for the expand function of [from_data](Array::from_data).
63    pub struct ArrayData<C> {
64        values: Vec<Variable>,
65        _ty: PhantomData<C>,
66    }
67
68    impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
69        for ArrayData<C>
70    {
71        fn from(value: T) -> Self {
72            let values: Vec<Variable> = value
73                .into_iter()
74                .map(|value| {
75                    let value: ExpandElementTyped<C> = value.into();
76                    *value.expand
77                })
78                .collect();
79            ArrayData {
80                values,
81                _ty: PhantomData,
82            }
83        }
84    }
85}
86
87/// Module that contains the implementation details of the line_size function.
88mod line {
89    use crate::prelude::Line;
90
91    use super::*;
92
93    impl<P: CubePrimitive> Array<Line<P>> {
94        /// Get the size of each line contained in the tensor.
95        ///
96        /// Same as the following:
97        ///
98        /// ```rust, ignore
99        /// let size = tensor[0].size();
100        /// ```
101        pub fn line_size(&self) -> u32 {
102            unexpanded!()
103        }
104
105        // Expand function of [size](Tensor::line_size).
106        pub fn __expand_line_size(
107            expand: <Self as CubeType>::ExpandType,
108            scope: &mut Scope,
109        ) -> u32 {
110            expand.__expand_line_size_method(scope)
111        }
112    }
113
114    impl<P: CubePrimitive> ExpandElementTyped<Array<Line<P>>> {
115        /// Comptime version of [size](Array::line_size).
116        pub fn line_size(&self) -> u32 {
117            self.expand
118                .item
119                .vectorization
120                .unwrap_or(NonZero::new(1).unwrap())
121                .get() as u32
122        }
123
124        // Expand method of [size](Array::line_size).
125        pub fn __expand_line_size_method(&self, _content: &mut Scope) -> u32 {
126            self.line_size()
127        }
128    }
129}
130
131/// Module that contains the implementation details of vectorization functions.
132///
133/// TODO: Remove vectorization in favor of the line API.
134mod vectorization {
135    use super::*;
136
137    impl<T: CubePrimitive + Clone> Array<T> {
138        #[allow(unused_variables)]
139        pub fn vectorized<L: Index>(length: L, vectorization_factor: u32) -> Self {
140            Array { _val: PhantomData }
141        }
142
143        pub fn to_vectorized(self, _vectorization_factor: u32) -> T {
144            unexpanded!()
145        }
146
147        pub fn __expand_vectorized(
148            scope: &mut Scope,
149            size: ExpandElementTyped<u32>,
150            vectorization_factor: u32,
151        ) -> <Self as CubeType>::ExpandType {
152            let size = size.value();
153            let size = match size.kind {
154                crate::ir::VariableKind::ConstantScalar(value) => value.as_u32(),
155                _ => panic!("Shared memory need constant initialization value"),
156            };
157            scope
158                .create_local_array(
159                    Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
160                    size,
161                )
162                .into()
163        }
164    }
165
166    impl<C: CubePrimitive> ExpandElementTyped<Array<C>> {
167        pub fn __expand_to_vectorized_method(
168            self,
169            scope: &mut Scope,
170            vectorization_factor: ExpandElementTyped<u32>,
171        ) -> ExpandElementTyped<C> {
172            let factor = vectorization_factor
173                .constant()
174                .expect("Vectorization must be comptime")
175                .as_u32();
176            let var = self.expand.clone();
177            let item = Item::vectorized(var.item.elem(), NonZero::new(factor as u8));
178
179            let new_var = if factor == 1 {
180                let new_var = scope.create_local(item);
181                let element = index::expand(
182                    scope,
183                    self.clone(),
184                    ExpandElementTyped::from_lit(scope, 0u32),
185                );
186                assign::expand::<C>(scope, element, new_var.clone().into());
187                new_var
188            } else {
189                let new_var = scope.create_local_mut(item);
190                for i in 0..factor {
191                    let expand: Self = self.expand.clone().into();
192                    let element =
193                        index::expand(scope, expand, ExpandElementTyped::from_lit(scope, i));
194                    index_assign::expand::<Array<C>>(
195                        scope,
196                        new_var.clone().into(),
197                        ExpandElementTyped::from_lit(scope, i),
198                        element,
199                    );
200                }
201                new_var
202            };
203            new_var.into()
204        }
205    }
206}
207
208/// Module that contains the implementation details of the metadata functions.
209mod metadata {
210    use crate::ir::Instruction;
211
212    use super::*;
213
214    impl<E: CubeType> Array<E> {
215        /// Obtain the array length
216        #[allow(clippy::len_without_is_empty)]
217        pub fn len(&self) -> u32 {
218            unexpanded!()
219        }
220
221        /// Obtain the array buffer length
222        pub fn buffer_len(&self) -> u32 {
223            unexpanded!()
224        }
225    }
226
227    impl<T: CubeType> ExpandElementTyped<Array<T>> {
228        // Expand method of [len](Array::len).
229        pub fn __expand_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
230            let out = scope.create_local(Item::new(u32::as_elem(scope)));
231            scope.register(Instruction::new(
232                Metadata::Length {
233                    var: self.expand.into(),
234                },
235                out.clone().into(),
236            ));
237            out.into()
238        }
239
240        // Expand method of [buffer_len](Array::buffer_len).
241        pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
242            let out = scope.create_local(Item::new(u32::as_elem(scope)));
243            scope.register(Instruction::new(
244                Metadata::BufferLength {
245                    var: self.expand.into(),
246                },
247                out.clone().into(),
248            ));
249            out.into()
250        }
251    }
252}
253
254/// Module that contains the implementation details of the index functions.
255mod indexation {
256    use cubecl_ir::Operator;
257
258    use crate::{
259        ir::{BinaryOperator, Instruction},
260        prelude::{CubeIndex, CubeIndexMut},
261    };
262
263    use super::*;
264
265    impl<E: CubePrimitive> Array<E> {
266        /// Perform an unchecked index into the array
267        ///
268        /// # Safety
269        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
270        /// always in bounds
271        pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
272        where
273            Self: CubeIndex<I>,
274        {
275            unexpanded!()
276        }
277
278        /// Perform an unchecked index assignment into the array
279        ///
280        /// # Safety
281        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
282        /// always in bounds
283        pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
284        where
285            Self: CubeIndexMut<I>,
286        {
287            unexpanded!()
288        }
289    }
290
291    impl<E: CubePrimitive> ExpandElementTyped<Array<E>> {
292        pub fn __expand_index_unchecked_method(
293            self,
294            scope: &mut Scope,
295            i: ExpandElementTyped<u32>,
296        ) -> ExpandElementTyped<E> {
297            let out = scope.create_local(self.expand.item);
298            scope.register(Instruction::new(
299                Operator::UncheckedIndex(BinaryOperator {
300                    lhs: *self.expand,
301                    rhs: i.expand.consume(),
302                }),
303                *out,
304            ));
305            out.into()
306        }
307
308        pub fn __expand_index_assign_unchecked_method(
309            self,
310            scope: &mut Scope,
311            i: ExpandElementTyped<u32>,
312            value: ExpandElementTyped<E>,
313        ) {
314            scope.register(Instruction::new(
315                Operator::UncheckedIndexAssign(BinaryOperator {
316                    lhs: i.expand.consume(),
317                    rhs: value.expand.consume(),
318                }),
319                *self.expand,
320            ));
321        }
322    }
323}
324
325impl<C: CubeType> CubeType for Array<C> {
326    type ExpandType = ExpandElementTyped<Array<C>>;
327}
328
329impl<C: CubeType> CubeType for &Array<C> {
330    type ExpandType = ExpandElementTyped<Array<C>>;
331}
332
333impl<C: CubeType> ExpandElementBaseInit for Array<C> {
334    fn init_elem(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
335        // The type can't be deeply cloned/copied.
336        elem
337    }
338}
339
340impl<T: CubeType<ExpandType = ExpandElementTyped<T>>> SizedContainer for Array<T> {
341    type Item = T;
342}
343
344impl<T: CubeType> Iterator for &Array<T> {
345    type Item = T;
346
347    fn next(&mut self) -> Option<Self::Item> {
348        unexpanded!()
349    }
350}
351
352impl<T: CubePrimitive> List<T> for Array<T> {
353    fn __expand_read(
354        scope: &mut Scope,
355        this: ExpandElementTyped<Array<T>>,
356        idx: ExpandElementTyped<u32>,
357    ) -> ExpandElementTyped<T> {
358        index::expand(scope, this, idx)
359    }
360}
361
362impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Array<T>> {
363    fn __expand_read_method(
364        self,
365        scope: &mut Scope,
366        idx: ExpandElementTyped<u32>,
367    ) -> ExpandElementTyped<T> {
368        index::expand(scope, self, idx)
369    }
370}
371
372impl<T: CubePrimitive> ListMut<T> for Array<T> {
373    fn __expand_write(
374        scope: &mut Scope,
375        this: ExpandElementTyped<Array<T>>,
376        idx: ExpandElementTyped<u32>,
377        value: ExpandElementTyped<T>,
378    ) {
379        index_assign::expand(scope, this, idx, value);
380    }
381}
382
383impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Array<T>> {
384    fn __expand_write_method(
385        self,
386        scope: &mut Scope,
387        idx: ExpandElementTyped<u32>,
388        value: ExpandElementTyped<T>,
389    ) {
390        index_assign::expand(scope, self, idx, value);
391    }
392}