cubecl_core/frontend/container/array/
base.rs

1use std::{marker::PhantomData, num::NonZero};
2
3use crate::frontend::{
4    CubePrimitive, ExpandElement, ExpandElementBaseInit, ExpandElementTyped, IntoRuntime,
5};
6use crate::prelude::SizedContainer;
7use crate::{
8    frontend::CubeType,
9    ir::{Item, Metadata},
10    unexpanded,
11};
12use crate::{
13    frontend::{indexation::Index, CubeContext},
14    prelude::{assign, index, index_assign},
15};
16
/// A contiguous array of elements of type `E`.
///
/// The struct stores no data of its own: its only field is a [`PhantomData`]
/// marker that ties the array to its element type.
pub struct Array<E> {
    /// Zero-sized marker carrying the element type.
    _val: PhantomData<E>,
}
21
22/// Module that contains the implementation details of the new function.
23mod new {
24    use super::*;
25    use crate::ir::Variable;
26
27    impl<T: CubePrimitive + Clone> Array<T> {
28        /// Create a new array of the given length.
29        #[allow(unused_variables)]
30        pub fn new<L: Index>(length: L) -> Self {
31            Array { _val: PhantomData }
32        }
33
34        /// Create an array from data.
35        pub fn from_data<C: CubePrimitive>(_data: impl IntoIterator<Item = C>) -> Self {
36            Array { _val: PhantomData }
37        }
38
39        /// Expand function of [new](Array::new).
40        pub fn __expand_new(
41            context: &mut CubeContext,
42            size: ExpandElementTyped<u32>,
43        ) -> <Self as CubeType>::ExpandType {
44            let size = size
45                .constant()
46                .expect("Array need constant initialization value")
47                .as_u32();
48            let elem = T::as_elem(context);
49            context.create_local_array(Item::new(elem), size).into()
50        }
51
52        /// Expand function of [from_data](Array::from_data).
53        pub fn __expand_from_data<C: CubePrimitive>(
54            context: &mut CubeContext,
55            data: ArrayData<C>,
56        ) -> <Self as CubeType>::ExpandType {
57            let var = context.create_const_array(Item::new(T::as_elem(context)), data.values);
58            ExpandElementTyped::new(var)
59        }
60    }
61
62    /// Type useful for the expand function of [from_data](Array::from_data).
63    pub struct ArrayData<C> {
64        values: Vec<Variable>,
65        _ty: PhantomData<C>,
66    }
67
68    impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
69        for ArrayData<C>
70    {
71        fn from(value: T) -> Self {
72            let values: Vec<Variable> = value
73                .into_iter()
74                .map(|value| {
75                    let value: ExpandElementTyped<C> = value.into();
76                    *value.expand
77                })
78                .collect();
79            ArrayData {
80                values,
81                _ty: PhantomData,
82            }
83        }
84    }
85}
86
87/// Module that contains the implementation details of the line_size function.
88mod line {
89    use crate::prelude::Line;
90
91    use super::*;
92
93    impl<P: CubePrimitive> Array<Line<P>> {
94        /// Get the size of each line contained in the tensor.
95        ///
96        /// Same as the following:
97        ///
98        /// ```rust, ignore
99        /// let size = tensor[0].size();
100        /// ```
101        pub fn line_size(&self) -> u32 {
102            unexpanded!()
103        }
104
105        // Expand function of [size](Tensor::line_size).
106        pub fn __expand_line_size(
107            expand: <Self as CubeType>::ExpandType,
108            context: &mut CubeContext,
109        ) -> u32 {
110            expand.__expand_line_size_method(context)
111        }
112    }
113
114    impl<P: CubePrimitive> ExpandElementTyped<Array<Line<P>>> {
115        /// Comptime version of [size](Array::line_size).
116        pub fn line_size(&self) -> u32 {
117            self.expand
118                .item
119                .vectorization
120                .unwrap_or(NonZero::new(1).unwrap())
121                .get() as u32
122        }
123
124        // Expand method of [size](Array::line_size).
125        pub fn __expand_line_size_method(&self, _content: &mut CubeContext) -> u32 {
126            self.line_size()
127        }
128    }
129}
130
131/// Module that contains the implementation details of vectorization functions.
132///
133/// TODO: Remove vectorization in favor of the line API.
134mod vectorization {
135    use super::*;
136
137    impl<T: CubePrimitive + Clone> Array<T> {
138        #[allow(unused_variables)]
139        pub fn vectorized<L: Index>(length: L, vectorization_factor: u32) -> Self {
140            Array { _val: PhantomData }
141        }
142
143        pub fn to_vectorized(self, _vectorization_factor: u32) -> T {
144            unexpanded!()
145        }
146
147        pub fn __expand_vectorized(
148            context: &mut CubeContext,
149            size: ExpandElementTyped<u32>,
150            vectorization_factor: u32,
151        ) -> <Self as CubeType>::ExpandType {
152            let size = size.value();
153            let size = match size.kind {
154                crate::ir::VariableKind::ConstantScalar(value) => value.as_u32(),
155                _ => panic!("Shared memory need constant initialization value"),
156            };
157            context
158                .create_local_array(
159                    Item::vectorized(
160                        T::as_elem(context),
161                        NonZero::new(vectorization_factor as u8),
162                    ),
163                    size,
164                )
165                .into()
166        }
167    }
168
169    impl<C: CubePrimitive> ExpandElementTyped<Array<C>> {
170        pub fn __expand_to_vectorized_method(
171            self,
172            context: &mut CubeContext,
173            vectorization_factor: ExpandElementTyped<u32>,
174        ) -> ExpandElementTyped<C> {
175            let factor = vectorization_factor
176                .constant()
177                .expect("Vectorization must be comptime")
178                .as_u32();
179            let var = self.expand.clone();
180            let item = Item::vectorized(var.item.elem(), NonZero::new(factor as u8));
181
182            let new_var = if factor == 1 {
183                let new_var = context.create_local(item);
184                let element = index::expand(
185                    context,
186                    self.clone(),
187                    ExpandElementTyped::from_lit(context, 0u32),
188                );
189                assign::expand(context, element, new_var.clone().into());
190                new_var
191            } else {
192                let new_var = context.create_local_mut(item);
193                for i in 0..factor {
194                    let expand: Self = self.expand.clone().into();
195                    let element =
196                        index::expand(context, expand, ExpandElementTyped::from_lit(context, i));
197                    index_assign::expand::<Array<C>>(
198                        context,
199                        new_var.clone().into(),
200                        ExpandElementTyped::from_lit(context, i),
201                        element,
202                    );
203                }
204                new_var
205            };
206            new_var.into()
207        }
208    }
209}
210
211/// Module that contains the implementation details of the metadata functions.
212mod metadata {
213    use crate::ir::Instruction;
214
215    use super::*;
216
217    impl<E: CubeType> Array<E> {
218        /// Obtain the array length
219        #[allow(clippy::len_without_is_empty)]
220        pub fn len(&self) -> u32 {
221            unexpanded!()
222        }
223
224        /// Obtain the array buffer length
225        pub fn buffer_len(&self) -> u32 {
226            unexpanded!()
227        }
228    }
229
230    impl<T: CubeType> ExpandElementTyped<Array<T>> {
231        // Expand method of [len](Array::len).
232        pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
233            let out = context.create_local(Item::new(u32::as_elem(context)));
234            context.register(Instruction::new(
235                Metadata::Length {
236                    var: self.expand.into(),
237                },
238                out.clone().into(),
239            ));
240            out.into()
241        }
242
243        // Expand method of [buffer_len](Array::buffer_len).
244        pub fn __expand_buffer_len_method(
245            self,
246            context: &mut CubeContext,
247        ) -> ExpandElementTyped<u32> {
248            let out = context.create_local(Item::new(u32::as_elem(context)));
249            context.register(Instruction::new(
250                Metadata::BufferLength {
251                    var: self.expand.into(),
252                },
253                out.clone().into(),
254            ));
255            out.into()
256        }
257    }
258}
259
260/// Module that contains the implementation details of the index functions.
261mod indexation {
262    use crate::{
263        ir::{BinaryOperator, Instruction, Operator},
264        prelude::{CubeIndex, CubeIndexMut},
265    };
266
267    use super::*;
268
269    impl<E: CubePrimitive> Array<E> {
270        /// Perform an unchecked index into the array
271        ///
272        /// # Safety
273        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
274        /// always in bounds
275        pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
276        where
277            Self: CubeIndex<I>,
278        {
279            unexpanded!()
280        }
281
282        /// Perform an unchecked index assignment into the array
283        ///
284        /// # Safety
285        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
286        /// always in bounds
287        pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
288        where
289            Self: CubeIndexMut<I>,
290        {
291            unexpanded!()
292        }
293    }
294
295    impl<E: CubePrimitive> ExpandElementTyped<Array<E>> {
296        pub fn __expand_index_unchecked_method(
297            self,
298            context: &mut CubeContext,
299            i: ExpandElementTyped<u32>,
300        ) -> ExpandElementTyped<E> {
301            let out = context.create_local(self.expand.item);
302            context.register(Instruction::new(
303                Operator::UncheckedIndex(BinaryOperator {
304                    lhs: *self.expand,
305                    rhs: i.expand.consume(),
306                }),
307                *out,
308            ));
309            out.into()
310        }
311
312        pub fn __expand_index_assign_unchecked_method(
313            self,
314            context: &mut CubeContext,
315            i: ExpandElementTyped<u32>,
316            value: ExpandElementTyped<E>,
317        ) {
318            context.register(Instruction::new(
319                Operator::UncheckedIndexAssign(BinaryOperator {
320                    lhs: i.expand.consume(),
321                    rhs: value.expand.consume(),
322                }),
323                *self.expand,
324            ));
325        }
326    }
327}
328
329impl<E: CubePrimitive> IntoRuntime for Array<E> {
330    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
331        unimplemented!("Array can't exist at compile time")
332    }
333}
334
335impl<C: CubeType> CubeType for Array<C> {
336    type ExpandType = ExpandElementTyped<Array<C>>;
337}
338
339impl<C: CubeType> CubeType for &Array<C> {
340    type ExpandType = ExpandElementTyped<Array<C>>;
341}
342
343impl<C: CubeType> ExpandElementBaseInit for Array<C> {
344    fn init_elem(_context: &mut crate::prelude::CubeContext, elem: ExpandElement) -> ExpandElement {
345        // The type can't be deeply cloned/copied.
346        elem
347    }
348}
349
350impl<T: CubeType<ExpandType = ExpandElementTyped<T>>> SizedContainer for Array<T> {
351    type Item = T;
352}
353
354impl<T: CubeType> Iterator for &Array<T> {
355    type Item = T;
356
357    fn next(&mut self) -> Option<Self::Item> {
358        unexpanded!()
359    }
360}