1use core::ops::{Index, IndexMut};
2
3use cubecl_ir::{
4 IndexAssignOperator, Instruction, ManagedVariable, Operator, Scope, VariableKind, VectorSize,
5};
6
7use super::{CubeType, NativeExpand, index_expand, index_expand_no_vec};
8use crate::{ir::Variable, prelude::CubePrimitive, unexpanded};
9
10pub trait CubeIndex:
13 CubeType<
14 ExpandType: CubeIndexExpand<
15 Idx = <Self::Idx as CubeType>::ExpandType,
16 Output = <Self::Output as CubeType>::ExpandType,
17 >,
18>
19{
20 type Output: CubeType;
21 type Idx: CubeType;
22
23 fn cube_idx(&self, _i: Self::Idx) -> &Self::Output {
24 unexpanded!()
25 }
26
27 fn expand_index(
28 scope: &mut Scope,
29 array: Self::ExpandType,
30 index: <Self::Idx as CubeType>::ExpandType,
31 ) -> <Self::Output as CubeType>::ExpandType {
32 array.expand_index(scope, index)
33 }
34 fn expand_index_unchecked(
35 scope: &mut Scope,
36 array: Self::ExpandType,
37 index: <Self::Idx as CubeType>::ExpandType,
38 ) -> <Self::Output as CubeType>::ExpandType {
39 array.expand_index_unchecked(scope, index)
40 }
41}
42
/// Indexing backed directly by the standard [`Index`] trait.
///
/// `cube_idx` returns exactly what [`Index::index`] returns for the same index.
pub trait ComptimeIndex<I>: Index<I> {
    /// Returns `&self[i]` by delegating to [`Index::index`].
    fn cube_idx(&self, i: I) -> &Self::Output {
        <Self as Index<I>>::index(self, i)
    }
}
51
52impl<I, T: Index<I>> ComptimeIndex<I> for T {}
53impl<I, T: IndexMut<I>> ComptimeIndexMut<I> for T {}
54
55pub trait ComptimeIndexMut<I>: ComptimeIndex<I> + IndexMut<I> {
56 fn cube_idx_mut(&mut self, i: I) -> &mut Self::Output {
57 self.index_mut(i)
58 }
59}
60
61pub trait CubeIndexExpand {
62 type Output;
63 type Idx;
64 fn expand_index(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
65 fn expand_index_unchecked(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
66}
67
68pub trait CubeIndexMut:
69 CubeIndex
70 + CubeType<ExpandType: CubeIndexMutExpand<Output = <Self::Output as CubeType>::ExpandType>>
71{
72 fn cube_idx_mut(&mut self, _i: <Self as CubeIndex>::Idx) -> &mut <Self as CubeIndex>::Output {
73 unexpanded!()
74 }
75 fn expand_index_mut(
76 scope: &mut Scope,
77 array: Self::ExpandType,
78 index: <Self::Idx as CubeType>::ExpandType,
79 value: <Self::Output as CubeType>::ExpandType,
80 ) {
81 array.expand_index_mut(scope, index, value)
82 }
83}
84
85pub trait CubeIndexMutExpand: CubeIndexExpand {
86 fn expand_index_mut(
87 self,
88 scope: &mut Scope,
89 index: <Self as CubeIndexExpand>::Idx,
90 value: <Self as CubeIndexExpand>::Output,
91 );
92}
93
94pub(crate) fn expand_index_native<A: CubeType + CubeIndex>(
95 scope: &mut Scope,
96 array: NativeExpand<A>,
97 index: NativeExpand<usize>,
98 vector_size: Option<VectorSize>,
99 checked: bool,
100) -> NativeExpand<A::Output>
101where
102 A::Output: CubeType + Sized,
103{
104 let index: ManagedVariable = index.into();
105 let index_var: Variable = *index;
106 let index = match index_var.kind {
107 VariableKind::Constant(value) => {
108 ManagedVariable::Plain(Variable::constant(value, usize::as_type(scope)))
109 }
110 _ => index,
111 };
112 let array: ManagedVariable = array.into();
113 let var: Variable = *array;
114 let var = if checked {
115 match var.kind {
116 VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
117 index_expand_no_vec(scope, array, index, Operator::Index)
118 }
119 _ => index_expand(scope, array, index, vector_size, Operator::Index),
120 }
121 } else {
122 match var.kind {
123 VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
124 index_expand_no_vec(scope, array, index, Operator::UncheckedIndex)
125 }
126 _ => index_expand(scope, array, index, vector_size, Operator::UncheckedIndex),
127 }
128 };
129
130 NativeExpand::new(var)
131}
132
133pub(crate) fn expand_index_assign_native<A: CubeType<ExpandType = NativeExpand<A>> + CubeIndexMut>(
134 scope: &mut Scope,
135 array: A::ExpandType,
136 index: NativeExpand<usize>,
137 value: NativeExpand<<A as CubeIndex>::Output>,
138 vector_size: Option<VectorSize>,
139 checked: bool,
140) where
141 A::Output: CubeType + Sized,
142{
143 let index: Variable = index.expand.into();
144 let index = match index.kind {
145 VariableKind::Constant(value) => Variable::constant(value, usize::as_type(scope)),
146 _ => index,
147 };
148
149 let vector_size = vector_size.unwrap_or(0);
150 if checked {
151 scope.register(Instruction::new(
152 Operator::IndexAssign(IndexAssignOperator {
153 index,
154 value: value.expand.into(),
155 vector_size,
156 unroll_factor: 1,
157 }),
158 array.expand.into(),
159 ));
160 } else {
161 scope.register(Instruction::new(
162 Operator::UncheckedIndexAssign(IndexAssignOperator {
163 index,
164 value: value.expand.into(),
165 vector_size,
166 unroll_factor: 1,
167 }),
168 array.expand.into(),
169 ));
170 }
171}