cubecl_core/frontend/container/array/
base.rs1use std::marker::PhantomData;
2
3use cubecl_ir::{ExpandElement, Scope};
4
5use crate::prelude::{
6 LinedExpand, List, ListExpand, ListMut, ListMutExpand, SizedContainer, index_unchecked,
7};
8use crate::prelude::{assign, index, index_assign};
9use crate::{self as cubecl};
10use crate::{
11 frontend::CubeType,
12 ir::{Metadata, Type},
13 unexpanded,
14};
15use crate::{
16 frontend::{CubePrimitive, ExpandElementIntoMut, ExpandElementTyped},
17 prelude::Lined,
18};
19use cubecl_macros::{cube, intrinsic};
20
/// Array type usable inside `#[cube]` functions.
///
/// The host-side value holds no elements, only a `PhantomData<E>` marker (see
/// `Array::from_data`, which returns `Array { _val: PhantomData }`). The actual storage is
/// created in the IR `Scope` by the expand functions, e.g. via `create_local_array` or
/// `create_const_array`.
pub struct Array<E> {
    _val: PhantomData<E>,
}

/// Expand-time handle of an [`Array`]: the typed IR element that stands in for the array
/// while a `#[cube]` function is being expanded.
type ArrayExpand<E> = ExpandElementTyped<Array<E>>;
27
/// Constructors for [`Array`]: local arrays of a constant length, and constant arrays
/// initialized from data.
mod new {

    use cubecl_macros::intrinsic;

    use super::*;
    use crate::ir::Variable;

    #[cube]
    impl<T: CubePrimitive + Clone> Array<T> {
        /// Create a new local array holding `length` elements of type `T`.
        ///
        /// # Panics
        ///
        /// During expansion, if `length` is not a constant value
        /// ("Array needs constant initialization value").
        #[allow(unused_variables)]
        pub fn new(length: u32) -> Self {
            intrinsic!(|scope| {
                // The array size must be known while expanding; a runtime length is rejected.
                let size = length
                    .constant()
                    .expect("Array needs constant initialization value")
                    .as_u32();
                let elem = T::as_type(scope);
                scope.create_local_array(Type::new(elem), size).into()
            })
        }
    }

    impl<T: CubePrimitive + Clone> Array<T> {
        /// Create a constant array initialized from `data`.
        ///
        /// This host-side body ignores `data` and only returns the marker value; the real
        /// work happens in [`Self::__expand_from_data`].
        #[allow(unused_variables)]
        pub fn from_data<C: CubePrimitive>(data: impl IntoIterator<Item = C>) -> Self {
            Array { _val: PhantomData }
        }

        /// Expand function of [`Self::from_data`].
        ///
        /// Registers a constant array in `scope`, typed with element type `T` and filled
        /// with the variables collected in `data`.
        pub fn __expand_from_data<C: CubePrimitive>(
            scope: &mut Scope,
            data: ArrayData<C>,
        ) -> <Self as CubeType>::ExpandType {
            let var = scope.create_const_array(Type::new(T::as_type(scope)), data.values);
            ExpandElementTyped::new(var)
        }
    }

    /// Initial values for a constant array, already converted to IR variables.
    ///
    /// `C` records the element type the values were created from.
    pub struct ArrayData<C> {
        values: Vec<Variable>,
        _ty: PhantomData<C>,
    }

    impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
        for ArrayData<C>
    {
        fn from(value: T) -> Self {
            // Turn every element into its expand form and keep only the underlying variable.
            let values: Vec<Variable> = value
                .into_iter()
                .map(|value| {
                    let value: ExpandElementTyped<C> = value.into();
                    *value.expand
                })
                .collect();
            ArrayData {
                values,
                _ty: PhantomData,
            }
        }
    }
}
93
94mod line {
96 use crate::prelude::Line;
97
98 use super::*;
99
100 impl<P: CubePrimitive> Array<Line<P>> {
101 pub fn line_size(&self) -> u32 {
109 unexpanded!()
110 }
111
112 pub fn __expand_line_size(
114 expand: <Self as CubeType>::ExpandType,
115 scope: &mut Scope,
116 ) -> u32 {
117 expand.__expand_line_size_method(scope)
118 }
119 }
120}
121
/// Creation of line-typed local arrays and packing of array elements into a single line.
mod vectorization {

