cubecl_core/frontend/container/sequence/
base.rs1use cubecl_ir::{ExpandElement, Scope};
2use serde::{Deserialize, Serialize};
3
4use crate::{
5 frontend::{CubeType, ExpandElementTyped, IntoMut, branch::Iterable, indexation::Index},
6 prelude::CubeDebug,
7};
8use std::{cell::RefCell, rc::Rc};
9
10#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
18pub struct Sequence<T: CubeType> {
19 values: Vec<T>,
20}
21
22impl<T: CubeType> Default for Sequence<T> {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl<T: CubeType> IntoMut for Sequence<T> {
29 fn into_mut(self, _scope: &mut Scope) -> Self {
30 self
31 }
32}
33impl<T: CubeType> CubeDebug for Sequence<T> {}
34
35impl<T: CubeType + Clone> Sequence<T> {
36 pub fn rev(&self) -> Self {
37 Self {
38 values: self.values.iter().rev().cloned().collect(),
39 }
40 }
41}
42
43impl<T: CubeType> Sequence<T> {
44 pub fn new() -> Self {
46 Self { values: Vec::new() }
47 }
48
49 pub fn push(&mut self, value: T) {
51 self.values.push(value);
52 }
53
54 #[allow(clippy::len_without_is_empty)]
56 pub fn len(&self) -> u32 {
57 self.values.len() as u32
58 }
59
60 #[allow(unused_variables, clippy::should_implement_trait)]
62 pub fn index<I: Index>(&self, index: I) -> &T {
63 let index: ExpandElementTyped<u32> = ExpandElement::Plain(index.value()).into();
64 let index = index
65 .constant()
66 .expect("Only constant are supported")
67 .as_usize();
68
69 self.values.get(index).unwrap()
70 }
71
72 #[allow(unused_variables, clippy::should_implement_trait)]
74 pub fn index_mut<I: Index>(&mut self, index: I) -> &mut T {
75 let index: ExpandElementTyped<u32> = ExpandElement::Plain(index.value()).into();
76 let index = index
77 .constant()
78 .expect("Only constant are supported")
79 .as_usize();
80
81 self.values.get_mut(index).unwrap()
82 }
83
84 pub fn __expand_new(_scope: &mut Scope) -> SequenceExpand<T> {
86 SequenceExpand {
87 values: Rc::new(RefCell::new(Vec::new())),
88 }
89 }
90
91 #[allow(unused_variables, clippy::should_implement_trait)]
93 pub fn insert<I: Index>(&mut self, index: I, value: T) {
94 *self.index_mut(index) = value;
95 }
96
97 pub fn __expand_push(scope: &mut Scope, expand: &mut SequenceExpand<T>, value: T::ExpandType) {
99 expand.__expand_push_method(scope, value)
100 }
101
102 pub fn __expand_index(
104 scope: &mut Scope,
105 expand: SequenceExpand<T>,
106 index: ExpandElementTyped<u32>,
107 ) -> T::ExpandType {
108 expand.__expand_index_method(scope, index)
109 }
110
111 pub fn __expand_index_mut(
113 scope: &mut Scope,
114 expand: SequenceExpand<T>,
115 index: ExpandElementTyped<u32>,
116 ) -> T::ExpandType {
117 expand.__expand_index_mut_method(scope, index)
118 }
119}
120
121pub struct SequenceExpand<T: CubeType> {
123 pub(super) values: Rc<RefCell<Vec<T::ExpandType>>>,
126}
127
128impl<T: CubeType> Iterable<T> for SequenceExpand<T> {
129 fn expand(self, scope: &mut Scope, func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType)) {
130 self.expand_unroll(scope, func);
131 }
132
133 fn expand_unroll(
134 self,
135 scope: &mut Scope,
136 mut func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
137 ) {
138 for elem in self {
139 func(scope, elem);
140 }
141 }
142}
143
144impl<T: CubeType> IntoMut for SequenceExpand<T> {
145 fn into_mut(self, scope: &mut Scope) -> Self {
146 let mut values = self.values.borrow_mut();
147 values.iter_mut().for_each(|v| {
148 *v = IntoMut::into_mut(v.clone(), scope);
149 });
150 core::mem::drop(values);
151
152 self
153 }
154}
155impl<T: CubeType> CubeDebug for SequenceExpand<T> {}
156
157impl<T: CubeType> Clone for SequenceExpand<T> {
158 fn clone(&self) -> Self {
159 Self {
160 values: self.values.clone(),
161 }
162 }
163}
164
165impl<T: CubeType> IntoIterator for Sequence<T> {
166 type Item = T;
167
168 type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
169
170 fn into_iter(self) -> Self::IntoIter {
171 self.values.into_iter()
172 }
173}
174
175impl<T: CubeType> IntoIterator for SequenceExpand<T> {
176 type Item = T::ExpandType;
177
178 type IntoIter = <Vec<T::ExpandType> as IntoIterator>::IntoIter;
179
180 fn into_iter(self) -> Self::IntoIter {
181 self.values.take().into_iter()
182 }
183}
184
185impl<T: CubeType> CubeType for Sequence<T> {
186 type ExpandType = SequenceExpand<T>;
187}
188
189impl<T: CubeType> SequenceExpand<T> {
190 #[allow(clippy::len_without_is_empty)]
191 pub fn len(&self) -> u32 {
192 self.values.borrow().len() as u32
193 }
194 pub fn __expand_push_method(&mut self, _scope: &mut Scope, value: T::ExpandType) {
196 self.values.borrow_mut().push(value);
197 }
198
199 pub fn __expand_insert_method(
201 &self,
202 _scope: &mut Scope,
203 index: ExpandElementTyped<u32>,
204 value: T::ExpandType,
205 ) {
206 let index = index
207 .constant()
208 .expect("Only constant are supported")
209 .as_usize();
210
211 let mut values = self.values.borrow_mut();
212
213 if values.len() == index {
214 values.push(value);
215 } else {
216 values[index] = value;
217 }
218 }
219
220 pub fn __expand_index_method(
222 &self,
223 _scope: &mut Scope,
224 index: ExpandElementTyped<u32>,
225 ) -> T::ExpandType {
226 let index = index
227 .constant()
228 .expect("Only constant are supported")
229 .as_usize();
230
231 self.values.borrow()[index].clone()
232 }
233
234 pub fn __expand_index_mut_method(
236 &self,
237 _scope: &mut Scope,
238 index: ExpandElementTyped<u32>,
239 ) -> T::ExpandType {
240 let index = index
241 .constant()
242 .expect("Only constant are supported")
243 .as_usize();
244
245 self.values.borrow()[index].clone()
246 }
247
248 pub fn __expand_len_method(&self, _scope: &mut Scope) -> u32 {
249 let values = self.values.borrow();
250 values.len() as u32
251 }
252
253 pub fn __expand_rev_method(self, _scope: &mut Scope) -> Self {
254 self.values.borrow_mut().reverse();
255 self
256 }
257}