cubecl_core/frontend/container/
iter.rs1use alloc::boxed::Box;
2
3use cubecl_ir::ExpandElement;
4
5use crate::{
6 ir::{Branch, RangeLoop, Scope, Type},
7 prelude::{CubeIndex, CubePrimitive, CubeType, ExpandElementTyped, Iterable, index},
8};
9
10use super::Array;
11
12pub trait SizedContainer: CubeIndex<Idx: CubePrimitive, Output = Self::Item> + Sized {
13 type Item: CubePrimitive;
14
15 fn len(val: &ExpandElement, scope: &mut Scope) -> ExpandElement {
17 let val: ExpandElementTyped<Array<Self::Item>> = val.clone().into();
19 val.__expand_len_method(scope).expand
20 }
21}
22
23impl<T: SizedContainer + CubeType<ExpandType = ExpandElementTyped<T>>> Iterable<T::Item>
24 for ExpandElementTyped<T>
25{
26 fn expand(
27 self,
28 scope: &mut Scope,
29 mut body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
30 ) {
31 let index_ty = Type::new(u32::as_type(scope));
32 let len: ExpandElement = T::len(&self.expand, scope);
33
34 let mut child = scope.child();
35 let i = child.create_local_restricted(index_ty);
36
37 let index = i.clone().into();
38 let item = index::expand(&mut child, self, index);
39 body(&mut child, item);
40
41 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
42 i: *i,
43 start: 0u32.into(),
44 end: *len,
45 step: None,
46 inclusive: false,
47 scope: child,
48 })));
49 }
50
51 fn expand_unroll(
52 self,
53 _scope: &mut Scope,
54 _body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
55 ) {
56 unimplemented!("Can't unroll array iterator")
57 }
58}