Skip to main content

cubecl_core/frontend/
indexation.rs

1use core::ops::{Index, IndexMut};
2
3use cubecl_ir::{
4    IndexAssignOperator, Instruction, ManagedVariable, Operator, Scope, VariableKind, VectorSize,
5};
6
7use super::{CubeType, NativeExpand, index_expand, index_expand_no_vec};
8use crate::{ir::Variable, prelude::CubePrimitive, unexpanded};
9
10/// Fake indexation so we can rewrite indexes into scalars as calls to this fake function in the
11/// non-expanded function
12pub trait CubeIndex:
13    CubeType<
14    ExpandType: CubeIndexExpand<
15        Idx = <Self::Idx as CubeType>::ExpandType,
16        Output = <Self::Output as CubeType>::ExpandType,
17    >,
18>
19{
20    type Output: CubeType;
21    type Idx: CubeType;
22
23    fn cube_idx(&self, _i: Self::Idx) -> &Self::Output {
24        unexpanded!()
25    }
26
27    fn expand_index(
28        scope: &mut Scope,
29        array: Self::ExpandType,
30        index: <Self::Idx as CubeType>::ExpandType,
31    ) -> <Self::Output as CubeType>::ExpandType {
32        array.expand_index(scope, index)
33    }
34    fn expand_index_unchecked(
35        scope: &mut Scope,
36        array: Self::ExpandType,
37        index: <Self::Idx as CubeType>::ExpandType,
38    ) -> <Self::Output as CubeType>::ExpandType {
39        array.expand_index_unchecked(scope, index)
40    }
41}
42
/// Comptime counterpart of the `CubeIndex` placeholder.
///
/// The helper that rewrites index operators into `cube_idx(...)` calls cannot tell whether
/// the indexed value is comptime. Every [`Index`] type therefore gets a `cube_idx` method
/// with the same signature, so method resolution dispatches to the right trait on its own.
/// This one simply forwards to [`Index::index`].
pub trait ComptimeIndex<I>: Index<I> {
    /// Forwards to [`Index::index`].
    fn cube_idx(&self, i: I) -> &Self::Output {
        Index::index(self, i)
    }
}

/// Mutable comptime counterpart of the `CubeIndex` placeholder; forwards to
/// [`IndexMut::index_mut`].
pub trait ComptimeIndexMut<I>: ComptimeIndex<I> + IndexMut<I> {
    /// Forwards to [`IndexMut::index_mut`].
    fn cube_idx_mut(&mut self, i: I) -> &mut Self::Output {
        IndexMut::index_mut(self, i)
    }
}

// Every standard `Index` / `IndexMut` type gets the comptime traits for free.
impl<I, T: Index<I>> ComptimeIndex<I> for T {}
impl<I, T: IndexMut<I>> ComptimeIndexMut<I> for T {}
60
61pub trait CubeIndexExpand {
62    type Output;
63    type Idx;
64    fn expand_index(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
65    fn expand_index_unchecked(self, scope: &mut Scope, index: Self::Idx) -> Self::Output;
66}
67
68pub trait CubeIndexMut:
69    CubeIndex
70    + CubeType<ExpandType: CubeIndexMutExpand<Output = <Self::Output as CubeType>::ExpandType>>
71{
72    fn cube_idx_mut(&mut self, _i: <Self as CubeIndex>::Idx) -> &mut <Self as CubeIndex>::Output {
73        unexpanded!()
74    }
75    fn expand_index_mut(
76        scope: &mut Scope,
77        array: Self::ExpandType,
78        index: <Self::Idx as CubeType>::ExpandType,
79        value: <Self::Output as CubeType>::ExpandType,
80    ) {
81        array.expand_index_mut(scope, index, value)
82    }
83}
84
85pub trait CubeIndexMutExpand: CubeIndexExpand {
86    fn expand_index_mut(
87        self,
88        scope: &mut Scope,
89        index: <Self as CubeIndexExpand>::Idx,
90        value: <Self as CubeIndexExpand>::Output,
91    );
92}
93
94pub(crate) fn expand_index_native<A: CubeType + CubeIndex>(
95    scope: &mut Scope,
96    array: NativeExpand<A>,
97    index: NativeExpand<usize>,
98    vector_size: Option<VectorSize>,
99    checked: bool,
100) -> NativeExpand<A::Output>
101where
102    A::Output: CubeType + Sized,
103{
104    let index: ManagedVariable = index.into();
105    let index_var: Variable = *index;
106    let index = match index_var.kind {
107        VariableKind::Constant(value) => {
108            ManagedVariable::Plain(Variable::constant(value, usize::as_type(scope)))
109        }
110        _ => index,
111    };
112    let array: ManagedVariable = array.into();
113    let var: Variable = *array;
114    let var = if checked {
115        match var.kind {
116            VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
117                index_expand_no_vec(scope, array, index, Operator::Index)
118            }
119            _ => index_expand(scope, array, index, vector_size, Operator::Index),
120        }
121    } else {
122        match var.kind {
123            VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
124                index_expand_no_vec(scope, array, index, Operator::UncheckedIndex)
125            }
126            _ => index_expand(scope, array, index, vector_size, Operator::UncheckedIndex),
127        }
128    };
129
130    NativeExpand::new(var)
131}
132
133pub(crate) fn expand_index_assign_native<A: CubeType<ExpandType = NativeExpand<A>> + CubeIndexMut>(
134    scope: &mut Scope,
135    array: A::ExpandType,
136    index: NativeExpand<usize>,
137    value: NativeExpand<<A as CubeIndex>::Output>,
138    vector_size: Option<VectorSize>,
139    checked: bool,
140) where
141    A::Output: CubeType + Sized,
142{
143    let index: Variable = index.expand.into();
144    let index = match index.kind {
145        VariableKind::Constant(value) => Variable::constant(value, usize::as_type(scope)),
146        _ => index,
147    };
148
149    let vector_size = vector_size.unwrap_or(0);
150    if checked {
151        scope.register(Instruction::new(
152            Operator::IndexAssign(IndexAssignOperator {
153                index,
154                value: value.expand.into(),
155                vector_size,
156                unroll_factor: 1,
157            }),
158            array.expand.into(),
159        ));
160    } else {
161        scope.register(Instruction::new(
162            Operator::UncheckedIndexAssign(IndexAssignOperator {
163                index,
164                value: value.expand.into(),
165                vector_size,
166                unroll_factor: 1,
167            }),
168            array.expand.into(),
169        ));
170    }
171}