use crate::{
frontend::{CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, SizedContainer},
ir::{Metadata, Scope, Type},
prelude::{
Line, Lined, LinedExpand, List, ListExpand, ListMut, ListMutExpand, index, index_assign,
index_unchecked,
},
unexpanded,
};
use cubecl_ir::{ExpandElement, LineSize};
use cubecl_macros::{cube, intrinsic};
use std::marker::PhantomData;
use crate as cubecl;
/// Host-side handle for a tensor whose elements are `T`.
///
/// The struct itself only carries a `PhantomData<T>`; the operations defined in this file
/// (metadata queries, indexing, line size) are expanded through its expand type,
/// `ExpandElementTyped<Tensor<T>>`.
#[derive(new, Clone, Copy)]
pub struct Tensor<T: CubeType> {
_val: PhantomData<T>,
}
/// Shorthand for the expand-side representation of `Tensor<T>`.
// NOTE(review): not referenced by the code visible in this file — confirm it is still used.
type TensorExpand<T> = ExpandElementTyped<Tensor<T>>;
/// `#[cube]` metadata queries on [`Tensor`]: strides, shapes, coordinates, lengths and rank.
mod metadata {
    use cubecl_ir::ExpandElement;

    use super::*;
    use crate::{
        ir::{Arithmetic, BinaryOperator, Instruction},
        prelude::Array,
    };

    #[cube]
    impl<T: CubeType> Tensor<T> {
        /// Returns the stride of the tensor along dimension `dim`.
        ///
        /// Expands to a `Metadata::Stride` instruction on this tensor; the result is
        /// written to a fresh `usize` local.
        // NOTE(review): the allow presumably silences warnings from the non-expanded
        // body generated around `intrinsic!`, which never reads `dim`.
        #[allow(unused_variables)]
        pub fn stride(&self, dim: usize) -> usize {
            intrinsic!(|scope| {
                let dim: ExpandElement = dim.into();
                let out = scope.create_local(Type::new(usize::as_type(scope)));
                scope.register(Instruction::new(
                    Metadata::Stride {
                        dim: *dim,
                        var: self.expand.into(),
                    },
                    out.clone().into(),
                ));
                out.into()
            })
        }

        /// Returns the size of the tensor along dimension `dim`.
        ///
        /// Expands to a `Metadata::Shape` instruction on this tensor; the result is
        /// written to a fresh `usize` local.
        // NOTE(review): see `stride` — `dim` is only read inside `intrinsic!`.
        #[allow(unused_variables)]
        pub fn shape(&self, dim: usize) -> usize {
            intrinsic!(|scope| {
                let dim: ExpandElement = dim.into();
                let out = scope.create_local(Type::new(usize::as_type(scope)));
                scope.register(Instruction::new(
                    Metadata::Shape {
                        dim: *dim,
                        var: self.expand.into(),
                    },
                    out.clone().into(),
                ));
                out.into()
            })
        }

        /// Returns the coordinate along dimension `dim` of the element at flat `index`.
        ///
        /// Computes `(index / stride(dim)) % shape(dim)`. For a strided layout, this is the
        /// element's position along `dim`.
        // NOTE(review): `index` and `dim` are only read inside `intrinsic!`.
        #[allow(unused_variables)]
        pub fn coordinate(&self, index: usize, dim: usize) -> usize {
            intrinsic!(|scope| {
                let index: ExpandElement = index.into();
                let stride = self.clone().__expand_stride_method(scope, dim.clone());
                let shape = self.clone().__expand_shape_method(scope, dim.clone());

                // num_strides = index / stride
                let num_strides = scope.create_local(Type::new(usize::as_type(scope)));
                scope.register(Instruction::new(
                    Arithmetic::Div(BinaryOperator {
                        lhs: *index,
                        rhs: stride.expand.into(),
                    }),
                    num_strides.clone().into(),
                ));

                // coordinate = num_strides % shape
                let coordinate = scope.create_local(Type::new(usize::as_type(scope)));
                scope.register(Instruction::new(
                    Arithmetic::Modulo(BinaryOperator {
                        lhs: *num_strides,
                        rhs: shape.expand.into(),
                    }),
                    coordinate.clone().into(),
                ));

                coordinate.into()
            })
        }

