1use core::ops::{Index, IndexMut};
2
3use cubecl_ir::{
4 ExpandElement, IndexAssignOperator, Instruction, LineSize, Operator, Scope, VariableKind,
5};
6
7use super::{CubeType, ExpandElementTyped, index_expand, index_expand_no_vec};
8use crate::{ir::Variable, prelude::CubePrimitive, unexpanded};
9
10pub trait CubeIndex:
13 CubeType<
14 ExpandType: CubeIndexExpand<
15 Idx = <Self::Idx as CubeType>::ExpandType,
16 Output = <Self::Output as CubeType>::ExpandType,
17 >,
18>
19{
20 type Output: CubeType;
21 type Idx: CubeType;
22
23 fn cube_idx(&self, _i: Self::Idx) -> &Self::Output {
24 unexpanded!()
25 }
26
27 fn expand_index(
28 scope: &mut Scope,
29 array: Self::ExpandType,
30 index: <Self::Idx as CubeType>::ExpandType,
31 ) -> <Self::Output as CubeType>::ExpandType {
32 array.expand_index(scope, index)
33 }
34 fn expand_index_unchecked(
35 scope: &mut Scope,
36 array: Self::ExpandType,
37 index: <Self::Idx as CubeType>::ExpandType,
38 ) -> <Self::Output as CubeType>::ExpandType {
39 array.expand_index_unchecked(scope, index)
40 }
41}
42
/// Indexing for ordinary Rust values that implement [`Index`], exposed under the `cube_idx`
/// method name.
///
/// The name matches `CubeIndex::cube_idx`, presumably so the same method-call syntax works for
/// comptime values. Every `T: Index<I>` receives this trait through a blanket impl in this
/// module.
pub trait ComptimeIndex<I>: Index<I> {
    /// Returns a reference to the element at `i` by forwarding to [`Index::index`].
    fn cube_idx(&self, i: I) -> &Self::Output {
        <Self as Index<I>>::index(self, i)
    }
}
51
52impl<I, T: Index<I>> ComptimeIndex<I> for T {}
53impl<I, T: IndexMut<I>> ComptimeIndexMut<I> for T {}
54
55pub trait ComptimeIndexMut<I>: ComptimeIndex<I> + IndexMut<I> {
56 fn cube_idx_mut(&mut self, i: I) -> &mut Self::Output {
57 self.index_mut(i)
58 }
59}
60
/// Expansion-time indexing, implemented on the expanded representation of an indexable type.
///
/// `CubeIndex`'s default `expand_index` / `expand_index_unchecked` functions forward to the
/// methods of this trait.
pub trait CubeIndexExpand {
    /// Expanded type produced by indexing.
    type Output;
    /// Expanded type of the index.
    type Idx;
    /// Expands an index read of `self` at `index` within `scope`.
    fn expand_index(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
    /// Unchecked variant of `expand_index`.
    ///
    /// By name this presumably skips bounds checking; the exact behavior is up to the
    /// implementor.
    fn expand_index_unchecked(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
}
67
68pub trait CubeIndexMut:
69 CubeIndex
70 + CubeType<ExpandType: CubeIndexMutExpand<Output = <Self::Output as CubeType>::ExpandType>>
71{
72 fn cube_idx_mut(&mut self, _i: <Self as CubeIndex>::Idx) -> &mut <Self as CubeIndex>::Output {
73 unexpanded!()
74 }
75 fn expand_index_mut(
76 scope: &mut Scope,
77 array: Self::ExpandType,
78 index: <Self::Idx as CubeType>::ExpandType,
79 value: <Self::Output as CubeType>::ExpandType,
80 ) {
81 array.expand_index_mut(scope, index, value)
82 }
83}
84
/// Expansion-time indexed assignment, implemented on an expanded value that already supports
/// [`CubeIndexExpand`].
///
/// `CubeIndexMut`'s default `expand_index_mut` forwards to this trait.
pub trait CubeIndexMutExpand: CubeIndexExpand {
    /// Expands a write of `value` at `index` into `self` within `scope`. The exact behavior is
    /// up to the implementor.
    ///
    /// `index` and `value` reuse the [`CubeIndexExpand`] associated types `Idx` and `Output`.
    /// The written value therefore has the same expanded type that an index read yields.
    fn expand_index_mut(
        self,
        scope: &mut Scope,
        index: <Self as CubeIndexExpand>::Idx,
        value: <Self as CubeIndexExpand>::Output,
    );
}
93
94pub(crate) fn expand_index_native<A: CubeType + CubeIndex>(
95 scope: &mut Scope,
96 array: ExpandElementTyped<A>,
97 index: ExpandElementTyped<usize>,
98 line_size: Option<LineSize>,
99 checked: bool,
100) -> ExpandElementTyped<A::Output>
101where
102 A::Output: CubeType + Sized,
103{
104 let index: ExpandElement = index.into();
105 let index_var: Variable = *index;
106 let index = match index_var.kind {
107 VariableKind::Constant(value) => {
108 ExpandElement::Plain(Variable::constant(value, usize::as_type(scope)))
109 }
110 _ => index,
111 };
112 let array: ExpandElement = array.into();
113 let var: Variable = *array;
114 let var = if checked {
115 match var.kind {
116 VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
117 index_expand_no_vec(scope, array, index, Operator::Index)
118 }
119 _ => index_expand(scope, array, index, line_size, Operator::Index),
120 }
121 } else {
122 match var.kind {
123 VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
124 index_expand_no_vec(scope, array, index, Operator::UncheckedIndex)
125 }
126 _ => index_expand(scope, array, index, line_size, Operator::UncheckedIndex),
127 }
128 };
129
130 ExpandElementTyped::new(var)
131}
132
133pub(crate) fn expand_index_assign_native<
134 A: CubeType<ExpandType = ExpandElementTyped<A>> + CubeIndexMut,
135>(
136 scope: &mut Scope,
137 array: A::ExpandType,
138 index: ExpandElementTyped<usize>,
139 value: ExpandElementTyped<<A as CubeIndex>::Output>,
140 line_size: Option<LineSize>,
141 checked: bool,
142) where
143 A::Output: CubeType + Sized,
144{
145 let index: Variable = index.expand.into();
146 let index = match index.kind {
147 VariableKind::Constant(value) => Variable::constant(value, usize::as_type(scope)),
148 _ => index,
149 };
150
151 let line_size = line_size.unwrap_or(0);
152 if checked {
153 scope.register(Instruction::new(
154 Operator::IndexAssign(IndexAssignOperator {
155 index,
156 value: value.expand.into(),
157 line_size,
158 unroll_factor: 1,
159 }),
160 array.expand.into(),
161 ));
162 } else {
163 scope.register(Instruction::new(
164 Operator::UncheckedIndexAssign(IndexAssignOperator {
165 index,
166 value: value.expand.into(),
167 line_size,
168 unroll_factor: 1,
169 }),
170 array.expand.into(),
171 ));
172 }
173}