cubecl_core/frontend/container/
slice.rs1use std::marker::PhantomData;
2
3use crate::{
4 frontend::{indexation::Index, Tensor},
5 ir::{self, Operator},
6 prelude::{CubeContext, IntoRuntime},
7 unexpanded,
8};
9use crate::{
10 frontend::{
11 Array, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, SharedMemory,
12 SizedContainer,
13 },
14 ir::Instruction,
15};
16
17use super::Line;
18
/// Read-only view over a range of elements of a cube container.
///
/// Slices can be taken from `Array`, `Tensor`, `SharedMemory`, or from another
/// `Slice`/`SliceMut` (see the `slice_op!` invocations in this file). Unlike
/// [`SliceMut`], no element-assignment methods are defined for it here.
///
/// This is a zero-sized marker type: the methods declared on it in this file
/// are `unexpanded!()` placeholders, and the real work is done on
/// `ExpandElementTyped<Slice<E>>`.
#[derive(Clone)]
pub struct Slice<E> {
    // Ties the element type to the slice without storing any data.
    _e: PhantomData<E>,
}
28
/// Mutable view over a range of elements of a cube container.
///
/// Besides the read methods shared with [`Slice`], this type exposes unchecked
/// element assignment. Like `Slice`, it is a zero-sized marker type whose
/// methods in this file are `unexpanded!()` placeholders. The actual
/// expansion is implemented on `ExpandElementTyped<SliceMut<E>>`.
pub struct SliceMut<E> {
    // Ties the element type to the slice without storing any data.
    _e: PhantomData<E>,
}
37
38mod metadata {
39 use super::*;
40
41 impl<E> Slice<E> {
42 #[allow(clippy::len_without_is_empty)]
44 pub fn len(&self) -> u32 {
45 unexpanded!()
46 }
47
48 pub fn to_aligned(&self) -> Slice<Line<E>>
50 where
51 E: CubePrimitive,
52 {
53 unexpanded!()
54 }
55 }
56
57 impl<E> SliceMut<E> {
58 #[allow(clippy::len_without_is_empty)]
60 pub fn len(&self) -> u32 {
61 unexpanded!()
62 }
63
64 pub fn into_aligned(self) -> SliceMut<Line<E>>
66 where
67 E: CubePrimitive,
68 {
69 unexpanded!()
70 }
71 }
72
73 impl<C: CubeType> ExpandElementTyped<Slice<C>> {
74 pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
76 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
77 elem.__expand_len_method(context)
78 }
79
80 pub fn __expand_to_aligned_method(
82 self,
83 _context: &mut CubeContext,
84 ) -> ExpandElementTyped<Slice<Line<C>>>
85 where
86 C: CubePrimitive,
87 {
88 self.expand.into()
89 }
90
91 pub fn __expand_clone_method(
93 self,
94 _context: &mut CubeContext,
95 ) -> ExpandElementTyped<Slice<Line<C>>>
96 where
97 C: CubePrimitive,
98 {
99 self.expand.into()
100 }
101 }
102
103 impl<C: CubeType> ExpandElementTyped<SliceMut<C>> {
104 pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
106 let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
107 elem.__expand_len_method(context)
108 }
109
110 pub fn __expand_into_aligned_method(
112 self,
113 _context: &mut CubeContext,
114 ) -> ExpandElementTyped<SliceMut<Line<C>>>
115 where
116 C: CubePrimitive,
117 {
118 self.expand.into()
119 }
120 }
121}
122
123mod indexation {
125 use ir::Instruction;
126
127 use crate::{
128 ir::{BinaryOperator, Operator},
129 prelude::{CubeIndex, CubeIndexMut},
130 };
131
132 use super::*;
133
134 impl<E: CubePrimitive> Slice<E> {
135 pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
141 where
142 Self: CubeIndex<I>,
143 {
144 unexpanded!()
145 }
146 }
147
148 impl<E: CubePrimitive> SliceMut<E> {
149 pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
155 where
156 Self: CubeIndex<I>,
157 {
158 unexpanded!()
159 }
160
161 pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
167 where
168 Self: CubeIndexMut<I>,
169 {
170 unexpanded!()
171 }
172 }
173
174 impl<E: CubePrimitive> ExpandElementTyped<Slice<E>> {
175 pub fn __expand_index_unchecked_method(
176 self,
177 context: &mut CubeContext,
178 i: ExpandElementTyped<u32>,
179 ) -> ExpandElementTyped<E> {
180 let out = context.create_local(self.expand.item);
181 context.register(Instruction::new(
182 Operator::UncheckedIndex(BinaryOperator {
183 lhs: *self.expand,
184 rhs: i.expand.consume(),
185 }),
186 *out,
187 ));
188 out.into()
189 }
190 }
191
192 impl<E: CubePrimitive> ExpandElementTyped<SliceMut<E>> {
193 pub fn __expand_index_unchecked_method(
194 self,
195 context: &mut CubeContext,
196 i: ExpandElementTyped<u32>,
197 ) -> ExpandElementTyped<E> {
198 let out = context.create_local(self.expand.item);
199 context.register(Instruction::new(
200 Operator::UncheckedIndex(BinaryOperator {
201 lhs: *self.expand,
202 rhs: i.expand.consume(),
203 }),
204 *out,
205 ));
206 out.into()
207 }
208
209 pub fn __expand_index_assign_unchecked_method(
210 self,
211 context: &mut CubeContext,
212 i: ExpandElementTyped<u32>,
213 value: ExpandElementTyped<E>,
214 ) {
215 context.register(Instruction::new(
216 Operator::UncheckedIndexAssign(BinaryOperator {
217 lhs: i.expand.consume(),
218 rhs: value.expand.consume(),
219 }),
220 *self.expand,
221 ));
222 }
223 }
224}
225
226impl<E: CubeType> CubeType for Slice<E> {
227 type ExpandType = ExpandElementTyped<Slice<E>>;
228}
229
230impl<C: CubeType> Init for ExpandElementTyped<Slice<C>> {
231 fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
232 self
234 }
235}
236
237impl<E: CubeType> CubeType for SliceMut<E> {
238 type ExpandType = ExpandElementTyped<SliceMut<E>>;
239}
240
241impl<E: CubeType> CubeType for &mut SliceMut<E> {
242 type ExpandType = ExpandElementTyped<SliceMut<E>>;
243}
244
245impl<C: CubeType> Init for ExpandElementTyped<SliceMut<C>> {
246 fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
247 self
249 }
250}
251
252impl<C: CubeType<ExpandType = ExpandElementTyped<C>>> SizedContainer for Slice<C> {
253 type Item = C;
254}
255
256impl<T: CubeType> Iterator for Slice<T> {
257 type Item = T;
258
259 fn next(&mut self) -> Option<Self::Item> {
260 unexpanded!()
261 }
262}
263
264pub trait SliceOperator<E: CubeType>: CubeType<ExpandType = Self::Expand> {
265 type Expand: SliceOperatorExpand<E>;
266
267 #[allow(unused_variables)]
271 fn slice<Start: Index, End: Index>(&self, start: Start, end: End) -> Slice<E> {
272 unexpanded!()
273 }
274 fn __expand_slice(
276 context: &mut CubeContext,
277 expand: Self::Expand,
278 start: ExpandElementTyped<u32>,
279 end: ExpandElementTyped<u32>,
280 ) -> ExpandElementTyped<Slice<E>> {
281 expand.__expand_slice_method(context, start, end)
282 }
283
284 #[allow(unused_variables)]
288 fn slice_mut<Start: Index, End: Index>(&mut self, start: Start, end: End) -> SliceMut<E> {
289 unexpanded!()
290 }
291
292 fn __expand_slice_mut(
294 context: &mut CubeContext,
295 expand: Self::Expand,
296 start: ExpandElementTyped<u32>,
297 end: ExpandElementTyped<u32>,
298 ) -> ExpandElementTyped<SliceMut<E>> {
299 expand.__expand_slice_mut_method(context, start, end)
300 }
301
302 #[allow(unused_variables)]
304 fn to_slice(&self) -> Slice<E> {
305 unexpanded!()
306 }
307
308 fn __expand_to_slice(
310 context: &mut CubeContext,
311 expand: Self::Expand,
312 ) -> ExpandElementTyped<Slice<E>> {
313 expand.__expand_to_slice_method(context)
314 }
315
316 #[allow(unused_variables)]
318 fn to_slice_mut(&mut self) -> SliceMut<E> {
319 unexpanded!()
320 }
321
322 fn __expand_to_slice_mut(
324 context: &mut CubeContext,
325 expand: Self::Expand,
326 ) -> ExpandElementTyped<SliceMut<E>> {
327 expand.__expand_to_slice_mut_method(context)
328 }
329}
330
331pub trait SliceOperatorExpand<E: CubeType>: Into<ExpandElement> + Clone {
332 fn slice_base<Start: Index, End: Index>(
333 &self,
334 context: &mut CubeContext,
335 start: Start,
336 end: End,
337 ) -> ExpandElement;
338
339 fn __expand_slice_method(
340 &self,
341 context: &mut CubeContext,
342 start: ExpandElementTyped<u32>,
343 end: ExpandElementTyped<u32>,
344 ) -> ExpandElementTyped<Slice<E>> {
345 ExpandElementTyped::new(self.slice_base(context, start, end))
346 }
347
348 fn __expand_slice_mut_method(
349 &self,
350 context: &mut CubeContext,
351 start: ExpandElementTyped<u32>,
352 end: ExpandElementTyped<u32>,
353 ) -> ExpandElementTyped<SliceMut<E>> {
354 ExpandElementTyped::new(self.slice_base(context, start, end))
355 }
356
357 fn __expand_to_slice_method(&self, _context: &mut CubeContext) -> ExpandElementTyped<Slice<E>> {
358 let expand = self.clone().into();
359 ExpandElementTyped::new(expand)
360 }
361
362 fn __expand_to_slice_mut_method(
363 &self,
364 _context: &mut CubeContext,
365 ) -> ExpandElementTyped<SliceMut<E>> {
366 let expand = self.clone().into();
367 ExpandElementTyped::new(expand)
368 }
369}
370
371macro_rules! slice_op {
372 ($type:ident) => {
373 impl<E: CubePrimitive> SliceOperator<E> for $type<E> {
374 type Expand = ExpandElementTyped<$type<E>>;
375 }
376
377 impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$type<E>> {
378 fn slice_base<Start: Index, End: Index>(
379 &self,
380 context: &mut CubeContext,
381 start: Start,
382 end: End,
383 ) -> ExpandElement {
384 slice_expand(context, self.clone(), start, end)
385 }
386 }
387 };
388 (slice $type:ident) => {
389 impl<E: CubePrimitive> SliceOperator<E> for $type<E> {
390 type Expand = ExpandElementTyped<$type<E>>;
391 }
392
393 impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$type<E>> {
394 fn slice_base<Start: Index, End: Index>(
395 &self,
396 context: &mut CubeContext,
397 start: Start,
398 end: End,
399 ) -> ExpandElement {
400 slice_expand(context, self.clone(), start, end)
401 }
402 }
403 };
404}
405
406slice_op!(Array);
407slice_op!(Tensor);
408slice_op!(SharedMemory);
409slice_op!(slice Slice);
410slice_op!(slice SliceMut);
411
412pub fn slice_expand<I: Into<ExpandElement>, S1: Index, S2: Index>(
413 context: &mut CubeContext,
414 input: I,
415 start: S1,
416 end: S2, ) -> ExpandElement {
418 let input = input.into();
419 let out = context.create_slice(input.item);
420
421 context.register(Instruction::new(
422 Operator::Slice(ir::SliceOperator {
423 input: *input,
424 start: start.value(),
425 end: end.value(),
426 }),
427 *out,
428 ));
429
430 out
431}
432
433impl<E: CubePrimitive> IntoRuntime for Slice<E> {
434 fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
435 unimplemented!("Array can't exist at compile time")
436 }
437}
438
439impl<E: CubePrimitive> IntoRuntime for SliceMut<E> {
440 fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
441 unimplemented!("Array can't exist at compile time")
442 }
443}