cubecl_core/frontend/container/array/
base.rs1use std::marker::PhantomData;
2
3use cubecl_ir::{ExpandElement, LineSize, Scope};
4
5use crate::prelude::{
6 LinedExpand, List, ListExpand, ListMut, ListMutExpand, SizedContainer, index_unchecked,
7};
8use crate::prelude::{assign, index, index_assign};
9use crate::{self as cubecl};
10use crate::{
11 frontend::CubeType,
12 ir::{Metadata, Type},
13 unexpanded,
14};
15use crate::{
16 frontend::{CubePrimitive, ExpandElementIntoMut, ExpandElementTyped},
17 prelude::Lined,
18};
19use cubecl_macros::{cube, intrinsic};
20
/// An array of `E` values that can be named and used from `#[cube]` functions.
///
/// The host-side type holds no data at all: its only field is a zero-sized marker.
/// The actual storage is an IR variable created while the kernel is expanded
/// (for example by `Scope::create_local_array` / `Scope::create_const_array`).
pub struct Array<E> {
    /// Ties the array to its element type without storing any `E`.
    _val: PhantomData<E>,
}
25
/// Shorthand for `ExpandElementTyped<Array<E>>`, the expand-time handle that
/// `CubeType for Array<E>` uses as its `ExpandType`.
type ArrayExpand<E> = ExpandElementTyped<Array<E>>;
27
mod new {

    use cubecl_macros::intrinsic;

    use super::*;
    use crate::ir::Variable;

    #[cube]
    impl<T: CubePrimitive + Clone> Array<T> {
        /// Creates a new local array holding `length` elements of type `T`.
        ///
        /// `length` is a comptime value; the array itself is allocated with
        /// `Scope::create_local_array` when the kernel is expanded.
        // `length` is presumably only read inside the `intrinsic!` closure,
        // which is why unused-variable warnings are silenced.
        #[allow(unused_variables)]
        pub fn new(#[comptime] length: usize) -> Self {
            intrinsic!(|scope| {
                let elem = T::as_type(scope);
                scope.create_local_array(Type::new(elem), length).into()
            })
        }
    }

    impl<T: CubePrimitive + Clone> Array<T> {
        /// Creates a constant array whose element type is `T`, initialised from `data`.
        ///
        /// This is the unexpanded entry point. Presumably, inside kernels the call is
        /// routed to `Array::__expand_from_data`, which receives the values as `ArrayData`.
        // `data` is only referenced from the `intrinsic!` closure (as `data.values`),
        // hence the allow.
        #[allow(unused_variables)]
        pub fn from_data<C: CubePrimitive>(data: impl IntoIterator<Item = C>) -> Self {
            intrinsic!(|scope| {
                scope
                    .create_const_array(Type::new(T::as_type(scope)), data.values)
                    .into()
            })
        }

        /// Expand-time counterpart of `Array::from_data`.
        ///
        /// Registers a constant array with element type `T` whose values are the
        /// IR variables collected in `data`, and wraps it in a typed expand handle.
        // NOTE(review): the values come from `C` expand elements but the array type is
        // built from `T`; no cast is emitted here when `C` != `T` — confirm this is intended.
        pub fn __expand_from_data<C: CubePrimitive>(
            scope: &mut Scope,
            data: ArrayData<C>,
        ) -> <Self as CubeType>::ExpandType {
            let var = scope.create_const_array(Type::new(T::as_type(scope)), data.values);
            ExpandElementTyped::new(var)
        }
    }

    /// Values for a constant array, already lowered to IR variables.
    ///
    /// Built from any iterator of `C` through the `From` impl below. `C` is kept only
    /// as a type marker.
    pub struct ArrayData<C> {
        /// One IR variable per element, in iteration order.
        values: Vec<Variable>,
        _ty: PhantomData<C>,
    }

    impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
        for ArrayData<C>
    {
        /// Converts every item into its expand-time form and keeps the underlying
        /// IR `Variable`, preserving iteration order.
        fn from(value: T) -> Self {
            let values: Vec<Variable> = value
                .into_iter()
                .map(|value| {
                    let value: ExpandElementTyped<C> = value.into();
                    // Dereference the expand element down to the plain IR variable.
                    *value.expand
                })
                .collect();
            ArrayData {
                values,
                _ty: PhantomData,
            }
        }
    }
}
93
94mod line {
96 use crate::prelude::Line;
97
98 use super::*;
99
100 impl<P: CubePrimitive> Array<Line<P>> {
101 pub fn line_size(&self) -> LineSize {
109 unexpanded!()
110 }
111
112 pub fn __expand_line_size(
114 expand: <Self as CubeType>::ExpandType,
115 scope: &mut Scope,
116 ) -> LineSize {
117 expand.__expand_line_size_method(scope)
118 }
119 }
120}
121
mod vectorization {

