cubecl_core/frontend/container/array/
base.rs1use alloc::vec::Vec;
2use core::{
3 marker::PhantomData,
4 ops::{Deref, DerefMut},
5};
6
7use cubecl_ir::{ExpandElement, LineSize, Scope};
8
9use crate::prelude::{
10 LinedExpand, List, ListExpand, ListMut, ListMutExpand, SizedContainer, index_unchecked,
11};
12use crate::prelude::{assign, index, index_assign};
13use crate::{self as cubecl};
14use crate::{
15 frontend::CubeType,
16 ir::{Metadata, Type},
17 unexpanded,
18};
19use crate::{
20 frontend::{CubePrimitive, ExpandElementIntoMut, ExpandElementTyped},
21 prelude::Lined,
22};
23use cubecl_macros::{cube, intrinsic};
24
/// An array of `E` values usable from `#[cube]` kernel code.
///
/// The host-side struct stores no data, only a `PhantomData` marker. Inside kernel expansion
/// an array is manipulated through its expand type, `ExpandElementTyped<Array<E>>`.
pub struct Array<E> {
    _val: PhantomData<E>,
}
29
/// Shorthand for the expand-time handle of an [`Array<E>`] (its `CubeType::ExpandType`).
// NOTE(review): not referenced anywhere in this file as seen; confirm it is still needed.
type ArrayExpand<E> = ExpandElementTyped<Array<E>>;
31
32mod new {
34
35 use cubecl_macros::intrinsic;
36
37 use super::*;
38 use crate::ir::Variable;
39
40 #[cube]
41 impl<T: CubePrimitive + Clone> Array<T> {
42 #[allow(unused_variables)]
44 pub fn new(#[comptime] length: usize) -> Self {
45 intrinsic!(|scope| {
46 let elem = T::as_type(scope);
47 scope.create_local_array(Type::new(elem), length).into()
48 })
49 }
50 }
51
52 impl<T: CubePrimitive + Clone> Array<T> {
53 #[allow(unused_variables)]
55 pub fn from_data<C: CubePrimitive>(data: impl IntoIterator<Item = C>) -> Self {
56 intrinsic!(|scope| {
57 scope
58 .create_const_array(Type::new(T::as_type(scope)), data.values)
59 .into()
60 })
61 }
62
63 pub fn __expand_from_data<C: CubePrimitive>(
65 scope: &mut Scope,
66 data: ArrayData<C>,
67 ) -> <Self as CubeType>::ExpandType {
68 let var = scope.create_const_array(Type::new(T::as_type(scope)), data.values);
69 ExpandElementTyped::new(var)
70 }
71 }
72
73 pub struct ArrayData<C> {
75 values: Vec<Variable>,
76 _ty: PhantomData<C>,
77 }
78
79 impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
80 for ArrayData<C>
81 {
82 fn from(value: T) -> Self {
83 let values: Vec<Variable> = value
84 .into_iter()
85 .map(|value| {
86 let value: ExpandElementTyped<C> = value.into();
87 *value.expand
88 })
89 .collect();
90 ArrayData {
91 values,
92 _ty: PhantomData,
93 }
94 }
95 }
96}
97
mod line {
    use crate::prelude::Line;

    use super::*;

    impl<P: CubePrimitive> Array<Line<P>> {
        /// Line size of the array's `Line<P>` elements.
        ///
        /// Host-side placeholder: the body is `unexpanded!()`, so it is only meaningful
        /// once expanded inside kernel code.
        pub fn line_size(&self) -> LineSize {
            unexpanded!()
        }

        /// Expand-time counterpart of [`Array::line_size`].
        ///
        /// Forwards to `__expand_line_size_method` on the expand handle.
        // NOTE(review): takes `expand` before `scope`, unlike the other `__expand_*`
        // functions in this file; confirm the generated call sites rely on this order.
        pub fn __expand_line_size(
            expand: <Self as CubeType>::ExpandType,
            scope: &mut Scope,
        ) -> LineSize {
            expand.__expand_line_size_method(scope)
        }
    }
}
125
mod vectorization {
    use cubecl_ir::LineSize;

    use super::*;

    #[cube]
    impl<T: CubePrimitive + Clone> Array<T> {
        /// Allocates an array of `length` elements, each a line of `line_size` values of `T`.
        ///
        /// Both arguments are compile-time values.
        #[allow(unused_variables)]
        pub fn lined(#[comptime] length: usize, #[comptime] line_size: LineSize) -> Self {
            intrinsic!(|scope| {
                scope
                    .create_local_array(Type::new(T::as_type(scope)).line(line_size), length)
                    .into()
            })
        }

