// cubecl_core/frontend/container/array/base.rs

1use std::marker::PhantomData;
2
3use cubecl_ir::{ExpandElement, LineSize, Scope};
4
5use crate::prelude::{
6    LinedExpand, List, ListExpand, ListMut, ListMutExpand, SizedContainer, index_unchecked,
7};
8use crate::prelude::{assign, index, index_assign};
9use crate::{self as cubecl};
10use crate::{
11    frontend::CubeType,
12    ir::{Metadata, Type},
13    unexpanded,
14};
15use crate::{
16    frontend::{CubePrimitive, ExpandElementIntoMut, ExpandElementTyped},
17    prelude::Lined,
18};
19use cubecl_macros::{cube, intrinsic};
20
/// A contiguous array of elements.
///
/// On the host this is a zero-sized marker type: its only field is a [`PhantomData`] tying it
/// to the element type `E`. The operations defined on it are expanded through its
/// [`CubeType::ExpandType`]; the host-side method bodies are placeholders.
pub struct Array<E> {
    _val: PhantomData<E>,
}
25
26type ArrayExpand<E> = ExpandElementTyped<Array<E>>;
27
28/// Module that contains the implementation details of the new function.
29mod new {
30
31    use cubecl_macros::intrinsic;
32
33    use super::*;
34    use crate::ir::Variable;
35
36    #[cube]
37    impl<T: CubePrimitive + Clone> Array<T> {
38        /// Create a new array of the given length.
39        #[allow(unused_variables)]
40        pub fn new(#[comptime] length: usize) -> Self {
41            intrinsic!(|scope| {
42                let elem = T::as_type(scope);
43                scope.create_local_array(Type::new(elem), length).into()
44            })
45        }
46    }
47
48    impl<T: CubePrimitive + Clone> Array<T> {
49        /// Create an array from data.
50        #[allow(unused_variables)]
51        pub fn from_data<C: CubePrimitive>(data: impl IntoIterator<Item = C>) -> Self {
52            intrinsic!(|scope| {
53                scope
54                    .create_const_array(Type::new(T::as_type(scope)), data.values)
55                    .into()
56            })
57        }
58
59        /// Expand function of [from_data](Array::from_data).
60        pub fn __expand_from_data<C: CubePrimitive>(
61            scope: &mut Scope,
62            data: ArrayData<C>,
63        ) -> <Self as CubeType>::ExpandType {
64            let var = scope.create_const_array(Type::new(T::as_type(scope)), data.values);
65            ExpandElementTyped::new(var)
66        }
67    }
68
69    /// Type useful for the expand function of [from_data](Array::from_data).
70    pub struct ArrayData<C> {
71        values: Vec<Variable>,
72        _ty: PhantomData<C>,
73    }
74
75    impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
76        for ArrayData<C>
77    {
78        fn from(value: T) -> Self {
79            let values: Vec<Variable> = value
80                .into_iter()
81                .map(|value| {
82                    let value: ExpandElementTyped<C> = value.into();
83                    *value.expand
84                })
85                .collect();
86            ArrayData {
87                values,
88                _ty: PhantomData,
89            }
90        }
91    }
92}
93
94/// Module that contains the implementation details of the line_size function.
95mod line {
96    use crate::prelude::Line;
97
98    use super::*;
99
100    impl<P: CubePrimitive> Array<Line<P>> {
101        /// Get the size of each line contained in the tensor.
102        ///
103        /// Same as the following:
104        ///
105        /// ```rust, ignore
106        /// let size = tensor[0].size();
107        /// ```
108        pub fn line_size(&self) -> LineSize {
109            unexpanded!()
110        }
111
112        // Expand function of [size](Tensor::line_size).
113        pub fn __expand_line_size(
114            expand: <Self as CubeType>::ExpandType,
115            scope: &mut Scope,
116        ) -> LineSize {
117            expand.__expand_line_size_method(scope)
118        }
119    }
120}
121
122/// Module that contains the implementation details of vectorization functions.
123///
124/// TODO: Remove vectorization in favor of the line API.
125mod vectorization {
126
127    use cubecl_ir::LineSize;
128
129    use super::*;
130
131    #[cube]
132    impl<T: CubePrimitive + Clone> Array<T> {
133        #[allow(unused_variables)]
134        pub fn lined(#[comptime] length: usize, #[comptime] line_size: LineSize) -> Self {
135            intrinsic!(|scope| {
136                scope
137                    .create_local_array(Type::new(T::as_type(scope)).line(line_size), length)
138                    .into()
139            })
140        }
141
142        #[allow(unused_variables)]
143        pub fn to_lined(self, #[comptime] line_size: LineSize) -> T {
144            intrinsic!(|scope| {
145                let factor = line_size;
146                let var = self.expand.clone();
147                let item = Type::new(var.storage_type()).line(factor);
148
149                let new_var = if factor == 1 {
150                    let new_var = scope.create_local(item);
151                    let element =
152                        index::expand(scope, self.clone(), ExpandElementTyped::from_lit(scope, 0));
153                    assign::expand_no_check::<T>(scope, element, new_var.clone().into());
154                    new_var
155                } else {
156                    let new_var = scope.create_local_mut(item);
157                    for i in 0..factor {
158                        let expand: Self = self.expand.clone().into();
159                        let element =
160                            index::expand(scope, expand, ExpandElementTyped::from_lit(scope, i));
161                        index_assign::expand::<ExpandElementTyped<Array<T>>, T>(
162                            scope,
163                            new_var.clone().into(),
164                            ExpandElementTyped::from_lit(scope, i),
165                            element,
166                        );
167                    }
168                    new_var
169                };
170                new_var.into()
171            })
172        }
173    }
174}
175
176/// Module that contains the implementation details of the metadata functions.
177mod metadata {
178    use crate::{ir::Instruction, prelude::expand_length_native};
179
180    use super::*;
181
182    #[cube]
183    impl<E: CubeType> Array<E> {
184        /// Obtain the array length
185        #[allow(clippy::len_without_is_empty)]
186        pub fn len(&self) -> usize {
187            intrinsic!(|scope| {
188                ExpandElement::Plain(expand_length_native(scope, *self.expand)).into()
189            })
190        }
191
192        /// Obtain the array buffer length
193        pub fn buffer_len(&self) -> usize {
194            intrinsic!(|scope| {
195                let out = scope.create_local(Type::new(usize::as_type(scope)));
196                scope.register(Instruction::new(
197                    Metadata::BufferLength {
198                        var: self.expand.into(),
199                    },
200                    out.clone().into(),
201                ));
202                out.into()
203            })
204        }
205    }
206}
207
208/// Module that contains the implementation details of the index functions.
209mod indexation {
210    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
211
212    use crate::{
213        ir::Instruction,
214        prelude::{CubeIndex, CubeIndexMut},
215    };
216
217    use super::*;
218
219    #[cube]
220    impl<E: CubePrimitive> Array<E> {
221        /// Perform an unchecked index into the array
222        ///
223        /// # Safety
224        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
225        /// always in bounds
226        #[allow(unused_variables)]
227        pub unsafe fn index_unchecked(&self, i: usize) -> &E
228        where
229            Self: CubeIndex,
230        {
231            intrinsic!(|scope| {
232                let out = scope.create_local(self.expand.ty);
233                scope.register(Instruction::new(
234                    Operator::UncheckedIndex(IndexOperator {
235                        list: *self.expand,
236                        index: i.expand.consume(),
237                        line_size: 0,
238                        unroll_factor: 1,
239                    }),
240                    *out,
241                ));
242                out.into()
243            })
244        }
245
246        /// Perform an unchecked index assignment into the array
247        ///
248        /// # Safety
249        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
250        /// always in bounds
251        #[allow(unused_variables)]
252        pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E)
253        where
254            Self: CubeIndexMut,
255        {
256            intrinsic!(|scope| {
257                scope.register(Instruction::new(
258                    Operator::UncheckedIndexAssign(IndexAssignOperator {
259                        index: i.expand.consume(),
260                        value: value.expand.consume(),
261                        line_size: 0,
262                        unroll_factor: 1,
263                    }),
264                    *self.expand,
265                ));
266            })
267        }
268    }
269}
270
271impl<C: CubeType> CubeType for Array<C> {
272    type ExpandType = ExpandElementTyped<Array<C>>;
273}
274
275impl<C: CubeType> CubeType for &Array<C> {
276    type ExpandType = ExpandElementTyped<Array<C>>;
277}
278
279impl<C: CubeType> ExpandElementIntoMut for Array<C> {
280    fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
281        // The type can't be deeply cloned/copied.
282        elem
283    }
284}
285
286impl<T: CubePrimitive> SizedContainer for Array<T> {
287    type Item = T;
288}
289
290impl<T: CubeType> Iterator for &Array<T> {
291    type Item = T;
292
293    fn next(&mut self) -> Option<Self::Item> {
294        unexpanded!()
295    }
296}
297
298impl<T: CubePrimitive> List<T> for Array<T> {
299    fn __expand_read(
300        scope: &mut Scope,
301        this: ExpandElementTyped<Array<T>>,
302        idx: ExpandElementTyped<usize>,
303    ) -> ExpandElementTyped<T> {
304        index::expand(scope, this, idx)
305    }
306}
307
308impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Array<T>> {
309    fn __expand_read_method(
310        &self,
311        scope: &mut Scope,
312        idx: ExpandElementTyped<usize>,
313    ) -> ExpandElementTyped<T> {
314        index::expand(scope, self.clone(), idx)
315    }
316    fn __expand_read_unchecked_method(
317        &self,
318        scope: &mut Scope,
319        idx: ExpandElementTyped<usize>,
320    ) -> ExpandElementTyped<T> {
321        index_unchecked::expand(scope, self.clone(), idx)
322    }
323
324    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
325        Self::__expand_len(scope, self.clone())
326    }
327}
328
329impl<T: CubePrimitive> Lined for Array<T> {}
330impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Array<T>> {
331    fn line_size(&self) -> LineSize {
332        self.expand.ty.line_size()
333    }
334}
335
336impl<T: CubePrimitive> ListMut<T> for Array<T> {
337    fn __expand_write(
338        scope: &mut Scope,
339        this: ExpandElementTyped<Array<T>>,
340        idx: ExpandElementTyped<usize>,
341        value: ExpandElementTyped<T>,
342    ) {
343        index_assign::expand(scope, this, idx, value);
344    }
345}
346
347impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Array<T>> {
348    fn __expand_write_method(
349        &self,
350        scope: &mut Scope,
351        idx: ExpandElementTyped<usize>,
352        value: ExpandElementTyped<T>,
353    ) {
354        index_assign::expand(scope, self.clone(), idx, value);
355    }
356}