    use cubecl_ir::LineSize;

    use super::*;

    #[cube]
    impl<T: CubePrimitive + Clone> Array<T> {
        /// Creates a new local array of `length` elements.
        ///
        /// The element type is `T` lined (vectorized) by `line_size`. Both arguments are
        /// comptime values.
        // The arguments are presumably only read inside the `intrinsic!` closure,
        // hence the allow.
        #[allow(unused_variables)]
        pub fn lined(#[comptime] length: usize, #[comptime] line_size: LineSize) -> Self {
            intrinsic!(|scope| {
                scope
                    .create_local_array(Type::new(T::as_type(scope)).line(line_size), length)
                    .into()
            })
        }

        /// Packs the first `line_size` elements of this array into one lined value.
        ///
        /// With `line_size == 1`, element `0` is copied into a fresh local. Otherwise a
        /// mutable local of the lined type is created, and each lane `i` in
        /// `0..line_size` is written from array element `i`.
        // NOTE(review): this relies on the array holding at least `line_size` elements;
        // the function itself does not check the length.
        #[allow(unused_variables)]
        pub fn to_lined(self, #[comptime] line_size: LineSize) -> T {
            intrinsic!(|scope| {
                let factor = line_size;
                let var = self.expand.clone();
                // The lined result keeps the array's storage type.
                let item = Type::new(var.storage_type()).line(factor);

                let new_var = if factor == 1 {
                    // Single lane: a plain local copy of element 0 is enough.
                    let new_var = scope.create_local(item);
                    let element =
                        index::expand(scope, self.clone(), ExpandElementTyped::from_lit(scope, 0));
                    assign::expand_no_check::<T>(scope, element, new_var.clone().into());
                    new_var
                } else {
                    // Several lanes: fill lane `i` of the mutable result from element `i`.
                    let new_var = scope.create_local_mut(item);
                    for i in 0..factor {
                        let expand: Self = self.expand.clone().into();
                        let element =
                            index::expand(scope, expand, ExpandElementTyped::from_lit(scope, i));
                        index_assign::expand::<ExpandElementTyped<Array<T>>, T>(
                            scope,
                            new_var.clone().into(),
                            ExpandElementTyped::from_lit(scope, i),
                            element,
                        );
                    }
                    new_var
                };
                new_var.into()
            })
        }
    }
}
175
mod metadata {
    use crate::{ir::Instruction, prelude::expand_length_native};

    use super::*;

    #[cube]
    impl<E: CubeType> Array<E> {
        /// Returns the number of elements in the array.
        ///
        /// Expands to the length produced by `expand_length_native` for this array.
        // No `is_empty` is provided alongside `len`, so the clippy lint is silenced.
        #[allow(clippy::len_without_is_empty)]
        pub fn len(&self) -> usize {
            intrinsic!(|scope| {
                ExpandElement::Plain(expand_length_native(scope, *self.expand)).into()
            })
        }

