Skip to main content

cubecl_core/frontend/container/array/
base.rs

1use alloc::vec::Vec;
2use core::{
3    marker::PhantomData,
4    ops::{Deref, DerefMut},
5};
6
7use cubecl_ir::{ExpandElement, LineSize, Scope};
8
9use crate::prelude::{
10    LinedExpand, List, ListExpand, ListMut, ListMutExpand, SizedContainer, index_unchecked,
11};
12use crate::prelude::{assign, index, index_assign};
13use crate::{self as cubecl};
14use crate::{
15    frontend::CubeType,
16    ir::{Metadata, Type},
17    unexpanded,
18};
19use crate::{
20    frontend::{CubePrimitive, ExpandElementIntoMut, ExpandElementTyped},
21    prelude::Lined,
22};
23use cubecl_macros::{cube, intrinsic};
24
/// A contiguous array of elements.
///
/// This type stores no data of its own (its only field is a `PhantomData<E>`). During kernel
/// expansion it is represented by its expand type, `ExpandElementTyped<Array<E>>`, as declared
/// by its `CubeType` implementation.
pub struct Array<E> {
    _val: PhantomData<E>,
}

// Shorthand for the expanded form of `Array<E>`.
// NOTE(review): not named explicitly anywhere in this file; presumably referenced by code that
// `#[cube]` generates for the `impl Array<_>` blocks — confirm before renaming or removing.
type ArrayExpand<E> = ExpandElementTyped<Array<E>>;
31
32/// Module that contains the implementation details of the new function.
33mod new {
34
35    use cubecl_macros::intrinsic;
36
37    use super::*;
38    use crate::ir::Variable;
39
40    #[cube]
41    impl<T: CubePrimitive + Clone> Array<T> {
42        /// Create a new array of the given length.
43        #[allow(unused_variables)]
44        pub fn new(#[comptime] length: usize) -> Self {
45            intrinsic!(|scope| {
46                let elem = T::as_type(scope);
47                scope.create_local_array(Type::new(elem), length).into()
48            })
49        }
50    }
51
52    impl<T: CubePrimitive + Clone> Array<T> {
53        /// Create an array from data.
54        #[allow(unused_variables)]
55        pub fn from_data<C: CubePrimitive>(data: impl IntoIterator<Item = C>) -> Self {
56            intrinsic!(|scope| {
57                scope
58                    .create_const_array(Type::new(T::as_type(scope)), data.values)
59                    .into()
60            })
61        }
62
63        /// Expand function of [`from_data`](Array::from_data).
64        pub fn __expand_from_data<C: CubePrimitive>(
65            scope: &mut Scope,
66            data: ArrayData<C>,
67        ) -> <Self as CubeType>::ExpandType {
68            let var = scope.create_const_array(Type::new(T::as_type(scope)), data.values);
69            ExpandElementTyped::new(var)
70        }
71    }
72
73    /// Type useful for the expand function of [`from_data`](Array::from_data).
74    pub struct ArrayData<C> {
75        values: Vec<Variable>,
76        _ty: PhantomData<C>,
77    }
78
79    impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
80        for ArrayData<C>
81    {
82        fn from(value: T) -> Self {
83            let values: Vec<Variable> = value
84                .into_iter()
85                .map(|value| {
86                    let value: ExpandElementTyped<C> = value.into();
87                    *value.expand
88                })
89                .collect();
90            ArrayData {
91                values,
92                _ty: PhantomData,
93            }
94        }
95    }
96}
97
98/// Module that contains the implementation details of the `line_size` function.
99mod line {
100    use crate::prelude::Line;
101
102    use super::*;
103
104    impl<P: CubePrimitive> Array<Line<P>> {
105        /// Get the size of each line contained in the tensor.
106        ///
107        /// Same as the following:
108        ///
109        /// ```rust, ignore
110        /// let size = tensor[0].size();
111        /// ```
112        pub fn line_size(&self) -> LineSize {
113            unexpanded!()
114        }
115
116        // Expand function of [size](Tensor::line_size).
117        pub fn __expand_line_size(
118            expand: <Self as CubeType>::ExpandType,
119            scope: &mut Scope,
120        ) -> LineSize {
121            expand.__expand_line_size_method(scope)
122        }
123    }
124}
125
/// Module that contains the implementation details of vectorization functions.
///
/// TODO: Remove vectorization in favor of the line API.
mod vectorization {