        /// Returns the length reported by `Array`'s `len` expansion for this tensor's
        /// underlying variable.
        // NOTE(review): the tensor is re-typed as `Array<u32>` only to reach `Array`'s
        // `len` expansion; the `u32` element type is presumably irrelevant — confirm.
        // `Tensor` exposes no `is_empty`, hence the allow.
        #[allow(clippy::len_without_is_empty)]
        pub fn len(&self) -> usize {
            intrinsic!(|scope| {
                let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
                elem.__expand_len_method(scope)
            })
        }

        /// Returns the buffer length reported by `Array`'s `buffer_len` expansion for this
        /// tensor's underlying variable.
        // `len_without_is_empty` only applies to methods named `len`, so no allow is
        // needed here.
        pub fn buffer_len(&self) -> usize {
            intrinsic!(|scope| {
                let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
                elem.__expand_buffer_len_method(scope)
            })
        }

        /// Returns the number of dimensions of the tensor.
        ///
        /// Expands to a `Metadata::Rank` instruction on this tensor; the result is written
        /// to a fresh `usize` local.
        pub fn rank(&self) -> usize {
            intrinsic!(|scope| {
                let out = scope.create_local(Type::new(usize::as_type(scope)));
                scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
                out.into()
            })
        }
    }
}
/// Unchecked element access on [`Tensor`], emitted directly as IR instructions.
mod indexation {
use cubecl_ir::{IndexAssignOperator, IndexOperator, Operator};
use crate::{
ir::Instruction,
prelude::{CubeIndex, CubeIndexMut},
};
use super::*;
#[cube]
impl<E: CubePrimitive> Tensor<E> {
/// Reads element `i` of the tensor without emitting any bounds check.
///
/// Emits an `Operator::UncheckedIndex` instruction. Its output is a fresh local created
/// with the tensor variable's type.
///
/// # Safety
///
/// No bounds check is emitted by this function, so the caller presumably must guarantee
/// that `i` is in range for this tensor.
/// NOTE(review): confirm the exact out-of-range behaviour with the backends.
// NOTE(review): `i` is only read inside `intrinsic!`, hence the allow.
#[allow(unused_variables)]
pub unsafe fn index_unchecked(&self, i: usize) -> &E
where
Self: CubeIndex,
{
intrinsic!(|scope| {
let out = scope.create_local(self.expand.ty);
scope.register(Instruction::new(
Operator::UncheckedIndex(IndexOperator {
list: *self.expand,
index: i.expand.consume(),
// NOTE(review): 0 presumably means "use the list's own line size" — confirm.
line_size: 0,
unroll_factor: 1,
}),
*out,
));
out.into()
})
}
/// Writes `value` to element `i` of the tensor without emitting any bounds check.
///
/// Emits an `Operator::UncheckedIndexAssign` instruction. The instruction's output is the
/// tensor variable itself.
///
/// # Safety
///
/// No bounds check is emitted by this function, so the caller presumably must guarantee
/// that `i` is in range for this tensor.
/// NOTE(review): confirm the exact out-of-range behaviour with the backends.
// NOTE(review): `i` and `value` are only read inside `intrinsic!`, hence the allow.
#[allow(unused_variables)]
pub unsafe fn index_assign_unchecked(&mut self, i: usize, value: E)
where
Self: CubeIndexMut,
{
intrinsic!(|scope| {
scope.register(Instruction::new(
Operator::UncheckedIndexAssign(IndexAssignOperator {
index: i.expand.consume(),
value: value.expand.consume(),
// NOTE(review): 0 presumably means "use the list's own line size" — confirm.
line_size: 0,
unroll_factor: 1,
}),
*self.expand,
));
})
}
}
}
/// Line-size access for tensors whose elements are `Line<P>`.
mod line {
use super::*;
impl<P: CubePrimitive> Tensor<Line<P>> {
/// Line size of this tensor's elements.
///
/// The host-side body is `unexpanded!()`. NOTE(review): this presumably means it is only
/// meaningful inside `#[cube]` code, where `__expand_line_size` is used instead.
pub fn line_size(&self) -> LineSize {
unexpanded!()
}
/// Expansion counterpart of `line_size`: forwards to `__expand_line_size_method` on the
/// tensor's expand type.
pub fn __expand_line_size(
expand: <Self as CubeType>::ExpandType,
scope: &mut Scope,
) -> LineSize {
expand.__expand_line_size_method(scope)
}
}
}
/// A tensor is a sized container whose items are `T`.
impl<T: CubePrimitive> SizedContainer for Tensor<T> {
type Item = T;
}
/// Host-side placeholder iterator over `&Tensor<T>`; `next` is `unexpanded!()`.
// NOTE(review): presumably exists so kernel source like `for x in &tensor` type-checks
// before cube expansion — confirm against the macro's loop handling.
impl<T: CubeType> Iterator for &Tensor<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
unexpanded!()
}
}
// Whether a tensor is named by value, by shared or mutable reference, or through a raw
// pointer, it expands to the same `ExpandElementTyped<Tensor<T>>` handle.

