1use cubecl_ir::{ExpandElement, IndexAssignOperator, Instruction, Operator, Scope, VariableKind};
2
3use super::{CubeType, ExpandElementTyped, index_expand, index_expand_no_vec};
4use crate::{
5 ir::{IntKind, UIntKind, Variable},
6 unexpanded,
7};
8
9pub trait CubeIndex:
12 CubeType<
13 ExpandType: CubeIndexExpand<
14 Idx = <Self::Idx as CubeType>::ExpandType,
15 Output = <Self::Output as CubeType>::ExpandType,
16 >,
17>
18{
19 type Output: CubeType;
20 type Idx: CubeType;
21
22 fn cube_idx(&self, _i: Self::Idx) -> &Self::Output {
23 unexpanded!()
24 }
25
26 fn expand_index(
27 scope: &mut Scope,
28 array: Self::ExpandType,
29 index: <Self::Idx as CubeType>::ExpandType,
30 ) -> <Self::Output as CubeType>::ExpandType {
31 array.expand_index(scope, index)
32 }
33 fn expand_index_unchecked(
34 scope: &mut Scope,
35 array: Self::ExpandType,
36 index: <Self::Idx as CubeType>::ExpandType,
37 ) -> <Self::Output as CubeType>::ExpandType {
38 array.expand_index_unchecked(scope, index)
39 }
40}
41
42pub trait CubeIndexExpand {
43 type Output;
44 type Idx;
45 fn expand_index(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
46 fn expand_index_unchecked(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
47}
48
49pub trait CubeIndexMut:
50 CubeIndex
51 + CubeType<ExpandType: CubeIndexMutExpand<Output = <Self::Output as CubeType>::ExpandType>>
52{
53 fn cube_idx_mut(&mut self, _i: Self::Idx) -> &mut Self::Output {
54 unexpanded!()
55 }
56 fn expand_index_mut(
57 scope: &mut Scope,
58 array: Self::ExpandType,
59 index: <Self::Idx as CubeType>::ExpandType,
60 value: <Self::Output as CubeType>::ExpandType,
61 ) {
62 array.expand_index_mut(scope, index, value)
63 }
64}
65
66pub trait CubeIndexMutExpand: CubeIndexExpand {
67 fn expand_index_mut(self, scope: &mut Scope, index: Self::Idx, value: Self::Output);
68}
69
70pub(crate) fn expand_index_native<A: CubeType + CubeIndex>(
71 scope: &mut Scope,
72 array: ExpandElementTyped<A>,
73 index: ExpandElementTyped<u32>,
74 line_size: Option<u32>,
75 checked: bool,
76) -> ExpandElementTyped<A::Output>
77where
78 A::Output: CubeType + Sized,
79{
80 let index: ExpandElement = index.into();
81 let index_var: Variable = *index;
82 let index = match index_var.kind {
83 VariableKind::ConstantScalar(value) => ExpandElement::Plain(Variable::constant(
84 crate::ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32),
85 )),
86 _ => index,
87 };
88 let array: ExpandElement = array.into();
89 let var: Variable = *array;
90 let var = if checked {
91 match var.kind {
92 VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
93 index_expand_no_vec(scope, array, index, Operator::Index)
94 }
95 _ => index_expand(scope, array, index, line_size, Operator::Index),
96 }
97 } else {
98 match var.kind {
99 VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
100 index_expand_no_vec(scope, array, index, Operator::UncheckedIndex)
101 }
102 _ => index_expand(scope, array, index, line_size, Operator::UncheckedIndex),
103 }
104 };
105
106 ExpandElementTyped::new(var)
107}
108
109pub(crate) fn expand_index_assign_native<
110 A: CubeType<ExpandType = ExpandElementTyped<A>> + CubeIndexMut,
111>(
112 scope: &mut Scope,
113 array: A::ExpandType,
114 index: ExpandElementTyped<u32>,
115 value: ExpandElementTyped<A::Output>,
116 line_size: Option<u32>,
117 checked: bool,
118) where
119 A::Output: CubeType + Sized,
120{
121 let index: Variable = index.expand.into();
122 let index = match index.kind {
123 VariableKind::ConstantScalar(value) => Variable::constant(
124 crate::ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32),
125 ),
126 _ => index,
127 };
128 let line_size = line_size.unwrap_or(0);
129 if checked {
130 scope.register(Instruction::new(
131 Operator::IndexAssign(IndexAssignOperator {
132 index,
133 value: value.expand.into(),
134 line_size,
135 unroll_factor: 1,
136 }),
137 array.expand.into(),
138 ));
139 } else {
140 scope.register(Instruction::new(
141 Operator::UncheckedIndexAssign(IndexAssignOperator {
142 index,
143 value: value.expand.into(),
144 line_size,
145 unroll_factor: 1,
146 }),
147 array.expand.into(),
148 ));
149 }
150}
151
152pub trait Index {
153 fn value(self) -> Variable;
154}
155
156impl Index for i32 {
157 fn value(self) -> Variable {
158 Variable::constant(crate::ir::ConstantScalarValue::Int(
159 self as i64,
160 IntKind::I32,
161 ))
162 }
163}
164
165impl Index for u32 {
166 fn value(self) -> Variable {
167 Variable::constant(crate::ir::ConstantScalarValue::UInt(
168 self as u64,
169 UIntKind::U32,
170 ))
171 }
172}
173
174impl Index for ExpandElement {
175 fn value(self) -> Variable {
176 *self
177 }
178}
179
180impl Index for ExpandElementTyped<u32> {
181 fn value(self) -> Variable {
182 *self.expand
183 }
184}