    use cubecl_ir::LineSize;

    use super::*;

    #[cube]
    impl<T: CubePrimitive + Clone> Array<T> {
        /// Create an array of `length` elements.
        ///
        /// The element type is `T` lined to `line_size` lanes, and the array is allocated
        /// through `scope.create_local_array`.
        // NOTE(review): the parameters are only read inside `intrinsic!`; the allow presumably
        // silences the lint on the host-side signature — confirm.
        #[allow(unused_variables)]
        pub fn lined(#[comptime] length: usize, #[comptime] line_size: LineSize) -> Self {
            intrinsic!(|scope| {
                scope
                    .create_local_array(Type::new(T::as_type(scope)).line(line_size), length)
                    .into()
            })
        }

        /// Pack the first `line_size` elements of this array into one new local variable.
        ///
        /// The new variable's type is the array's storage type lined to `line_size` lanes.
        /// Element `k` of the array, for `k` in `0..line_size`, is read with `index::expand`
        /// and becomes lane `k` of the result.
        ///
        /// NOTE(review): the declared return type is `T`, while the produced variable has type
        /// `Type::new(storage).line(line_size)`. Callers presumably choose `T` to match that
        /// lined type — confirm. Whether reads past the array's length are caught depends on
        /// `index::expand`, which is not visible here.
        #[allow(unused_variables)]
        pub fn to_lined(self, #[comptime] line_size: LineSize) -> T {
            intrinsic!(|scope| {
                let factor = line_size;
                let var = self.expand.clone();
                // Same storage type as the array, but with `factor` lanes.
                let item = Type::new(var.storage_type()).line(factor);

                let new_var = if factor == 1 {
                    // Single lane: copy element 0 with one plain assignment.
                    // NOTE(review): this branch uses `create_local`, while the multi-lane branch
                    // uses `create_local_mut` — confirm the asymmetry is intended.
                    let new_var = scope.create_local(item);
                    let element =
                        index::expand(scope, self.clone(), ExpandElementTyped::from_lit(scope, 0));
                    assign::expand_no_check::<T>(scope, element, new_var.clone().into());
                    new_var
                } else {
                    // Mutable local, since it is written one lane at a time below.
                    let new_var = scope.create_local_mut(item);
                    for i in 0..factor {
                        // `index::expand` takes the list by value, so the array handle is
                        // re-wrapped for every read.
                        let expand: Self = self.expand.clone().into();
                        let element =
                            index::expand(scope, expand, ExpandElementTyped::from_lit(scope, i));
                        // Write lane `i` by treating the new line variable as an `Array<T>`.
                        index_assign::expand::<ExpandElementTyped<Array<T>>, T>(
                            scope,
                            new_var.clone().into(),
                            ExpandElementTyped::from_lit(scope, i),
                            element,
                        );
                    }
                    new_var
                };
                new_var.into()
            })
        }
    }
}
179
180/// Module that contains the implementation details of the metadata functions.
181mod metadata {
182    use crate::{ir::Instruction, prelude::expand_length_native};
183
184    use super::*;
185
186    #[cube]
187    impl<E: CubeType> Array<E> {
188        /// Obtain the array length
189        #[allow(clippy::len_without_is_empty)]
190        pub fn len(&self) -> usize {
191            intrinsic!(|scope| {
192                ExpandElement::Plain(expand_length_native(scope, *self.expand)).into()
193            })
194        }
195
196        /// Obtain the array buffer length
197        pub fn buffer_len(&self) -> usize {
198            intrinsic!(|scope| {
199                let out = scope.create_local(Type::new(usize::as_type(scope)));
200                scope.register(Instruction::new(
201                    Metadata::BufferLength {
202                        var: self.expand.into(),
203                    },
204                    out.clone().into(),
205                ));
206                out.into()
207            })
208        }
209    }
210}
211
212/// Module that contains the implementation details of the index functions.
213mod indexation {
214    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
215
216    use crate::{
217        ir::Instruction,
218        prelude::{CubeIndex, CubeIndexMut},
219    };
220
221    use super::*;
222
223    #[cube]
224    impl<E: CubePrimitive> Array<E> {
225        /// Perform an unchecked index into the array
226        ///
227        /// # Safety
228        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
229        /// always in bounds
230        #[allow(unused_variables)]
231        pub unsafe fn index_unchecked(&self, i: usize) -> &E
232        where
233            Self: CubeIndex,
234        {
235            intrinsic!(|scope| {
236                let out = scope.create_local(self.expand.ty);
237                scope.register(Instruction::new(
238                    Operator::UncheckedIndex(IndexOperator {
239                        list: *self.expand,
240                        index: i.expand.consume(),
241                        line_size: 0,
242                        unroll_factor: 1,
243                    }),
244                    *out,
245                ));
246                out.into()
247            })
248        }
249
250        /// Perform an unchecked index assignment into the array
251        ///
252        /// # Safety
253        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
254        /// always in bounds
255        #[allow(unused_variables)]
256        pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E)
257        where
258            Self: CubeIndexMut,
259        {
260            intrinsic!(|scope| {
261                scope.register(Instruction::new(
262                    Operator::UncheckedIndexAssign(IndexAssignOperator {
263                        index: i.expand.consume(),
264                        value: value.expand.consume(),
265                        line_size: 0,
266                        unroll_factor: 1,
267                    }),
268                    *self.expand,
269                ));
270            })
271        }
272    }
273}
274
275impl<C: CubeType> CubeType for Array<C> {
276    type ExpandType = ExpandElementTyped<Array<C>>;
277}
278
279impl<C: CubeType> CubeType for &Array<C> {
280    type ExpandType = ExpandElementTyped<Array<C>>;
281}
282
283impl<C: CubeType> ExpandElementIntoMut for Array<C> {
284    fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
285        // The type can't be deeply cloned/copied.
286        elem
287    }
288}
289
290impl<T: CubePrimitive> SizedContainer for Array<T> {
291    type Item = T;
292}
293
294impl<T: CubeType> Iterator for &Array<T> {
295    type Item = T;
296
297    fn next(&mut self) -> Option<Self::Item> {
298        unexpanded!()
299    }
300}
301
302impl<T: CubePrimitive> List<T> for Array<T> {
303    fn __expand_read(
304        scope: &mut Scope,
305        this: ExpandElementTyped<Array<T>>,
306        idx: ExpandElementTyped<usize>,
307    ) -> ExpandElementTyped<T> {
308        index::expand(scope, this, idx)
309    }
310}
311
312impl<T: CubePrimitive> Deref for Array<T> {
313    type Target = [T];
314
315    fn deref(&self) -> &Self::Target {
316        unexpanded!()
317    }
318}
319
320impl<T: CubePrimitive> DerefMut for Array<T> {
321    fn deref_mut(&mut self) -> &mut Self::Target {
322        unexpanded!()
323    }
324}
325
326impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Array<T>> {
327    fn __expand_read_method(
328        &self,
329        scope: &mut Scope,
330        idx: ExpandElementTyped<usize>,
331    ) -> ExpandElementTyped<T> {
332        index::expand(scope, self.clone(), idx)
333    }
334    fn __expand_read_unchecked_method(
335        &self,
336        scope: &mut Scope,
337        idx: ExpandElementTyped<usize>,
338    ) -> ExpandElementTyped<T> {
339        index_unchecked::expand(scope, self.clone(), idx)
340    }
341
342    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
343        Self::__expand_len(scope, self.clone())
344    }
345}
346
347impl<T: CubePrimitive> Lined for Array<T> {}
348impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Array<T>> {
349    fn line_size(&self) -> LineSize {
350        self.expand.ty.line_size()
351    }
352}
353
354impl<T: CubePrimitive> ListMut<T> for Array<T> {
355    fn __expand_write(
356        scope: &mut Scope,
357        this: ExpandElementTyped<Array<T>>,
358        idx: ExpandElementTyped<usize>,
359        value: ExpandElementTyped<T>,
360    ) {
361        index_assign::expand(scope, this, idx, value);
362    }
363}
364
365impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Array<T>> {
366    fn __expand_write_method(
367        &self,
368        scope: &mut Scope,
369        idx: ExpandElementTyped<usize>,
370        value: ExpandElementTyped<T>,
371    ) {
372        index_assign::expand(scope, self.clone(), idx, value);
373    }
374}