        /// Returns the length of the buffer backing this array.
        ///
        /// Registers a `Metadata::BufferLength` instruction that writes into a fresh
        /// `usize` local.
        // NOTE(review): presumably this can differ from `len()` (e.g. when the array is a
        // view into a larger buffer) — confirm against the runtime metadata semantics.
        pub fn buffer_len(&self) -> usize {
            intrinsic!(|scope| {
                let out = scope.create_local(Type::new(usize::as_type(scope)));
                scope.register(Instruction::new(
                    Metadata::BufferLength {
                        var: self.expand.into(),
                    },
                    out.clone().into(),
                ));
                out.into()
            })
        }
    }
}
207
mod indexation {
    use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};

    use crate::{
        ir::Instruction,
        prelude::{CubeIndex, CubeIndexMut},
    };

    use super::*;

    #[cube]
    impl<E: CubePrimitive> Array<E> {
        /// Reads element `i` through `Operator::UncheckedIndex`.
        ///
        /// This does not go through the regular (checked) index expansion.
        ///
        /// # Safety
        ///
        /// The caller must guarantee that `i` is within the bounds of the array.
        // NOTE(review): the effect of an out-of-bounds `i` is backend-dependent — confirm.
        // `i` is presumably only read inside the `intrinsic!` closure, hence the allow.
        #[allow(unused_variables)]
        pub unsafe fn index_unchecked(&self, i: usize) -> &E
        where
            Self: CubeIndex,
        {
            intrinsic!(|scope| {
                // `out` takes the array variable's `ty` (presumably its element type).
                let out = scope.create_local(self.expand.ty);
                scope.register(Instruction::new(
                    Operator::UncheckedIndex(IndexOperator {
                        list: *self.expand,
                        index: i.expand.consume(),
                        // NOTE(review): `0` presumably means "no explicit line size" — confirm.
                        line_size: 0,
                        unroll_factor: 1,
                    }),
                    *out,
                ));
                out.into()
            })
        }

        /// Writes `value` into element `i` through `Operator::UncheckedIndexAssign`.
        ///
        /// This does not go through the regular (checked) index-assign expansion.
        ///
        /// # Safety
        ///
        /// The caller must guarantee that `i` is within the bounds of the array.
        // NOTE(review): the effect of an out-of-bounds `i` is backend-dependent — confirm.
        #[allow(unused_variables)]
        pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E)
        where
            Self: CubeIndexMut,
        {
            intrinsic!(|scope| {
                // The array variable itself is the instruction's output operand.
                scope.register(Instruction::new(
                    Operator::UncheckedIndexAssign(IndexAssignOperator {
                        index: i.expand.consume(),
                        value: value.expand.consume(),
                        line_size: 0,
                        unroll_factor: 1,
                    }),
                    *self.expand,
                ));
            })
        }
    }
}
270
271impl<C: CubeType> CubeType for Array<C> {
272 type ExpandType = ExpandElementTyped<Array<C>>;
273}
274
275impl<C: CubeType> CubeType for &Array<C> {
276 type ExpandType = ExpandElementTyped<Array<C>>;
277}
278
279impl<C: CubeType> ExpandElementIntoMut for Array<C> {
280 fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
281 elem
283 }
284}
285
286impl<T: CubePrimitive> SizedContainer for Array<T> {
287 type Item = T;
288}
289
290impl<T: CubeType> Iterator for &Array<T> {
291 type Item = T;
292
293 fn next(&mut self) -> Option<Self::Item> {
294 unexpanded!()
295 }
296}
297
298impl<T: CubePrimitive> List<T> for Array<T> {
299 fn __expand_read(
300 scope: &mut Scope,
301 this: ExpandElementTyped<Array<T>>,
302 idx: ExpandElementTyped<usize>,
303 ) -> ExpandElementTyped<T> {
304 index::expand(scope, this, idx)
305 }
306}
307
308impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Array<T>> {
309 fn __expand_read_method(
310 &self,
311 scope: &mut Scope,
312 idx: ExpandElementTyped<usize>,
313 ) -> ExpandElementTyped<T> {
314 index::expand(scope, self.clone(), idx)
315 }
316 fn __expand_read_unchecked_method(
317 &self,
318 scope: &mut Scope,
319 idx: ExpandElementTyped<usize>,
320 ) -> ExpandElementTyped<T> {
321 index_unchecked::expand(scope, self.clone(), idx)
322 }
323
324 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
325 Self::__expand_len(scope, self.clone())
326 }
327}
328
329impl<T: CubePrimitive> Lined for Array<T> {}
330impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Array<T>> {
331 fn line_size(&self) -> LineSize {
332 self.expand.ty.line_size()
333 }
334}
335
336impl<T: CubePrimitive> ListMut<T> for Array<T> {
337 fn __expand_write(
338 scope: &mut Scope,
339 this: ExpandElementTyped<Array<T>>,
340 idx: ExpandElementTyped<usize>,
341 value: ExpandElementTyped<T>,
342 ) {
343 index_assign::expand(scope, this, idx, value);
344 }
345}
346
347impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Array<T>> {
348 fn __expand_write_method(
349 &self,
350 scope: &mut Scope,
351 idx: ExpandElementTyped<usize>,
352 value: ExpandElementTyped<T>,
353 ) {
354 index_assign::expand(scope, self.clone(), idx, value);
355 }
356}