cubecl_core/frontend/container/array/
base.rs1use std::{marker::PhantomData, num::NonZero};
2
3use cubecl_ir::{ExpandElement, Scope};
4
5use crate as cubecl;
6use crate::frontend::{CubePrimitive, ExpandElementIntoMut, ExpandElementTyped};
7use crate::prelude::{List, ListExpand, ListMut, ListMutExpand, SizedContainer, index_unchecked};
8use crate::prelude::{assign, index, index_assign};
9use crate::{
10 frontend::CubeType,
11 ir::{Item, Metadata},
12 unexpanded,
13};
14use cubecl_macros::{cube, intrinsic};
15
16pub struct Array<E> {
18 _val: PhantomData<E>,
19}
20
21type ArrayExpand<E> = ExpandElementTyped<Array<E>>;
22
23mod new {
25
26 use cubecl_macros::intrinsic;
27
28 use super::*;
29 use crate::ir::Variable;
30
31 #[cube]
32 impl<T: CubePrimitive + Clone> Array<T> {
33 #[allow(unused_variables)]
35 pub fn new(length: u32) -> Self {
36 intrinsic!(|scope| {
37 let size = length
38 .constant()
39 .expect("Array needs constant initialization value")
40 .as_u32();
41 let elem = T::as_elem(scope);
42 scope.create_local_array(Item::new(elem), size).into()
43 })
44 }
45 }
46
47 impl<T: CubePrimitive + Clone> Array<T> {
48 #[allow(unused_variables)]
50 pub fn from_data<C: CubePrimitive>(data: impl IntoIterator<Item = C>) -> Self {
51 Array { _val: PhantomData }
52 }
53
54 pub fn __expand_from_data<C: CubePrimitive>(
56 scope: &mut Scope,
57 data: ArrayData<C>,
58 ) -> <Self as CubeType>::ExpandType {
59 let var = scope.create_const_array(Item::new(T::as_elem(scope)), data.values);
60 ExpandElementTyped::new(var)
61 }
62 }
63
64 pub struct ArrayData<C> {
66 values: Vec<Variable>,
67 _ty: PhantomData<C>,
68 }
69
70 impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
71 for ArrayData<C>
72 {
73 fn from(value: T) -> Self {
74 let values: Vec<Variable> = value
75 .into_iter()
76 .map(|value| {
77 let value: ExpandElementTyped<C> = value.into();
78 *value.expand
79 })
80 .collect();
81 ArrayData {
82 values,
83 _ty: PhantomData,
84 }
85 }
86 }
87}
88
89mod line {
91 use crate::prelude::Line;
92
93 use super::*;
94
95 impl<P: CubePrimitive> Array<Line<P>> {
96 pub fn line_size(&self) -> u32 {
104 unexpanded!()
105 }
106
107 pub fn __expand_line_size(
109 expand: <Self as CubeType>::ExpandType,
110 scope: &mut Scope,
111 ) -> u32 {
112 expand.__expand_line_size_method(scope)
113 }
114 }
115
116 impl<P: CubePrimitive> ExpandElementTyped<Array<Line<P>>> {
117 pub fn line_size(&self) -> u32 {
119 self.expand
120 .item
121 .vectorization
122 .unwrap_or(NonZero::new(1).unwrap())
123 .get() as u32
124 }
125
126 pub fn __expand_line_size_method(&self, _content: &mut Scope) -> u32 {
128 self.line_size()
129 }
130 }
131}
132
133mod vectorization {
137
138 use super::*;
139
140 #[cube]
141 impl<T: CubePrimitive + Clone> Array<T> {
142 #[allow(unused_variables)]
143 pub fn vectorized(#[comptime] length: u32, #[comptime] vectorization_factor: u32) -> Self {
144 intrinsic!(|scope| {
145 scope
146 .create_local_array(
147 Item::vectorized(
148 T::as_elem(scope),
149 NonZero::new(vectorization_factor as u8),
150 ),
151 length,
152 )
153 .into()
154 })
155 }
156
157 #[allow(unused_variables)]
158 pub fn to_vectorized(self, #[comptime] vectorization_factor: u32) -> T {
159 intrinsic!(|scope| {
160 let factor = vectorization_factor;
161 let var = self.expand.clone();
162 let item = Item::vectorized(var.item.elem(), NonZero::new(factor as u8));
163
164 let new_var = if factor == 1 {
165 let new_var = scope.create_local(item);
166 let element = index::expand(
167 scope,
168 self.clone(),
169 ExpandElementTyped::from_lit(scope, 0u32),
170 );
171 assign::expand_no_check::<T>(scope, element, new_var.clone().into());
172 new_var
173 } else {
174 let new_var = scope.create_local_mut(item);
175 for i in 0..factor {
176 let expand: Self = self.expand.clone().into();
177 let element =
178 index::expand(scope, expand, ExpandElementTyped::from_lit(scope, i));
179 index_assign::expand::<ExpandElementTyped<Array<T>>, T>(
180 scope,
181 new_var.clone().into(),
182 ExpandElementTyped::from_lit(scope, i),
183 element,
184 );
185 }
186 new_var
187 };
188 new_var.into()
189 })
190 }
191 }
192}
193
194mod metadata {
196 use crate::{ir::Instruction, prelude::expand_length_native};
197
198 use super::*;
199
200 #[cube]
201 impl<E: CubeType> Array<E> {
202 #[allow(clippy::len_without_is_empty)]
204 pub fn len(&self) -> u32 {
205 intrinsic!(|scope| {
206 ExpandElement::Plain(expand_length_native(scope, *self.expand)).into()
207 })
208 }
209
210 pub fn buffer_len(&self) -> u32 {
212 intrinsic!(|scope| {
213 let out = scope.create_local(Item::new(u32::as_elem(scope)));
214 scope.register(Instruction::new(
215 Metadata::BufferLength {
216 var: self.expand.into(),
217 },
218 out.clone().into(),
219 ));
220 out.into()
221 })
222 }
223 }
224}
225
226mod indexation {
228 use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
229
230 use crate::{
231 ir::Instruction,
232 prelude::{CubeIndex, CubeIndexMut},
233 };
234
235 use super::*;
236
237 #[cube]
238 impl<E: CubePrimitive> Array<E> {
239 #[allow(unused_variables)]
245 pub unsafe fn index_unchecked(&self, i: u32) -> &E
246 where
247 Self: CubeIndex,
248 {
249 intrinsic!(|scope| {
250 let out = scope.create_local(self.expand.item);
251 scope.register(Instruction::new(
252 Operator::UncheckedIndex(IndexOperator {
253 list: *self.expand,
254 index: i.expand.consume(),
255 line_size: 0,
256 }),
257 *out,
258 ));
259 out.into()
260 })
261 }
262
263 #[allow(unused_variables)]
269 pub unsafe fn index_assign_unchecked(&mut self, i: u32, value: E)
270 where
271 Self: CubeIndexMut,
272 {
273 intrinsic!(|scope| {
274 scope.register(Instruction::new(
275 Operator::UncheckedIndexAssign(IndexAssignOperator {
276 index: i.expand.consume(),
277 value: value.expand.consume(),
278 line_size: 0,
279 }),
280 *self.expand,
281 ));
282 })
283 }
284 }
285}
286
287impl<C: CubeType> CubeType for Array<C> {
288 type ExpandType = ExpandElementTyped<Array<C>>;
289}
290
291impl<C: CubeType> CubeType for &Array<C> {
292 type ExpandType = ExpandElementTyped<Array<C>>;
293}
294
295impl<C: CubeType> ExpandElementIntoMut for Array<C> {
296 fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
297 elem
299 }
300}
301
302impl<T: CubePrimitive> SizedContainer for Array<T> {
303 type Item = T;
304}
305
306impl<T: CubeType> Iterator for &Array<T> {
307 type Item = T;
308
309 fn next(&mut self) -> Option<Self::Item> {
310 unexpanded!()
311 }
312}
313
314impl<T: CubePrimitive> List<T> for Array<T> {
315 fn __expand_read(
316 scope: &mut Scope,
317 this: ExpandElementTyped<Array<T>>,
318 idx: ExpandElementTyped<u32>,
319 ) -> ExpandElementTyped<T> {
320 index::expand(scope, this, idx)
321 }
322}
323
324impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Array<T>> {
325 fn __expand_read_method(
326 &self,
327 scope: &mut Scope,
328 idx: ExpandElementTyped<u32>,
329 ) -> ExpandElementTyped<T> {
330 index::expand(scope, self.clone(), idx)
331 }
332 fn __expand_read_unchecked_method(
333 &self,
334 scope: &mut Scope,
335 idx: ExpandElementTyped<u32>,
336 ) -> ExpandElementTyped<T> {
337 index_unchecked::expand(scope, self.clone(), idx)
338 }
339}
340
341impl<T: CubePrimitive> ListMut<T> for Array<T> {
342 fn __expand_write(
343 scope: &mut Scope,
344 this: ExpandElementTyped<Array<T>>,
345 idx: ExpandElementTyped<u32>,
346 value: ExpandElementTyped<T>,
347 ) {
348 index_assign::expand(scope, this, idx, value);
349 }
350}
351
352impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Array<T>> {
353 fn __expand_write_method(
354 &self,
355 scope: &mut Scope,
356 idx: ExpandElementTyped<u32>,
357 value: ExpandElementTyped<T>,
358 ) {
359 index_assign::expand(scope, self.clone(), idx, value);
360 }
361}