use super::{CubeType, ExpandElementTyped};
use crate as cubecl;
use crate::{prelude::*, unexpanded};
use cubecl_ir::{LineSize, Scope};
/// Indexed, read-only list interface declared through the `#[cube]` attribute.
///
/// Implementors must also implement `SliceOperator<T>` and `Lined`. Every
/// method body below is an `unexpanded!()` placeholder. The reference impls
/// in this file additionally define `__expand_read`, which is presumably
/// generated by `#[cube]` — confirm against the macro.
// `len` has no `is_empty` companion on this trait, hence the clippy allow.
#[allow(clippy::len_without_is_empty)]
#[cube(self_type = "ref", expand_base_traits = "SliceOperatorExpand<T>")]
pub trait List<T: CubePrimitive>: SliceOperator<T> + Lined {
/// Reads the element at `index`.
#[allow(unused)]
fn read(&self, index: usize) -> T {
unexpanded!()
}
/// Reads the element at `index`; presumably skips the bounds check that
/// `read` performs — confirm against the implementations.
#[allow(unused)]
fn read_unchecked(&self, index: usize) -> T {
unexpanded!()
}
/// Returns the length of the list.
#[allow(unused)]
fn len(&self) -> usize {
unexpanded!();
}
}
/// Writable counterpart of `List`, declared through the `#[cube]` attribute.
///
/// Implementors must also implement `List<T>` and `SliceMutOperator<T>`.
#[cube(self_type = "ref", expand_base_traits = "SliceMutOperatorExpand<T>")]
pub trait ListMut<T: CubePrimitive>: List<T> + SliceMutOperator<T> {
/// Writes `value` at `index`. Note that the receiver is `&self`, not
/// `&mut self`. The body is an `unexpanded!()` placeholder.
#[allow(unused)]
fn write(&self, index: usize, value: T) {
unexpanded!()
}
}
/// Forwarding impl: `&L` is a `List<T>` whenever `L` is.
///
/// The `where` clause requires `&L` to share `L`'s `ExpandType`, which lets
/// `this` be passed straight through to `L`'s expand function. Only `read`
/// and `__expand_read` are overridden here; `read_unchecked` and `len` keep
/// the trait's provided bodies.
impl<'r, T, L> List<T> for &'r L
where
    T: CubePrimitive,
    L: List<T>,
    &'r L: CubeType<ExpandType = L::ExpandType>,
{
    /// Delegates to `L`'s `read` through the reference.
    fn read(&self, index: usize) -> T {
        <L as List<T>>::read(self, index)
    }

    /// Delegates to `L`'s `__expand_read` with the arguments unchanged.
    fn __expand_read(
        scope: &mut Scope,
        this: Self::ExpandType,
        index: ExpandElementTyped<usize>,
    ) -> <T as CubeType>::ExpandType {
        <L as List<T>>::__expand_read(scope, this, index)
    }
}
/// Forwarding impl: `&mut L` is a `List<T>` whenever `L` is.
///
/// The `where` clause requires `&mut L` to share `L`'s `ExpandType`, which
/// lets `this` be passed straight through to `L`'s expand function. Only
/// `read` and `__expand_read` are overridden here; `read_unchecked` and `len`
/// keep the trait's provided bodies.
impl<'r, T, L> List<T> for &'r mut L
where
    T: CubePrimitive,
    L: List<T>,
    &'r mut L: CubeType<ExpandType = L::ExpandType>,
{
    /// Delegates to `L`'s `read` through the reference.
    fn read(&self, index: usize) -> T {
        <L as List<T>>::read(self, index)
    }

    /// Delegates to `L`'s `__expand_read` with the arguments unchanged.
    fn __expand_read(
        scope: &mut Scope,
        this: Self::ExpandType,
        index: ExpandElementTyped<usize>,
    ) -> <T as CubeType>::ExpandType {
        <L as List<T>>::__expand_read(scope, this, index)
    }
}
/// Forwarding impl: `&L` is a `ListMut<T>` whenever `L` is.
///
/// The `where` clause requires `&L` to share `L`'s `ExpandType`, so the
/// expand value can be handed to `L`'s expand function as-is.
impl<'r, T, L> ListMut<T> for &'r L
where
    T: CubePrimitive,
    L: ListMut<T>,
    &'r L: CubeType<ExpandType = L::ExpandType>,
{
    /// Delegates to `L`'s `write` through the reference.
    fn write(&self, index: usize, value: T) {
        <L as ListMut<T>>::write(self, index, value)
    }

    /// Delegates to `L`'s `__expand_write` with the arguments unchanged.
    fn __expand_write(
        scope: &mut Scope,
        this: Self::ExpandType,
        index: ExpandElementTyped<usize>,
        value: T::ExpandType,
    ) {
        <L as ListMut<T>>::__expand_write(scope, this, index, value)
    }
}
/// Forwarding impl: `&mut L` is a `ListMut<T>` whenever `L` is.
///
/// The `where` clause requires `&mut L` to share `L`'s `ExpandType`, so the
/// expand value can be handed to `L`'s expand function as-is.
impl<'r, T, L> ListMut<T> for &'r mut L
where
    T: CubePrimitive,
    L: ListMut<T>,
    &'r mut L: CubeType<ExpandType = L::ExpandType>,
{
    /// Delegates to `L`'s `write` through the reference.
    fn write(&self, index: usize, value: T) {
        <L as ListMut<T>>::write(self, index, value)
    }

    /// Delegates to `L`'s `__expand_write` with the arguments unchanged.
    fn __expand_write(
        scope: &mut Scope,
        this: Self::ExpandType,
        index: ExpandElementTyped<usize>,
        value: T::ExpandType,
    ) {
        <L as ListMut<T>>::__expand_write(scope, this, index, value)
    }
}
pub trait Lined: CubeType<ExpandType: LinedExpand> {
fn line_size(&self) -> LineSize {
unexpanded!()
}
fn __expand_line_size(_scope: &mut Scope, this: Self::ExpandType) -> LineSize {
this.line_size()
}
}
/// Expand-side companion of `Lined`: the only required item is `line_size`.
pub trait LinedExpand {
    /// Returns the line size of this expand value.
    fn line_size(&self) -> LineSize;

    /// Method-style entry point that takes a scope; it ignores the scope and
    /// returns `line_size()`.
    fn __expand_line_size_method(&self, _scope: &mut Scope) -> LineSize {
        Self::line_size(self)
    }
}
/// `&L` is `Lined` whenever `L` is, provided `&L`'s `ExpandType` implements
/// `LinedExpand`. All methods use the trait's provided bodies.
impl<'r, L> Lined for &'r L
where
    L: Lined,
    &'r L: CubeType<ExpandType: LinedExpand>,
{
}
/// `&mut L` is `Lined` whenever `L` is, provided `&mut L`'s `ExpandType`
/// implements `LinedExpand`. All methods use the trait's provided bodies.
impl<'r, L> Lined for &'r mut L
where
    L: Lined,
    &'r mut L: CubeType<ExpandType: LinedExpand>,
{
}