cubecl_core/frontend/container/array/base.rs

1use std::{marker::PhantomData, num::NonZero};
2
3use cubecl_ir::{ExpandElement, Scope};
4
5use crate as cubecl;
6use crate::frontend::{CubePrimitive, ExpandElementIntoMut, ExpandElementTyped};
7use crate::prelude::{List, ListExpand, ListMut, ListMutExpand, SizedContainer, index_unchecked};
8use crate::prelude::{assign, index, index_assign};
9use crate::{
10    frontend::CubeType,
11    ir::{Item, Metadata},
12    unexpanded,
13};
14use cubecl_macros::{cube, intrinsic};
15
16/// A contiguous array of elements.
17pub struct Array<E> {
18    _val: PhantomData<E>,
19}
20
21type ArrayExpand<E> = ExpandElementTyped<Array<E>>;
22
23/// Module that contains the implementation details of the new function.
24mod new {
25
26    use cubecl_macros::intrinsic;
27
28    use super::*;
29    use crate::ir::Variable;
30
31    #[cube]
32    impl<T: CubePrimitive + Clone> Array<T> {
33        /// Create a new array of the given length.
34        #[allow(unused_variables)]
35        pub fn new(length: u32) -> Self {
36            intrinsic!(|scope| {
37                let size = length
38                    .constant()
39                    .expect("Array needs constant initialization value")
40                    .as_u32();
41                let elem = T::as_elem(scope);
42                scope.create_local_array(Item::new(elem), size).into()
43            })
44        }
45    }
46
47    impl<T: CubePrimitive + Clone> Array<T> {
48        /// Create an array from data.
49        #[allow(unused_variables)]
50        pub fn from_data<C: CubePrimitive>(data: impl IntoIterator<Item = C>) -> Self {
51            Array { _val: PhantomData }
52        }
53
54        /// Expand function of [from_data](Array::from_data).
55        pub fn __expand_from_data<C: CubePrimitive>(
56            scope: &mut Scope,
57            data: ArrayData<C>,
58        ) -> <Self as CubeType>::ExpandType {
59            let var = scope.create_const_array(Item::new(T::as_elem(scope)), data.values);
60            ExpandElementTyped::new(var)
61        }
62    }
63
64    /// Type useful for the expand function of [from_data](Array::from_data).
65    pub struct ArrayData<C> {
66        values: Vec<Variable>,
67        _ty: PhantomData<C>,
68    }
69
70    impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
71        for ArrayData<C>
72    {
73        fn from(value: T) -> Self {
74            let values: Vec<Variable> = value
75                .into_iter()
76                .map(|value| {
77                    let value: ExpandElementTyped<C> = value.into();
78                    *value.expand
79                })
80                .collect();
81            ArrayData {
82                values,
83                _ty: PhantomData,
84            }
85        }
86    }
87}
88
89/// Module that contains the implementation details of the line_size function.
90mod line {
91    use crate::prelude::Line;
92
93    use super::*;
94
95    impl<P: CubePrimitive> Array<Line<P>> {
96        /// Get the size of each line contained in the tensor.
97        ///
98        /// Same as the following:
99        ///
100        /// ```rust, ignore
101        /// let size = tensor[0].size();
102        /// ```
103        pub fn line_size(&self) -> u32 {
104            unexpanded!()
105        }
106
107        // Expand function of [size](Tensor::line_size).
108        pub fn __expand_line_size(
109            expand: <Self as CubeType>::ExpandType,
110            scope: &mut Scope,
111        ) -> u32 {
112            expand.__expand_line_size_method(scope)
113        }
114    }
115
116    impl<P: CubePrimitive> ExpandElementTyped<Array<Line<P>>> {
117        /// Comptime version of [size](Array::line_size).
118        pub fn line_size(&self) -> u32 {
119            self.expand
120                .item
121                .vectorization
122                .unwrap_or(NonZero::new(1).unwrap())
123                .get() as u32
124        }
125
126        // Expand method of [size](Array::line_size).
127        pub fn __expand_line_size_method(&self, _content: &mut Scope) -> u32 {
128            self.line_size()
129        }
130    }
131}
132
/// Module that contains the implementation details of vectorization functions.
///
/// TODO: Remove vectorization in favor of the line API.
mod vectorization {

