use cubecl_ir::ExpandElement;
use crate::{
ir::{Branch, RangeLoop, Scope, Type},
prelude::{CubeIndex, CubePrimitive, CubeType, ExpandElementTyped, Iterable, index},
};
use super::Array;
/// A container that is indexed by `u32`, yields [`SizedContainer::Item`] values and can
/// report its length during expansion.
pub trait SizedContainer: Sized + CubeIndex<Idx = u32, Output = Self::Item> {
    /// Element type produced when indexing into the container.
    type Item: CubePrimitive;

    /// Expands to the length of the container held in `val`.
    ///
    /// The default implementation views `val` as an `Array<Self::Item>` and forwards to
    /// that array's `__expand_len_method`, returning the untyped expand element of the result.
    fn len(val: &ExpandElement, scope: &mut Scope) -> ExpandElement {
        // `into()` consumes its receiver, so work on a clone of the borrowed element.
        let as_array: ExpandElementTyped<Array<Self::Item>> = val.clone().into();
        let length = as_array.__expand_len_method(scope);
        length.expand
    }
}
/// Iteration over every element of a [`SizedContainer`], expanded as a range loop.
impl<T> Iterable<T::Item> for ExpandElementTyped<T>
where
    T: SizedContainer + CubeType<ExpandType = ExpandElementTyped<T>>,
{
    /// Registers a `Branch::RangeLoop` on `scope` that runs from `0` (inclusive) up to the
    /// container length (exclusive), with no explicit step.
    ///
    /// `body` is invoked once, while expanding, with the loop's child scope and the
    /// expansion of `self[i]`, where `i` is the loop variable.
    fn expand(
        self,
        scope: &mut Scope,
        mut body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
    ) {
        let counter_ty = Type::new(u32::as_type(scope));
        // The length is expanded in the parent scope, before the child scope is created.
        let length: ExpandElement = T::len(&self.expand, scope);

        let mut loop_scope = scope.child();
        let counter = loop_scope.create_local_restricted(counter_ty);

        // Index the container with the loop variable inside the child scope and hand the
        // resulting element to the caller-provided body.
        let position = counter.clone().into();
        let element = index::expand(&mut loop_scope, self, position);
        body(&mut loop_scope, element);

        let range_loop = RangeLoop {
            i: *counter,
            start: 0u32.into(),
            end: *length,
            step: None,
            inclusive: false,
            scope: loop_scope,
        };
        scope.register(Branch::RangeLoop(Box::new(range_loop)));
    }

    /// Unrolled iteration is not supported for this iterator.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`.
    fn expand_unroll(
        self,
        _scope: &mut Scope,
        _body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
    ) {
        unimplemented!("Can't unroll array iterator")
    }
}