use super::{ReadOnly, ReadWrite, Slice, SliceExpand, SliceOriginExpand, SliceVisibility};
use crate as cubecl;
use crate::{ir::Scope, prelude::*, unexpanded};
use cubecl_common::tf32;
use cubecl_ir::ManagedVariable;
/// Reports whether `C` and `T` form an `f32`/`tf32` pair, in either order.
///
/// The check compares the storage types obtained through
/// `as_type(scope).storage_type()` for `C`, `T`, `f32` and `tf32`.
pub(crate) fn is_tf32<C: CubePrimitive, T: CubePrimitive>(scope: &mut Scope) -> bool {
    let first = C::as_type(scope).storage_type();
    let second = T::as_type(scope).storage_type();
    let single = f32::as_type(scope).storage_type();
    let tensor_float = tf32::as_type(scope).storage_type();

    let f32_then_tf32 = first == single && second == tensor_float;
    let tf32_then_f32 = first == tensor_float && second == single;

    f32_then_tf32 || tf32_then_f32
}
/// Marker implementation: `SharedMemory<E>` uses the default `SliceOperator` methods.
impl<E> SliceOperator<E> for SharedMemory<E> where E: CubePrimitive {}
/// Read-only slice construction for an expanded `SharedMemory<E>`.
impl<E: CubePrimitive> SliceOperatorExpand<E> for NativeExpand<SharedMemory<E>> {
    /// Creates a read-only slice whose origin is a clone of this shared memory,
    /// forwarding `start` and `end` unchanged to `Slice::__expand_new`.
    fn __expand_slice_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadOnly> {
        let origin = SliceOriginExpand::SharedMemory(self.clone());
        Slice::__expand_new(scope, origin, start, end)
    }

    /// Creates a read-only slice over the whole shared memory: it starts at `0`
    /// and ends at the value returned by `expand_length_native`.
    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
        // The length is resolved first, before the slice itself is built.
        let total = expand_length_native(scope, *self.expand);

        let origin = SliceOriginExpand::SharedMemory(self.clone());
        let first: NativeExpand<usize> = 0usize.into();
        let last: NativeExpand<usize> = ManagedVariable::Plain(total).into();
        Slice::__expand_new(scope, origin, first, last)
    }
}
/// Marker implementation: `SharedMemory<E>` uses the default `SliceMutOperator` methods.
impl<E> SliceMutOperator<E> for SharedMemory<E> where E: CubePrimitive {}
/// Read-write slice construction for an expanded `SharedMemory<E>`.
impl<E: CubePrimitive> SliceMutOperatorExpand<E> for NativeExpand<SharedMemory<E>> {
    /// Creates a read-write slice whose origin is a clone of this shared memory,
    /// forwarding `start` and `end` unchanged to `Slice::__expand_new`.
    fn __expand_slice_mut_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadWrite> {
        let origin = SliceOriginExpand::SharedMemory(self.clone());
        Slice::__expand_new(scope, origin, start, end)
    }

    /// Creates a read-write slice over the whole shared memory: it starts at `0`
    /// and ends at the value returned by `expand_length_native`.
    fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
        // The length is resolved first, before the slice itself is built.
        let total = expand_length_native(scope, *self.expand);

        let origin = SliceOriginExpand::SharedMemory(self.clone());
        let first: NativeExpand<usize> = 0usize.into();
        let last: NativeExpand<usize> = ManagedVariable::Plain(total).into();
        Slice::__expand_new(scope, origin, first, last)
    }
}
/// Marker implementation: `Tensor<E>` uses the default `SliceOperator` methods.
impl<E> SliceOperator<E> for Tensor<E> where E: CubePrimitive {}
/// Read-only slice construction for an expanded `Tensor<E>`.
impl<E: CubePrimitive> SliceOperatorExpand<E> for NativeExpand<Tensor<E>> {
    /// Creates a read-only slice whose origin is a clone of this tensor,
    /// forwarding `start` and `end` unchanged to `Slice::__expand_new`.
    fn __expand_slice_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadOnly> {
        let origin = SliceOriginExpand::Tensor(self.clone());
        Slice::__expand_new(scope, origin, start, end)
    }

    /// Creates a read-only slice over the whole tensor: it starts at `0` and
    /// ends at the value returned by `__expand_len_method`.
    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
        // The length is expanded first, before the slice itself is built.
        let total = self.clone().__expand_len_method(scope);

