Skip to main content

cubecl_core/frontend/container/sequence/
base.rs

1use alloc::vec::Vec;
2use cubecl_ir::Scope;
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    frontend::{CubeType, ExpandElementTyped, IntoMut, branch::Iterable},
7    prelude::{CubeDebug, CubeIndex, CubeIndexExpand},
8};
9use alloc::rc::Rc;
10use core::{cell::RefCell, ops::Deref};
11
12/// A sequence of [cube types](CubeType) that is inlined during compilation.
13///
14/// In other words, it allows you to group a dynamic amount of variables at compile time.
15///
16/// All methods [push](Sequence::push), [index](Sequence::index) and
17/// [`into_iter`](Sequence::into_iter) are executed _during_ compilation and don't add any overhead
18/// on the generated kernel.
19#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
20pub struct Sequence<T: CubeType> {
21    values: Vec<T>,
22}
23
24impl<T: CubeType> Default for Sequence<T> {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl<T: CubeType> IntoMut for Sequence<T> {
31    fn into_mut(self, _scope: &mut Scope) -> Self {
32        self
33    }
34}
35impl<T: CubeType> CubeDebug for Sequence<T> {}
36
37impl<T: CubeType + Clone> Sequence<T> {
38    pub fn rev(&self) -> Self {
39        Self {
40            values: self.values.iter().rev().cloned().collect(),
41        }
42    }
43}
44
45impl<T: CubeType> Sequence<T> {
46    /// Create a new empty sequence.
47    pub fn new() -> Self {
48        Self { values: Vec::new() }
49    }
50
51    /// Push a new value into the sequence.
52    pub fn push(&mut self, value: T) {
53        self.values.push(value);
54    }
55
56    /// Obtain the sequence length.
57    #[allow(clippy::len_without_is_empty)]
58    pub fn len(&self) -> usize {
59        self.values.len()
60    }
61
62    /// Get the variable at the given position in the sequence.
63    #[allow(unused_variables, clippy::should_implement_trait)]
64    pub fn index(&self, index: usize) -> &T {
65        self.values.get(index).unwrap()
66    }
67
68    /// Get the variable at the given position in the sequence.
69    #[allow(unused_variables, clippy::should_implement_trait)]
70    pub fn index_mut(&mut self, index: usize) -> &mut T {
71        self.values.get_mut(index).unwrap()
72    }
73
74    /// Expand function of [new](Self::new).
75    pub fn __expand_new(_scope: &mut Scope) -> SequenceExpand<T> {
76        SequenceExpand {
77            values: Rc::new(RefCell::new(Vec::new())),
78        }
79    }
80
81    /// Insert an item at the given index.
82    #[allow(unused_variables, clippy::should_implement_trait)]
83    pub fn insert(&mut self, index: usize, value: T) {
84        *self.index_mut(index) = value;
85    }
86
87    /// Expand function of [push](Self::push).
88    pub fn __expand_push(scope: &mut Scope, expand: &mut SequenceExpand<T>, value: T::ExpandType) {
89        expand.__expand_push_method(scope, value)
90    }
91
92    /// Expand function of [index](Self::index).
93    pub fn __expand_index(
94        scope: &mut Scope,
95        expand: SequenceExpand<T>,
96        index: usize,
97    ) -> T::ExpandType {
98        expand.__expand_index_method(scope, index)
99    }
100
101    /// Expand function of [`index_mut`](Self::index_mut).
102    pub fn __expand_index_mut(
103        scope: &mut Scope,
104        expand: SequenceExpand<T>,
105        index: usize,
106    ) -> T::ExpandType {
107        expand.__expand_index_mut_method(scope, index)
108    }
109}
110
111impl<T: CubeType> CubeIndex for Sequence<T> {
112    type Output = T;
113    type Idx = usize;
114}
115
116impl<T: CubeType> Deref for Sequence<T> {
117    type Target = [T];
118
119    fn deref(&self) -> &Self::Target {
120        &self.values
121    }
122}
123
124impl<T: CubeType> CubeIndexExpand for SequenceExpand<T> {
125    type Output = T::ExpandType;
126    type Idx = ExpandElementTyped<usize>;
127
128    fn expand_index(self, scope: &mut Scope, index: Self::Idx) -> Self::Output {
129        let index = index
130            .constant()
131            .expect("Sequence index must be constant")
132            .as_usize();
133        self.__expand_index_method(scope, index)
134    }
135
136    fn expand_index_unchecked(self, scope: &mut Scope, index: Self::Idx) -> Self::Output {
137        let index = index
138            .constant()
139            .expect("Sequence index must be constant")
140            .as_usize();
141        self.__expand_index_method(scope, index)
142    }
143}
144
145/// Expand type of [Sequence].
146pub struct SequenceExpand<T: CubeType> {
147    // We clone the expand type during the compilation phase, but for register reuse, not for
148    // copying data. To achieve the intended behavior, we have to share the same underlying values.
149    pub(super) values: Rc<RefCell<Vec<T::ExpandType>>>,
150}
151
152impl<T: CubeType> Iterable<T> for SequenceExpand<T> {
153    fn expand(self, scope: &mut Scope, func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType)) {
154        self.expand_unroll(scope, func);
155    }
156
157    fn expand_unroll(
158        self,
159        scope: &mut Scope,
160        mut func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
161    ) {
162        for elem in self {
163            func(scope, elem);
164        }
165    }
166
167    fn const_len(&self) -> Option<usize> {
168        Some(self.values.borrow().len())
169    }
170}
171
172impl<T: CubeType> IntoMut for SequenceExpand<T> {
173    fn into_mut(self, scope: &mut Scope) -> Self {
174        let mut values = self.values.borrow_mut();
175        values.iter_mut().for_each(|v| {
176            *v = IntoMut::into_mut(v.clone(), scope);
177        });
178        core::mem::drop(values);
179
180        self
181    }
182}
183impl<T: CubeType> CubeDebug for SequenceExpand<T> {}
184
185impl<T: CubeType> Clone for SequenceExpand<T> {
186    fn clone(&self) -> Self {
187        Self {
188            values: self.values.clone(),
189        }
190    }
191}
192
193impl<T: CubeType> IntoIterator for Sequence<T> {
194    type Item = T;
195
196    type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
197
198    fn into_iter(self) -> Self::IntoIter {
199        self.values.into_iter()
200    }
201}
202
203impl<T: CubeType> IntoIterator for SequenceExpand<T> {
204    type Item = T::ExpandType;
205
206    type IntoIter = <Vec<T::ExpandType> as IntoIterator>::IntoIter;
207
208    fn into_iter(self) -> Self::IntoIter {
209        self.values.take().into_iter()
210    }
211}
212
213impl<T: CubeType> SequenceExpand<T> {
214    /// Provides an iterator without modifying the sequence
215    pub fn iter_cloned(&self) -> impl Iterator<Item = T::ExpandType> {
216        self.values.borrow().clone().into_iter()
217    }
218}
219
220impl<T: CubeType> CubeType for Sequence<T> {
221    type ExpandType = SequenceExpand<T>;
222}
223
224impl<T: CubeType> SequenceExpand<T> {
225    #[allow(clippy::len_without_is_empty)]
226    pub fn len(&self) -> usize {
227        self.values.borrow().len()
228    }
229    /// Expand method of [push](Sequence::push).
230    pub fn __expand_push_method(&mut self, _scope: &mut Scope, value: T::ExpandType) {
231        self.values.borrow_mut().push(value);
232    }
233
234    /// Expand method of [insert](Sequence::insert).
235    pub fn __expand_insert_method(&self, _scope: &mut Scope, index: usize, value: T::ExpandType) {
236        let mut values = self.values.borrow_mut();
237
238        if values.len() == index {
239            values.push(value);
240        } else {
241            values[index] = value;
242        }
243    }
244
245    /// Expand method of [index](Sequence::index).
246    pub fn __expand_index_method(&self, _scope: &mut Scope, index: usize) -> T::ExpandType {
247        self.values.borrow()[index].clone()
248    }
249
250    /// Expand method of [`index_mut`](Sequence::index_mut).
251    pub fn __expand_index_mut_method(&self, _scope: &mut Scope, index: usize) -> T::ExpandType {
252        self.values.borrow()[index].clone()
253    }
254
255    pub fn __expand_len_method(&self, _scope: &mut Scope) -> usize {
256        let values = self.values.borrow();
257        values.len()
258    }
259
260    pub fn __expand_rev_method(self, _scope: &mut Scope) -> Self {
261        let mut values = self.values.borrow().clone();
262        values.reverse();
263        Self {
264            values: Rc::new(RefCell::new(values)),
265        }
266    }
267
268    pub fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
269        self.clone()
270    }
271}