cubecl_core/frontend/container/sequence/
base.rs1use cubecl_ir::{ExpandElement, Scope};
2use serde::{Deserialize, Serialize};
3
4use crate::{
5 frontend::{CubeType, ExpandElementTyped, IntoMut, branch::Iterable, indexation::Index},
6 prelude::CubeDebug,
7};
8use std::{cell::RefCell, rc::Rc};
9
10#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
18pub struct Sequence<T: CubeType> {
19 values: Vec<T>,
20}
21
22impl<T: CubeType> Default for Sequence<T> {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl<T: CubeType> IntoMut for Sequence<T> {
29 fn into_mut(self, _scope: &mut Scope) -> Self {
30 self
31 }
32}
33impl<T: CubeType> CubeDebug for Sequence<T> {}
34
35impl<T: CubeType + Clone> Sequence<T> {
36 pub fn rev(&self) -> Self {
37 Self {
38 values: self.values.iter().rev().cloned().collect(),
39 }
40 }
41}
42
43impl<T: CubeType> Sequence<T> {
44 pub fn new() -> Self {
46 Self { values: Vec::new() }
47 }
48
49 pub fn push(&mut self, value: T) {
51 self.values.push(value);
52 }
53
54 #[allow(clippy::len_without_is_empty)]
56 pub fn len(&self) -> u32 {
57 self.values.len() as u32
58 }
59
60 #[allow(unused_variables, clippy::should_implement_trait)]
62 pub fn index<I: Index>(&self, index: I) -> &T {
63 let index: ExpandElementTyped<u32> = ExpandElement::Plain(index.value()).into();
64 let index = index
65 .constant()
66 .expect("Only constant are supported")
67 .as_usize();
68
69 self.values.get(index).unwrap()
70 }
71
72 #[allow(unused_variables, clippy::should_implement_trait)]
74 pub fn index_mut<I: Index>(&mut self, index: I) -> &mut T {
75 let index: ExpandElementTyped<u32> = ExpandElement::Plain(index.value()).into();
76 let index = index
77 .constant()
78 .expect("Only constant are supported")
79 .as_usize();
80
81 self.values.get_mut(index).unwrap()
82 }
83
84 pub fn __expand_new(_scope: &mut Scope) -> SequenceExpand<T> {
86 SequenceExpand {
87 values: Rc::new(RefCell::new(Vec::new())),
88 }
89 }
90
91 #[allow(unused_variables, clippy::should_implement_trait)]
93 pub fn insert<I: Index>(&mut self, index: I, value: T) {
94 *self.index_mut(index) = value;
95 }
96
97 pub fn __expand_push(scope: &mut Scope, expand: &mut SequenceExpand<T>, value: T::ExpandType) {
99 expand.__expand_push_method(scope, value)
100 }
101
102 pub fn __expand_index(
104 scope: &mut Scope,
105 expand: SequenceExpand<T>,
106 index: ExpandElementTyped<u32>,
107 ) -> T::ExpandType {
108 expand.__expand_index_method(scope, index)
109 }
110
111 pub fn __expand_index_mut(
113 scope: &mut Scope,
114 expand: SequenceExpand<T>,
115 index: ExpandElementTyped<u32>,
116 ) -> T::ExpandType {
117 expand.__expand_index_mut_method(scope, index)
118 }
119}
120
121pub struct SequenceExpand<T: CubeType> {
123 pub(super) values: Rc<RefCell<Vec<T::ExpandType>>>,
126}
127
128impl<T: CubeType> Iterable<T> for SequenceExpand<T> {
129 fn expand(self, scope: &mut Scope, func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType)) {
130 self.expand_unroll(scope, func);
131 }
132
133 fn expand_unroll(
134 self,
135 scope: &mut Scope,
136 mut func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
137 ) {
138 for elem in self {
139 func(scope, elem);
140 }
141 }
142}
143
144impl<T: CubeType> IntoMut for SequenceExpand<T> {
145 fn into_mut(self, scope: &mut Scope) -> Self {
146 let mut values = self.values.borrow_mut();
147 values.iter_mut().for_each(|v| {
148 *v = IntoMut::into_mut(v.clone(), scope);
149 });
150 core::mem::drop(values);
151
152 self
153 }
154}
155impl<T: CubeType> CubeDebug for SequenceExpand<T> {}
156
157impl<T: CubeType> Clone for SequenceExpand<T> {
158 fn clone(&self) -> Self {
159 Self {
160 values: self.values.clone(),
161 }
162 }
163}
164
165impl<T: CubeType> IntoIterator for Sequence<T> {
166 type Item = T;
167
168 type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
169
170 fn into_iter(self) -> Self::IntoIter {
171 self.values.into_iter()
172 }
173}
174
175impl<T: CubeType> IntoIterator for SequenceExpand<T> {
176 type Item = T::ExpandType;
177
178 type IntoIter = <Vec<T::ExpandType> as IntoIterator>::IntoIter;
179
180 fn into_iter(self) -> Self::IntoIter {
181 self.values.take().into_iter()
182 }
183}
184
185impl<T: CubeType> SequenceExpand<T> {
186 pub fn iter_cloned(&self) -> impl Iterator<Item = T::ExpandType> {
188 self.values.borrow().clone().into_iter()
189 }
190}
191
192impl<T: CubeType> CubeType for Sequence<T> {
193 type ExpandType = SequenceExpand<T>;
194}
195
196impl<T: CubeType> SequenceExpand<T> {
197 #[allow(clippy::len_without_is_empty)]
198 pub fn len(&self) -> u32 {
199 self.values.borrow().len() as u32
200 }
201 pub fn __expand_push_method(&mut self, _scope: &mut Scope, value: T::ExpandType) {
203 self.values.borrow_mut().push(value);
204 }
205
206 pub fn __expand_insert_method(
208 &self,
209 _scope: &mut Scope,
210 index: ExpandElementTyped<u32>,
211 value: T::ExpandType,
212 ) {
213 let index = index
214 .constant()
215 .expect("Only constant are supported")
216 .as_usize();
217
218 let mut values = self.values.borrow_mut();
219
220 if values.len() == index {
221 values.push(value);
222 } else {
223 values[index] = value;
224 }
225 }
226
227 pub fn __expand_index_method(
229 &self,
230 _scope: &mut Scope,
231 index: ExpandElementTyped<u32>,
232 ) -> T::ExpandType {
233 let index = index
234 .constant()
235 .expect("Only constant are supported")
236 .as_usize();
237
238 self.values.borrow()[index].clone()
239 }
240
241 pub fn __expand_index_mut_method(
243 &self,
244 _scope: &mut Scope,
245 index: ExpandElementTyped<u32>,
246 ) -> T::ExpandType {
247 let index = index
248 .constant()
249 .expect("Only constant are supported")
250 .as_usize();
251
252 self.values.borrow()[index].clone()
253 }
254
255 pub fn __expand_len_method(&self, _scope: &mut Scope) -> u32 {
256 let values = self.values.borrow();
257 values.len() as u32
258 }
259
260 pub fn __expand_rev_method(self, _scope: &mut Scope) -> Self {
261 let mut values = self.values.borrow().clone();
262 values.reverse();
263 Self {
264 values: Rc::new(RefCell::new(values)),
265 }
266 }
267
268 pub fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
269 self.clone()
270 }
271}