cubecl_core/frontend/container/sequence/base.rs

1use cubecl_ir::{ExpandElement, Scope};
2use serde::{Deserialize, Serialize};
3
4use crate::{
5    frontend::{CubeType, ExpandElementTyped, IntoMut, branch::Iterable, indexation::Index},
6    prelude::CubeDebug,
7};
8use std::{cell::RefCell, rc::Rc};
9
10/// A sequence of [cube types](CubeType) that is inlined during compilation.
11///
12/// In other words, it allows you to group a dynamic amount of variables at compile time.
13///
14/// All methods [push](Sequence::push), [index](Sequence::index) and
15/// [into_iter](Sequence::into_iter) are executed _during_ compilation and don't add any overhead
16/// on the generated kernel.
17#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
18pub struct Sequence<T: CubeType> {
19    values: Vec<T>,
20}
21
22impl<T: CubeType> Default for Sequence<T> {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl<T: CubeType> IntoMut for Sequence<T> {
29    fn into_mut(self, _scope: &mut Scope) -> Self {
30        self
31    }
32}
33impl<T: CubeType> CubeDebug for Sequence<T> {}
34
35impl<T: CubeType + Clone> Sequence<T> {
36    pub fn rev(&self) -> Self {
37        Self {
38            values: self.values.iter().rev().cloned().collect(),
39        }
40    }
41}
42
43impl<T: CubeType> Sequence<T> {
44    /// Create a new empty sequence.
45    pub fn new() -> Self {
46        Self { values: Vec::new() }
47    }
48
49    /// Push a new value into the sequence.
50    pub fn push(&mut self, value: T) {
51        self.values.push(value);
52    }
53
54    /// Obtain the sequence length.
55    #[allow(clippy::len_without_is_empty)]
56    pub fn len(&self) -> u32 {
57        self.values.len() as u32
58    }
59
60    /// Get the variable at the given position in the sequence.
61    #[allow(unused_variables, clippy::should_implement_trait)]
62    pub fn index<I: Index>(&self, index: I) -> &T {
63        let index: ExpandElementTyped<u32> = ExpandElement::Plain(index.value()).into();
64        let index = index
65            .constant()
66            .expect("Only constant are supported")
67            .as_usize();
68
69        self.values.get(index).unwrap()
70    }
71
72    /// Get the variable at the given position in the sequence.
73    #[allow(unused_variables, clippy::should_implement_trait)]
74    pub fn index_mut<I: Index>(&mut self, index: I) -> &mut T {
75        let index: ExpandElementTyped<u32> = ExpandElement::Plain(index.value()).into();
76        let index = index
77            .constant()
78            .expect("Only constant are supported")
79            .as_usize();
80
81        self.values.get_mut(index).unwrap()
82    }
83
84    /// Expand function of [new](Self::new).
85    pub fn __expand_new(_scope: &mut Scope) -> SequenceExpand<T> {
86        SequenceExpand {
87            values: Rc::new(RefCell::new(Vec::new())),
88        }
89    }
90
91    /// Insert an item at the given index.
92    #[allow(unused_variables, clippy::should_implement_trait)]
93    pub fn insert<I: Index>(&mut self, index: I, value: T) {
94        *self.index_mut(index) = value;
95    }
96
97    /// Expand function of [push](Self::push).
98    pub fn __expand_push(scope: &mut Scope, expand: &mut SequenceExpand<T>, value: T::ExpandType) {
99        expand.__expand_push_method(scope, value)
100    }
101
102    /// Expand function of [index](Self::index).
103    pub fn __expand_index(
104        scope: &mut Scope,
105        expand: SequenceExpand<T>,
106        index: ExpandElementTyped<u32>,
107    ) -> T::ExpandType {
108        expand.__expand_index_method(scope, index)
109    }
110
111    /// Expand function of [index_mut](Self::index_mut).
112    pub fn __expand_index_mut(
113        scope: &mut Scope,
114        expand: SequenceExpand<T>,
115        index: ExpandElementTyped<u32>,
116    ) -> T::ExpandType {
117        expand.__expand_index_mut_method(scope, index)
118    }
119}
120
121/// Expand type of [Sequence].
122pub struct SequenceExpand<T: CubeType> {
123    // We clone the expand type during the compilation phase, but for register reuse, not for
124    // copying data. To achieve the intended behavior, we have to share the same underlying values.
125    pub(super) values: Rc<RefCell<Vec<T::ExpandType>>>,
126}
127
128impl<T: CubeType> Iterable<T> for SequenceExpand<T> {
129    fn expand(self, scope: &mut Scope, func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType)) {
130        self.expand_unroll(scope, func);
131    }
132
133    fn expand_unroll(
134        self,
135        scope: &mut Scope,
136        mut func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
137    ) {
138        for elem in self {
139            func(scope, elem);
140        }
141    }
142}
143
144impl<T: CubeType> IntoMut for SequenceExpand<T> {
145    fn into_mut(self, scope: &mut Scope) -> Self {
146        let mut values = self.values.borrow_mut();
147        values.iter_mut().for_each(|v| {
148            *v = IntoMut::into_mut(v.clone(), scope);
149        });
150        core::mem::drop(values);
151
152        self
153    }
154}
155impl<T: CubeType> CubeDebug for SequenceExpand<T> {}
156
157impl<T: CubeType> Clone for SequenceExpand<T> {
158    fn clone(&self) -> Self {
159        Self {
160            values: self.values.clone(),
161        }
162    }
163}
164
165impl<T: CubeType> IntoIterator for Sequence<T> {
166    type Item = T;
167
168    type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
169
170    fn into_iter(self) -> Self::IntoIter {
171        self.values.into_iter()
172    }
173}
174
175impl<T: CubeType> IntoIterator for SequenceExpand<T> {
176    type Item = T::ExpandType;
177
178    type IntoIter = <Vec<T::ExpandType> as IntoIterator>::IntoIter;
179
180    fn into_iter(self) -> Self::IntoIter {
181        self.values.take().into_iter()
182    }
183}
184
185impl<T: CubeType> CubeType for Sequence<T> {
186    type ExpandType = SequenceExpand<T>;
187}
188
189impl<T: CubeType> SequenceExpand<T> {
190    #[allow(clippy::len_without_is_empty)]
191    pub fn len(&self) -> u32 {
192        self.values.borrow().len() as u32
193    }
194    /// Expand method of [push](Sequence::push).
195    pub fn __expand_push_method(&mut self, _scope: &mut Scope, value: T::ExpandType) {
196        self.values.borrow_mut().push(value);
197    }
198
199    /// Expand method of [insert](Sequence::insert).
200    pub fn __expand_insert_method(
201        &self,
202        _scope: &mut Scope,
203        index: ExpandElementTyped<u32>,
204        value: T::ExpandType,
205    ) {
206        let index = index
207            .constant()
208            .expect("Only constant are supported")
209            .as_usize();
210
211        let mut values = self.values.borrow_mut();
212
213        if values.len() == index {
214            values.push(value);
215        } else {
216            values[index] = value;
217        }
218    }
219
220    /// Expand method of [index](Sequence::index).
221    pub fn __expand_index_method(
222        &self,
223        _scope: &mut Scope,
224        index: ExpandElementTyped<u32>,
225    ) -> T::ExpandType {
226        let index = index
227            .constant()
228            .expect("Only constant are supported")
229            .as_usize();
230
231        self.values.borrow()[index].clone()
232    }
233
234    /// Expand method of [index_mut](Sequence::index_mut).
235    pub fn __expand_index_mut_method(
236        &self,
237        _scope: &mut Scope,
238        index: ExpandElementTyped<u32>,
239    ) -> T::ExpandType {
240        let index = index
241            .constant()
242            .expect("Only constant are supported")
243            .as_usize();
244
245        self.values.borrow()[index].clone()
246    }
247
248    pub fn __expand_len_method(&self, _scope: &mut Scope) -> u32 {
249        let values = self.values.borrow();
250        values.len() as u32
251    }
252
253    pub fn __expand_rev_method(self, _scope: &mut Scope) -> Self {
254        self.values.borrow_mut().reverse();
255        self
256    }
257}