// cubecl_core/frontend/container/sequence/base.rs

1use cubecl_ir::Scope;
2use serde::{Deserialize, Serialize};
3
4use crate::{
5    frontend::{CubeType, ExpandElementTyped, IntoMut, branch::Iterable},
6    prelude::{CubeDebug, CubeIndex, CubeIndexExpand},
7};
8use std::{cell::RefCell, ops::Deref, rc::Rc};
9
10/// A sequence of [cube types](CubeType) that is inlined during compilation.
11///
12/// In other words, it allows you to group a dynamic amount of variables at compile time.
13///
14/// All methods [push](Sequence::push), [index](Sequence::index) and
15/// [`into_iter`](Sequence::into_iter) are executed _during_ compilation and don't add any overhead
16/// on the generated kernel.
17#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
18pub struct Sequence<T: CubeType> {
19    values: Vec<T>,
20}
21
22impl<T: CubeType> Default for Sequence<T> {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl<T: CubeType> IntoMut for Sequence<T> {
29    fn into_mut(self, _scope: &mut Scope) -> Self {
30        self
31    }
32}
33impl<T: CubeType> CubeDebug for Sequence<T> {}
34
35impl<T: CubeType + Clone> Sequence<T> {
36    pub fn rev(&self) -> Self {
37        Self {
38            values: self.values.iter().rev().cloned().collect(),
39        }
40    }
41}
42
43impl<T: CubeType> Sequence<T> {
44    /// Create a new empty sequence.
45    pub fn new() -> Self {
46        Self { values: Vec::new() }
47    }
48
49    /// Push a new value into the sequence.
50    pub fn push(&mut self, value: T) {
51        self.values.push(value);
52    }
53
54    /// Obtain the sequence length.
55    #[allow(clippy::len_without_is_empty)]
56    pub fn len(&self) -> usize {
57        self.values.len()
58    }
59
60    /// Get the variable at the given position in the sequence.
61    #[allow(unused_variables, clippy::should_implement_trait)]
62    pub fn index(&self, index: usize) -> &T {
63        self.values.get(index).unwrap()
64    }
65
66    /// Get the variable at the given position in the sequence.
67    #[allow(unused_variables, clippy::should_implement_trait)]
68    pub fn index_mut(&mut self, index: usize) -> &mut T {
69        self.values.get_mut(index).unwrap()
70    }
71
72    /// Expand function of [new](Self::new).
73    pub fn __expand_new(_scope: &mut Scope) -> SequenceExpand<T> {
74        SequenceExpand {
75            values: Rc::new(RefCell::new(Vec::new())),
76        }
77    }
78
79    /// Insert an item at the given index.
80    #[allow(unused_variables, clippy::should_implement_trait)]
81    pub fn insert(&mut self, index: usize, value: T) {
82        *self.index_mut(index) = value;
83    }
84
85    /// Expand function of [push](Self::push).
86    pub fn __expand_push(scope: &mut Scope, expand: &mut SequenceExpand<T>, value: T::ExpandType) {
87        expand.__expand_push_method(scope, value)
88    }
89
90    /// Expand function of [index](Self::index).
91    pub fn __expand_index(
92        scope: &mut Scope,
93        expand: SequenceExpand<T>,
94        index: usize,
95    ) -> T::ExpandType {
96        expand.__expand_index_method(scope, index)
97    }
98
99    /// Expand function of [`index_mut`](Self::index_mut).
100    pub fn __expand_index_mut(
101        scope: &mut Scope,
102        expand: SequenceExpand<T>,
103        index: usize,
104    ) -> T::ExpandType {
105        expand.__expand_index_mut_method(scope, index)
106    }
107}
108
109impl<T: CubeType> CubeIndex for Sequence<T> {
110    type Output = T;
111    type Idx = usize;
112}
113
114impl<T: CubeType> Deref for Sequence<T> {
115    type Target = [T];
116
117    fn deref(&self) -> &Self::Target {
118        &self.values
119    }
120}
121
122impl<T: CubeType> CubeIndexExpand for SequenceExpand<T> {
123    type Output = T::ExpandType;
124    type Idx = ExpandElementTyped<usize>;
125
126    fn expand_index(self, scope: &mut Scope, index: Self::Idx) -> Self::Output {
127        let index = index
128            .constant()
129            .expect("Sequence index must be constant")
130            .as_usize();
131        self.__expand_index_method(scope, index)
132    }
133
134    fn expand_index_unchecked(self, scope: &mut Scope, index: Self::Idx) -> Self::Output {
135        let index = index
136            .constant()
137            .expect("Sequence index must be constant")
138            .as_usize();
139        self.__expand_index_method(scope, index)
140    }
141}
142
143/// Expand type of [Sequence].
144pub struct SequenceExpand<T: CubeType> {
145    // We clone the expand type during the compilation phase, but for register reuse, not for
146    // copying data. To achieve the intended behavior, we have to share the same underlying values.
147    pub(super) values: Rc<RefCell<Vec<T::ExpandType>>>,
148}
149
150impl<T: CubeType> Iterable<T> for SequenceExpand<T> {
151    fn expand(self, scope: &mut Scope, func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType)) {
152        self.expand_unroll(scope, func);
153    }
154
155    fn expand_unroll(
156        self,
157        scope: &mut Scope,
158        mut func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
159    ) {
160        for elem in self {
161            func(scope, elem);
162        }
163    }
164
165    fn const_len(&self) -> Option<usize> {
166        Some(self.values.borrow().len())
167    }
168}
169
170impl<T: CubeType> IntoMut for SequenceExpand<T> {
171    fn into_mut(self, scope: &mut Scope) -> Self {
172        let mut values = self.values.borrow_mut();
173        values.iter_mut().for_each(|v| {
174            *v = IntoMut::into_mut(v.clone(), scope);
175        });
176        core::mem::drop(values);
177
178        self
179    }
180}
181impl<T: CubeType> CubeDebug for SequenceExpand<T> {}
182
183impl<T: CubeType> Clone for SequenceExpand<T> {
184    fn clone(&self) -> Self {
185        Self {
186            values: self.values.clone(),
187        }
188    }
189}
190
191impl<T: CubeType> IntoIterator for Sequence<T> {
192    type Item = T;
193
194    type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
195
196    fn into_iter(self) -> Self::IntoIter {
197        self.values.into_iter()
198    }
199}
200
201impl<T: CubeType> IntoIterator for SequenceExpand<T> {
202    type Item = T::ExpandType;
203
204    type IntoIter = <Vec<T::ExpandType> as IntoIterator>::IntoIter;
205
206    fn into_iter(self) -> Self::IntoIter {
207        self.values.take().into_iter()
208    }
209}
210
211impl<T: CubeType> SequenceExpand<T> {
212    /// Provides an iterator without modifying the sequence
213    pub fn iter_cloned(&self) -> impl Iterator<Item = T::ExpandType> {
214        self.values.borrow().clone().into_iter()
215    }
216}
217
218impl<T: CubeType> CubeType for Sequence<T> {
219    type ExpandType = SequenceExpand<T>;
220}
221
222impl<T: CubeType> SequenceExpand<T> {
223    #[allow(clippy::len_without_is_empty)]
224    pub fn len(&self) -> usize {
225        self.values.borrow().len()
226    }
227    /// Expand method of [push](Sequence::push).
228    pub fn __expand_push_method(&mut self, _scope: &mut Scope, value: T::ExpandType) {
229        self.values.borrow_mut().push(value);
230    }
231
232    /// Expand method of [insert](Sequence::insert).
233    pub fn __expand_insert_method(&self, _scope: &mut Scope, index: usize, value: T::ExpandType) {
234        let mut values = self.values.borrow_mut();
235
236        if values.len() == index {
237            values.push(value);
238        } else {
239            values[index] = value;
240        }
241    }
242
243    /// Expand method of [index](Sequence::index).
244    pub fn __expand_index_method(&self, _scope: &mut Scope, index: usize) -> T::ExpandType {
245        self.values.borrow()[index].clone()
246    }
247
248    /// Expand method of [`index_mut`](Sequence::index_mut).
249    pub fn __expand_index_mut_method(&self, _scope: &mut Scope, index: usize) -> T::ExpandType {
250        self.values.borrow()[index].clone()
251    }
252
253    pub fn __expand_len_method(&self, _scope: &mut Scope) -> usize {
254        let values = self.values.borrow();
255        values.len()
256    }
257
258    pub fn __expand_rev_method(self, _scope: &mut Scope) -> Self {
259        let mut values = self.values.borrow().clone();
260        values.reverse();
261        Self {
262            values: Rc::new(RefCell::new(values)),
263        }
264    }
265
266    pub fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
267        self.clone()
268    }
269}