    use super::*;

    #[cube]
    impl<T: CubePrimitive + Clone> Array<T> {
        /// Create a local array of `length` elements whose element type is `T` with line
        /// size `line_size` (via `Type::line`).
        ///
        /// Both arguments are `#[comptime]`.
        #[allow(unused_variables)]
        pub fn vectorized(#[comptime] length: u32, #[comptime] line_size: u32) -> Self {
            intrinsic!(|scope| {
                scope
                    .create_local_array(Type::new(T::as_type(scope)).line(line_size), length)
                    .into()
            })
        }

        /// Pack the first `line_size` elements of this array into one value of line size
        /// `line_size`.
        ///
        /// When `line_size` is 1, element 0 is copied into a fresh local. Otherwise each
        /// element `i` in `0..line_size` is written to index `i` of a fresh mutable local.
        #[allow(unused_variables)]
        pub fn to_vectorized(self, #[comptime] line_size: u32) -> T {
            intrinsic!(|scope| {
                let factor = line_size;
                let var = self.expand.clone();
                // Result type: the array's storage type with line size `factor`.
                let item = Type::new(var.storage_type()).line(factor);

                let new_var = if factor == 1 {
                    // Single element: a plain copy of element 0, no per-index writes needed.
                    let new_var = scope.create_local(item);
                    let element = index::expand(
                        scope,
                        self.clone(),
                        ExpandElementTyped::from_lit(scope, 0u32),
                    );
                    assign::expand_no_check::<T>(scope, element, new_var.clone().into());
                    new_var
                } else {
                    // Build the line index by index; this loop runs at expansion time, emitting
                    // one read and one indexed write per index.
                    let new_var = scope.create_local_mut(item);
                    for i in 0..factor {
                        let expand: Self = self.expand.clone().into();
                        let element =
                            index::expand(scope, expand, ExpandElementTyped::from_lit(scope, i));
                        index_assign::expand::<ExpandElementTyped<Array<T>>, T>(
                            scope,
                            new_var.clone().into(),
                            ExpandElementTyped::from_lit(scope, i),
                            element,
                        );
                    }
                    new_var
                };
                new_var.into()
            })
        }
    }
}
176
/// Length queries on [`Array`].
mod metadata {
    use crate::{ir::Instruction, prelude::expand_length_native};

    use super::*;

    #[cube]
    impl<E: CubeType> Array<E> {
        /// Length of the array, as computed by `expand_length_native` during expansion.
        // No `is_empty` counterpart is provided for cube arrays, hence the clippy allowance.
        #[allow(clippy::len_without_is_empty)]
        pub fn len(&self) -> u32 {
            intrinsic!(|scope| {
                ExpandElement::Plain(expand_length_native(scope, *self.expand)).into()
            })
        }

