// cubecl_core/frontend/indexation.rs

1use core::ops::{Index, IndexMut};
2
3use cubecl_ir::{
4    ExpandElement, IndexAssignOperator, Instruction, LineSize, Operator, Scope, VariableKind,
5};
6
7use super::{CubeType, ExpandElementTyped, index_expand, index_expand_no_vec};
8use crate::{ir::Variable, prelude::CubePrimitive, unexpanded};
9
10/// Fake indexation so we can rewrite indexes into scalars as calls to this fake function in the
11/// non-expanded function
12pub trait CubeIndex:
13    CubeType<
14    ExpandType: CubeIndexExpand<
15        Idx = <Self::Idx as CubeType>::ExpandType,
16        Output = <Self::Output as CubeType>::ExpandType,
17    >,
18>
19{
20    type Output: CubeType;
21    type Idx: CubeType;
22
23    fn cube_idx(&self, _i: Self::Idx) -> &Self::Output {
24        unexpanded!()
25    }
26
27    fn expand_index(
28        scope: &mut Scope,
29        array: Self::ExpandType,
30        index: <Self::Idx as CubeType>::ExpandType,
31    ) -> <Self::Output as CubeType>::ExpandType {
32        array.expand_index(scope, index)
33    }
34    fn expand_index_unchecked(
35        scope: &mut Scope,
36        array: Self::ExpandType,
37        index: <Self::Idx as CubeType>::ExpandType,
38    ) -> <Self::Output as CubeType>::ExpandType {
39        array.expand_index_unchecked(scope, index)
40    }
41}
42
/// Workaround for comptime indexing.
///
/// The helper that rewrites index operators cannot tell whether a variable is comptime.
/// This trait exposes `cube_idx` with the same signature used in unexpanded code, so the
/// call resolves to the ordinary [`Index`] implementation for plain Rust values.
pub trait ComptimeIndex<I>: Index<I> {
    /// Forwards to [`Index::index`].
    fn cube_idx(&self, index: I) -> &Self::Output {
        Index::index(self, index)
    }
}

/// Mutable counterpart of [`ComptimeIndex`].
pub trait ComptimeIndexMut<I>: ComptimeIndex<I> + IndexMut<I> {
    /// Forwards to [`IndexMut::index_mut`].
    fn cube_idx_mut(&mut self, index: I) -> &mut Self::Output {
        IndexMut::index_mut(self, index)
    }
}

// Blanket impls: every `Index` / `IndexMut` type gets the forwarding methods for free.
impl<I, T: Index<I>> ComptimeIndex<I> for T {}
impl<I, T: IndexMut<I>> ComptimeIndexMut<I> for T {}
60
61pub trait CubeIndexExpand {
62    type Output;
63    type Idx;
64    fn expand_index(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
65    fn expand_index_unchecked(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
66}
67
68pub trait CubeIndexMut:
69    CubeIndex
70    + CubeType<ExpandType: CubeIndexMutExpand<Output = <Self::Output as CubeType>::ExpandType>>
71{
72    fn cube_idx_mut(&mut self, _i: <Self as CubeIndex>::Idx) -> &mut <Self as CubeIndex>::Output {
73        unexpanded!()
74    }
75    fn expand_index_mut(
76        scope: &mut Scope,
77        array: Self::ExpandType,
78        index: <Self::Idx as CubeType>::ExpandType,
79        value: <Self::Output as CubeType>::ExpandType,
80    ) {
81        array.expand_index_mut(scope, index, value)
82    }
83}
84
85pub trait CubeIndexMutExpand: CubeIndexExpand {
86    fn expand_index_mut(
87        self,
88        scope: &mut Scope,
89        index: <Self as CubeIndexExpand>::Idx,
90        value: <Self as CubeIndexExpand>::Output,
91    );
92}
93
94pub(crate) fn expand_index_native<A: CubeType + CubeIndex>(
95    scope: &mut Scope,
96    array: ExpandElementTyped<A>,
97    index: ExpandElementTyped<usize>,
98    line_size: Option<LineSize>,
99    checked: bool,
100) -> ExpandElementTyped<A::Output>
101where
102    A::Output: CubeType + Sized,
103{
104    let index: ExpandElement = index.into();
105    let index_var: Variable = *index;
106    let index = match index_var.kind {
107        VariableKind::Constant(value) => {
108            ExpandElement::Plain(Variable::constant(value, usize::as_type(scope)))
109        }
110        _ => index,
111    };
112    let array: ExpandElement = array.into();
113    let var: Variable = *array;
114    let var = if checked {
115        match var.kind {
116            VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
117                index_expand_no_vec(scope, array, index, Operator::Index)
118            }
119            _ => index_expand(scope, array, index, line_size, Operator::Index),
120        }
121    } else {
122        match var.kind {
123            VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
124                index_expand_no_vec(scope, array, index, Operator::UncheckedIndex)
125            }
126            _ => index_expand(scope, array, index, line_size, Operator::UncheckedIndex),
127        }
128    };
129
130    ExpandElementTyped::new(var)
131}
132
133pub(crate) fn expand_index_assign_native<
134    A: CubeType<ExpandType = ExpandElementTyped<A>> + CubeIndexMut,
135>(
136    scope: &mut Scope,
137    array: A::ExpandType,
138    index: ExpandElementTyped<usize>,
139    value: ExpandElementTyped<<A as CubeIndex>::Output>,
140    line_size: Option<LineSize>,
141    checked: bool,
142) where
143    A::Output: CubeType + Sized,
144{
145    let index: Variable = index.expand.into();
146    let index = match index.kind {
147        VariableKind::Constant(value) => Variable::constant(value, usize::as_type(scope)),
148        _ => index,
149    };
150
151    let line_size = line_size.unwrap_or(0);
152    if checked {
153        scope.register(Instruction::new(
154            Operator::IndexAssign(IndexAssignOperator {
155                index,
156                value: value.expand.into(),
157                line_size,
158                unroll_factor: 1,
159            }),
160            array.expand.into(),
161        ));
162    } else {
163        scope.register(Instruction::new(
164            Operator::UncheckedIndexAssign(IndexAssignOperator {
165                index,
166                value: value.expand.into(),
167                line_size,
168                unroll_factor: 1,
169            }),
170            array.expand.into(),
171        ));
172    }
173}