cubecl-core 0.10.0-pre.3

CubeCL core crate
Documentation
use alloc::boxed::Box;

use cubecl_ir::ManagedVariable;

use crate::{
    ir::{Branch, RangeLoop, Scope},
    prelude::{CubeIndex, CubePrimitive, CubeType, Iterable, NativeExpand, index},
};

use super::Array;

/// A container that can be indexed through [`CubeIndex`] and can report its length.
pub trait SizedContainer: CubeIndex<Idx: CubePrimitive, Output = Self::Item> + Sized {
    /// Element type produced when the container is indexed.
    type Item: CubePrimitive;

    /// Return the length of the container.
    ///
    /// The default implementation converts a clone of `val` into an expanded
    /// `Array<Self::Item>` and delegates to that type's expand `len` method.
    fn len(val: &ManagedVariable, scope: &mut Scope) -> ManagedVariable {
        let as_array: NativeExpand<Array<Self::Item>> = val.clone().into();
        let length = as_array.__expand_len_method(scope);
        length.expand
    }
}

/// Iteration support for the expand type of any [`SizedContainer`].
impl<T: SizedContainer + CubeType<ExpandType = NativeExpand<T>>> Iterable<T::Item>
    for NativeExpand<T>
{
    /// Registers a range loop over the container in `scope`.
    ///
    /// `body` is invoked exactly once, against a child scope, with the element
    /// obtained by indexing the container with the loop variable. The resulting
    /// child scope becomes the body of a `Branch::RangeLoop` running from `0`
    /// (inclusive) to the container length (exclusive).
    fn expand(
        self,
        scope: &mut Scope,
        mut body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
    ) {
        let counter_ty = u32::as_type(scope);
        // The length is queried against the parent scope, before the child is created.
        let length: ManagedVariable = T::len(&self.expand, scope);

        let mut loop_scope = scope.child();
        let counter = loop_scope.create_local_restricted(counter_ty);

        // Index the container with the loop variable and hand the element to the body.
        let position = counter.clone().into();
        let element = index::expand(&mut loop_scope, self, position);
        body(&mut loop_scope, element);

        let range_loop = RangeLoop {
            i: *counter,
            start: 0u32.into(),
            end: *length,
            step: None,
            inclusive: false,
            scope: loop_scope,
        };
        scope.register(Branch::RangeLoop(Box::new(range_loop)));
    }

    /// Unrolled iteration is not supported for this iterator.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`.
    fn expand_unroll(
        self,
        _scope: &mut Scope,
        _body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
    ) {
        unimplemented!("Can't unroll array iterator")
    }
}