cubecl_core/frontend/container/sequence/
base.rs1use alloc::vec::Vec;
2use cubecl_ir::Scope;
3use serde::{Deserialize, Serialize};
4
5use crate::{
6 frontend::{CubeType, ExpandElementTyped, IntoMut, branch::Iterable},
7 prelude::{CubeDebug, CubeIndex, CubeIndexExpand},
8};
9use alloc::rc::Rc;
10use core::{cell::RefCell, ops::Deref};
11
12#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
20pub struct Sequence<T: CubeType> {
21 values: Vec<T>,
22}
23
24impl<T: CubeType> Default for Sequence<T> {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30impl<T: CubeType> IntoMut for Sequence<T> {
31 fn into_mut(self, _scope: &mut Scope) -> Self {
32 self
33 }
34}
35impl<T: CubeType> CubeDebug for Sequence<T> {}
36
37impl<T: CubeType + Clone> Sequence<T> {
38 pub fn rev(&self) -> Self {
39 Self {
40 values: self.values.iter().rev().cloned().collect(),
41 }
42 }
43}
44
45impl<T: CubeType> Sequence<T> {
46 pub fn new() -> Self {
48 Self { values: Vec::new() }
49 }
50
51 pub fn push(&mut self, value: T) {
53 self.values.push(value);
54 }
55
56 #[allow(clippy::len_without_is_empty)]
58 pub fn len(&self) -> usize {
59 self.values.len()
60 }
61
62 #[allow(unused_variables, clippy::should_implement_trait)]
64 pub fn index(&self, index: usize) -> &T {
65 self.values.get(index).unwrap()
66 }
67
68 #[allow(unused_variables, clippy::should_implement_trait)]
70 pub fn index_mut(&mut self, index: usize) -> &mut T {
71 self.values.get_mut(index).unwrap()
72 }
73
74 pub fn __expand_new(_scope: &mut Scope) -> SequenceExpand<T> {
76 SequenceExpand {
77 values: Rc::new(RefCell::new(Vec::new())),
78 }
79 }
80
81 #[allow(unused_variables, clippy::should_implement_trait)]
83 pub fn insert(&mut self, index: usize, value: T) {
84 *self.index_mut(index) = value;
85 }
86
87 pub fn __expand_push(scope: &mut Scope, expand: &mut SequenceExpand<T>, value: T::ExpandType) {
89 expand.__expand_push_method(scope, value)
90 }
91
92 pub fn __expand_index(
94 scope: &mut Scope,
95 expand: SequenceExpand<T>,
96 index: usize,
97 ) -> T::ExpandType {
98 expand.__expand_index_method(scope, index)
99 }
100
101 pub fn __expand_index_mut(
103 scope: &mut Scope,
104 expand: SequenceExpand<T>,
105 index: usize,
106 ) -> T::ExpandType {
107 expand.__expand_index_mut_method(scope, index)
108 }
109}
110
111impl<T: CubeType> CubeIndex for Sequence<T> {
112 type Output = T;
113 type Idx = usize;
114}
115
116impl<T: CubeType> Deref for Sequence<T> {
117 type Target = [T];
118
119 fn deref(&self) -> &Self::Target {
120 &self.values
121 }
122}
123
124impl<T: CubeType> CubeIndexExpand for SequenceExpand<T> {
125 type Output = T::ExpandType;
126 type Idx = ExpandElementTyped<usize>;
127
128 fn expand_index(self, scope: &mut Scope, index: Self::Idx) -> Self::Output {
129 let index = index
130 .constant()
131 .expect("Sequence index must be constant")
132 .as_usize();
133 self.__expand_index_method(scope, index)
134 }
135
136 fn expand_index_unchecked(self, scope: &mut Scope, index: Self::Idx) -> Self::Output {
137 let index = index
138 .constant()
139 .expect("Sequence index must be constant")
140 .as_usize();
141 self.__expand_index_method(scope, index)
142 }
143}
144
145pub struct SequenceExpand<T: CubeType> {
147 pub(super) values: Rc<RefCell<Vec<T::ExpandType>>>,
150}
151
152impl<T: CubeType> Iterable<T> for SequenceExpand<T> {
153 fn expand(self, scope: &mut Scope, func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType)) {
154 self.expand_unroll(scope, func);
155 }
156
157 fn expand_unroll(
158 self,
159 scope: &mut Scope,
160 mut func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
161 ) {
162 for elem in self {
163 func(scope, elem);
164 }
165 }
166
167 fn const_len(&self) -> Option<usize> {
168 Some(self.values.borrow().len())
169 }
170}
171
172impl<T: CubeType> IntoMut for SequenceExpand<T> {
173 fn into_mut(self, scope: &mut Scope) -> Self {
174 let mut values = self.values.borrow_mut();
175 values.iter_mut().for_each(|v| {
176 *v = IntoMut::into_mut(v.clone(), scope);
177 });
178 core::mem::drop(values);
179
180 self
181 }
182}
183impl<T: CubeType> CubeDebug for SequenceExpand<T> {}
184
185impl<T: CubeType> Clone for SequenceExpand<T> {
186 fn clone(&self) -> Self {
187 Self {
188 values: self.values.clone(),
189 }
190 }
191}
192
193impl<T: CubeType> IntoIterator for Sequence<T> {
194 type Item = T;
195
196 type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
197
198 fn into_iter(self) -> Self::IntoIter {
199 self.values.into_iter()
200 }
201}
202
203impl<T: CubeType> IntoIterator for SequenceExpand<T> {
204 type Item = T::ExpandType;
205
206 type IntoIter = <Vec<T::ExpandType> as IntoIterator>::IntoIter;
207
208 fn into_iter(self) -> Self::IntoIter {
209 self.values.take().into_iter()
210 }
211}
212
213impl<T: CubeType> SequenceExpand<T> {
214 pub fn iter_cloned(&self) -> impl Iterator<Item = T::ExpandType> {
216 self.values.borrow().clone().into_iter()
217 }
218}
219
220impl<T: CubeType> CubeType for Sequence<T> {
221 type ExpandType = SequenceExpand<T>;
222}
223
224impl<T: CubeType> SequenceExpand<T> {
225 #[allow(clippy::len_without_is_empty)]
226 pub fn len(&self) -> usize {
227 self.values.borrow().len()
228 }
229 pub fn __expand_push_method(&mut self, _scope: &mut Scope, value: T::ExpandType) {
231 self.values.borrow_mut().push(value);
232 }
233
234 pub fn __expand_insert_method(&self, _scope: &mut Scope, index: usize, value: T::ExpandType) {
236 let mut values = self.values.borrow_mut();
237
238 if values.len() == index {
239 values.push(value);
240 } else {
241 values[index] = value;
242 }
243 }
244
245 pub fn __expand_index_method(&self, _scope: &mut Scope, index: usize) -> T::ExpandType {
247 self.values.borrow()[index].clone()
248 }
249
250 pub fn __expand_index_mut_method(&self, _scope: &mut Scope, index: usize) -> T::ExpandType {
252 self.values.borrow()[index].clone()
253 }
254
255 pub fn __expand_len_method(&self, _scope: &mut Scope) -> usize {
256 let values = self.values.borrow();
257 values.len()
258 }
259
260 pub fn __expand_rev_method(self, _scope: &mut Scope) -> Self {
261 let mut values = self.values.borrow().clone();
262 values.reverse();
263 Self {
264 values: Rc::new(RefCell::new(values)),
265 }
266 }
267
268 pub fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
269 self.clone()
270 }
271}