cubecl_core/frontend/container/
iter.rs1use cubecl_ir::ExpandElement;
2
3use crate::{
4 ir::{Branch, Item, RangeLoop, Scope},
5 prelude::{CubeIndex, CubePrimitive, CubeType, ExpandElementTyped, Iterable, index},
6};
7
8use super::Array;
9
10pub trait SizedContainer:
11 CubeIndex<ExpandElementTyped<u32>, Output = Self::Item> + CubeType
12{
13 type Item: CubeType<ExpandType = ExpandElementTyped<Self::Item>>;
14
15 fn len(val: &ExpandElement, scope: &mut Scope) -> ExpandElement {
17 let val: ExpandElementTyped<Array<Self::Item>> = val.clone().into();
19 val.__expand_len_method(scope).expand
20 }
21}
22
23impl<T: SizedContainer> Iterable<T::Item> for ExpandElementTyped<T> {
24 fn expand(
25 self,
26 scope: &mut Scope,
27 mut body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
28 ) {
29 let index_ty = Item::new(u32::as_elem(scope));
30 let len: ExpandElement = T::len(&self.expand, scope);
31
32 let mut child = scope.child();
33 let i = child.create_local_restricted(index_ty);
34
35 let item = index::expand(&mut child, self, i.clone().into());
36 body(&mut child, item);
37
38 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
39 i: *i,
40 start: 0u32.into(),
41 end: *len,
42 step: None,
43 inclusive: false,
44 scope: child,
45 })));
46 }
47
48 fn expand_unroll(
49 self,
50 _scope: &mut Scope,
51 _body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
52 ) {
53 unimplemented!("Can't unroll array iterator")
54 }
55}