// cubecl_core/frontend/container/array/base.rs

1use std::marker::PhantomData;
2
3use cubecl_ir::{ExpandElement, Scope};
4
5use crate::prelude::{
6    LinedExpand, List, ListExpand, ListMut, ListMutExpand, SizedContainer, index_unchecked,
7};
8use crate::prelude::{assign, index, index_assign};
9use crate::{self as cubecl};
10use crate::{
11    frontend::CubeType,
12    ir::{Metadata, Type},
13    unexpanded,
14};
15use crate::{
16    frontend::{CubePrimitive, ExpandElementIntoMut, ExpandElementTyped},
17    prelude::Lined,
18};
19use cubecl_macros::{cube, intrinsic};
20
/// A contiguous array of elements.
///
/// On the host side this is a zero-sized marker type: it stores no data and only carries the
/// element type `E` through [`PhantomData`].
pub struct Array<E> {
    _val: PhantomData<E>,
}

26type ArrayExpand<E> = ExpandElementTyped<Array<E>>;
27
28/// Module that contains the implementation details of the new function.
29mod new {
30
31    use cubecl_macros::intrinsic;
32
33    use super::*;
34    use crate::ir::Variable;
35
36    #[cube]
37    impl<T: CubePrimitive + Clone> Array<T> {
38        /// Create a new array of the given length.
39        #[allow(unused_variables)]
40        pub fn new(length: u32) -> Self {
41            intrinsic!(|scope| {
42                let size = length
43                    .constant()
44                    .expect("Array needs constant initialization value")
45                    .as_u32();
46                let elem = T::as_type(scope);
47                scope.create_local_array(Type::new(elem), size).into()
48            })
49        }
50    }
51
52    impl<T: CubePrimitive + Clone> Array<T> {
53        /// Create an array from data.
54        #[allow(unused_variables)]
55        pub fn from_data<C: CubePrimitive>(data: impl IntoIterator<Item = C>) -> Self {
56            Array { _val: PhantomData }
57        }
58
59        /// Expand function of [from_data](Array::from_data).
60        pub fn __expand_from_data<C: CubePrimitive>(
61            scope: &mut Scope,
62            data: ArrayData<C>,
63        ) -> <Self as CubeType>::ExpandType {
64            let var = scope.create_const_array(Type::new(T::as_type(scope)), data.values);
65            ExpandElementTyped::new(var)
66        }
67    }
68
69    /// Type useful for the expand function of [from_data](Array::from_data).
70    pub struct ArrayData<C> {
71        values: Vec<Variable>,
72        _ty: PhantomData<C>,
73    }
74
75    impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
76        for ArrayData<C>
77    {
78        fn from(value: T) -> Self {
79            let values: Vec<Variable> = value
80                .into_iter()
81                .map(|value| {
82                    let value: ExpandElementTyped<C> = value.into();
83                    *value.expand
84                })
85                .collect();
86            ArrayData {
87                values,
88                _ty: PhantomData,
89            }
90        }
91    }
92}
93
94/// Module that contains the implementation details of the line_size function.
95mod line {
96    use crate::prelude::Line;
97
98    use super::*;
99
100    impl<P: CubePrimitive> Array<Line<P>> {
101        /// Get the size of each line contained in the tensor.
102        ///
103        /// Same as the following:
104        ///
105        /// ```rust, ignore
106        /// let size = tensor[0].size();
107        /// ```
108        pub fn line_size(&self) -> u32 {
109            unexpanded!()
110        }
111
112        // Expand function of [size](Tensor::line_size).
113        pub fn __expand_line_size(
114            expand: <Self as CubeType>::ExpandType,
115            scope: &mut Scope,
116        ) -> u32 {
117            expand.__expand_line_size_method(scope)
118        }
119    }
120}
121
122/// Module that contains the implementation details of vectorization functions.
123///
124/// TODO: Remove vectorization in favor of the line API.
125mod vectorization {
126
127    use super::*;
128
129    #[cube]
130    impl<T: CubePrimitive + Clone> Array<T> {
131        #[allow(unused_variables)]
132        pub fn vectorized(#[comptime] length: u32, #[comptime] line_size: u32) -> Self {
133            intrinsic!(|scope| {
134                scope
135                    .create_local_array(Type::new(T::as_type(scope)).line(line_size), length)
136                    .into()
137            })
138        }
139
140        #[allow(unused_variables)]
141        pub fn to_vectorized(self, #[comptime] line_size: u32) -> T {
142            intrinsic!(|scope| {
143                let factor = line_size;
144                let var = self.expand.clone();
145                let item = Type::new(var.storage_type()).line(factor);
146
147                let new_var = if factor == 1 {
148                    let new_var = scope.create_local(item);
149                    let element = index::expand(
150                        scope,
151                        self.clone(),
152                        ExpandElementTyped::from_lit(scope, 0u32),
153                    );
154                    assign::expand_no_check::<T>(scope, element, new_var.clone().into());
155                    new_var
156                } else {
157                    let new_var = scope.create_local_mut(item);
158                    for i in 0..factor {
159                        let expand: Self = self.expand.clone().into();
160                        let element =
161                            index::expand(scope, expand, ExpandElementTyped::from_lit(scope, i));
162                        index_assign::expand::<ExpandElementTyped<Array<T>>, T>(
163                            scope,
164                            new_var.clone().into(),
165                            ExpandElementTyped::from_lit(scope, i),
166                            element,
167                        );
168                    }
169                    new_var
170                };
171                new_var.into()
172            })
173        }
174    }
175}
176
177/// Module that contains the implementation details of the metadata functions.
178mod metadata {
179    use crate::{ir::Instruction, prelude::expand_length_native};
180
181    use super::*;
182
183    #[cube]
184    impl<E: CubeType> Array<E> {
185        /// Obtain the array length
186        #[allow(clippy::len_without_is_empty)]
187        pub fn len(&self) -> u32 {
188            intrinsic!(|scope| {
189                ExpandElement::Plain(expand_length_native(scope, *self.expand)).into()
190            })
191        }
192
193        /// Obtain the array buffer length
194        pub fn buffer_len(&self) -> u32 {
195            intrinsic!(|scope| {
196                let out = scope.create_local(Type::new(u32::as_type(scope)));
197                scope.register(Instruction::new(
198                    Metadata::BufferLength {
199                        var: self.expand.into(),
200                    },
201                    out.clone().into(),
202                ));
203                out.into()
204            })
205        }
206    }
207}
208
209/// Module that contains the implementation details of the index functions.
210mod indexation {
211    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
212
213    use crate::{
214        ir::Instruction,
215        prelude::{CubeIndex, CubeIndexMut},
216    };
217
218    use super::*;
219
220    #[cube]
221    impl<E: CubePrimitive> Array<E> {
222        /// Perform an unchecked index into the array
223        ///
224        /// # Safety
225        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
226        /// always in bounds
227        #[allow(unused_variables)]
228        pub unsafe fn index_unchecked(&self, i: u32) -> &E
229        where
230            Self: CubeIndex,
231        {
232            intrinsic!(|scope| {
233                let out = scope.create_local(self.expand.ty);
234                scope.register(Instruction::new(
235                    Operator::UncheckedIndex(IndexOperator {
236                        list: *self.expand,
237                        index: i.expand.consume(),
238                        line_size: 0,
239                        unroll_factor: 1,
240                    }),
241                    *out,
242                ));
243                out.into()
244            })
245        }
246
247        /// Perform an unchecked index assignment into the array
248        ///
249        /// # Safety
250        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
251        /// always in bounds
252        #[allow(unused_variables)]
253        pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E)
254        where
255            Self: CubeIndexMut,
256        {
257            intrinsic!(|scope| {
258                scope.register(Instruction::new(
259                    Operator::UncheckedIndexAssign(IndexAssignOperator {
260                        index: i.expand.consume(),
261                        value: value.expand.consume(),
262                        line_size: 0,
263                        unroll_factor: 1,
264                    }),
265                    *self.expand,
266                ));
267            })
268        }
269    }
270}
271
272impl<C: CubeType> CubeType for Array<C> {
273    type ExpandType = ExpandElementTyped<Array<C>>;
274}
275
276impl<C: CubeType> CubeType for &Array<C> {
277    type ExpandType = ExpandElementTyped<Array<C>>;
278}
279
280impl<C: CubeType> ExpandElementIntoMut for Array<C> {
281    fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
282        // The type can't be deeply cloned/copied.
283        elem
284    }
285}
286
287impl<T: CubePrimitive> SizedContainer for Array<T> {
288    type Item = T;
289}
290
291impl<T: CubeType> Iterator for &Array<T> {
292    type Item = T;
293
294    fn next(&mut self) -> Option<Self::Item> {
295        unexpanded!()
296    }
297}
298
299impl<T: CubePrimitive> List<T> for Array<T> {
300    fn __expand_read(
301        scope: &mut Scope,
302        this: ExpandElementTyped<Array<T>>,
303        idx: ExpandElementTyped<u32>,
304    ) -> ExpandElementTyped<T> {
305        index::expand(scope, this, idx)
306    }
307}
308
309impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Array<T>> {
310    fn __expand_read_method(
311        &self,
312        scope: &mut Scope,
313        idx: ExpandElementTyped<u32>,
314    ) -> ExpandElementTyped<T> {
315        index::expand(scope, self.clone(), idx)
316    }
317    fn __expand_read_unchecked_method(
318        &self,
319        scope: &mut Scope,
320        idx: ExpandElementTyped<u32>,
321    ) -> ExpandElementTyped<T> {
322        index_unchecked::expand(scope, self.clone(), idx)
323    }
324
325    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
326        Self::__expand_len(scope, self.clone())
327    }
328}
329
330impl<T: CubePrimitive> Lined for Array<T> {}
331impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Array<T>> {
332    fn line_size(&self) -> u32 {
333        self.expand.ty.line_size()
334    }
335}
336
337impl<T: CubePrimitive> ListMut<T> for Array<T> {
338    fn __expand_write(
339        scope: &mut Scope,
340        this: ExpandElementTyped<Array<T>>,
341        idx: ExpandElementTyped<u32>,
342        value: ExpandElementTyped<T>,
343    ) {
344        index_assign::expand(scope, this, idx, value);
345    }
346}
347
348impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Array<T>> {
349    fn __expand_write_method(
350        &self,
351        scope: &mut Scope,
352        idx: ExpandElementTyped<u32>,
353        value: ExpandElementTyped<T>,
354    ) {
355        index_assign::expand(scope, self.clone(), idx, value);
356    }
357}