cubecl-core 0.9.0

CubeCL core crate
Documentation
use cubecl_ir::ExpandElement;

use crate::{
    ir::{Branch, RangeLoop, Scope, Type},
    prelude::{CubeIndex, CubePrimitive, CubeType, ExpandElementTyped, Iterable, index},
};

use super::Array;

pub trait SizedContainer: CubeIndex<Idx: CubePrimitive, Output = Self::Item> + Sized {
    type Item: CubePrimitive;

    /// Return the length of the container.
    fn len(val: &ExpandElement, scope: &mut Scope) -> ExpandElement {
        // By default we use the expand len method of the Array type.
        let val: ExpandElementTyped<Array<Self::Item>> = val.clone().into();
        val.__expand_len_method(scope).expand
    }
}

impl<T: SizedContainer + CubeType<ExpandType = ExpandElementTyped<T>>> Iterable<T::Item>
    for ExpandElementTyped<T>
{
    /// Expand iteration over the container into a runtime [`RangeLoop`].
    ///
    /// The container length is computed in the enclosing `scope`. The loop body is
    /// built in a child scope: a restricted `u32` counter indexes the container, and
    /// the resulting element is handed to `body`. The registered loop runs from `0`
    /// (inclusive) up to the length (exclusive), with no explicit step.
    fn expand(
        self,
        scope: &mut Scope,
        mut body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
    ) {
        // Resolve the counter type and the bound against the outer scope, so the
        // length computation is emitted outside the loop body.
        let counter_ty = Type::new(u32::as_type(scope));
        let upper_bound: ExpandElement = T::len(&self.expand, scope);

        // Everything executed per iteration lives in the child scope.
        let mut loop_scope = scope.child();
        let counter = loop_scope.create_local_restricted(counter_ty);

        let element = index::expand(&mut loop_scope, self, counter.clone().into());
        body(&mut loop_scope, element);

        let range_loop = RangeLoop {
            i: *counter,
            start: 0u32.into(),
            end: *upper_bound,
            step: None,
            inclusive: false,
            scope: loop_scope,
        };
        scope.register(Branch::RangeLoop(Box::new(range_loop)));
    }

    /// Unrolling is not supported for this iterator; calling this always panics.
    ///
    /// NOTE(review): presumably this is because the length is a runtime value
    /// (see `SizedContainer::len`) rather than a compile-time constant — confirm.
    fn expand_unroll(
        self,
        _scope: &mut Scope,
        _body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
    ) {
        unimplemented!("Can't unroll array iterator")
    }
}