        /// Packs the first `line_size` elements of this array into one lined value.
        ///
        /// The result variable's type is the array's storage type lined by `line_size`.
        /// - When `line_size == 1`, element 0 is assigned into a fresh local.
        /// - Otherwise, elements `0..line_size` are read one at a time and written into the
        ///   matching lane of a fresh mutable local.
        // NOTE(review): the signature says `T`, but the produced variable is lined by
        // `line_size`; confirm callers expect this.
        #[allow(unused_variables)]
        pub fn to_lined(self, #[comptime] line_size: LineSize) -> T {
            intrinsic!(|scope| {
                let factor = line_size;
                let var = self.expand.clone();
                let item = Type::new(var.storage_type()).line(factor);

                let new_var = if factor == 1 {
                    // Single lane: assigned exactly once, presumably why the non-`_mut`
                    // local constructor suffices here.
                    let new_var = scope.create_local(item);
                    let element =
                        index::expand(scope, self.clone(), ExpandElementTyped::from_lit(scope, 0));
                    assign::expand_no_check::<T>(scope, element, new_var.clone().into());
                    new_var
                } else {
                    // Several lanes are written one by one, so a mutable local is created.
                    let new_var = scope.create_local_mut(item);
                    for i in 0..factor {
                        let expand: Self = self.expand.clone().into();
                        let element =
                            index::expand(scope, expand, ExpandElementTyped::from_lit(scope, i));
                        // The lined local is written lane by lane through the array
                        // index-assign helper; the `Array<T>` type argument presumably only
                        // satisfies that helper's signature.
                        index_assign::expand::<ExpandElementTyped<Array<T>>, T>(
                            scope,
                            new_var.clone().into(),
                            ExpandElementTyped::from_lit(scope, i),
                            element,
                        );
                    }
                    new_var
                };
                new_var.into()
            })
        }
    }
}
179
mod metadata {
    use crate::{ir::Instruction, prelude::expand_length_native};

    use super::*;

    #[cube]
    impl<E: CubeType> Array<E> {
        /// Number of elements in the array.
        ///
        /// Lowered via `expand_length_native` on the array's IR variable.
        // No `is_empty` counterpart is defined, hence the clippy allowance.
        #[allow(clippy::len_without_is_empty)]
        pub fn len(&self) -> usize {
            intrinsic!(|scope| {
                ExpandElement::Plain(expand_length_native(scope, *self.expand)).into()
            })
        }

        /// Length reported by a `Metadata::BufferLength` query on the array's variable.
        // NOTE(review): unlike `len`, this asks for buffer metadata; presumably the size of
        // the backing buffer rather than the logical length. Confirm for local arrays.
        pub fn buffer_len(&self) -> usize {
            intrinsic!(|scope| {
                // Fresh `usize` local that receives the metadata query's result.
                let out = scope.create_local(Type::new(usize::as_type(scope)));
                scope.register(Instruction::new(
                    Metadata::BufferLength {
                        var: self.expand.into(),
                    },
                    out.clone().into(),
                ));
                out.into()
            })
        }
    }
}
211
mod indexation {
    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};

    use crate::{
        ir::Instruction,
        prelude::{CubeIndex, CubeIndexMut},
    };

    use super::*;

    #[cube]
    impl<E: CubePrimitive> Array<E> {
        /// Reads element `i` using `Operator::UncheckedIndex`.
        ///
        /// The result goes into a fresh local typed like the array's variable, which is
        /// returned.
        ///
        /// # Safety
        ///
        /// The caller must guarantee `i < self.len()`. The unchecked index operator is used
        /// instead of the checked path, so an out-of-bounds `i` is presumably undefined
        /// behaviour on the target backend. Confirm per backend.
        #[allow(unused_variables)]
        pub unsafe fn index_unchecked(&self, i: usize) -> &E
        where
            Self: CubeIndex,
        {
            intrinsic!(|scope| {
                let out = scope.create_local(self.expand.ty);
                scope.register(Instruction::new(
                    Operator::UncheckedIndex(IndexOperator {
                        list: *self.expand,
                        index: i.expand.consume(),
                        // NOTE(review): `0` presumably means "use the list's own line size".
                        line_size: 0,
                        unroll_factor: 1,
                    }),
                    *out,
                ));
                out.into()
            })
        }

        /// Writes `value` at index `i` using `Operator::UncheckedIndexAssign`.
        ///
        /// The array's own variable is the instruction output.
        ///
        /// # Safety
        ///
        /// The caller must guarantee `i < self.len()`. No checked indexing path is used, so
        /// an out-of-bounds write is presumably undefined behaviour on the target backend.
        #[allow(unused_variables)]
        pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E)
        where
            Self: CubeIndexMut,
        {
            intrinsic!(|scope| {
                scope.register(Instruction::new(
                    Operator::UncheckedIndexAssign(IndexAssignOperator {
                        index: i.expand.consume(),
                        value: value.expand.consume(),
                        // NOTE(review): same `0` line-size sentinel as in `index_unchecked`.
                        line_size: 0,
                        unroll_factor: 1,
                    }),
                    *self.expand,
                ));
            })
        }
    }
}
274
/// Owned arrays expand to a typed handle (`ExpandElementTyped<Array<C>>`).
impl<C: CubeType> CubeType for Array<C> {
    type ExpandType = ExpandElementTyped<Array<C>>;
}

