use alloc::vec::Vec;
use cubecl_ir::Scope;
use serde::{Deserialize, Serialize};
use crate::{
frontend::{CubeType, IntoMut, NativeExpand, branch::Iterable},
prelude::{CubeDebug, CubeIndex, CubeIndexExpand},
};
use alloc::rc::Rc;
use core::{cell::RefCell, ops::Deref};
/// Host-side ordered container of `T` values backed by a `Vec`.
///
/// The `CubeType` impl in this file pairs it with `SequenceExpand<T>` as its
/// expand-time representation.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct Sequence<T: CubeType> {
/// Backing storage for the elements.
values: Vec<T>,
}
/// The default sequence is the empty sequence.
impl<T: CubeType> Default for Sequence<T> {
    /// Builds a sequence that holds no elements.
    fn default() -> Self {
        Self { values: Vec::new() }
    }
}
/// `IntoMut` for the host-side sequence is the identity conversion.
impl<T: CubeType> IntoMut for Sequence<T> {
/// Returns `self` unchanged; the scope argument is not used.
fn into_mut(self, _scope: &mut Scope) -> Self {
self
}
}
// Opts `Sequence` into `CubeDebug` without overriding anything: the impl body
// is empty, so whatever defaults the trait provides apply.
impl<T: CubeType> CubeDebug for Sequence<T> {}
impl<T: CubeType + Clone> Sequence<T> {
    /// Returns a new sequence holding clones of the elements in reverse order.
    ///
    /// `self` is left untouched.
    pub fn rev(&self) -> Self {
        let mut reversed = Vec::with_capacity(self.values.len());
        // Walk back-to-front so elements are cloned in the order they are stored.
        for item in self.values.iter().rev() {
            reversed.push(item.clone());
        }
        Self { values: reversed }
    }
}
impl<T: CubeType> Sequence<T> {
    /// Creates an empty sequence.
    pub fn new() -> Self {
        Self { values: Vec::new() }
    }

    /// Appends `value` at the end of the sequence.
    pub fn push(&mut self, value: T) {
        self.values.push(value);
    }

    /// Returns the number of elements currently stored.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Returns a shared reference to the element at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    #[allow(unused_variables, clippy::should_implement_trait)]
    pub fn index(&self, index: usize) -> &T {
        // Direct indexing reports both the index and the length on failure.
        &self.values[index]
    }

    /// Returns an exclusive reference to the element at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    #[allow(unused_variables, clippy::should_implement_trait)]
    pub fn index_mut(&mut self, index: usize) -> &mut T {
        &mut self.values[index]
    }

    /// Expand-time counterpart of [`Sequence::new`]: creates an empty
    /// `SequenceExpand` with freshly allocated storage.
    pub fn __expand_new(_scope: &mut Scope) -> SequenceExpand<T> {
        SequenceExpand {
            values: Rc::new(RefCell::new(Vec::new())),
        }
    }

    /// Stores `value` at `index`.
    ///
    /// An existing element at `index` is replaced. When `index` equals the
    /// current length the value is appended, which matches what
    /// `SequenceExpand::__expand_insert_method` does at expand time.
    ///
    /// # Panics
    ///
    /// Panics if `index > self.len()`.
    #[allow(unused_variables, clippy::should_implement_trait)]
    pub fn insert(&mut self, index: usize, value: T) {
        if index == self.values.len() {
            self.values.push(value);
        } else {
            self.values[index] = value;
        }
    }

    /// Expand-time counterpart of [`Sequence::push`]; forwards to
    /// `SequenceExpand::__expand_push_method`.
    pub fn __expand_push(scope: &mut Scope, expand: &mut SequenceExpand<T>, value: T::ExpandType) {
        expand.__expand_push_method(scope, value)
    }

