Skip to main content

cubecl_core/frontend/container/array/
base.rs

1use alloc::vec::Vec;
2use core::{
3    marker::PhantomData,
4    ops::{Deref, DerefMut},
5};
6
7use cubecl_ir::{ManagedVariable, Scope, VectorSize};
8
9use crate::frontend::{CubePrimitive, NativeExpand};
10use crate::prelude::*;
11use crate::{self as cubecl};
12use crate::{
13    frontend::CubeType,
14    ir::{Metadata, Type},
15    unexpanded,
16};
17use cubecl_macros::{cube, intrinsic};
18
19/// A contiguous array of elements.
20pub struct Array<E> {
21    _val: PhantomData<E>,
22}
23
24type ArrayExpand<E> = NativeExpand<Array<E>>;
25
26/// Module that contains the implementation details of the new function.
27mod new {
28
29    use cubecl_macros::intrinsic;
30
31    use super::*;
32    use crate::ir::Variable;
33
34    #[cube]
35    impl<T: CubePrimitive + Clone> Array<T> {
36        /// Create a new array of the given length.
37        #[allow(unused_variables)]
38        pub fn new(#[comptime] length: usize) -> Self {
39            intrinsic!(|scope| {
40                let elem = T::as_type(scope);
41                scope.create_local_array(elem, length).into()
42            })
43        }
44    }
45
46    impl<T: CubePrimitive + Clone> Array<T> {
47        /// Create an array from data.
48        #[allow(unused_variables)]
49        pub fn from_data<C: CubePrimitive>(data: impl IntoIterator<Item = C>) -> Self {
50            intrinsic!(|scope| {
51                scope
52                    .create_const_array(Type::new(T::as_type(scope)), data.values)
53                    .into()
54            })
55        }
56
57        /// Expand function of [`from_data`](Array::from_data).
58        pub fn __expand_from_data<C: CubePrimitive>(
59            scope: &mut Scope,
60            data: ArrayData<C>,
61        ) -> <Self as CubeType>::ExpandType {
62            let var = scope.create_const_array(T::as_type(scope), data.values);
63            NativeExpand::new(var)
64        }
65    }
66
67    /// Type useful for the expand function of [`from_data`](Array::from_data).
68    pub struct ArrayData<C> {
69        values: Vec<Variable>,
70        _ty: PhantomData<C>,
71    }
72
73    impl<C: CubePrimitive + Into<NativeExpand<C>>, T: IntoIterator<Item = C>> From<T> for ArrayData<C> {
74        fn from(value: T) -> Self {
75            let values: Vec<Variable> = value
76                .into_iter()
77                .map(|value| {
78                    let value: NativeExpand<C> = value.into();
79                    *value.expand
80                })
81                .collect();
82            ArrayData {
83                values,
84                _ty: PhantomData,
85            }
86        }
87    }
88}
89
90/// Module that contains the implementation details of the `vector_size` function.
91mod vector {
92    use super::*;
93
94    impl<P: CubePrimitive> Array<P> {
95        /// Get the size of each vector contained in the tensor.
96        ///
97        /// Same as the following:
98        ///
99        /// ```rust, ignore
100        /// let size = tensor[0].vector_size();
101        /// ```
102        pub fn vector_size(&self) -> VectorSize {
103            P::vector_size()
104        }
105
106        // Expand function of [size](Tensor::vector_size).
107        pub fn __expand_vector_size(
108            expand: <Self as CubeType>::ExpandType,
109            scope: &mut Scope,
110        ) -> VectorSize {
111            expand.__expand_vector_size_method(scope)
112        }
113    }
114}
115
/// Module that contains the implementation details of vectorization functions.
mod vectorization {
    use super::*;

