cubecl_core/frontend/container/sequence/
base.rs1use cubecl_ir::Scope;
2use serde::{Deserialize, Serialize};
3
4use crate::{
5 frontend::{CubeType, ExpandElementTyped, IntoMut, branch::Iterable},
6 prelude::{CubeDebug, CubeIndex, CubeIndexExpand},
7};
8use std::{cell::RefCell, ops::Deref, rc::Rc};
9
/// An ordered, growable list of values of type `T`, backed by a [`Vec`].
///
/// During kernel expansion this type is represented by [`SequenceExpand`], whose indexing
/// only accepts constant indices and whose iteration is always unrolled.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct Sequence<T: CubeType> {
    /// The stored elements, in order.
    values: Vec<T>,
}
21
22impl<T: CubeType> Default for Sequence<T> {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl<T: CubeType> IntoMut for Sequence<T> {
29 fn into_mut(self, _scope: &mut Scope) -> Self {
30 self
31 }
32}
33impl<T: CubeType> CubeDebug for Sequence<T> {}
34
35impl<T: CubeType + Clone> Sequence<T> {
36 pub fn rev(&self) -> Self {
37 Self {
38 values: self.values.iter().rev().cloned().collect(),
39 }
40 }
41}
42
43impl<T: CubeType> Sequence<T> {
44 pub fn new() -> Self {
46 Self { values: Vec::new() }
47 }
48
49 pub fn push(&mut self, value: T) {
51 self.values.push(value);
52 }
53
54 #[allow(clippy::len_without_is_empty)]
56 pub fn len(&self) -> usize {
57 self.values.len()
58 }
59
60 #[allow(unused_variables, clippy::should_implement_trait)]
62 pub fn index(&self, index: usize) -> &T {
63 self.values.get(index).unwrap()
64 }
65
66 #[allow(unused_variables, clippy::should_implement_trait)]
68 pub fn index_mut(&mut self, index: usize) -> &mut T {
69 self.values.get_mut(index).unwrap()
70 }
71
72 pub fn __expand_new(_scope: &mut Scope) -> SequenceExpand<T> {
74 SequenceExpand {
75 values: Rc::new(RefCell::new(Vec::new())),
76 }
77 }
78
79 #[allow(unused_variables, clippy::should_implement_trait)]
81 pub fn insert(&mut self, index: usize, value: T) {
82 *self.index_mut(index) = value;
83 }
84
85 pub fn __expand_push(scope: &mut Scope, expand: &mut SequenceExpand<T>, value: T::ExpandType) {
87 expand.__expand_push_method(scope, value)
88 }
89
90 pub fn __expand_index(
92 scope: &mut Scope,
93 expand: SequenceExpand<T>,
94 index: usize,
95 ) -> T::ExpandType {
96 expand.__expand_index_method(scope, index)
97 }
98
99 pub fn __expand_index_mut(
101 scope: &mut Scope,
102 expand: SequenceExpand<T>,
103 index: usize,
104 ) -> T::ExpandType {
105 expand.__expand_index_mut_method(scope, index)
106 }
107}
108
109impl<T: CubeType> CubeIndex for Sequence<T> {
110 type Output = T;
111 type Idx = usize;
112}
113
114impl<T: CubeType> Deref for Sequence<T> {
115 type Target = [T];
116
117 fn deref(&self) -> &Self::Target {
118 &self.values
119 }
120}
121
122impl<T: CubeType> CubeIndexExpand for SequenceExpand<T> {
123 type Output = T::ExpandType;
124 type Idx = ExpandElementTyped<usize>;
125
126 fn expand_index(self, scope: &mut Scope, index: Self::Idx) -> Self::Output {
127 let index = index
128 .constant()
129 .expect("Sequence index must be constant")
130 .as_usize();
131 self.__expand_index_method(scope, index)
132 }
133
134 fn expand_index_unchecked(self, scope: &mut Scope, index: Self::Idx) -> Self::Output {
135 let index = index
136 .constant()
137 .expect("Sequence index must be constant")
138 .as_usize();
139 self.__expand_index_method(scope, index)
140 }
141}
142
/// Expand-time representation of [`Sequence`].
///
/// The element list sits behind `Rc<RefCell<..>>`, and [`Clone`] for this type only clones
/// the `Rc`, so every clone shares (and observes mutations to) the same elements.
pub struct SequenceExpand<T: CubeType> {
    // Shared, interior-mutable storage; see the type-level docs for the aliasing behaviour.
    pub(super) values: Rc<RefCell<Vec<T::ExpandType>>>,
}
149
150impl<T: CubeType> Iterable<T> for SequenceExpand<T> {
151 fn expand(self, scope: &mut Scope, func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType)) {
152 self.expand_unroll(scope, func);
153 }
154
155 fn expand_unroll(
156 self,
157 scope: &mut Scope,
158 mut func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
159 ) {
160 for elem in self {
161 func(scope, elem);
162 }
163 }
164
165 fn const_len(&self) -> Option<usize> {
166 Some(self.values.borrow().len())
167 }
168}
169
170impl<T: CubeType> IntoMut for SequenceExpand<T> {
171 fn into_mut(self, scope: &mut Scope) -> Self {
172 let mut values = self.values.borrow_mut();
173 values.iter_mut().for_each(|v| {
174 *v = IntoMut::into_mut(v.clone(), scope);
175 });
176 core::mem::drop(values);
177
178 self
179 }
180}
181impl<T: CubeType> CubeDebug for SequenceExpand<T> {}
182
183impl<T: CubeType> Clone for SequenceExpand<T> {
184 fn clone(&self) -> Self {
185 Self {
186 values: self.values.clone(),
187 }
188 }
189}
190
191impl<T: CubeType> IntoIterator for Sequence<T> {
192 type Item = T;
193
194 type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
195
196 fn into_iter(self) -> Self::IntoIter {
197 self.values.into_iter()
198 }
199}
200
201impl<T: CubeType> IntoIterator for SequenceExpand<T> {
202 type Item = T::ExpandType;
203
204 type IntoIter = <Vec<T::ExpandType> as IntoIterator>::IntoIter;
205
206 fn into_iter(self) -> Self::IntoIter {
207 self.values.take().into_iter()
208 }
209}
210
211impl<T: CubeType> SequenceExpand<T> {
212 pub fn iter_cloned(&self) -> impl Iterator<Item = T::ExpandType> {
214 self.values.borrow().clone().into_iter()
215 }
216}
217
/// During expansion a [`Sequence<T>`] is represented by a [`SequenceExpand<T>`].
impl<T: CubeType> CubeType for Sequence<T> {
    type ExpandType = SequenceExpand<T>;
}
221
222impl<T: CubeType> SequenceExpand<T> {
223 #[allow(clippy::len_without_is_empty)]
224 pub fn len(&self) -> usize {
225 self.values.borrow().len()
226 }
227 pub fn __expand_push_method(&mut self, _scope: &mut Scope, value: T::ExpandType) {
229 self.values.borrow_mut().push(value);
230 }
231
232 pub fn __expand_insert_method(&self, _scope: &mut Scope, index: usize, value: T::ExpandType) {
234 let mut values = self.values.borrow_mut();
235
236 if values.len() == index {
237 values.push(value);
238 } else {
239 values[index] = value;
240 }
241 }
242
243 pub fn __expand_index_method(&self, _scope: &mut Scope, index: usize) -> T::ExpandType {
245 self.values.borrow()[index].clone()
246 }
247
248 pub fn __expand_index_mut_method(&self, _scope: &mut Scope, index: usize) -> T::ExpandType {
250 self.values.borrow()[index].clone()
251 }
252
253 pub fn __expand_len_method(&self, _scope: &mut Scope) -> usize {
254 let values = self.values.borrow();
255 values.len()
256 }
257
258 pub fn __expand_rev_method(self, _scope: &mut Scope) -> Self {
259 let mut values = self.values.borrow().clone();
260 values.reverse();
261 Self {
262 values: Rc::new(RefCell::new(values)),
263 }
264 }
265
266 pub fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
267 self.clone()
268 }
269}