1use cubecl_ir::{ExpandElement, IndexAssignOperator, Instruction, Operator, Scope, VariableKind};
2
3use super::{CubeType, ExpandElementTyped, index_expand, index_expand_no_vec};
4use crate::{
5 ir::{IntKind, UIntKind, Variable},
6 unexpanded,
7};
8
9pub trait CubeIndex:
12 CubeType<ExpandType: CubeIndexExpand<Output = <Self::Output as CubeType>::ExpandType>>
13{
14 type Output: CubeType;
15
16 fn cube_idx(&self, _i: u32) -> &Self::Output {
17 unexpanded!()
18 }
19
20 fn expand_index(
21 scope: &mut Scope,
22 array: Self::ExpandType,
23 index: ExpandElementTyped<u32>,
24 ) -> <Self::Output as CubeType>::ExpandType {
25 array.expand_index(scope, index)
26 }
27 fn expand_index_unchecked(
28 scope: &mut Scope,
29 array: Self::ExpandType,
30 index: ExpandElementTyped<u32>,
31 ) -> <Self::Output as CubeType>::ExpandType {
32 array.expand_index_unchecked(scope, index)
33 }
34}
35
/// Expand-time counterpart of [`CubeIndex`]: implemented by expand types that can be indexed.
///
/// `CubeIndex` requires its `ExpandType` to implement this trait with
/// `Output = <Self::Output as CubeType>::ExpandType`.
pub trait CubeIndexExpand {
    /// Expand type of the element produced by indexing.
    type Output;
    /// Expands an index read of `self` at `index` within `scope`.
    fn expand_index(self, scope: &mut Scope, index: ExpandElementTyped<u32>) -> Self::Output;
    /// Same as `expand_index`, but the "unchecked" variant
    /// (NOTE(review): presumably skips bounds checking — confirm against implementors).
    fn expand_index_unchecked(
        self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
    ) -> Self::Output;
}
45
46pub trait CubeIndexMut:
47 CubeIndex
48 + CubeType<ExpandType: CubeIndexMutExpand<Output = <Self::Output as CubeType>::ExpandType>>
49{
50 fn cube_idx_mut(&mut self, _i: u32) -> &mut Self::Output {
51 unexpanded!()
52 }
53 fn expand_index_mut(
54 scope: &mut Scope,
55 array: Self::ExpandType,
56 index: ExpandElementTyped<u32>,
57 value: <Self::Output as CubeType>::ExpandType,
58 ) {
59 array.expand_index_mut(scope, index, value)
60 }
61}
62
/// Expand-time counterpart of [`CubeIndexMut`]: implemented by expand types that support
/// index assignment.
///
/// `CubeIndexMut` requires its `ExpandType` to implement this trait with
/// `Output = <Self::Output as CubeType>::ExpandType`.
pub trait CubeIndexMutExpand {
    /// Expand type of the element being stored.
    type Output;
    /// Expands the assignment `self[index] = value` within `scope`; returns nothing.
    fn expand_index_mut(
        self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
        value: Self::Output,
    );
}
72
73pub(crate) fn expand_index_native<A: CubeType + CubeIndex>(
74 scope: &mut Scope,
75 array: ExpandElementTyped<A>,
76 index: ExpandElementTyped<u32>,
77 line_size: Option<u32>,
78 checked: bool,
79) -> ExpandElementTyped<A::Output>
80where
81 A::Output: CubeType + Sized,
82{
83 let index: ExpandElement = index.into();
84 let index_var: Variable = *index;
85 let index = match index_var.kind {
86 VariableKind::ConstantScalar(value) => ExpandElement::Plain(Variable::constant(
87 crate::ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32),
88 )),
89 _ => index,
90 };
91 let array: ExpandElement = array.into();
92 let var: Variable = *array;
93 let var = if checked {
94 match var.kind {
95 VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
96 index_expand_no_vec(scope, array, index, Operator::Index)
97 }
98 _ => index_expand(scope, array, index, line_size, Operator::Index),
99 }
100 } else {
101 match var.kind {
102 VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
103 index_expand_no_vec(scope, array, index, Operator::UncheckedIndex)
104 }
105 _ => index_expand(scope, array, index, line_size, Operator::UncheckedIndex),
106 }
107 };
108
109 ExpandElementTyped::new(var)
110}
111
112pub(crate) fn expand_index_assign_native<
113 A: CubeType<ExpandType = ExpandElementTyped<A>> + CubeIndexMut,
114>(
115 scope: &mut Scope,
116 array: A::ExpandType,
117 index: ExpandElementTyped<u32>,
118 value: ExpandElementTyped<A::Output>,
119 line_size: Option<u32>,
120 checked: bool,
121) where
122 A::Output: CubeType + Sized,
123{
124 let index: Variable = index.expand.into();
125 let index = match index.kind {
126 VariableKind::ConstantScalar(value) => Variable::constant(
127 crate::ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32),
128 ),
129 _ => index,
130 };
131 let line_size = line_size.unwrap_or(0);
132 if checked {
133 scope.register(Instruction::new(
134 Operator::IndexAssign(IndexAssignOperator {
135 index,
136 value: value.expand.into(),
137 line_size,
138 }),
139 array.expand.into(),
140 ));
141 } else {
142 scope.register(Instruction::new(
143 Operator::UncheckedIndexAssign(IndexAssignOperator {
144 index,
145 value: value.expand.into(),
146 line_size,
147 }),
148 array.expand.into(),
149 ));
150 }
151}
152
/// Conversion of an index-like value into an IR [`Variable`].
///
/// Implemented below for `i32` and `u32` (emitted as constants) and for expand elements
/// (which yield their underlying variable).
pub trait Index {
    /// Consumes `self` and returns the IR variable representing it.
    fn value(self) -> Variable;
}
156
157impl Index for i32 {
158 fn value(self) -> Variable {
159 Variable::constant(crate::ir::ConstantScalarValue::Int(
160 self as i64,
161 IntKind::I32,
162 ))
163 }
164}
165
166impl Index for u32 {
167 fn value(self) -> Variable {
168 Variable::constant(crate::ir::ConstantScalarValue::UInt(
169 self as u64,
170 UIntKind::U32,
171 ))
172 }
173}
174
175impl Index for ExpandElement {
176 fn value(self) -> Variable {
177 *self
178 }
179}
180
181impl Index for ExpandElementTyped<u32> {
182 fn value(self) -> Variable {
183 *self.expand
184 }
185}