    #[cube]
    impl<T: CubePrimitive + Clone> Array<T> {
        /// Gather the first `N::value()` elements of this array into a single value whose
        /// vector size is `N::value()`.
        ///
        /// When the factor is 1, element 0 is copied into a new local. Otherwise a mutable
        /// local is created and elements `0..factor` are written into it one index at a time.
        /// This function performs no check of `factor` against the array's length.
        ///
        /// NOTE(review): the return type is `T`; presumably callers pick a `T` whose vector
        /// size matches `N` — confirm against call sites.
        #[allow(unused_variables)]
        pub fn to_vectorized<N: Size>(self) -> T {
            let factor = N::value();
            intrinsic!(|scope| {
                let var = self.expand.clone();
                // Same storage type as the array, with its vector size set to `factor`.
                let item = Type::new(var.storage_type()).with_vector_size(factor);

                let new_var = if factor == 1 {
                    // Single element: a plain assignment of element 0 is enough.
                    let new_var = scope.create_local(item);
                    let element =
                        index::expand(scope, self.clone(), NativeExpand::from_lit(scope, 0));
                    assign::expand_no_check::<T>(scope, element, new_var.clone().into());
                    new_var
                } else {
                    // Presumably mutable because it is written index by index below.
                    let new_var = scope.create_local_mut(item);
                    for i in 0..factor {
                        let expand: Self = self.expand.clone().into();
                        let element =
                            index::expand(scope, expand, NativeExpand::from_lit(scope, i));
                        // Index `i` of the result receives element `i` of the array.
                        index_assign::expand::<NativeExpand<Array<T>>, T>(
                            scope,
                            new_var.clone().into(),
                            NativeExpand::from_lit(scope, i),
                            element,
                        );
                    }
                    new_var
                };
                new_var.into()
            })
        }
    }
}
155
/// Module that contains the implementation details of the metadata functions.
mod metadata {
    use crate::{ir::Instruction, prelude::expand_length_native};

    use super::*;

    #[cube]
    impl<E: CubeType> Array<E> {
        /// Obtain the array length.
        ///
        /// Expands to the length variable produced by `expand_length_native` for this array.
        // No `is_empty` counterpart is defined for `Array`, so the clippy lint is silenced.
        #[allow(clippy::len_without_is_empty)]
        pub fn len(&self) -> usize {
            intrinsic!(|scope| {
                ManagedVariable::Plain(expand_length_native(scope, *self.expand)).into()
            })
        }