    use super::*;

    #[cube]
    impl<T: CubePrimitive + Clone> Array<T> {
        /// Create a local array of `length` items whose element type is `T`
        /// vectorized by `vectorization_factor`.
        ///
        /// Both arguments are comptime. NOTE(review): `vectorization_factor`
        /// is cast with `as u8`, so values above 255 are truncated, and a
        /// resulting 0 makes `NonZero::new` yield `None`.
        #[allow(unused_variables)]
        pub fn vectorized(#[comptime] length: u32, #[comptime] vectorization_factor: u32) -> Self {
            intrinsic!(|scope| {
                scope
                    .create_local_array(
                        Item::vectorized(
                            T::as_elem(scope),
                            NonZero::new(vectorization_factor as u8),
                        ),
                        length,
                    )
                    .into()
            })
        }

        /// Copy the first `vectorization_factor` elements of this array into
        /// a single new local.
        ///
        /// The new local's item is the array's element type vectorized by
        /// `vectorization_factor`. The same `as u8` truncation as in
        /// `vectorized` applies to the factor.
        #[allow(unused_variables)]
        pub fn to_vectorized(self, #[comptime] vectorization_factor: u32) -> T {
            intrinsic!(|scope| {
                let factor = vectorization_factor;
                let var = self.expand.clone();
                // Array element type, re-wrapped with the requested factor.
                let item = Item::vectorized(var.item.elem(), NonZero::new(factor as u8));

                let new_var = if factor == 1 {
                    // Single element: copy index 0 into the new local.
                    // NOTE(review): this branch uses `create_local` while the
                    // multi-element branch uses `create_local_mut`; confirm the
                    // assignment below is valid on this local.
                    let new_var = scope.create_local(item);
                    let element = index::expand(
                        scope,
                        self.clone(),
                        ExpandElementTyped::from_lit(scope, 0u32),
                    );
                    assign::expand_no_check::<T>(scope, element, new_var.clone().into());
                    new_var
                } else {
                    // Copy elements 0..factor one by one into the same
                    // positions of the new local.
                    let new_var = scope.create_local_mut(item);
                    for i in 0..factor {
                        let expand: Self = self.expand.clone().into();
                        let element =
                            index::expand(scope, expand, ExpandElementTyped::from_lit(scope, i));
                        index_assign::expand::<ExpandElementTyped<Array<T>>, T>(
                            scope,
                            new_var.clone().into(),
                            ExpandElementTyped::from_lit(scope, i),
                            element,
                        );
                    }
                    new_var
                };
                new_var.into()
            })
        }
    }
}
193
194/// Module that contains the implementation details of the metadata functions.
195mod metadata {
196    use crate::{ir::Instruction, prelude::expand_length_native};
197
198    use super::*;
199
200    #[cube]
201    impl<E: CubeType> Array<E> {
202        /// Obtain the array length
203        #[allow(clippy::len_without_is_empty)]
204        pub fn len(&self) -> u32 {
205            intrinsic!(|scope| {
206                ExpandElement::Plain(expand_length_native(scope, *self.expand)).into()
207            })
208        }
209
210        /// Obtain the array buffer length
211        pub fn buffer_len(&self) -> u32 {
212            intrinsic!(|scope| {
213                let out = scope.create_local(Item::new(u32::as_elem(scope)));
214                scope.register(Instruction::new(
215                    Metadata::BufferLength {
216                        var: self.expand.into(),
217                    },
218                    out.clone().into(),
219                ));
220                out.into()
221            })
222        }
223    }
224}
225
226/// Module that contains the implementation details of the index functions.
227mod indexation {
228    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
229
230    use crate::{
231        ir::Instruction,
232        prelude::{CubeIndex, CubeIndexMut},
233    };
234
235    use super::*;
236
237    #[cube]
238    impl<E: CubePrimitive> Array<E> {
239        /// Perform an unchecked index into the array
240        ///
241        /// # Safety
242        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
243        /// always in bounds
244        #[allow(unused_variables)]
245        pub unsafe fn index_unchecked(&self, i: u32) -> &E
246        where
247            Self: CubeIndex,
248        {
249            intrinsic!(|scope| {
250                let out = scope.create_local(self.expand.item);
251                scope.register(Instruction::new(
252                    Operator::UncheckedIndex(IndexOperator {
253                        list: *self.expand,
254                        index: i.expand.consume(),
255                        line_size: 0,
256                    }),
257                    *out,
258                ));
259                out.into()
260            })
261        }
262
263        /// Perform an unchecked index assignment into the array
264        ///
265        /// # Safety
266        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
267        /// always in bounds
268        #[allow(unused_variables)]
269        pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E)
270        where
271            Self: CubeIndexMut,
272        {
273            intrinsic!(|scope| {
274                scope.register(Instruction::new(
275                    Operator::UncheckedIndexAssign(IndexAssignOperator {
276                        index: i.expand.consume(),
277                        value: value.expand.consume(),
278                        line_size: 0,
279                    }),
280                    *self.expand,
281                ));
282            })
283        }
284    }
285}
286
287impl<C: CubeType> CubeType for Array<C> {
288    type ExpandType = ExpandElementTyped<Array<C>>;
289}
290
291impl<C: CubeType> CubeType for &Array<C> {
292    type ExpandType = ExpandElementTyped<Array<C>>;
293}
294
295impl<C: CubeType> ExpandElementIntoMut for Array<C> {
296    fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
297        // The type can't be deeply cloned/copied.
298        elem
299    }
300}
301
302impl<T: CubePrimitive> SizedContainer for Array<T> {
303    type Item = T;
304}
305
306impl<T: CubeType> Iterator for &Array<T> {
307    type Item = T;
308
309    fn next(&mut self) -> Option<Self::Item> {
310        unexpanded!()
311    }
312}
313
314impl<T: CubePrimitive> List<T> for Array<T> {
315    fn __expand_read(
316        scope: &mut Scope,
317        this: ExpandElementTyped<Array<T>>,
318        idx: ExpandElementTyped<u32>,
319    ) -> ExpandElementTyped<T> {
320        index::expand(scope, this, idx)
321    }
322}
323
324impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Array<T>> {
325    fn __expand_read_method(
326        &self,
327        scope: &mut Scope,
328        idx: ExpandElementTyped<u32>,
329    ) -> ExpandElementTyped<T> {
330        index::expand(scope, self.clone(), idx)
331    }
332    fn __expand_read_unchecked_method(
333        &self,
334        scope: &mut Scope,
335        idx: ExpandElementTyped<u32>,
336    ) -> ExpandElementTyped<T> {
337        index_unchecked::expand(scope, self.clone(), idx)
338    }
339}
340
341impl<T: CubePrimitive> ListMut<T> for Array<T> {
342    fn __expand_write(
343        scope: &mut Scope,
344        this: ExpandElementTyped<Array<T>>,
345        idx: ExpandElementTyped<u32>,
346        value: ExpandElementTyped<T>,
347    ) {
348        index_assign::expand(scope, this, idx, value);
349    }
350}
351
352impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Array<T>> {
353    fn __expand_write_method(
354        &self,
355        scope: &mut Scope,
356        idx: ExpandElementTyped<u32>,
357        value: ExpandElementTyped<T>,
358    ) {
359        index_assign::expand(scope, self.clone(), idx, value);
360    }
361}