    /// Expand-time counterpart of [`Sequence::index`]; forwards to
    /// `SequenceExpand::__expand_index_method`.
    pub fn __expand_index(
        scope: &mut Scope,
        expand: SequenceExpand<T>,
        index: usize,
    ) -> T::ExpandType {
        expand.__expand_index_method(scope, index)
    }

    /// Expand-time counterpart of [`Sequence::index_mut`]; forwards to
    /// `SequenceExpand::__expand_index_mut_method`.
    pub fn __expand_index_mut(
        scope: &mut Scope,
        expand: SequenceExpand<T>,
        index: usize,
    ) -> T::ExpandType {
        expand.__expand_index_mut_method(scope, index)
    }
}
// Indexing contract for the host-side sequence, as declared by the two
// associated types below.
impl<T: CubeType> CubeIndex for Sequence<T> {
// Indexing yields the element type itself.
type Output = T;
// Positions are plain `usize` values.
type Idx = usize;
}
/// Lets a `Sequence<T>` be used wherever a `[T]` slice is expected.
impl<T: CubeType> Deref for Sequence<T> {
    type Target = [T];

    /// Borrows the stored elements as a slice.
    fn deref(&self) -> &Self::Target {
        self.values.as_slice()
    }
}
/// Expand-time indexing for sequences. The index must resolve to a constant,
/// which is then used as a plain `usize` position.
impl<T: CubeType> CubeIndexExpand for SequenceExpand<T> {
    type Output = T::ExpandType;
    type Idx = NativeExpand<usize>;

    /// Resolves `index` to a constant position and returns a clone of the
    /// element stored there.
    ///
    /// # Panics
    ///
    /// Panics with "Sequence index must be constant" when `index.constant()`
    /// yields no value, and on an out-of-range position.
    fn expand_index(self, scope: &mut Scope, index: Self::Idx) -> Self::Output {
        let resolved = index.constant();
        let position = resolved.expect("Sequence index must be constant").as_usize();
        self.__expand_index_method(scope, position)
    }

    /// Same lookup as [`CubeIndexExpand::expand_index`]; the sequence has no
    /// separate unchecked path.
    fn expand_index_unchecked(self, scope: &mut Scope, index: Self::Idx) -> Self::Output {
        <Self as CubeIndexExpand>::expand_index(self, scope, index)
    }
}
/// Expand-time representation of a `Sequence<T>`: a list of `T::ExpandType`
/// values behind an `Rc<RefCell<..>>`.
///
/// The `Clone` impl in this file clones the `Rc`, so clones refer to the same
/// underlying list.
pub struct SequenceExpand<T: CubeType> {
// Shared, interior-mutable storage for the expanded elements.
pub(super) values: Rc<RefCell<Vec<T::ExpandType>>>,
}
/// Iteration over an expanded sequence always visits the stored elements one
/// by one; `expand` and `expand_unroll` behave the same.
impl<T: CubeType> Iterable<T> for SequenceExpand<T> {
    /// Delegates to [`Iterable::expand_unroll`].
    fn expand(self, scope: &mut Scope, func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType)) {
        <Self as Iterable<T>>::expand_unroll(self, scope, func);
    }

    /// Calls `func` once per element, in the order produced by this type's
    /// `IntoIterator` implementation.
    fn expand_unroll(
        self,
        scope: &mut Scope,
        mut func: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
    ) {
        IntoIterator::into_iter(self).for_each(|elem| func(scope, elem));
    }

    /// The length is always known: it is the number of stored elements.
    fn const_len(&self) -> Option<usize> {
        let count = self.values.borrow().len();
        Some(count)
    }
}
/// Converts every stored element through `IntoMut`, in place.
impl<T: CubeType> IntoMut for SequenceExpand<T> {
    /// Replaces each element with the result of `IntoMut::into_mut` applied to
    /// a clone of it, then returns the same handle.
    fn into_mut(self, scope: &mut Scope) -> Self {
        {
            let mut slots = self.values.borrow_mut();
            for slot in slots.iter_mut() {
                let converted = IntoMut::into_mut(slot.clone(), scope);
                *slot = converted;
            }
            // The mutable borrow ends with this block, before `self` is moved out.
        }
        self
    }
}
// Opts `SequenceExpand` into `CubeDebug` without overriding anything: the impl
// body is empty, so whatever defaults the trait provides apply.
impl<T: CubeType> CubeDebug for SequenceExpand<T> {}
impl<T: CubeType> Clone for SequenceExpand<T> {
fn clone(&self) -> Self {
Self {
values: self.values.clone(),
}
}
}
/// Consuming iteration over the host-side sequence.
impl<T: CubeType> IntoIterator for Sequence<T> {
    type Item = T;
    type IntoIter = <Vec<T> as IntoIterator>::IntoIter;