        /// Obtain the array buffer length.
        ///
        /// Registers a `Metadata::BufferLength` instruction for this array and returns the
        /// fresh `usize` local that the instruction writes to.
        ///
        /// NOTE(review): how this value differs from [`len`](Array::len) depends on how the
        /// backend lowers `BufferLength` — confirm before relying on the distinction.
        pub fn buffer_len(&self) -> usize {
            intrinsic!(|scope| {
                let out = scope.create_local(usize::as_type(scope));
                scope.register(Instruction::new(
                    Metadata::BufferLength {
                        var: self.expand.into(),
                    },
                    out.clone().into(),
                ));
                out.into()
            })
        }
    }
}
187
/// Module that contains the implementation details of the index functions.
mod indexation {
    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};

    use crate::ir::Instruction;

    use super::*;

    #[cube]
    impl<E: CubePrimitive> Array<E> {
        /// Perform an unchecked index into the array
        ///
        /// Emits an `Operator::UncheckedIndex` instruction whose output is a fresh local with
        /// the same type as the array variable (`self.expand.ty`).
        ///
        /// # Safety
        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
        /// always in bounds
        #[allow(unused_variables)]
        pub unsafe fn index_unchecked(&self, i: usize) -> &E {
            intrinsic!(|scope| {
                let out = scope.create_local(self.expand.ty);
                scope.register(Instruction::new(
                    Operator::UncheckedIndex(IndexOperator {
                        list: *self.expand,
                        index: i.expand.consume(),
                        // NOTE(review): `0` presumably means "no override, use the list's
                        // own vector size" — confirm against `IndexOperator`.
                        vector_size: 0,
                        unroll_factor: 1,
                    }),
                    *out,
                ));
                out.into()
            })
        }

        /// Perform an unchecked index assignment into the array
        ///
        /// Emits an `Operator::UncheckedIndexAssign` instruction that stores `value` at index
        /// `i`; the array variable itself is the instruction's output.
        ///
        /// # Safety
        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
        /// always in bounds
        #[allow(unused_variables)]
        pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E) {
            intrinsic!(|scope| {
                scope.register(Instruction::new(
                    Operator::UncheckedIndexAssign(IndexAssignOperator {
                        index: i.expand.consume(),
                        value: value.expand.consume(),
                        // NOTE(review): same field conventions as in `index_unchecked` above —
                        // confirm against `IndexAssignOperator`.
                        vector_size: 0,
                        unroll_factor: 1,
                    }),
                    // The write target: the array variable is the output of the assignment.
                    *self.expand,
                ));
            })
        }
    }
}
241
242impl<C: CubeType> CubeType for Array<C> {
243    type ExpandType = NativeExpand<Array<C>>;
244}
245
246impl<C: CubeType> CubeType for &Array<C> {
247    type ExpandType = NativeExpand<Array<C>>;
248}
249
250impl<C: CubeType> IntoMut for NativeExpand<Array<C>> {
251    fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
252        // The type can't be deeply cloned/copied.
253        self
254    }
255}
256
257impl<T: CubePrimitive> SizedContainer for Array<T> {
258    type Item = T;
259}
260
261impl<T: CubeType> Iterator for &Array<T> {
262    type Item = T;
263
264    fn next(&mut self) -> Option<Self::Item> {
265        unexpanded!()
266    }
267}
268
269impl<T: CubePrimitive> List<T> for Array<T> {
270    fn __expand_read(
271        scope: &mut Scope,
272        this: NativeExpand<Array<T>>,
273        idx: NativeExpand<usize>,
274    ) -> NativeExpand<T> {
275        index::expand(scope, this, idx)
276    }
277}
278
279impl<T: CubePrimitive> Deref for Array<T> {
280    type Target = [T];
281
282    fn deref(&self) -> &Self::Target {
283        unexpanded!()
284    }
285}
286
287impl<T: CubePrimitive> DerefMut for Array<T> {
288    fn deref_mut(&mut self) -> &mut Self::Target {
289        unexpanded!()
290    }
291}
292
293impl<T: CubePrimitive> ListExpand<T> for NativeExpand<Array<T>> {
294    fn __expand_read_method(&self, scope: &mut Scope, idx: NativeExpand<usize>) -> NativeExpand<T> {
295        index::expand(scope, self.clone(), idx)
296    }
297    fn __expand_read_unchecked_method(
298        &self,
299        scope: &mut Scope,
300        idx: NativeExpand<usize>,
301    ) -> NativeExpand<T> {
302        index_unchecked::expand(scope, self.clone(), idx)
303    }
304
305    fn __expand_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
306        Self::__expand_len(scope, self.clone())
307    }
308}
309
310impl<T: CubePrimitive> Vectorized for Array<T> {}
311impl<T: CubePrimitive> VectorizedExpand for NativeExpand<Array<T>> {
312    fn vector_size(&self) -> VectorSize {
313        self.expand.ty.vector_size()
314    }
315}
316
317impl<T: CubePrimitive> ListMut<T> for Array<T> {
318    fn __expand_write(
319        scope: &mut Scope,
320        this: NativeExpand<Array<T>>,
321        idx: NativeExpand<usize>,
322        value: NativeExpand<T>,
323    ) {
324        index_assign::expand(scope, this, idx, value);
325    }
326}
327
328impl<T: CubePrimitive> ListMutExpand<T> for NativeExpand<Array<T>> {
329    fn __expand_write_method(
330        &self,
331        scope: &mut Scope,
332        idx: NativeExpand<usize>,
333        value: NativeExpand<T>,
334    ) {
335        index_assign::expand(scope, self.clone(), idx, value);
336    }
337}