        /// Length of the array's underlying buffer.
        ///
        /// Emits a `Metadata::BufferLength` instruction on this array and returns its `u32`
        /// result. NOTE(review): how this differs from [`Self::len`] (e.g. for line-typed
        /// elements) is decided by the backend's handling of `BufferLength` — confirm.
        pub fn buffer_len(&self) -> u32 {
            intrinsic!(|scope| {
                let out = scope.create_local(Type::new(u32::as_type(scope)));
                scope.register(Instruction::new(
                    Metadata::BufferLength {
                        var: self.expand.into(),
                    },
                    out.clone().into(),
                ));
                out.into()
            })
        }
    }
}
208
/// Unchecked element access on [`Array`].
mod indexation {
    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};

    use crate::{
        ir::Instruction,
        prelude::{CubeIndex, CubeIndexMut},
    };

    use super::*;

    #[cube]
    impl<E: CubePrimitive> Array<E> {
        /// Read element `i` without a bounds check.
        ///
        /// # Safety
        ///
        /// The caller must guarantee that `i` is within the array's bounds. This emits
        /// `Operator::UncheckedIndex`, so nothing guards the access. NOTE(review): the effect
        /// of an out-of-bounds index is backend-dependent and presumably undefined.
        #[allow(unused_variables)]
        pub unsafe fn index_unchecked(&self, i: u32) -> &E
        where
            Self: CubeIndex,
        {
            intrinsic!(|scope| {
                // NOTE(review): the result local reuses the array variable's own `ty`,
                // presumably because an array's `ty` is its element type — confirm.
                let out = scope.create_local(self.expand.ty);
                scope.register(Instruction::new(
                    Operator::UncheckedIndex(IndexOperator {
                        list: *self.expand,
                        index: i.expand.consume(),
                        line_size: 0,
                        unroll_factor: 1,
                    }),
                    *out,
                ));
                out.into()
            })
        }

        /// Write `value` at index `i` without a bounds check.
        ///
        /// # Safety
        ///
        /// The caller must guarantee that `i` is within the array's bounds. This emits
        /// `Operator::UncheckedIndexAssign`, so nothing guards the write. NOTE(review): the
        /// effect of an out-of-bounds write is backend-dependent and presumably undefined.
        #[allow(unused_variables)]
        pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E)
        where
            Self: CubeIndexMut,
        {
            intrinsic!(|scope| {
                // The array itself is the instruction's output operand.
                scope.register(Instruction::new(
                    Operator::UncheckedIndexAssign(IndexAssignOperator {
                        index: i.expand.consume(),
                        value: value.expand.consume(),
                        line_size: 0,
                        unroll_factor: 1,
                    }),
                    *self.expand,
                ));
            })
        }
    }
}
271
272impl<C: CubeType> CubeType for Array<C> {
273 type ExpandType = ExpandElementTyped<Array<C>>;
274}
275
276impl<C: CubeType> CubeType for &Array<C> {
277 type ExpandType = ExpandElementTyped<Array<C>>;
278}
279
280impl<C: CubeType> ExpandElementIntoMut for Array<C> {
281 fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
282 elem
284 }
285}
286
287impl<T: CubePrimitive> SizedContainer for Array<T> {
288 type Item = T;
289}
290
291impl<T: CubeType> Iterator for &Array<T> {
292 type Item = T;
293
294 fn next(&mut self) -> Option<Self::Item> {
295 unexpanded!()
296 }
297}
298
299impl<T: CubePrimitive> List<T> for Array<T> {
300 fn __expand_read(
301 scope: &mut Scope,
302 this: ExpandElementTyped<Array<T>>,
303 idx: ExpandElementTyped<u32>,
304 ) -> ExpandElementTyped<T> {
305 index::expand(scope, this, idx)
306 }
307}
308
309impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Array<T>> {
310 fn __expand_read_method(
311 &self,
312 scope: &mut Scope,
313 idx: ExpandElementTyped<u32>,
314 ) -> ExpandElementTyped<T> {
315 index::expand(scope, self.clone(), idx)
316 }
317 fn __expand_read_unchecked_method(
318 &self,
319 scope: &mut Scope,
320 idx: ExpandElementTyped<u32>,
321 ) -> ExpandElementTyped<T> {
322 index_unchecked::expand(scope, self.clone(), idx)
323 }
324
325 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
326 Self::__expand_len(scope, self.clone())
327 }
328}
329
330impl<T: CubePrimitive> Lined for Array<T> {}
331impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Array<T>> {
332 fn line_size(&self) -> u32 {
333 self.expand.ty.line_size()
334 }
335}
336
337impl<T: CubePrimitive> ListMut<T> for Array<T> {
338 fn __expand_write(
339 scope: &mut Scope,
340 this: ExpandElementTyped<Array<T>>,
341 idx: ExpandElementTyped<u32>,
342 value: ExpandElementTyped<T>,
343 ) {
344 index_assign::expand(scope, this, idx, value);
345 }
346}
347
348impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Array<T>> {
349 fn __expand_write_method(
350 &self,
351 scope: &mut Scope,
352 idx: ExpandElementTyped<u32>,
353 value: ExpandElementTyped<T>,
354 ) {
355 index_assign::expand(scope, self.clone(), idx, value);
356 }
357}