cubecl_core/frontend/container/array/
base.rs1use std::{marker::PhantomData, num::NonZero};
2
3use cubecl_ir::{ExpandElement, Scope};
4
5use crate::frontend::{CubePrimitive, ExpandElementBaseInit, ExpandElementTyped};
6use crate::prelude::{List, ListExpand, ListMut, ListMutExpand, SizedContainer};
7use crate::{
8 frontend::CubeType,
9 ir::{Item, Metadata},
10 unexpanded,
11};
12use crate::{
13 frontend::indexation::Index,
14 prelude::{assign, index, index_assign},
15};
16
/// Kernel-side array of `E` values.
///
/// Outside of expansion this is purely a type-level marker: the only field is a
/// `PhantomData`, so the struct is zero-sized. The actual storage is created by the
/// `__expand_*` constructors.
pub struct Array<E> {
    /// Ties the element type to the array without storing any `E`.
    _val: PhantomData<E>,
}
21
22mod new {
24 use super::*;
25 use crate::ir::Variable;
26
27 impl<T: CubePrimitive + Clone> Array<T> {
28 #[allow(unused_variables)]
30 pub fn new<L: Index>(length: L) -> Self {
31 Array { _val: PhantomData }
32 }
33
34 pub fn from_data<C: CubePrimitive>(_data: impl IntoIterator<Item = C>) -> Self {
36 Array { _val: PhantomData }
37 }
38
39 pub fn __expand_new(
41 scope: &mut Scope,
42 size: ExpandElementTyped<u32>,
43 ) -> <Self as CubeType>::ExpandType {
44 let size = size
45 .constant()
46 .expect("Array need constant initialization value")
47 .as_u32();
48 let elem = T::as_elem(scope);
49 scope.create_local_array(Item::new(elem), size).into()
50 }
51
52 pub fn __expand_from_data<C: CubePrimitive>(
54 scope: &mut Scope,
55 data: ArrayData<C>,
56 ) -> <Self as CubeType>::ExpandType {
57 let var = scope.create_const_array(Item::new(T::as_elem(scope)), data.values);
58 ExpandElementTyped::new(var)
59 }
60 }
61
62 pub struct ArrayData<C> {
64 values: Vec<Variable>,
65 _ty: PhantomData<C>,
66 }
67
68 impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
69 for ArrayData<C>
70 {
71 fn from(value: T) -> Self {
72 let values: Vec<Variable> = value
73 .into_iter()
74 .map(|value| {
75 let value: ExpandElementTyped<C> = value.into();
76 *value.expand
77 })
78 .collect();
79 ArrayData {
80 values,
81 _ty: PhantomData,
82 }
83 }
84 }
85}
86
87mod line {
89 use crate::prelude::Line;
90
91 use super::*;
92
93 impl<P: CubePrimitive> Array<Line<P>> {
94 pub fn line_size(&self) -> u32 {
102 unexpanded!()
103 }
104
105 pub fn __expand_line_size(
107 expand: <Self as CubeType>::ExpandType,
108 scope: &mut Scope,
109 ) -> u32 {
110 expand.__expand_line_size_method(scope)
111 }
112 }
113
114 impl<P: CubePrimitive> ExpandElementTyped<Array<Line<P>>> {
115 pub fn line_size(&self) -> u32 {
117 self.expand
118 .item
119 .vectorization
120 .unwrap_or(NonZero::new(1).unwrap())
121 .get() as u32
122 }
123
124 pub fn __expand_line_size_method(&self, _content: &mut Scope) -> u32 {
126 self.line_size()
127 }
128 }
129}
130
131mod vectorization {
135 use super::*;
136
137 impl<T: CubePrimitive + Clone> Array<T> {
138 #[allow(unused_variables)]
139 pub fn vectorized<L: Index>(length: L, vectorization_factor: u32) -> Self {
140 Array { _val: PhantomData }
141 }
142
143 pub fn to_vectorized(self, _vectorization_factor: u32) -> T {
144 unexpanded!()
145 }
146
147 pub fn __expand_vectorized(
148 scope: &mut Scope,
149 size: ExpandElementTyped<u32>,
150 vectorization_factor: u32,
151 ) -> <Self as CubeType>::ExpandType {
152 let size = size.value();
153 let size = match size.kind {
154 crate::ir::VariableKind::ConstantScalar(value) => value.as_u32(),
155 _ => panic!("Shared memory need constant initialization value"),
156 };
157 scope
158 .create_local_array(
159 Item::vectorized(T::as_elem(scope), NonZero::new(vectorization_factor as u8)),
160 size,
161 )
162 .into()
163 }
164 }
165
166 impl<C: CubePrimitive> ExpandElementTyped<Array<C>> {
167 pub fn __expand_to_vectorized_method(
168 self,
169 scope: &mut Scope,
170 vectorization_factor: ExpandElementTyped<u32>,
171 ) -> ExpandElementTyped<C> {
172 let factor = vectorization_factor
173 .constant()
174 .expect("Vectorization must be comptime")
175 .as_u32();
176 let var = self.expand.clone();
177 let item = Item::vectorized(var.item.elem(), NonZero::new(factor as u8));
178
179 let new_var = if factor == 1 {
180 let new_var = scope.create_local(item);
181 let element = index::expand(
182 scope,
183 self.clone(),
184 ExpandElementTyped::from_lit(scope, 0u32),
185 );
186 assign::expand::<C>(scope, element, new_var.clone().into());
187 new_var
188 } else {
189 let new_var = scope.create_local_mut(item);
190 for i in 0..factor {
191 let expand: Self = self.expand.clone().into();
192 let element =
193 index::expand(scope, expand, ExpandElementTyped::from_lit(scope, i));
194 index_assign::expand::<Array<C>>(
195 scope,
196 new_var.clone().into(),
197 ExpandElementTyped::from_lit(scope, i),
198 element,
199 );
200 }
201 new_var
202 };
203 new_var.into()
204 }
205 }
206}
207
208mod metadata {
210 use crate::ir::Instruction;
211
212 use super::*;
213
214 impl<E: CubeType> Array<E> {
215 #[allow(clippy::len_without_is_empty)]
217 pub fn len(&self) -> u32 {
218 unexpanded!()
219 }
220
221 pub fn buffer_len(&self) -> u32 {
223 unexpanded!()
224 }
225 }
226
227 impl<T: CubeType> ExpandElementTyped<Array<T>> {
228 pub fn __expand_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
230 let out = scope.create_local(Item::new(u32::as_elem(scope)));
231 scope.register(Instruction::new(
232 Metadata::Length {
233 var: self.expand.into(),
234 },
235 out.clone().into(),
236 ));
237 out.into()
238 }
239
240 pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
242 let out = scope.create_local(Item::new(u32::as_elem(scope)));
243 scope.register(Instruction::new(
244 Metadata::BufferLength {
245 var: self.expand.into(),
246 },
247 out.clone().into(),
248 ));
249 out.into()
250 }
251 }
252}
253
254mod indexation {
256 use cubecl_ir::Operator;
257
258 use crate::{
259 ir::{BinaryOperator, Instruction},
260 prelude::{CubeIndex, CubeIndexMut},
261 };
262
263 use super::*;
264
265 impl<E: CubePrimitive> Array<E> {
266 pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
272 where
273 Self: CubeIndex<I>,
274 {
275 unexpanded!()
276 }
277
278 pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
284 where
285 Self: CubeIndexMut<I>,
286 {
287 unexpanded!()
288 }
289 }
290
291 impl<E: CubePrimitive> ExpandElementTyped<Array<E>> {
292 pub fn __expand_index_unchecked_method(
293 self,
294 scope: &mut Scope,
295 i: ExpandElementTyped<u32>,
296 ) -> ExpandElementTyped<E> {
297 let out = scope.create_local(self.expand.item);
298 scope.register(Instruction::new(
299 Operator::UncheckedIndex(BinaryOperator {
300 lhs: *self.expand,
301 rhs: i.expand.consume(),
302 }),
303 *out,
304 ));
305 out.into()
306 }
307
308 pub fn __expand_index_assign_unchecked_method(
309 self,
310 scope: &mut Scope,
311 i: ExpandElementTyped<u32>,
312 value: ExpandElementTyped<E>,
313 ) {
314 scope.register(Instruction::new(
315 Operator::UncheckedIndexAssign(BinaryOperator {
316 lhs: i.expand.consume(),
317 rhs: value.expand.consume(),
318 }),
319 *self.expand,
320 ));
321 }
322 }
323}
324
325impl<C: CubeType> CubeType for Array<C> {
326 type ExpandType = ExpandElementTyped<Array<C>>;
327}
328
329impl<C: CubeType> CubeType for &Array<C> {
330 type ExpandType = ExpandElementTyped<Array<C>>;
331}
332
333impl<C: CubeType> ExpandElementBaseInit for Array<C> {
334 fn init_elem(_scope: &mut crate::ir::Scope, elem: ExpandElement) -> ExpandElement {
335 elem
337 }
338}
339
340impl<T: CubeType<ExpandType = ExpandElementTyped<T>>> SizedContainer for Array<T> {
341 type Item = T;
342}
343
344impl<T: CubeType> Iterator for &Array<T> {
345 type Item = T;
346
347 fn next(&mut self) -> Option<Self::Item> {
348 unexpanded!()
349 }
350}
351
352impl<T: CubePrimitive> List<T> for Array<T> {
353 fn __expand_read(
354 scope: &mut Scope,
355 this: ExpandElementTyped<Array<T>>,
356 idx: ExpandElementTyped<u32>,
357 ) -> ExpandElementTyped<T> {
358 index::expand(scope, this, idx)
359 }
360}
361
362impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Array<T>> {
363 fn __expand_read_method(
364 self,
365 scope: &mut Scope,
366 idx: ExpandElementTyped<u32>,
367 ) -> ExpandElementTyped<T> {
368 index::expand(scope, self, idx)
369 }
370}
371
372impl<T: CubePrimitive> ListMut<T> for Array<T> {
373 fn __expand_write(
374 scope: &mut Scope,
375 this: ExpandElementTyped<Array<T>>,
376 idx: ExpandElementTyped<u32>,
377 value: ExpandElementTyped<T>,
378 ) {
379 index_assign::expand(scope, this, idx, value);
380 }
381}
382
383impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Array<T>> {
384 fn __expand_write_method(
385 self,
386 scope: &mut Scope,
387 idx: ExpandElementTyped<u32>,
388 value: ExpandElementTyped<T>,
389 ) {
390 index_assign::expand(scope, self, idx, value);
391 }
392}