cubecl_core/frontend/container/iter.rs

1use crate::{
2    ir::{Branch, Item, RangeLoop},
3    prelude::{
4        index, CubeContext, CubeIndex, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped,
5        Iterable,
6    },
7};
8
9use super::Array;
10
11pub trait SizedContainer:
12    CubeIndex<ExpandElementTyped<u32>, Output = Self::Item> + CubeType
13{
14    type Item: CubeType<ExpandType = ExpandElementTyped<Self::Item>>;
15
16    /// Return the length of the container.
17    fn len(val: &ExpandElement, context: &mut CubeContext) -> ExpandElement {
18        // By default we use the expand len method of the Array type.
19        let val: ExpandElementTyped<Array<Self::Item>> = val.clone().into();
20        val.__expand_len_method(context).expand
21    }
22}
23
24impl<T: SizedContainer> Iterable<T::Item> for ExpandElementTyped<T> {
25    fn expand(
26        self,
27        context: &mut CubeContext,
28        mut body: impl FnMut(&mut CubeContext, <T::Item as CubeType>::ExpandType),
29    ) {
30        let index_ty = Item::new(u32::as_elem(context));
31        let len: ExpandElement = T::len(&self.expand, context);
32
33        let mut child = context.child();
34        let i = child.create_local_restricted(index_ty);
35
36        let item = index::expand(&mut child, self, i.clone().into());
37        body(&mut child, item);
38
39        context.register(Branch::RangeLoop(Box::new(RangeLoop {
40            i: *i,
41            start: 0u32.into(),
42            end: *len,
43            step: None,
44            inclusive: false,
45            scope: child.into_scope(),
46        })));
47    }
48
49    fn expand_unroll(
50        self,
51        _context: &mut CubeContext,
52        _body: impl FnMut(&mut CubeContext, <T::Item as CubeType>::ExpandType),
53    ) {
54        unimplemented!("Can't unroll array iterator")
55    }
56}