        let origin = SliceOriginExpand::Tensor(self.clone());
        let first: NativeExpand<usize> = 0usize.into();
        Slice::__expand_new(scope, origin, first, total)
    }
}
/// Marker implementation: `Tensor<E>` uses the default `SliceMutOperator` methods.
impl<E> SliceMutOperator<E> for Tensor<E> where E: CubePrimitive {}
/// Read-write slice construction for an expanded `Tensor<E>`.
impl<E: CubePrimitive> SliceMutOperatorExpand<E> for NativeExpand<Tensor<E>> {
    /// Creates a read-write slice whose origin is a clone of this tensor,
    /// forwarding `start` and `end` unchanged to `Slice::__expand_new`.
    fn __expand_slice_mut_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadWrite> {
        let origin = SliceOriginExpand::Tensor(self.clone());
        Slice::__expand_new(scope, origin, start, end)
    }

    /// Creates a read-write slice over the whole tensor: it starts at `0` and
    /// ends at the value returned by `__expand_len_method`.
    fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
        // The length is expanded first, before the slice itself is built.
        let total = self.clone().__expand_len_method(scope);

        let origin = SliceOriginExpand::Tensor(self.clone());
        let first: NativeExpand<usize> = 0usize.into();
        Slice::__expand_new(scope, origin, first, total)
    }
}
/// Marker implementation: `Array<E>` uses the default `SliceOperator` methods.
impl<E> SliceOperator<E> for Array<E> where E: CubePrimitive {}
/// Read-only slice construction for an expanded `Array<E>`.
impl<E: CubePrimitive> SliceOperatorExpand<E> for NativeExpand<Array<E>> {
    /// Creates a read-only slice whose origin is a clone of this array,
    /// forwarding `start` and `end` unchanged to `Slice::__expand_new`.
    fn __expand_slice_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadOnly> {
        let origin = SliceOriginExpand::Array(self.clone());
        Slice::__expand_new(scope, origin, start, end)
    }

    /// Creates a read-only slice over the whole array: it starts at `0` and
    /// ends at the value returned by `__expand_len_method`.
    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
        // The length is expanded first, before the slice itself is built.
        let total = self.clone().__expand_len_method(scope);

        let origin = SliceOriginExpand::Array(self.clone());
        let first: NativeExpand<usize> = 0usize.into();
        Slice::__expand_new(scope, origin, first, total)
    }
}
/// Marker implementation: `Array<E>` uses the default `SliceMutOperator` methods.
impl<E> SliceMutOperator<E> for Array<E> where E: CubePrimitive {}
/// Read-write slice construction for an expanded `Array<E>`.
impl<E: CubePrimitive> SliceMutOperatorExpand<E> for NativeExpand<Array<E>> {
    /// Creates a read-write slice whose origin is a clone of this array,
    /// forwarding `start` and `end` unchanged to `Slice::__expand_new`.
    fn __expand_slice_mut_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadWrite> {
        let origin = SliceOriginExpand::Array(self.clone());
        Slice::__expand_new(scope, origin, start, end)
    }

    /// Creates a read-write slice over the whole array: it starts at `0` and
    /// ends at the value returned by `__expand_len_method`.
    fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
        // The length is expanded first, before the slice itself is built.
        let total = self.clone().__expand_len_method(scope);

        let origin = SliceOriginExpand::Array(self.clone());
        let first: NativeExpand<usize> = 0usize.into();
        Slice::__expand_new(scope, origin, first, total)
    }
}
/// Marker implementation: a `Slice<E, IO>` of any visibility uses the default
/// `SliceOperator` methods.
impl<E, IO> SliceOperator<E> for Slice<E, IO>
where
    E: CubePrimitive,
    IO: SliceVisibility,
{
}
/// Read-only re-slicing of an already expanded slice, whatever its visibility.
impl<E: CubePrimitive, IO: SliceVisibility> SliceOperatorExpand<E> for SliceExpand<E, IO> {
    /// Creates a read-only sub-slice that shares this slice's origin.
    ///
    /// The new length is `end - start` and the new offset is
    /// `start + self.offset`; `vector_size` is carried over unchanged.
    fn __expand_slice_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadOnly> {
        // The subtraction is expanded before the addition; `start` is cloned
        // here because the offset computation below consumes it.
        let span = crate::frontend::sub::expand(scope, end, start.clone());
        let shifted = crate::frontend::add::expand(scope, start, self.offset.clone());

        SliceExpand {
            origin: self.origin.clone(),
            io: core::marker::PhantomData,
            offset: shifted,
            length: span,
            vector_size: self.vector_size,
        }
    }

    /// Creates a read-only view with the same origin, offset, length and
    /// `vector_size` as this slice. The scope is not used.
    fn __expand_to_slice_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
        let origin = self.origin.clone();
        let offset = self.offset.clone();
        let length = self.length.clone();

