cubecl_core/frontend/indexation.rs

1use cubecl_ir::{ExpandElement, IndexAssignOperator, Instruction, Operator, Scope, VariableKind};
2
3use super::{CubeType, ExpandElementTyped, index_expand, index_expand_no_vec};
4use crate::{
5    ir::{IntKind, UIntKind, Variable},
6    unexpanded,
7};
8
9/// Fake indexation so we can rewrite indexes into scalars as calls to this fake function in the
10/// non-expanded function
11pub trait CubeIndex:
12    CubeType<
13    ExpandType: CubeIndexExpand<
14        Idx = <Self::Idx as CubeType>::ExpandType,
15        Output = <Self::Output as CubeType>::ExpandType,
16    >,
17>
18{
19    type Output: CubeType;
20    type Idx: CubeType;
21
22    fn cube_idx(&self, _i: Self::Idx) -> &Self::Output {
23        unexpanded!()
24    }
25
26    fn expand_index(
27        scope: &mut Scope,
28        array: Self::ExpandType,
29        index: <Self::Idx as CubeType>::ExpandType,
30    ) -> <Self::Output as CubeType>::ExpandType {
31        array.expand_index(scope, index)
32    }
33    fn expand_index_unchecked(
34        scope: &mut Scope,
35        array: Self::ExpandType,
36        index: <Self::Idx as CubeType>::ExpandType,
37    ) -> <Self::Output as CubeType>::ExpandType {
38        array.expand_index_unchecked(scope, index)
39    }
40}
41
42pub trait CubeIndexExpand {
43    type Output;
44    type Idx;
45    fn expand_index(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
46    fn expand_index_unchecked(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
47}
48
49pub trait CubeIndexMut:
50    CubeIndex
51    + CubeType<ExpandType: CubeIndexMutExpand<Output = <Self::Output as CubeType>::ExpandType>>
52{
53    fn cube_idx_mut(&mut self, _i: Self::Idx) -> &mut Self::Output {
54        unexpanded!()
55    }
56    fn expand_index_mut(
57        scope: &mut Scope,
58        array: Self::ExpandType,
59        index: <Self::Idx as CubeType>::ExpandType,
60        value: <Self::Output as CubeType>::ExpandType,
61    ) {
62        array.expand_index_mut(scope, index, value)
63    }
64}
65
66pub trait CubeIndexMutExpand: CubeIndexExpand {
67    fn expand_index_mut(self, scope: &mut Scope, index: Self::Idx, value: Self::Output);
68}
69
70pub(crate) fn expand_index_native<A: CubeType + CubeIndex>(
71    scope: &mut Scope,
72    array: ExpandElementTyped<A>,
73    index: ExpandElementTyped<u32>,
74    line_size: Option<u32>,
75    checked: bool,
76) -> ExpandElementTyped<A::Output>
77where
78    A::Output: CubeType + Sized,
79{
80    let index: ExpandElement = index.into();
81    let index_var: Variable = *index;
82    let index = match index_var.kind {
83        VariableKind::ConstantScalar(value) => ExpandElement::Plain(Variable::constant(
84            crate::ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32),
85        )),
86        _ => index,
87    };
88    let array: ExpandElement = array.into();
89    let var: Variable = *array;
90    let var = if checked {
91        match var.kind {
92            VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
93                index_expand_no_vec(scope, array, index, Operator::Index)
94            }
95            _ => index_expand(scope, array, index, line_size, Operator::Index),
96        }
97    } else {
98        match var.kind {
99            VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
100                index_expand_no_vec(scope, array, index, Operator::UncheckedIndex)
101            }
102            _ => index_expand(scope, array, index, line_size, Operator::UncheckedIndex),
103        }
104    };
105
106    ExpandElementTyped::new(var)
107}
108
109pub(crate) fn expand_index_assign_native<
110    A: CubeType<ExpandType = ExpandElementTyped<A>> + CubeIndexMut,
111>(
112    scope: &mut Scope,
113    array: A::ExpandType,
114    index: ExpandElementTyped<u32>,
115    value: ExpandElementTyped<A::Output>,
116    line_size: Option<u32>,
117    checked: bool,
118) where
119    A::Output: CubeType + Sized,
120{
121    let index: Variable = index.expand.into();
122    let index = match index.kind {
123        VariableKind::ConstantScalar(value) => Variable::constant(
124            crate::ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32),
125        ),
126        _ => index,
127    };
128    let line_size = line_size.unwrap_or(0);
129    if checked {
130        scope.register(Instruction::new(
131            Operator::IndexAssign(IndexAssignOperator {
132                index,
133                value: value.expand.into(),
134                line_size,
135                unroll_factor: 1,
136            }),
137            array.expand.into(),
138        ));
139    } else {
140        scope.register(Instruction::new(
141            Operator::UncheckedIndexAssign(IndexAssignOperator {
142                index,
143                value: value.expand.into(),
144                line_size,
145                unroll_factor: 1,
146            }),
147            array.expand.into(),
148        ));
149    }
150}
151
152pub trait Index {
153    fn value(self) -> Variable;
154}
155
156impl Index for i32 {
157    fn value(self) -> Variable {
158        Variable::constant(crate::ir::ConstantScalarValue::Int(
159            self as i64,
160            IntKind::I32,
161        ))
162    }
163}
164
165impl Index for u32 {
166    fn value(self) -> Variable {
167        Variable::constant(crate::ir::ConstantScalarValue::UInt(
168            self as u64,
169            UIntKind::U32,
170        ))
171    }
172}
173
174impl Index for ExpandElement {
175    fn value(self) -> Variable {
176        *self
177    }
178}
179
180impl Index for ExpandElementTyped<u32> {
181    fn value(self) -> Variable {
182        *self.expand
183    }
184}