use alloc::boxed::Box;
use cubecl_ir::ManagedVariable;
use crate::{
ir::{Branch, RangeLoop, Scope},
prelude::{CubeIndex, CubePrimitive, CubeType, Iterable, NativeExpand, index},
};
use super::Array;
/// A container that can be indexed with a primitive index type and whose
/// length can be expanded into IR.
///
/// Indexing an implementor (through [`CubeIndex`]) yields
/// [`SizedContainer::Item`].
pub trait SizedContainer: CubeIndex<Idx: CubePrimitive, Output = Self::Item> + Sized {
    /// Element type produced when indexing the container.
    type Item: CubePrimitive;

    /// Emits into `scope` the expansion of the length of the container held
    /// by `val`, and returns the variable holding that length.
    ///
    /// The default implementation treats `val` as an [`Array`] of
    /// `Self::Item` and expands that array's `len` method. Implementors may
    /// override it when that reinterpretation does not apply.
    fn len(val: &ManagedVariable, scope: &mut Scope) -> ManagedVariable {
        let as_array: NativeExpand<Array<Self::Item>> = val.clone().into();
        let length = as_array.__expand_len_method(scope);
        length.expand
    }
}
impl<T: SizedContainer + CubeType<ExpandType = NativeExpand<T>>> Iterable<T::Item>
    for NativeExpand<T>
{
    /// Expands a `for` loop over every element of the container.
    ///
    /// A `u32` counter runs from `0` (inclusive) to the container's expanded
    /// length (exclusive), with no explicit step. `body` is invoked once, at
    /// expansion time, inside a child scope and receives the element at the
    /// counter's position. The child scope then becomes the body of a
    /// `RangeLoop` branch registered in `scope`.
    fn expand(
        self,
        scope: &mut Scope,
        mut body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
    ) {
        let counter_type = u32::as_type(scope);
        // The length is expanded in the parent scope, before the loop starts.
        let upper_bound: ManagedVariable = T::len(&self.expand, scope);

        // Everything emitted by the loop body goes into this child scope.
        let mut loop_scope = scope.child();
        let counter = loop_scope.create_local_restricted(counter_type);
        let position = counter.clone().into();
        let element = index::expand(&mut loop_scope, self, position);
        body(&mut loop_scope, element);

        let range_loop = RangeLoop {
            i: *counter,
            start: 0u32.into(),
            end: *upper_bound,
            step: None,
            inclusive: false,
            scope: loop_scope,
        };
        scope.register(Branch::RangeLoop(Box::new(range_loop)));
    }

    /// Unrolling is not supported for container iteration.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`.
    fn expand_unroll(
        self,
        _scope: &mut Scope,
        _body: impl FnMut(&mut Scope, <T::Item as CubeType>::ExpandType),
    ) {
        unimplemented!("Can't unroll array iterator")
    }
}