// cubecl_core/frontend/indexation.rs

1use cubecl_ir::{ExpandElement, IndexAssignOperator, Instruction, Operator, Scope, VariableKind};
2
3use super::{CubeType, ExpandElementTyped, index_expand, index_expand_no_vec};
4use crate::{
5    ir::{IntKind, UIntKind, Variable},
6    unexpanded,
7};
8
9/// Fake indexation so we can rewrite indexes into scalars as calls to this fake function in the
10/// non-expanded function
11pub trait CubeIndex:
12    CubeType<ExpandType: CubeIndexExpand<Output = <Self::Output as CubeType>::ExpandType>>
13{
14    type Output: CubeType;
15
16    fn cube_idx(&self, _i: u32) -> &Self::Output {
17        unexpanded!()
18    }
19
20    fn expand_index(
21        scope: &mut Scope,
22        array: Self::ExpandType,
23        index: ExpandElementTyped<u32>,
24    ) -> <Self::Output as CubeType>::ExpandType {
25        array.expand_index(scope, index)
26    }
27    fn expand_index_unchecked(
28        scope: &mut Scope,
29        array: Self::ExpandType,
30        index: ExpandElementTyped<u32>,
31    ) -> <Self::Output as CubeType>::ExpandType {
32        array.expand_index_unchecked(scope, index)
33    }
34}
35
36pub trait CubeIndexExpand {
37    type Output;
38    fn expand_index(self, scope: &mut Scope, index: ExpandElementTyped<u32>) -> Self::Output;
39    fn expand_index_unchecked(
40        self,
41        scope: &mut Scope,
42        index: ExpandElementTyped<u32>,
43    ) -> Self::Output;
44}
45
46pub trait CubeIndexMut:
47    CubeIndex
48    + CubeType<ExpandType: CubeIndexMutExpand<Output = <Self::Output as CubeType>::ExpandType>>
49{
50    fn cube_idx_mut(&mut self, _i: u32) -> &mut Self::Output {
51        unexpanded!()
52    }
53    fn expand_index_mut(
54        scope: &mut Scope,
55        array: Self::ExpandType,
56        index: ExpandElementTyped<u32>,
57        value: <Self::Output as CubeType>::ExpandType,
58    ) {
59        array.expand_index_mut(scope, index, value)
60    }
61}
62
63pub trait CubeIndexMutExpand {
64    type Output;
65    fn expand_index_mut(
66        self,
67        scope: &mut Scope,
68        index: ExpandElementTyped<u32>,
69        value: Self::Output,
70    );
71}
72
73pub(crate) fn expand_index_native<A: CubeType + CubeIndex>(
74    scope: &mut Scope,
75    array: ExpandElementTyped<A>,
76    index: ExpandElementTyped<u32>,
77    line_size: Option<u32>,
78    checked: bool,
79) -> ExpandElementTyped<A::Output>
80where
81    A::Output: CubeType + Sized,
82{
83    let index: ExpandElement = index.into();
84    let index_var: Variable = *index;
85    let index = match index_var.kind {
86        VariableKind::ConstantScalar(value) => ExpandElement::Plain(Variable::constant(
87            crate::ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32),
88        )),
89        _ => index,
90    };
91    let array: ExpandElement = array.into();
92    let var: Variable = *array;
93    let var = if checked {
94        match var.kind {
95            VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
96                index_expand_no_vec(scope, array, index, Operator::Index)
97            }
98            _ => index_expand(scope, array, index, line_size, Operator::Index),
99        }
100    } else {
101        match var.kind {
102            VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
103                index_expand_no_vec(scope, array, index, Operator::UncheckedIndex)
104            }
105            _ => index_expand(scope, array, index, line_size, Operator::UncheckedIndex),
106        }
107    };
108
109    ExpandElementTyped::new(var)
110}
111
112pub(crate) fn expand_index_assign_native<
113    A: CubeType<ExpandType = ExpandElementTyped<A>> + CubeIndexMut,
114>(
115    scope: &mut Scope,
116    array: A::ExpandType,
117    index: ExpandElementTyped<u32>,
118    value: ExpandElementTyped<A::Output>,
119    line_size: Option<u32>,
120    checked: bool,
121) where
122    A::Output: CubeType + Sized,
123{
124    let index: Variable = index.expand.into();
125    let index = match index.kind {
126        VariableKind::ConstantScalar(value) => Variable::constant(
127            crate::ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32),
128        ),
129        _ => index,
130    };
131    let line_size = line_size.unwrap_or(0);
132    if checked {
133        scope.register(Instruction::new(
134            Operator::IndexAssign(IndexAssignOperator {
135                index,
136                value: value.expand.into(),
137                line_size,
138            }),
139            array.expand.into(),
140        ));
141    } else {
142        scope.register(Instruction::new(
143            Operator::UncheckedIndexAssign(IndexAssignOperator {
144                index,
145                value: value.expand.into(),
146                line_size,
147            }),
148            array.expand.into(),
149        ));
150    }
151}
152
153pub trait Index {
154    fn value(self) -> Variable;
155}
156
157impl Index for i32 {
158    fn value(self) -> Variable {
159        Variable::constant(crate::ir::ConstantScalarValue::Int(
160            self as i64,
161            IntKind::I32,
162        ))
163    }
164}
165
166impl Index for u32 {
167    fn value(self) -> Variable {
168        Variable::constant(crate::ir::ConstantScalarValue::UInt(
169            self as u64,
170            UIntKind::U32,
171        ))
172    }
173}
174
175impl Index for ExpandElement {
176    fn value(self) -> Variable {
177        *self
178    }
179}
180
181impl Index for ExpandElementTyped<u32> {
182    fn value(self) -> Variable {
183        *self.expand
184    }
185}