Skip to main content

cubecl_core/frontend/container/iter.rs

1use alloc::boxed::Box;
2
3use cubecl_ir::ExpandElement;
4
5use crate::{
6    ir::{Branch, RangeLoop, Scope, Type},
7    prelude::{CubeIndex, CubePrimitive, CubeType, ExpandElementTyped, Iterable, index},
8};
9
10use super::Array;
11
12pub trait SizedContainer: CubeIndex<Idx: CubePrimitive, Output = Self::Item> + Sized {
13    type Item: CubePrimitive;
14
15    /// Return the length of the container.
16    fn len(val: &ExpandElement, scope: &mut Scope) -> ExpandElement {
17        // By default we use the expand len method of the Array type.
18        let val: ExpandElementTyped<Array<Self::Item>> = val.clone().into();
19        val.__expand_len_method(scope).expand
20    }
21}
22
23impl<T: SizedContainer + CubeType<ExpandType = ExpandElementTyped<T>>> Iterable<T::Item>
24    for ExpandElementTyped<T>
25{
26    fn expand(
27        self,
28        scope: &mut Scope,
29        mut body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
30    ) {
31        let index_ty = Type::new(u32::as_type(scope));
32        let len: ExpandElement = T::len(&self.expand, scope);
33
34        let mut child = scope.child();
35        let i = child.create_local_restricted(index_ty);
36
37        let index = i.clone().into();
38        let item = index::expand(&mut child, self, index);
39        body(&mut child, item);
40
41        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
42            i: *i,
43            start: 0u32.into(),
44            end: *len,
45            step: None,
46            inclusive: false,
47            scope: child,
48        })));
49    }
50
51    fn expand_unroll(
52        self,
53        _scope: &mut Scope,
54        _body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
55    ) {
56        unimplemented!("Can't unroll array iterator")
57    }
58}