cubecl_core/frontend/container/array/
base.rs1use std::{marker::PhantomData, num::NonZero};
2
3use crate::frontend::{
4 CubePrimitive, ExpandElement, ExpandElementBaseInit, ExpandElementTyped, IntoRuntime,
5};
6use crate::prelude::SizedContainer;
7use crate::{
8 frontend::CubeType,
9 ir::{Item, Metadata},
10 unexpanded,
11};
12use crate::{
13 frontend::{indexation::Index, CubeContext},
14 prelude::{assign, index, index_assign},
15};
16
/// Compile-time handle for an array of `E` elements used inside cube code.
///
/// The struct stores no data of its own: its only field is a zero-sized
/// [`PhantomData`] marker that records the element type.
pub struct Array<E> {
    /// Zero-sized marker tying the handle to its element type.
    _val: PhantomData<E>,
}
21
22mod new {
24 use super::*;
25 use crate::ir::Variable;
26
27 impl<T: CubePrimitive + Clone> Array<T> {
28 #[allow(unused_variables)]
30 pub fn new<L: Index>(length: L) -> Self {
31 Array { _val: PhantomData }
32 }
33
34 pub fn from_data<C: CubePrimitive>(_data: impl IntoIterator<Item = C>) -> Self {
36 Array { _val: PhantomData }
37 }
38
39 pub fn __expand_new(
41 context: &mut CubeContext,
42 size: ExpandElementTyped<u32>,
43 ) -> <Self as CubeType>::ExpandType {
44 let size = size
45 .constant()
46 .expect("Array need constant initialization value")
47 .as_u32();
48 let elem = T::as_elem(context);
49 context.create_local_array(Item::new(elem), size).into()
50 }
51
52 pub fn __expand_from_data<C: CubePrimitive>(
54 context: &mut CubeContext,
55 data: ArrayData<C>,
56 ) -> <Self as CubeType>::ExpandType {
57 let var = context.create_const_array(Item::new(T::as_elem(context)), data.values);
58 ExpandElementTyped::new(var)
59 }
60 }
61
62 pub struct ArrayData<C> {
64 values: Vec<Variable>,
65 _ty: PhantomData<C>,
66 }
67
68 impl<C: CubePrimitive + Into<ExpandElementTyped<C>>, T: IntoIterator<Item = C>> From<T>
69 for ArrayData<C>
70 {
71 fn from(value: T) -> Self {
72 let values: Vec<Variable> = value
73 .into_iter()
74 .map(|value| {
75 let value: ExpandElementTyped<C> = value.into();
76 *value.expand
77 })
78 .collect();
79 ArrayData {
80 values,
81 _ty: PhantomData,
82 }
83 }
84 }
85}
86
87mod line {
89 use crate::prelude::Line;
90
91 use super::*;
92
93 impl<P: CubePrimitive> Array<Line<P>> {
94 pub fn line_size(&self) -> u32 {
102 unexpanded!()
103 }
104
105 pub fn __expand_line_size(
107 expand: <Self as CubeType>::ExpandType,
108 context: &mut CubeContext,
109 ) -> u32 {
110 expand.__expand_line_size_method(context)
111 }
112 }
113
114 impl<P: CubePrimitive> ExpandElementTyped<Array<Line<P>>> {
115 pub fn line_size(&self) -> u32 {
117 self.expand
118 .item
119 .vectorization
120 .unwrap_or(NonZero::new(1).unwrap())
121 .get() as u32
122 }
123
124 pub fn __expand_line_size_method(&self, _content: &mut CubeContext) -> u32 {
126 self.line_size()
127 }
128 }
129}
130
131mod vectorization {
135 use super::*;
136
137 impl<T: CubePrimitive + Clone> Array<T> {
138 #[allow(unused_variables)]
139 pub fn vectorized<L: Index>(length: L, vectorization_factor: u32) -> Self {
140 Array { _val: PhantomData }
141 }
142
143 pub fn to_vectorized(self, _vectorization_factor: u32) -> T {
144 unexpanded!()
145 }
146
147 pub fn __expand_vectorized(
148 context: &mut CubeContext,
149 size: ExpandElementTyped<u32>,
150 vectorization_factor: u32,
151 ) -> <Self as CubeType>::ExpandType {
152 let size = size.value();
153 let size = match size.kind {
154 crate::ir::VariableKind::ConstantScalar(value) => value.as_u32(),
155 _ => panic!("Shared memory need constant initialization value"),
156 };
157 context
158 .create_local_array(
159 Item::vectorized(
160 T::as_elem(context),
161 NonZero::new(vectorization_factor as u8),
162 ),
163 size,
164 )
165 .into()
166 }
167 }
168
169 impl<C: CubePrimitive> ExpandElementTyped<Array<C>> {
170 pub fn __expand_to_vectorized_method(
171 self,
172 context: &mut CubeContext,
173 vectorization_factor: ExpandElementTyped<u32>,
174 ) -> ExpandElementTyped<C> {
175 let factor = vectorization_factor
176 .constant()
177 .expect("Vectorization must be comptime")
178 .as_u32();
179 let var = self.expand.clone();
180 let item = Item::vectorized(var.item.elem(), NonZero::new(factor as u8));
181
182 let new_var = if factor == 1 {
183 let new_var = context.create_local(item);
184 let element = index::expand(
185 context,
186 self.clone(),
187 ExpandElementTyped::from_lit(context, 0u32),
188 );
189 assign::expand(context, element, new_var.clone().into());
190 new_var
191 } else {
192 let new_var = context.create_local_mut(item);
193 for i in 0..factor {
194 let expand: Self = self.expand.clone().into();
195 let element =
196 index::expand(context, expand, ExpandElementTyped::from_lit(context, i));
197 index_assign::expand::<Array<C>>(
198 context,
199 new_var.clone().into(),
200 ExpandElementTyped::from_lit(context, i),
201 element,
202 );
203 }
204 new_var
205 };
206 new_var.into()
207 }
208 }
209}
210
211mod metadata {
213 use crate::ir::Instruction;
214
215 use super::*;
216
217 impl<E: CubeType> Array<E> {
218 #[allow(clippy::len_without_is_empty)]
220 pub fn len(&self) -> u32 {
221 unexpanded!()
222 }
223
224 pub fn buffer_len(&self) -> u32 {
226 unexpanded!()
227 }
228 }
229
230 impl<T: CubeType> ExpandElementTyped<Array<T>> {
231 pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
233 let out = context.create_local(Item::new(u32::as_elem(context)));
234 context.register(Instruction::new(
235 Metadata::Length {
236 var: self.expand.into(),
237 },
238 out.clone().into(),
239 ));
240 out.into()
241 }
242
243 pub fn __expand_buffer_len_method(
245 self,
246 context: &mut CubeContext,
247 ) -> ExpandElementTyped<u32> {
248 let out = context.create_local(Item::new(u32::as_elem(context)));
249 context.register(Instruction::new(
250 Metadata::BufferLength {
251 var: self.expand.into(),
252 },
253 out.clone().into(),
254 ));
255 out.into()
256 }
257 }
258}
259
260mod indexation {
262 use crate::{
263 ir::{BinaryOperator, Instruction, Operator},
264 prelude::{CubeIndex, CubeIndexMut},
265 };
266
267 use super::*;
268
269 impl<E: CubePrimitive> Array<E> {
270 pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
276 where
277 Self: CubeIndex<I>,
278 {
279 unexpanded!()
280 }
281
282 pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
288 where
289 Self: CubeIndexMut<I>,
290 {
291 unexpanded!()
292 }
293 }
294
295 impl<E: CubePrimitive> ExpandElementTyped<Array<E>> {
296 pub fn __expand_index_unchecked_method(
297 self,
298 context: &mut CubeContext,
299 i: ExpandElementTyped<u32>,
300 ) -> ExpandElementTyped<E> {
301 let out = context.create_local(self.expand.item);
302 context.register(Instruction::new(
303 Operator::UncheckedIndex(BinaryOperator {
304 lhs: *self.expand,
305 rhs: i.expand.consume(),
306 }),
307 *out,
308 ));
309 out.into()
310 }
311
312 pub fn __expand_index_assign_unchecked_method(
313 self,
314 context: &mut CubeContext,
315 i: ExpandElementTyped<u32>,
316 value: ExpandElementTyped<E>,
317 ) {
318 context.register(Instruction::new(
319 Operator::UncheckedIndexAssign(BinaryOperator {
320 lhs: i.expand.consume(),
321 rhs: value.expand.consume(),
322 }),
323 *self.expand,
324 ));
325 }
326 }
327}
328
329impl<E: CubePrimitive> IntoRuntime for Array<E> {
330 fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
331 unimplemented!("Array can't exist at compile time")
332 }
333}
334
335impl<C: CubeType> CubeType for Array<C> {
336 type ExpandType = ExpandElementTyped<Array<C>>;
337}
338
339impl<C: CubeType> CubeType for &Array<C> {
340 type ExpandType = ExpandElementTyped<Array<C>>;
341}
342
343impl<C: CubeType> ExpandElementBaseInit for Array<C> {
344 fn init_elem(_context: &mut crate::prelude::CubeContext, elem: ExpandElement) -> ExpandElement {
345 elem
347 }
348}
349
350impl<T: CubeType<ExpandType = ExpandElementTyped<T>>> SizedContainer for Array<T> {
351 type Item = T;
352}
353
354impl<T: CubeType> Iterator for &Array<T> {
355 type Item = T;
356
357 fn next(&mut self) -> Option<Self::Item> {
358 unexpanded!()
359 }
360}