cubecl_core/frontend/container/sequence/
base.rs1use cubecl_ir::{ExpandElement, Scope};
2use serde::{Deserialize, Serialize};
3
4use crate::{
5    frontend::{CubeType, ExpandElementTyped, IntoMut, branch::Iterable, indexation::Index},
6    prelude::CubeDebug,
7};
8use std::{cell::RefCell, rc::Rc};
9
10#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
18pub struct Sequence<T: CubeType> {
19    values: Vec<T>,
20}
21
22impl<T: CubeType> Default for Sequence<T> {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl<T: CubeType> IntoMut for Sequence<T> {
29    fn into_mut(self, _scope: &mut Scope) -> Self {
30        self
31    }
32}
33impl<T: CubeType> CubeDebug for Sequence<T> {}
34
35impl<T: CubeType + Clone> Sequence<T> {
36    pub fn rev(&self) -> Self {
37        Self {
38            values: self.values.iter().rev().cloned().collect(),
39        }
40    }
41}
42
43impl<T: CubeType> Sequence<T> {
44    pub fn new() -> Self {
46        Self { values: Vec::new() }
47    }
48
49    pub fn push(&mut self, value: T) {
51        self.values.push(value);
52    }
53
54    #[allow(clippy::len_without_is_empty)]
56    pub fn len(&self) -> u32 {
57        self.values.len() as u32
58    }
59
60    #[allow(unused_variables, clippy::should_implement_trait)]
62    pub fn index<I: Index>(&self, index: I) -> &T {
63        let index: ExpandElementTyped<u32> = ExpandElement::Plain(index.value()).into();
64        let index = index
65            .constant()
66            .expect("Only constant are supported")
67            .as_usize();
68
69        self.values.get(index).unwrap()
70    }
71
72    #[allow(unused_variables, clippy::should_implement_trait)]
74    pub fn index_mut<I: Index>(&mut self, index: I) -> &mut T {
75        let index: ExpandElementTyped<u32> = ExpandElement::Plain(index.value()).into();
76        let index = index
77            .constant()
78            .expect("Only constant are supported")
79            .as_usize();
80
81        self.values.get_mut(index).unwrap()
82    }
83
84    pub fn __expand_new(_scope: &mut Scope) -> SequenceExpand<T> {
86        SequenceExpand {
87            values: Rc::new(RefCell::new(Vec::new())),
88        }
89    }
90
91    #[allow(unused_variables, clippy::should_implement_trait)]
93    pub fn insert<I: Index>(&mut self, index: I, value: T) {
94        *self.index_mut(index) = value;
95    }
96
97    pub fn __expand_push(scope: &mut Scope, expand: &mut SequenceExpand<T>, value: T::ExpandType) {
99        expand.__expand_push_method(scope, value)
100    }
101
102    pub fn __expand_index(
104        scope: &mut Scope,
105        expand: SequenceExpand<T>,
106        index: ExpandElementTyped<u32>,
107    ) -> T::ExpandType {
108        expand.__expand_index_method(scope, index)
109    }
110
111    pub fn __expand_index_mut(
113        scope: &mut Scope,
114        expand: SequenceExpand<T>,
115        index: ExpandElementTyped<u32>,
116    ) -> T::ExpandType {
117        expand.__expand_index_mut_method(scope, index)
118    }
119}
120
121pub struct SequenceExpand<T: CubeType> {
123    pub(super) values: Rc<RefCell<Vec<T::ExpandType>>>,
126}
127
128impl<T: CubeType> Iterable<T> for SequenceExpand<T> {
129    fn expand(self, scope: &mut Scope, func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType)) {
130        self.expand_unroll(scope, func);
131    }
132
133    fn expand_unroll(
134        self,
135        scope: &mut Scope,
136        mut func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
137    ) {
138        for elem in self {
139            func(scope, elem);
140        }
141    }
142}
143
144impl<T: CubeType> IntoMut for SequenceExpand<T> {
145    fn into_mut(self, scope: &mut Scope) -> Self {
146        let mut values = self.values.borrow_mut();
147        values.iter_mut().for_each(|v| {
148            *v = IntoMut::into_mut(v.clone(), scope);
149        });
150        core::mem::drop(values);
151
152        self
153    }
154}
155impl<T: CubeType> CubeDebug for SequenceExpand<T> {}
156
157impl<T: CubeType> Clone for SequenceExpand<T> {
158    fn clone(&self) -> Self {
159        Self {
160            values: self.values.clone(),
161        }
162    }
163}
164
165impl<T: CubeType> IntoIterator for Sequence<T> {
166    type Item = T;
167
168    type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
169
170    fn into_iter(self) -> Self::IntoIter {
171        self.values.into_iter()
172    }
173}
174
175impl<T: CubeType> IntoIterator for SequenceExpand<T> {
176    type Item = T::ExpandType;
177
178    type IntoIter = <Vec<T::ExpandType> as IntoIterator>::IntoIter;
179
180    fn into_iter(self) -> Self::IntoIter {
181        self.values.take().into_iter()
182    }
183}
184
185impl<T: CubeType> SequenceExpand<T> {
186    pub fn iter_cloned(&self) -> impl Iterator<Item = T::ExpandType> {
188        self.values.borrow().clone().into_iter()
189    }
190}
191
192impl<T: CubeType> CubeType for Sequence<T> {
193    type ExpandType = SequenceExpand<T>;
194}
195
196impl<T: CubeType> SequenceExpand<T> {
197    #[allow(clippy::len_without_is_empty)]
198    pub fn len(&self) -> u32 {
199        self.values.borrow().len() as u32
200    }
201    pub fn __expand_push_method(&mut self, _scope: &mut Scope, value: T::ExpandType) {
203        self.values.borrow_mut().push(value);
204    }
205
206    pub fn __expand_insert_method(
208        &self,
209        _scope: &mut Scope,
210        index: ExpandElementTyped<u32>,
211        value: T::ExpandType,
212    ) {
213        let index = index
214            .constant()
215            .expect("Only constant are supported")
216            .as_usize();
217
218        let mut values = self.values.borrow_mut();
219
220        if values.len() == index {
221            values.push(value);
222        } else {
223            values[index] = value;
224        }
225    }
226
227    pub fn __expand_index_method(
229        &self,
230        _scope: &mut Scope,
231        index: ExpandElementTyped<u32>,
232    ) -> T::ExpandType {
233        let index = index
234            .constant()
235            .expect("Only constant are supported")
236            .as_usize();
237
238        self.values.borrow()[index].clone()
239    }
240
241    pub fn __expand_index_mut_method(
243        &self,
244        _scope: &mut Scope,
245        index: ExpandElementTyped<u32>,
246    ) -> T::ExpandType {
247        let index = index
248            .constant()
249            .expect("Only constant are supported")
250            .as_usize();
251
252        self.values.borrow()[index].clone()
253    }
254
255    pub fn __expand_len_method(&self, _scope: &mut Scope) -> u32 {
256        let values = self.values.borrow();
257        values.len() as u32
258    }
259
260    pub fn __expand_rev_method(self, _scope: &mut Scope) -> Self {
261        let mut values = self.values.borrow().clone();
262        values.reverse();
263        Self {
264            values: Rc::new(RefCell::new(values)),
265        }
266    }
267
268    pub fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
269        self.clone()
270    }
271}