Skip to main content

cubecl_core/frontend/container/array/
base.rs

1use std::{
2    marker::PhantomData,
3    ops::{Deref, DerefMut},
4};
5
6use cubecl_ir::{ExpandElement, LineSize, Scope};
7
8use crate::prelude::{
9    LinedExpand, List, ListExpand, ListMut, ListMutExpand, SizedContainer, index_unchecked,
10};
11use crate::prelude::{assign, index, index_assign};
12use crate::{self as cubecl};
13use crate::{
14    frontend::CubeType,
15    ir::{Metadata, Type},
16    unexpanded,
17};
18use crate::{
19    frontend::{CubePrimitive, ExpandElementIntoMut, ExpandElementTyped},
20    prelude::Lined,
21};
22use cubecl_macros::{cube, intrinsic};
23
24/// A contiguous array of elements.
25pub struct Array<E> {
26    _val: PhantomData<E>,
27}
28
29type ArrayExpand<E> = ExpandElementTyped<Array<E>>;
30
31/// Module that contains the implementation details of the new function.
32mod new {
33
34    use cubecl_macros::intrinsic;
35
36    use super::*;
37    use crate::ir::Variable;
38
39    #[cube]
40    impl<T: CubePrimitive + Clone> Array<T> {
41        /// Create a new array of the given length.
42        #[allow(unused_variables)]
43        pub fn new(#[comptime] length: usize) -> Self {
44            intrinsic!(|scope| {
45                let elem = T::as_type(scope);
46                scope.create_local_array(Type::new(elem), length).into()
47            })
48        }
49    }
50
51    impl<T: CubePrimitive + Clone> Array<T> {
52        /// Create an array from data.
53        #[allow(unused_variables)]
54        pub fn from_data<C: CubePrimitive>(data: impl IntoIterator<Item = C>) -> Self {
55            intrinsic!(|scope| {
56                scope
57                    .create_const_array(Type::new(T::as_type(scope)), data.values)
58                    .into()
59            })
60        }
61
62        /// Expand function of [`from_data`](Array::from_data).
63        pub fn __expand_from_data<C: CubePrimitive>(
64            scope: &mut Scope,
65            data: ArrayData<C>,
66        ) -> <Self as CubeType>::ExpandType {
67            let var = scope.create_const_array(Type::new(T::as_type(scope)), data.values);
68            ExpandElementTyped::new(var)
69        }
70    }
71
72    /// Type useful for the expand function of [`from_data`](Array::from_data).
73    pub struct ArrayData<C> {
74        values: Vec<Variable>,
75        _ty: PhantomData<C>,
76    }
77
78    impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
79        for ArrayData<C>
80    {
81        fn from(value: T) -> Self {
82            let values: Vec<Variable> = value
83                .into_iter()
84                .map(|value| {
85                    let value: ExpandElementTyped<C> = value.into();
86                    *value.expand
87                })
88                .collect();
89            ArrayData {
90                values,
91                _ty: PhantomData,
92            }
93        }
94    }
95}
96
97/// Module that contains the implementation details of the `line_size` function.
98mod line {
99    use crate::prelude::Line;
100
101    use super::*;
102
103    impl<P: CubePrimitive> Array<Line<P>> {
104        /// Get the size of each line contained in the tensor.
105        ///
106        /// Same as the following:
107        ///
108        /// ```rust, ignore
109        /// let size = tensor[0].size();
110        /// ```
111        pub fn line_size(&self) -> LineSize {
112            unexpanded!()
113        }
114
115        // Expand function of [size](Tensor::line_size).
116        pub fn __expand_line_size(
117            expand: <Self as CubeType>::ExpandType,
118            scope: &mut Scope,
119        ) -> LineSize {
120            expand.__expand_line_size_method(scope)
121        }
122    }
123}
124
125/// Module that contains the implementation details of vectorization functions.
126///
127/// TODO: Remove vectorization in favor of the line API.
128mod vectorization {
129
130    use cubecl_ir::LineSize;
131
132    use super::*;
133
134    #[cube]
135    impl<T: CubePrimitive + Clone> Array<T> {
136        #[allow(unused_variables)]
137        pub fn lined(#[comptime] length: usize, #[comptime] line_size: LineSize) -> Self {
138            intrinsic!(|scope| {
139                scope
140                    .create_local_array(Type::new(T::as_type(scope)).line(line_size), length)
141                    .into()
142            })
143        }
144
145        #[allow(unused_variables)]
146        pub fn to_lined(self, #[comptime] line_size: LineSize) -> T {
147            intrinsic!(|scope| {
148                let factor = line_size;
149                let var = self.expand.clone();
150                let item = Type::new(var.storage_type()).line(factor);
151
152                let new_var = if factor == 1 {
153                    let new_var = scope.create_local(item);
154                    let element =
155                        index::expand(scope, self.clone(), ExpandElementTyped::from_lit(scope, 0));
156                    assign::expand_no_check::<T>(scope, element, new_var.clone().into());
157                    new_var
158                } else {
159                    let new_var = scope.create_local_mut(item);
160                    for i in 0..factor {
161                        let expand: Self = self.expand.clone().into();
162                        let element =
163                            index::expand(scope, expand, ExpandElementTyped::from_lit(scope, i));
164                        index_assign::expand::<ExpandElementTyped<Array<T>>, T>(
165                            scope,
166                            new_var.clone().into(),
167                            ExpandElementTyped::from_lit(scope, i),
168                            element,
169                        );
170                    }
171                    new_var
172                };
173                new_var.into()
174            })
175        }
176    }
177}
178
179/// Module that contains the implementation details of the metadata functions.
180mod metadata {
181    use crate::{ir::Instruction, prelude::expand_length_native};
182
183    use super::*;
184
185    #[cube]
186    impl<E: CubeType> Array<E> {
187        /// Obtain the array length
188        #[allow(clippy::len_without_is_empty)]
189        pub fn len(&self) -> usize {
190            intrinsic!(|scope| {
191                ExpandElement::Plain(expand_length_native(scope, *self.expand)).into()
192            })
193        }
194
195        /// Obtain the array buffer length
196        pub fn buffer_len(&self) -> usize {
197            intrinsic!(|scope| {
198                let out = scope.create_local(Type::new(usize::as_type(scope)));
199                scope.register(Instruction::new(
200                    Metadata::BufferLength {
201                        var: self.expand.into(),
202                    },
203                    out.clone().into(),
204                ));
205                out.into()
206            })
207        }
208    }
209}
210
211/// Module that contains the implementation details of the index functions.
212mod indexation {
213    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
214
215    use crate::{
216        ir::Instruction,
217        prelude::{CubeIndex, CubeIndexMut},
218    };
219
220    use super::*;
221
222    #[cube]
223    impl<E: CubePrimitive> Array<E> {
224        /// Perform an unchecked index into the array
225        ///
226        /// # Safety
227        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
228        /// always in bounds
229        #[allow(unused_variables)]
230        pub unsafe fn index_unchecked(&self, i: usize) -> &E
231        where
232            Self: CubeIndex,
233        {
234            intrinsic!(|scope| {
235                let out = scope.create_local(self.expand.ty);
236                scope.register(Instruction::new(
237                    Operator::UncheckedIndex(IndexOperator {
238                        list: *self.expand,
239                        index: i.expand.consume(),
240                        line_size: 0,
241                        unroll_factor: 1,
242                    }),
243                    *out,
244                ));
245                out.into()
246            })
247        }
248
249        /// Perform an unchecked index assignment into the array
250        ///
251        /// # Safety
252        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
253        /// always in bounds
254        #[allow(unused_variables)]
255        pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E)
256        where
257            Self: CubeIndexMut,
258        {
259            intrinsic!(|scope| {
260                scope.register(Instruction::new(
261                    Operator::UncheckedIndexAssign(IndexAssignOperator {
262                        index: i.expand.consume(),
263                        value: value.expand.consume(),
264                        line_size: 0,
265                        unroll_factor: 1,
266                    }),
267                    *self.expand,
268                ));
269            })
270        }
271    }
272}
273
274impl<C: CubeType> CubeType for Array<C> {
275    type ExpandType = ExpandElementTyped<Array<C>>;
276}
277
278impl<C: CubeType> CubeType for &Array<C> {
279    type ExpandType = ExpandElementTyped<Array<C>>;
280}
281
282impl<C: CubeType> ExpandElementIntoMut for Array<C> {
283    fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
284        // The type can't be deeply cloned/copied.
285        elem
286    }
287}
288
289impl<T: CubePrimitive> SizedContainer for Array<T> {
290    type Item = T;
291}
292
293impl<T: CubeType> Iterator for &Array<T> {
294    type Item = T;
295
296    fn next(&mut self) -> Option<Self::Item> {
297        unexpanded!()
298    }
299}
300
301impl<T: CubePrimitive> List<T> for Array<T> {
302    fn __expand_read(
303        scope: &mut Scope,
304        this: ExpandElementTyped<Array<T>>,
305        idx: ExpandElementTyped<usize>,
306    ) -> ExpandElementTyped<T> {
307        index::expand(scope, this, idx)
308    }
309}
310
311impl<T: CubePrimitive> Deref for Array<T> {
312    type Target = [T];
313
314    fn deref(&self) -> &Self::Target {
315        unexpanded!()
316    }
317}
318
319impl<T: CubePrimitive> DerefMut for Array<T> {
320    fn deref_mut(&mut self) -> &mut Self::Target {
321        unexpanded!()
322    }
323}
324
325impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Array<T>> {
326    fn __expand_read_method(
327        &self,
328        scope: &mut Scope,
329        idx: ExpandElementTyped<usize>,
330    ) -> ExpandElementTyped<T> {
331        index::expand(scope, self.clone(), idx)
332    }
333    fn __expand_read_unchecked_method(
334        &self,
335        scope: &mut Scope,
336        idx: ExpandElementTyped<usize>,
337    ) -> ExpandElementTyped<T> {
338        index_unchecked::expand(scope, self.clone(), idx)
339    }
340
341    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
342        Self::__expand_len(scope, self.clone())
343    }
344}
345
346impl<T: CubePrimitive> Lined for Array<T> {}
347impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Array<T>> {
348    fn line_size(&self) -> LineSize {
349        self.expand.ty.line_size()
350    }
351}
352
353impl<T: CubePrimitive> ListMut<T> for Array<T> {
354    fn __expand_write(
355        scope: &mut Scope,
356        this: ExpandElementTyped<Array<T>>,
357        idx: ExpandElementTyped<usize>,
358        value: ExpandElementTyped<T>,
359    ) {
360        index_assign::expand(scope, this, idx, value);
361    }
362}
363
364impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Array<T>> {
365    fn __expand_write_method(
366        &self,
367        scope: &mut Scope,
368        idx: ExpandElementTyped<usize>,
369        value: ExpandElementTyped<T>,
370    ) {
371        index_assign::expand(scope, self.clone(), idx, value);
372    }
373}