/// Borrowed arrays expand to the same handle type as owned ones.
impl<C: CubeType> CubeType for &Array<C> {
    type ExpandType = ExpandElementTyped<Array<C>>;
}
282
283impl<C: CubeType> ExpandElementIntoMut for Array<C> {
284 fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
285 elem
287 }
288}
289
/// An `Array<T>` is a sized container whose items are `T`.
impl<T: CubePrimitive> SizedContainer for Array<T> {
    type Item = T;
}
293
/// Iteration over a borrowed array, yielding `T`.
///
/// Presumably exists so that looping over `&array` type-checks in kernel source. `next`
/// itself is an `unexpanded!()` placeholder and is not meant to run on the host.
impl<T: CubeType> Iterator for &Array<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        unexpanded!()
    }
}
301
302impl<T: CubePrimitive> List<T> for Array<T> {
303 fn __expand_read(
304 scope: &mut Scope,
305 this: ExpandElementTyped<Array<T>>,
306 idx: ExpandElementTyped<usize>,
307 ) -> ExpandElementTyped<T> {
308 index::expand(scope, this, idx)
309 }
310}
311
/// Exposes an array as a `[T]` slice.
///
/// Presumably this is so slice methods type-check in kernel source. The body is an
/// `unexpanded!()` placeholder and is not meant to run on the host.
impl<T: CubePrimitive> Deref for Array<T> {
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}

/// Mutable counterpart of the `Deref` impl; likewise an `unexpanded!()` placeholder.
impl<T: CubePrimitive> DerefMut for Array<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unexpanded!()
    }
}
325
326impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Array<T>> {
327 fn __expand_read_method(
328 &self,
329 scope: &mut Scope,
330 idx: ExpandElementTyped<usize>,
331 ) -> ExpandElementTyped<T> {
332 index::expand(scope, self.clone(), idx)
333 }
334 fn __expand_read_unchecked_method(
335 &self,
336 scope: &mut Scope,
337 idx: ExpandElementTyped<usize>,
338 ) -> ExpandElementTyped<T> {
339 index_unchecked::expand(scope, self.clone(), idx)
340 }
341
342 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
343 Self::__expand_len(scope, self.clone())
344 }
345}
346
/// Opts `Array<T>` into `Lined`; every item of the trait uses its default here.
impl<T: CubePrimitive> Lined for Array<T> {}

/// At expand time, the line size is read from the array variable's IR type.
impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Array<T>> {
    fn line_size(&self) -> LineSize {
        self.expand.ty.line_size()
    }
}
353
354impl<T: CubePrimitive> ListMut<T> for Array<T> {
355 fn __expand_write(
356 scope: &mut Scope,
357 this: ExpandElementTyped<Array<T>>,
358 idx: ExpandElementTyped<usize>,
359 value: ExpandElementTyped<T>,
360 ) {
361 index_assign::expand(scope, this, idx, value);
362 }
363}
364
365impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Array<T>> {
366 fn __expand_write_method(
367 &self,
368 scope: &mut Scope,
369 idx: ExpandElementTyped<usize>,
370 value: ExpandElementTyped<T>,
371 ) {
372 index_assign::expand(scope, self.clone(), idx, value);
373 }
374}