    /// Moves the elements out of the sequence, front to back.
    fn into_iter(self) -> Self::IntoIter {
        let Self { values } = self;
        values.into_iter()
    }
}
impl<T: CubeType> IntoIterator for SequenceExpand<T> {
type Item = T::ExpandType;
type IntoIter = <Vec<T::ExpandType> as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.values.take().into_iter()
}
}
impl<T: CubeType> SequenceExpand<T> {
    /// Iterates over clones of the stored elements, leaving the sequence as it was.
    ///
    /// The elements are copied into a temporary vector first, so the storage
    /// is no longer borrowed while the returned iterator is in use.
    pub fn iter_cloned(&self) -> impl Iterator<Item = T::ExpandType> {
        let snapshot: Vec<T::ExpandType> = self.values.borrow().clone();
        snapshot.into_iter()
    }
}
// Declares `SequenceExpand<T>` as the expand-time type of `Sequence<T>`.
impl<T: CubeType> CubeType for Sequence<T> {
type ExpandType = SequenceExpand<T>;
}
impl<T: CubeType> SequenceExpand<T> {
    /// Number of expanded elements currently stored.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        let stored = self.values.borrow();
        stored.len()
    }

    /// Appends `value` to the shared storage. The scope is not used.
    pub fn __expand_push_method(&mut self, _scope: &mut Scope, value: T::ExpandType) {
        let mut stored = self.values.borrow_mut();
        stored.push(value);
    }

    /// Stores `value` at `index`: replaces the element there, or appends when
    /// `index` equals the current length.
    ///
    /// # Panics
    ///
    /// Panics if `index` is greater than the current length.
    pub fn __expand_insert_method(&self, _scope: &mut Scope, index: usize, value: T::ExpandType) {
        let mut stored = self.values.borrow_mut();
        if index != stored.len() {
            stored[index] = value;
        } else {
            stored.push(value);
        }
    }

    /// Returns a clone of the element at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    pub fn __expand_index_method(&self, _scope: &mut Scope, index: usize) -> T::ExpandType {
        let stored = self.values.borrow();
        stored[index].clone()
    }

    /// Returns a clone of the element at `index`; identical to
    /// [`SequenceExpand::__expand_index_method`].
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    pub fn __expand_index_mut_method(&self, _scope: &mut Scope, index: usize) -> T::ExpandType {
        let stored = self.values.borrow();
        stored[index].clone()
    }

    /// Scope-taking variant of [`SequenceExpand::len`].
    pub fn __expand_len_method(&self, _scope: &mut Scope) -> usize {
        self.values.borrow().len()
    }

    /// Builds a new, independent sequence holding clones of the elements in
    /// reverse order. The original storage is not modified.
    pub fn __expand_rev_method(self, _scope: &mut Scope) -> Self {
        let mut reversed: Vec<T::ExpandType> = self.values.borrow().iter().cloned().collect();
        reversed.reverse();
        let storage = Rc::new(RefCell::new(reversed));
        Self { values: storage }
    }

    /// Returns another handle created through this type's `Clone` impl.
    pub fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
        Clone::clone(self)
    }
}