impl<T: CubeType> CubeType for Tensor<T> {
    type ExpandType = ExpandElementTyped<Tensor<T>>;
}

impl<T: CubeType> CubeType for &Tensor<T> {
    type ExpandType = ExpandElementTyped<Tensor<T>>;
}

impl<T: CubeType> CubeType for &mut Tensor<T> {
    type ExpandType = ExpandElementTyped<Tensor<T>>;
}

impl<T: CubeType> CubeType for *const Tensor<T> {
    type ExpandType = ExpandElementTyped<Tensor<T>>;
}

impl<T: CubeType> CubeType for *mut Tensor<T> {
    type ExpandType = ExpandElementTyped<Tensor<T>>;
}
/// Converting a tensor expand element to its mutable form returns `elem` unchanged;
/// no copy or new local is created.
impl<C: CubeType> ExpandElementIntoMut for Tensor<C> {
fn elem_into_mut(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
elem
}
}
/// Indexed reads on a tensor; `__expand_read` delegates to `index::expand`.
impl<T: CubePrimitive> List<T> for Tensor<T> {
fn __expand_read(
scope: &mut Scope,
this: ExpandElementTyped<Tensor<T>>,
idx: ExpandElementTyped<usize>,
) -> ExpandElementTyped<T> {
index::expand(scope, this, idx)
}
}
/// Read access on the expand-side tensor handle.
///
/// Each method works on its own clone of the handle and forwards to the matching free
/// expansion helper.
impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Tensor<T>> {
    /// Length of the tensor, forwarded to the associated `__expand_len` expansion.
    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
        let tensor = self.clone();
        Self::__expand_len(scope, tensor)
    }

    /// Reads element `idx` through `index::expand`.
    fn __expand_read_method(
        &self,
        scope: &mut Scope,
        idx: ExpandElementTyped<usize>,
    ) -> ExpandElementTyped<T> {
        let tensor = self.clone();
        index::expand(scope, tensor, idx)
    }

    /// Reads element `idx` through `index_unchecked::expand`.
    fn __expand_read_unchecked_method(
        &self,
        scope: &mut Scope,
        idx: ExpandElementTyped<usize>,
    ) -> ExpandElementTyped<T> {
        let tensor = self.clone();
        index_unchecked::expand(scope, tensor, idx)
    }
}
/// Marks `Tensor<T>` as a lined type; the line size is read from its expand form below.
impl<T: CubePrimitive> Lined for Tensor<T> {}
/// The line size of a tensor handle is the line size of its underlying variable's type.
impl<T: CubePrimitive> LinedExpand for ExpandElementTyped<Tensor<T>> {
fn line_size(&self) -> LineSize {
self.expand.ty.line_size()
}
}
/// Indexed writes on a tensor; `__expand_write` delegates to `index_assign::expand`.
impl<T: CubePrimitive> ListMut<T> for Tensor<T> {
fn __expand_write(
scope: &mut Scope,
this: ExpandElementTyped<Tensor<T>>,
idx: ExpandElementTyped<usize>,
value: ExpandElementTyped<T>,
) {
index_assign::expand(scope, this, idx, value);
}
}
/// Write access on the expand-side tensor handle.
impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<Tensor<T>> {
    /// Writes `value` to element `idx` through `index_assign::expand`, using a clone of
    /// this handle as the target.
    fn __expand_write_method(
        &self,
        scope: &mut Scope,
        idx: ExpandElementTyped<usize>,
        value: ExpandElementTyped<T>,
    ) {
        let tensor = self.clone();
        index_assign::expand(scope, tensor, idx, value);
    }
}