cubecl_core/frontend/container/iter.rs

1use cubecl_ir::ExpandElement;
2
3use crate::{
4    ir::{Branch, Item, RangeLoop, Scope},
5    prelude::{CubeIndex, CubePrimitive, CubeType, ExpandElementTyped, Iterable, index},
6};
7
8use super::Array;
9
10pub trait SizedContainer:
11    CubeIndex<ExpandElementTyped<u32>, Output = Self::Item> + CubeType
12{
13    type Item: CubeType<ExpandType = ExpandElementTyped<Self::Item>>;
14
15    /// Return the length of the container.
16    fn len(val: &ExpandElement, scope: &mut Scope) -> ExpandElement {
17        // By default we use the expand len method of the Array type.
18        let val: ExpandElementTyped<Array<Self::Item>> = val.clone().into();
19        val.__expand_len_method(scope).expand
20    }
21}
22
23impl<T: SizedContainer> Iterable<T::Item> for ExpandElementTyped<T> {
24    fn expand(
25        self,
26        scope: &mut Scope,
27        mut body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
28    ) {
29        let index_ty = Item::new(u32::as_elem(scope));
30        let len: ExpandElement = T::len(&self.expand, scope);
31
32        let mut child = scope.child();
33        let i = child.create_local_restricted(index_ty);
34
35        let item = index::expand(&mut child, self, i.clone().into());
36        body(&mut child, item);
37
38        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
39            i: *i,
40            start: 0u32.into(),
41            end: *len,
42            step: None,
43            inclusive: false,
44            scope: child,
45        })));
46    }
47
48    fn expand_unroll(
49        self,
50        _scope: &mut Scope,
51        _body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
52    ) {
53        unimplemented!("Can't unroll array iterator")
54    }
55}