use core::ops::{Index, IndexMut};
use cubecl_ir::{
IndexAssignOperator, Instruction, ManagedVariable, Operator, Scope, VariableKind, VectorSize,
};
use super::{CubeType, NativeExpand, index_expand, index_expand_no_vec};
use crate::{ir::Variable, prelude::CubePrimitive, unexpanded};
/// Read-indexing for a [`CubeType`] inside cube code.
///
/// The supertrait bound ties this trait to its expand-time counterpart: the
/// type's `ExpandType` must implement [`CubeIndexExpand`] with matching index
/// and output expand types. The default `expand_*` functions simply forward to
/// that implementation.
pub trait CubeIndex:
    CubeType<
        ExpandType: CubeIndexExpand<
            Idx = <Self::Idx as CubeType>::ExpandType,
            Output = <Self::Output as CubeType>::ExpandType,
        >,
    >
{
    /// Type used to index into `Self`.
    type Idx: CubeType;
    /// Type of the element produced by indexing.
    type Output: CubeType;

    /// User-facing indexing method.
    ///
    /// The body is `unexpanded!()`. Presumably this only exists so user-written
    /// code type-checks, and the `expand_*` functions below are what generate IR.
    fn cube_idx(&self, _i: Self::Idx) -> &Self::Output {
        unexpanded!()
    }

    /// Expands a checked index read by delegating to
    /// [`CubeIndexExpand::expand_index`] on `array`.
    fn expand_index(
        scope: &mut Scope,
        array: Self::ExpandType,
        index: <Self::Idx as CubeType>::ExpandType,
    ) -> <Self::Output as CubeType>::ExpandType {
        <Self::ExpandType as CubeIndexExpand>::expand_index(array, scope, index)
    }

    /// Expands an unchecked index read by delegating to
    /// [`CubeIndexExpand::expand_index_unchecked`] on `array`.
    fn expand_index_unchecked(
        scope: &mut Scope,
        array: Self::ExpandType,
        index: <Self::Idx as CubeType>::ExpandType,
    ) -> <Self::Output as CubeType>::ExpandType {
        <Self::ExpandType as CubeIndexExpand>::expand_index_unchecked(array, scope, index)
    }
}
/// Indexing for comptime values, i.e. any plain Rust type implementing [`Index`].
///
/// `cube_idx` forwards directly to [`Index::index`].
pub trait ComptimeIndex<I>: Index<I> {
    /// Returns a shared reference to the element at `i` via [`Index::index`].
    fn cube_idx(&self, i: I) -> &Self::Output {
        Index::index(self, i)
    }
}

// Every `Index<I>` type gets `ComptimeIndex<I>` for free.
impl<I, T: Index<I>> ComptimeIndex<I> for T {}

/// Mutable counterpart of [`ComptimeIndex`] for types implementing [`IndexMut`].
pub trait ComptimeIndexMut<I>: ComptimeIndex<I> + IndexMut<I> {
    /// Returns a mutable reference to the element at `i` via [`IndexMut::index_mut`].
    fn cube_idx_mut(&mut self, i: I) -> &mut Self::Output {
        IndexMut::index_mut(self, i)
    }
}

