cubecl_core/frontend/container/array/
base.rs1use std::{
2 marker::PhantomData,
3 ops::{Deref, DerefMut},
4};
5
6use cubecl_ir::{ExpandElement, LineSize, Scope};
7
8use crate::prelude::{
9 LinedExpand, List, ListExpand, ListMut, ListMutExpand, SizedContainer, index_unchecked,
10};
11use crate::prelude::{assign, index, index_assign};
12use crate::{self as cubecl};
13use crate::{
14 frontend::CubeType,
15 ir::{Metadata, Type},
16 unexpanded,
17};
18use crate::{
19 frontend::{CubePrimitive, ExpandElementIntoMut, ExpandElementTyped},
20 prelude::Lined,
21};
22use cubecl_macros::{cube, intrinsic};
23
24pub struct Array<E> {
26 _val: PhantomData<E>,
27}
28
29type ArrayExpand<E> = ExpandElementTyped<Array<E>>;
30
31mod new {
33
34 use cubecl_macros::intrinsic;
35
36 use super::*;
37 use crate::ir::Variable;
38
39 #[cube]
40 impl<T: CubePrimitive + Clone> Array<T> {
41 #[allow(unused_variables)]
43 pub fn new(#[comptime] length: usize) -> Self {
44 intrinsic!(|scope| {
45 let elem = T::as_type(scope);
46 scope.create_local_array(Type::new(elem), length).into()
47 })
48 }
49 }
50
51 impl<T: CubePrimitive + Clone> Array<T> {
52 #[allow(unused_variables)]
54 pub fn from_data<C: CubePrimitive>(data: impl IntoIterator<Item = C>) -> Self {
55 intrinsic!(|scope| {
56 scope
57 .create_const_array(Type::new(T::as_type(scope)), data.values)
58 .into()
59 })
60 }
61
62 pub fn __expand_from_data<C: CubePrimitive>(
64 scope: &mut Scope,
65 data: ArrayData<C>,
66 ) -> <Self as CubeType>::ExpandType {
67 let var = scope.create_const_array(Type::new(T::as_type(scope)), data.values);
68 ExpandElementTyped::new(var)
69 }
70 }
71
72 pub struct ArrayData<C> {
74 values: Vec<Variable>,
75 _ty: PhantomData<C>,
76 }
77
78 impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
79 for ArrayData<C>
80 {
81 fn from(value: T) -> Self {
82 let values: Vec<Variable> = value
83 .into_iter()
84 .map(|value| {
85 let value: ExpandElementTyped<C> = value.into();
86 *value.expand
87 })
88 .collect();
89 ArrayData {
90 values,
91 _ty: PhantomData,
92 }
93 }
94 }
95}
96
97mod line {
99 use crate::prelude::Line;
100
101 use super::*;
102
103 impl<P: CubePrimitive> Array<Line<P>> {
104 pub fn line_size(&self) -> LineSize {
112 unexpanded!()
113 }
114
115 pub fn __expand_line_size(
117 expand: <Self as CubeType>::ExpandType,
118 scope: &mut Scope,
119 ) -> LineSize {
120 expand.__expand_line_size_method(scope)
121 }
122 }
123}
124
125mod vectorization {
129
130 use cubecl_ir::LineSize;
131
132 use super::*;
133
134 #[cube]
135 impl<T: CubePrimitive + Clone> Array<T> {
136 #[allow(unused_variables)]
137 pub fn lined(#[comptime] length: usize, #[comptime] line_size: LineSize) -> Self {
138 intrinsic!(|scope| {
139 scope
140 .create_local_array(Type::new(T::as_type(scope)).line(line_size), length)
141 .into()
142 })
143 }
144
145 #[allow(unused_variables)]
146 pub fn to_lined(self, #[comptime] line_size: LineSize) -> T {
147 intrinsic!(|scope| {
148 let factor = line_size;
149 let var = self.expand.clone();
150 let item = Type::new(var.storage_type()).line(factor);
151
152 let new_var = if factor == 1 {
153 let new_var = scope.create_local(item);
154 let element =
155 index::expand(scope, self.clone(), ExpandElementTyped::from_lit(scope, 0));
156 assign::expand_no_check::<T>(scope, element, new_var.clone().into());
157 new_var
158 } else {
159 let new_var = scope.create_local_mut(item);
160 for i in 0..factor {
161 let expand: Self = self.expand.clone().into();
162 let element =
163 index::expand(scope, expand, ExpandElementTyped::from_lit(scope, i));
164 index_assign::expand::<ExpandElementTyped<Array<T>>, T>(
165 scope,
166 new_var.clone().into(),
167 ExpandElementTyped::from_lit(scope, i),
168 element,
169 );
170 }
171 new_var
172 };
173 new_var.into()
174 })
175 }
176 }
177}
178
179mod metadata {
181 use crate::{ir::Instruction, prelude::expand_length_native};
182
183 use super::*;
184
185 #[cube]
186 impl<E: CubeType> Array<E> {
187 #[allow(clippy::len_without_is_empty)]
189 pub fn len(&self) -> usize {
190 intrinsic!(|scope| {
191 ExpandElement::Plain(expand_length_native(scope, *self.expand)).into()
192 })
193 }
194
195 pub fn buffer_len(&self) -> usize {
197 intrinsic!(|scope| {
198 let out = scope.create_local(Type::new(usize::as_type(scope)));
199 scope.register(Instruction::new(
200 Metadata::BufferLength {
201 var: self.expand.into(),
202 },
203 out.clone().into(),
204 ));
205 out.into()
206 })
207 }
208 }
209}
210
211mod indexation {
213 use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
214
215 use crate::{
216 ir::Instruction,
217 prelude::{CubeIndex, CubeIndexMut},
218 };
219
220 use super::*;
221
222 #[cube]
223 impl<E: CubePrimitive> Array<E> {
224 #[allow(unused_variables)]
230 pub unsafe fn index_unchecked(&self, i: usize) -> &E
231 where
232 Self: CubeIndex,
233 {
234 intrinsic!(|scope| {
235 let out = scope.create_local(self.expand.ty);
236 scope.register(Instruction::new(
237 Operator::UncheckedIndex(IndexOperator {
238 list: *self.expand,
239 index: i.expand.consume(),
240 line_size: 0,
241 unroll_factor: 1,
242 }),
243 *out,
244 ));
245 out.into()
246 })
247 }
248
249 #[allow(unused_variables)]
255 pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E)
256 where
257 Self: CubeIndexMut,
258 {
259 intrinsic!(|scope| {
260 scope.register(Instruction::new(
261 Operator::UncheckedIndexAssign(IndexAssignOperator {
262 index: i.expand.consume(),
263 value: value.expand.consume(),
264 line_size: 0,
265 unroll_factor: 1,
266 }),
267 *self.expand,
268 ));
269 })
270 }
271 }
272}
273
274impl<C: CubeType> CubeType for Array<C> {
275 type ExpandType = ExpandElementTyped<Array<C>>;
276}
277
278impl<C: CubeType> CubeType for &Array<C> {
279 type ExpandType = ExpandElementTyped<Array<C>>;
280}
281
282impl<C: CubeType> ExpandElementIntoMut for Array<C> {
283 fn elem_into_mut(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
284 elem
286 }
287}
288
289impl<T: CubePrimitive> SizedContainer for Array<T> {
290 type Item = T;
291}
292
293impl<T: CubeType> Iterator for &Array<T> {
294 type Item = T;
295
296 fn next(&mut self) -> Option<Self::Item> {
297 unexpanded!()
298 }
299}
300
301impl<T: CubePrimitive> List<T> for Array<T> {
302 fn __expand_read(
303 scope: &mut Scope,
304 this: ExpandElementTyped<Array<T>>,
305 idx: ExpandElementTyped<usize>,
306 ) -> ExpandElementTyped<T> {
307 index::expand(scope, this, idx)
308 }
309}
310
311impl<T: CubePrimitive> Deref for Array<T> {
312 type Target = [T];
313
314 fn deref(&self) -> &Self::Target {
315 unexpanded!()
316 }
317}
318
319impl<T: CubePrimitive> DerefMut for Array<T> {
320 fn deref_mut(&mut self) -> &mut Self::Target {
321 unexpanded!()
322 }
323}
324
325impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Array<T>> {
326 fn __expand_read_method(
327 &self,
328 scope: &mut Scope,
329 idx: ExpandElementTyped<usize>,
330 ) -> ExpandElementTyped<T> {
331 index::expand(scope, self.clone(), idx)
332 }
333 fn __expand_read_unchecked_method(
334 &self,
335 scope: &mut Scope,
336 idx: ExpandElementTyped<usize>,
337 ) -> ExpandElementTyped<T> {
338 index_unchecked::expand(scope, self.clone(), idx)
339 }
340
341 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
342 Self::__expand_len(scope, self.clone())
343 }
344}
345
346impl<T: CubePrimitive> Lined for Array<T> {}
347impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Array<T>> {
348 fn line_size(&self) -> LineSize {
349 self.expand.ty.line_size()
350 }
351}
352
353impl<T: CubePrimitive> ListMut<T> for Array<T> {
354 fn __expand_write(
355 scope: &mut Scope,
356 this: ExpandElementTyped<Array<T>>,
357 idx: ExpandElementTyped<usize>,
358 value: ExpandElementTyped<T>,
359 ) {
360 index_assign::expand(scope, this, idx, value);
361 }
362}
363
364impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Array<T>> {
365 fn __expand_write_method(
366 &self,
367 scope: &mut Scope,
368 idx: ExpandElementTyped<usize>,
369 value: ExpandElementTyped<T>,
370 ) {
371 index_assign::expand(scope, self.clone(), idx, value);
372 }
373}