cubecl_core/frontend/container/
iter.rs

1use cubecl_ir::ExpandElement;
2
3use crate::{
4    ir::{Branch, Item, RangeLoop, Scope},
5    prelude::{CubeIndex, CubePrimitive, CubeType, ExpandElementTyped, Iterable, index},
6};
7
8use super::Array;
9
10pub trait SizedContainer: CubeIndex<Output = Self::Item> + Sized {
11    type Item: CubePrimitive;
12
13    /// Return the length of the container.
14    fn len(val: &ExpandElement, scope: &mut Scope) -> ExpandElement {
15        // By default we use the expand len method of the Array type.
16        let val: ExpandElementTyped<Array<Self::Item>> = val.clone().into();
17        val.__expand_len_method(scope).expand
18    }
19}
20
21impl<T: SizedContainer + CubeType<ExpandType = ExpandElementTyped<T>>> Iterable<T::Item>
22    for ExpandElementTyped<T>
23{
24    fn expand(
25        self,
26        scope: &mut Scope,
27        mut body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
28    ) {
29        let index_ty = Item::new(u32::as_elem(scope));
30        let len: ExpandElement = T::len(&self.expand, scope);
31
32        let mut child = scope.child();
33        let i = child.create_local_restricted(index_ty);
34
35        let index = i.clone().into();
36        let item = index::expand(&mut child, self, index);
37        body(&mut child, item);
38
39        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
40            i: *i,
41            start: 0u32.into(),
42            end: *len,
43            step: None,
44            inclusive: false,
45            scope: child,
46        })));
47    }
48
49    fn expand_unroll(
50        self,
51        _scope: &mut Scope,
52        _body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
53    ) {
54        unimplemented!("Can't unroll array iterator")
55    }
56}