        SliceExpand {
            origin,
            io: core::marker::PhantomData,
            offset,
            length,
            vector_size: self.vector_size,
        }
    }
}
/// Marker implementation: only `ReadWrite` slices use the default
/// `SliceMutOperator` methods.
impl<E> SliceMutOperator<E> for Slice<E, ReadWrite> where E: CubePrimitive {}
/// Read-write re-slicing of an already expanded `ReadWrite` slice.
impl<E: CubePrimitive> SliceMutOperatorExpand<E> for SliceExpand<E, ReadWrite> {
    /// Creates a read-write sub-slice that shares this slice's origin.
    ///
    /// The new length is `end - start` and the new offset is
    /// `start + self.offset`; `vector_size` is carried over unchanged.
    fn __expand_slice_mut_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadWrite> {
        // The subtraction is expanded before the addition; `start` is cloned
        // here because the offset computation below consumes it.
        let span = crate::frontend::sub::expand(scope, end, start.clone());
        let shifted = crate::frontend::add::expand(scope, start, self.offset.clone());

        SliceExpand {
            origin: self.origin.clone(),
            io: core::marker::PhantomData,
            offset: shifted,
            length: span,
            vector_size: self.vector_size,
        }
    }

    /// Creates a read-write view with the same origin, offset, length and
    /// `vector_size` as this slice. The scope is not used.
    fn __expand_to_slice_mut_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
        let origin = self.origin.clone();
        let offset = self.offset.clone();
        let length = self.length.clone();

        SliceExpand {
            origin,
            io: core::marker::PhantomData,
            offset,
            length,
            vector_size: self.vector_size,
        }
    }
}
/// Read-only slicing interface for containers holding elements of type `E`.
///
/// Both default bodies are `unexpanded!()` placeholders.
// NOTE(review): presumably the `#[cube]` attribute generates the matching
// `SliceOperatorExpand` trait (`__expand_slice_method`, `__expand_to_slice_method`)
// that is implemented elsewhere in this file — confirm against the macro.
#[cube(self_type = "ref")]
pub trait SliceOperator<E: CubePrimitive> {
/// Returns a read-only slice of `self` delimited by `start` and `end`.
// The parameters are unused in the placeholder body, hence the `allow`.
#[allow(unused_variables)]
fn slice(&self, start: usize, end: usize) -> Slice<E, ReadOnly> {
unexpanded!()
}
/// Returns a read-only slice view of `self` without explicit bounds.
#[allow(unused_variables)]
fn to_slice(&self) -> Slice<E, ReadOnly> {
unexpanded!()
}
}
/// Read-write slicing interface for containers holding elements of type `E`.
///
/// Both default bodies are `unexpanded!()` placeholders.
// NOTE(review): presumably the `#[cube]` attribute generates the matching
// `SliceMutOperatorExpand` trait (`__expand_slice_mut_method`,
// `__expand_to_slice_mut_method`) that is implemented elsewhere in this file —
// confirm against the macro.
#[cube(self_type = "ref")]
pub trait SliceMutOperator<E: CubePrimitive> {
/// Returns a read-write slice of `self` delimited by `start` and `end`.
// The parameters are unused in the placeholder body, hence the `allow`.
#[allow(unused_variables)]
fn slice_mut(&mut self, start: usize, end: usize) -> Slice<E, ReadWrite> {
unexpanded!()
}
/// Returns a read-write slice view of `self` without explicit bounds.
#[allow(unused_variables)]
fn to_slice_mut(&mut self) -> Slice<E, ReadWrite> {
unexpanded!()
}
}
/// Blanket implementation: a shared reference to a `SliceOperator` is itself a
/// `SliceOperator`, provided the reference has the same expand type as `L`.
impl<'a, T, L> SliceOperator<T> for &'a L
where
    T: CubePrimitive,
    L: SliceOperator<T>,
    &'a L: CubeType<ExpandType = L::ExpandType>,
{
}
/// Blanket implementation: a mutable reference to a `SliceOperator` is itself a
/// `SliceOperator`, provided the reference has the same expand type as `L`.
impl<'a, T, L> SliceOperator<T> for &'a mut L
where
    T: CubePrimitive,
    L: SliceOperator<T>,
    &'a mut L: CubeType<ExpandType = L::ExpandType>,
{
}
/// Blanket implementation: a shared reference to a `SliceMutOperator` is itself
/// a `SliceMutOperator`, provided the reference has the same expand type as `L`.
impl<'a, T, L> SliceMutOperator<T> for &'a L
where
    T: CubePrimitive,
    L: SliceMutOperator<T>,
    &'a L: CubeType<ExpandType = L::ExpandType>,
{
}
/// Blanket implementation: a mutable reference to a `SliceMutOperator` is itself
/// a `SliceMutOperator`, provided the reference has the same expand type as `L`.
impl<'a, T, L> SliceMutOperator<T> for &'a mut L
where
    T: CubePrimitive,
    L: SliceMutOperator<T>,
    &'a mut L: CubeType<ExpandType = L::ExpandType>,
{
}