cubecl_core/frontend/container/sequence/base.rs

1use cubecl_ir::{ExpandElement, Scope};
2use serde::{Deserialize, Serialize};
3
4use crate::{
5    frontend::{CubeType, ExpandElementTyped, IntoMut, branch::Iterable, indexation::Index},
6    prelude::CubeDebug,
7};
8use std::{cell::RefCell, rc::Rc};
9
10/// A sequence of [cube types](CubeType) that is inlined during compilation.
11///
12/// In other words, it allows you to group a dynamic amount of variables at compile time.
13///
14/// All methods [push](Sequence::push), [index](Sequence::index) and
15/// [into_iter](Sequence::into_iter) are executed _during_ compilation and don't add any overhead
16/// on the generated kernel.
17#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
18pub struct Sequence<T: CubeType> {
19    values: Vec<T>,
20}
21
22impl<T: CubeType> Default for Sequence<T> {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl<T: CubeType> IntoMut for Sequence<T> {
29    fn into_mut(self, _scope: &mut Scope) -> Self {
30        self
31    }
32}
33impl<T: CubeType> CubeDebug for Sequence<T> {}
34
35impl<T: CubeType + Clone> Sequence<T> {
36    pub fn rev(&self) -> Self {
37        Self {
38            values: self.values.iter().rev().cloned().collect(),
39        }
40    }
41}
42
43impl<T: CubeType> Sequence<T> {
44    /// Create a new empty sequence.
45    pub fn new() -> Self {
46        Self { values: Vec::new() }
47    }
48
49    /// Push a new value into the sequence.
50    pub fn push(&mut self, value: T) {
51        self.values.push(value);
52    }
53
54    /// Obtain the sequence length.
55    #[allow(clippy::len_without_is_empty)]
56    pub fn len(&self) -> u32 {
57        self.values.len() as u32
58    }
59
60    /// Get the variable at the given position in the sequence.
61    #[allow(unused_variables, clippy::should_implement_trait)]
62    pub fn index<I: Index>(&self, index: I) -> &T {
63        let index: ExpandElementTyped<u32> = ExpandElement::Plain(index.value()).into();
64        let index = index
65            .constant()
66            .expect("Only constant are supported")
67            .as_usize();
68
69        self.values.get(index).unwrap()
70    }
71
72    /// Get the variable at the given position in the sequence.
73    #[allow(unused_variables, clippy::should_implement_trait)]
74    pub fn index_mut<I: Index>(&mut self, index: I) -> &mut T {
75        let index: ExpandElementTyped<u32> = ExpandElement::Plain(index.value()).into();
76        let index = index
77            .constant()
78            .expect("Only constant are supported")
79            .as_usize();
80
81        self.values.get_mut(index).unwrap()
82    }
83
84    /// Expand function of [new](Self::new).
85    pub fn __expand_new(_scope: &mut Scope) -> SequenceExpand<T> {
86        SequenceExpand {
87            values: Rc::new(RefCell::new(Vec::new())),
88        }
89    }
90
91    /// Insert an item at the given index.
92    #[allow(unused_variables, clippy::should_implement_trait)]
93    pub fn insert<I: Index>(&mut self, index: I, value: T) {
94        *self.index_mut(index) = value;
95    }
96
97    /// Expand function of [push](Self::push).
98    pub fn __expand_push(scope: &mut Scope, expand: &mut SequenceExpand<T>, value: T::ExpandType) {
99        expand.__expand_push_method(scope, value)
100    }
101
102    /// Expand function of [index](Self::index).
103    pub fn __expand_index(
104        scope: &mut Scope,
105        expand: SequenceExpand<T>,
106        index: ExpandElementTyped<u32>,
107    ) -> T::ExpandType {
108        expand.__expand_index_method(scope, index)
109    }
110
111    /// Expand function of [index_mut](Self::index_mut).
112    pub fn __expand_index_mut(
113        scope: &mut Scope,
114        expand: SequenceExpand<T>,
115        index: ExpandElementTyped<u32>,
116    ) -> T::ExpandType {
117        expand.__expand_index_mut_method(scope, index)
118    }
119}
120
121/// Expand type of [Sequence].
122pub struct SequenceExpand<T: CubeType> {
123    // We clone the expand type during the compilation phase, but for register reuse, not for
124    // copying data. To achieve the intended behavior, we have to share the same underlying values.
125    pub(super) values: Rc<RefCell<Vec<T::ExpandType>>>,
126}
127
128impl<T: CubeType> Iterable<T> for SequenceExpand<T> {
129    fn expand(self, scope: &mut Scope, func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType)) {
130        self.expand_unroll(scope, func);
131    }
132
133    fn expand_unroll(
134        self,
135        scope: &mut Scope,
136        mut func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
137    ) {
138        for elem in self {
139            func(scope, elem);
140        }
141    }
142
143    fn const_len(&self) -> Option<usize> {
144        Some(self.values.borrow().len())
145    }
146}
147
148impl<T: CubeType> IntoMut for SequenceExpand<T> {
149    fn into_mut(self, scope: &mut Scope) -> Self {
150        let mut values = self.values.borrow_mut();
151        values.iter_mut().for_each(|v| {
152            *v = IntoMut::into_mut(v.clone(), scope);
153        });
154        core::mem::drop(values);
155
156        self
157    }
158}
159impl<T: CubeType> CubeDebug for SequenceExpand<T> {}
160
161impl<T: CubeType> Clone for SequenceExpand<T> {
162    fn clone(&self) -> Self {
163        Self {
164            values: self.values.clone(),
165        }
166    }
167}
168
169impl<T: CubeType> IntoIterator for Sequence<T> {
170    type Item = T;
171
172    type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
173
174    fn into_iter(self) -> Self::IntoIter {
175        self.values.into_iter()
176    }
177}
178
179impl<T: CubeType> IntoIterator for SequenceExpand<T> {
180    type Item = T::ExpandType;
181
182    type IntoIter = <Vec<T::ExpandType> as IntoIterator>::IntoIter;
183
184    fn into_iter(self) -> Self::IntoIter {
185        self.values.take().into_iter()
186    }
187}
188
189impl<T: CubeType> SequenceExpand<T> {
190    /// Provides an iterator without modifying the sequence
191    pub fn iter_cloned(&self) -> impl Iterator<Item = T::ExpandType> {
192        self.values.borrow().clone().into_iter()
193    }
194}
195
196impl<T: CubeType> CubeType for Sequence<T> {
197    type ExpandType = SequenceExpand<T>;
198}
199
200impl<T: CubeType> SequenceExpand<T> {
201    #[allow(clippy::len_without_is_empty)]
202    pub fn len(&self) -> u32 {
203        self.values.borrow().len() as u32
204    }
205    /// Expand method of [push](Sequence::push).
206    pub fn __expand_push_method(&mut self, _scope: &mut Scope, value: T::ExpandType) {
207        self.values.borrow_mut().push(value);
208    }
209
210    /// Expand method of [insert](Sequence::insert).
211    pub fn __expand_insert_method(
212        &self,
213        _scope: &mut Scope,
214        index: ExpandElementTyped<u32>,
215        value: T::ExpandType,
216    ) {
217        let index = index
218            .constant()
219            .expect("Only constant are supported")
220            .as_usize();
221
222        let mut values = self.values.borrow_mut();
223
224        if values.len() == index {
225            values.push(value);
226        } else {
227            values[index] = value;
228        }
229    }
230
231    /// Expand method of [index](Sequence::index).
232    pub fn __expand_index_method(
233        &self,
234        _scope: &mut Scope,
235        index: ExpandElementTyped<u32>,
236    ) -> T::ExpandType {
237        let index = index
238            .constant()
239            .expect("Only constant are supported")
240            .as_usize();
241
242        self.values.borrow()[index].clone()
243    }
244
245    /// Expand method of [index_mut](Sequence::index_mut).
246    pub fn __expand_index_mut_method(
247        &self,
248        _scope: &mut Scope,
249        index: ExpandElementTyped<u32>,
250    ) -> T::ExpandType {
251        let index = index
252            .constant()
253            .expect("Only constant are supported")
254            .as_usize();
255
256        self.values.borrow()[index].clone()
257    }
258
259    pub fn __expand_len_method(&self, _scope: &mut Scope) -> u32 {
260        let values = self.values.borrow();
261        values.len() as u32
262    }
263
264    pub fn __expand_rev_method(self, _scope: &mut Scope) -> Self {
265        let mut values = self.values.borrow().clone();
266        values.reverse();
267        Self {
268            values: Rc::new(RefCell::new(values)),
269        }
270    }
271
272    pub fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
273        self.clone()
274    }
275}