cubecl_core/frontend/container/
iter.rs1use crate::{
2 ir::{Branch, Item, RangeLoop},
3 prelude::{
4 index, CubeContext, CubeIndex, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped,
5 Iterable,
6 },
7};
8
9use super::Array;
10
11pub trait SizedContainer:
12 CubeIndex<ExpandElementTyped<u32>, Output = Self::Item> + CubeType
13{
14 type Item: CubeType<ExpandType = ExpandElementTyped<Self::Item>>;
15
16 fn len(val: &ExpandElement, context: &mut CubeContext) -> ExpandElement {
18 let val: ExpandElementTyped<Array<Self::Item>> = val.clone().into();
20 val.__expand_len_method(context).expand
21 }
22}
23
24impl<T: SizedContainer> Iterable<T::Item> for ExpandElementTyped<T> {
25 fn expand(
26 self,
27 context: &mut CubeContext,
28 mut body: impl FnMut(&mut CubeContext, <T::Item as CubeType>::ExpandType),
29 ) {
30 let index_ty = Item::new(u32::as_elem(context));
31 let len: ExpandElement = T::len(&self.expand, context);
32
33 let mut child = context.child();
34 let i = child.create_local_restricted(index_ty);
35
36 let item = index::expand(&mut child, self, i.clone().into());
37 body(&mut child, item);
38
39 context.register(Branch::RangeLoop(Box::new(RangeLoop {
40 i: *i,
41 start: 0u32.into(),
42 end: *len,
43 step: None,
44 inclusive: false,
45 scope: child.into_scope(),
46 })));
47 }
48
49 fn expand_unroll(
50 self,
51 _context: &mut CubeContext,
52 _body: impl FnMut(&mut CubeContext, <T::Item as CubeType>::ExpandType),
53 ) {
54 unimplemented!("Can't unroll array iterator")
55 }
56}