// Every `IndexMut<I>` type gets `ComptimeIndexMut<I>` for free.
impl<I, T: IndexMut<I>> ComptimeIndexMut<I> for T {}
/// Expand-time counterpart of `CubeIndex`.
///
/// `CubeIndex` requires its `ExpandType` to implement this trait, and its
/// default `expand_index` / `expand_index_unchecked` functions forward here.
pub trait CubeIndexExpand {
/// Expand type of the element produced by indexing.
type Output;
/// Expand type of the index.
type Idx;
/// Expands a checked index read of `self` at `index` within `scope`.
fn expand_index(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
/// Expands an unchecked index read of `self` at `index` within `scope`.
/// NOTE(review): presumably skips bounds checking — confirm against implementors.
fn expand_index_unchecked(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
}
/// Write-indexing (`array[index] = value`) for a [`CubeIndex`] type.
///
/// The expand type must implement [`CubeIndexMutExpand`] with the same output
/// expand type as [`CubeIndex`]. The default `expand_index_mut` forwards to it.
pub trait CubeIndexMut:
    CubeIndex
    + CubeType<ExpandType: CubeIndexMutExpand<Output = <Self::Output as CubeType>::ExpandType>>
{
    /// User-facing mutable indexing method.
    ///
    /// The body is `unexpanded!()`. Presumably this only exists so user-written
    /// code type-checks, and `expand_index_mut` is what generates IR.
    fn cube_idx_mut(&mut self, _i: <Self as CubeIndex>::Idx) -> &mut <Self as CubeIndex>::Output {
        unexpanded!()
    }

    /// Expands a write of `value` into `array` at `index` by delegating to
    /// [`CubeIndexMutExpand::expand_index_mut`].
    fn expand_index_mut(
        scope: &mut Scope,
        array: Self::ExpandType,
        index: <Self::Idx as CubeType>::ExpandType,
        value: <Self::Output as CubeType>::ExpandType,
    ) {
        <Self::ExpandType as CubeIndexMutExpand>::expand_index_mut(array, scope, index, value)
    }
}
/// Expand-time counterpart of `CubeIndexMut`: emits an index assignment.
///
/// Reuses the `Idx` and `Output` associated types from `CubeIndexExpand`.
pub trait CubeIndexMutExpand: CubeIndexExpand {
/// Expands `self[index] = value` within `scope`.
fn expand_index_mut(
self,
scope: &mut Scope,
index: <Self as CubeIndexExpand>::Idx,
value: <Self as CubeIndexExpand>::Output,
);
}
/// Expands an index read on a natively-expanded array into IR.
///
/// - A constant `index` is rebuilt as a constant typed with `usize::as_type(scope)`.
///   Any other index is used unchanged.
/// - `checked` selects `Operator::Index` (true) or `Operator::UncheckedIndex` (false).
/// - Arrays whose kind is `LocalMut` or `LocalConst` go through
///   `index_expand_no_vec`, which is not given `vector_size`. Every other kind
///   goes through `index_expand` with `vector_size`.
pub(crate) fn expand_index_native<A: CubeType + CubeIndex>(
    scope: &mut Scope,
    array: NativeExpand<A>,
    index: NativeExpand<usize>,
    vector_size: Option<VectorSize>,
    checked: bool,
) -> NativeExpand<A::Output>
where
    A::Output: CubeType + Sized,
{
    let raw_index: ManagedVariable = index.into();
    let index_value: Variable = *raw_index;
    let index = if let VariableKind::Constant(constant) = index_value.kind {
        // Re-type literal indices so they share the scope's `usize` type.
        ManagedVariable::Plain(Variable::constant(constant, usize::as_type(scope)))
    } else {
        raw_index
    };

    let array: ManagedVariable = array.into();
    let array_value: Variable = *array;
    let is_local = matches!(
        array_value.kind,
        VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. }
    );

    let expanded = match (checked, is_local) {
        (true, true) => index_expand_no_vec(scope, array, index, Operator::Index),
        (true, false) => index_expand(scope, array, index, vector_size, Operator::Index),
        (false, true) => index_expand_no_vec(scope, array, index, Operator::UncheckedIndex),
        (false, false) => {
            index_expand(scope, array, index, vector_size, Operator::UncheckedIndex)
        }
    };
    NativeExpand::new(expanded)
}
/// Registers an index-assignment (`array[index] = value`) instruction in `scope`.
///
/// - A constant `index` is rebuilt as a constant typed with `usize::as_type(scope)`.
/// - A missing `vector_size` is encoded as `0` in the operator.
/// - `unroll_factor` is always `1`.
/// - `checked` selects `Operator::IndexAssign` (true) or
///   `Operator::UncheckedIndexAssign` (false). The instruction's output is `array`.
pub(crate) fn expand_index_assign_native<
    A: CubeType<ExpandType = NativeExpand<A>> + CubeIndexMut,
>(
    scope: &mut Scope,
    array: A::ExpandType,
    index: NativeExpand<usize>,
    value: NativeExpand<<A as CubeIndex>::Output>,
    vector_size: Option<VectorSize>,
    checked: bool,
) where
    A::Output: CubeType + Sized,
{
    let index_var: Variable = index.expand.into();
    let index = if let VariableKind::Constant(constant) = index_var.kind {
        // Re-type literal indices so they share the scope's `usize` type.
        Variable::constant(constant, usize::as_type(scope))
    } else {
        index_var
    };
    let vector_size = vector_size.unwrap_or(0);

    // Only the operator variant depends on `checked`; the payload is shared.
    let assign = IndexAssignOperator {
        index,
        value: value.expand.into(),
        vector_size,
        unroll_factor: 1,
    };
    let operator = if checked {
        Operator::IndexAssign(assign)
    } else {
        Operator::UncheckedIndexAssign(assign)
    };
    scope.register(Instruction::new(operator, array.expand.into()));
}