use alloc::vec;
use core::marker::PhantomData;
use crate as cubecl;
use crate::{prelude::*, unexpanded};
use cubecl_ir::{Type, VectorSize};
use cubecl_runtime::server::TensorMapMeta;
use cubecl_zspace::{Strides, metadata::Metadata, strides};
use paste::paste;
pub use cubecl_runtime::tma::*;
/// Marker trait for the supported TMA tensor-map layouts.
///
/// Each implementor pairs a zero-sized marker type with the argument struct
/// used to build the corresponding [`TensorMapFormat`] variant.
pub trait TensorMapKind: CubeType + Clone + Copy + Send + Sync + 'static {
    /// Kind-specific configuration supplied at launch time.
    type Args: Clone;
    /// Wraps the kind-specific arguments into the runtime format enum.
    fn as_format(args: Self::Args) -> TensorMapFormat;
}
/// Marker type for the tiled tensor-map layout ([`TensorMapFormat::Tiled`]).
// Declared as a unit struct for consistency with `Im2col`/`Im2colWide`;
// the `Tiled {}` construction and pattern syntax remains valid for unit
// structs, so existing callers are unaffected.
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct Tiled;
/// Marker type for the im2col tensor-map layout ([`TensorMapFormat::Im2col`]).
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct Im2col;
/// Marker type for the wide im2col tensor-map layout ([`TensorMapFormat::Im2colWide`]).
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct Im2colWide;
impl TensorMapKind for Tiled {
    type Args = TiledArgs;
    fn as_format(args: Self::Args) -> TensorMapFormat {
        TensorMapFormat::Tiled(args)
    }
}
impl TensorMapKind for Im2col {
    type Args = Im2colArgs;
    fn as_format(args: Self::Args) -> TensorMapFormat {
        TensorMapFormat::Im2col(args)
    }
}
impl TensorMapKind for Im2colWide {
    type Args = Im2colWideArgs;
    fn as_format(args: Self::Args) -> TensorMapFormat {
        TensorMapFormat::Im2colWide(args)
    }
}
/// Launch-side argument for a [`TensorMap`]: the backing tensor handle plus
/// the metadata needed to encode the tensor-map descriptor.
pub struct TensorMapArg<R: Runtime, K: TensorMapKind> {
    /// Backing tensor; must be a concrete handle, not an alias (see `new`).
    pub tensor: TensorArg<R>,
    /// Descriptor settings: format, interleave, swizzle, prefetch, OOB fill, element type.
    pub metadata: TensorMapMeta,
    /// Ties the argument to its layout kind without storing a value.
    pub _kind: PhantomData<K>,
}
impl<R: Runtime, K: TensorMapKind> TensorMapArg<R, K> {
    /// Builds a tensor-map argument from the kind-specific `args`, the backing
    /// `tensor` and its element type `ty`, using default descriptor settings
    /// (dense element strides, no interleave/swizzle/prefetch, zero OOB fill).
    ///
    /// # Panics
    /// Panics if `tensor` is not a concrete handle (aliases carry no
    /// shape/stride information to encode a map from).
    pub fn new(args: K::Args, tensor: TensorArg<R>, ty: impl Into<Type>) -> Self {
        let ty = ty.into();
        let handle = match &tensor {
            TensorArg::Handle { handle, .. } => handle,
            _ => panic!("Can't use alias for TensorMap"),
        };
        let rank = handle.shape.len();
        let metadata = TensorMapMeta {
            format: K::as_format(args),
            metadata: Metadata::new(handle.shape.clone(), handle.strides.clone()),
            // Contiguous elements by default: stride 1 in every dimension.
            elem_stride: strides![1; rank],
            interleave: TensorMapInterleave::None,
            swizzle: TensorMapSwizzle::None,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: ty.storage_type(),
        };
        Self {
            metadata,
            tensor,
            _kind: PhantomData,
        }
    }

    /// Overrides the per-dimension element strides of the descriptor.
    pub fn with_elem_stride(mut self, elem_stride: Strides) -> Self {
        self.metadata.elem_stride = elem_stride;
        self
    }

    /// Sets the interleave mode of the descriptor.
    pub fn with_interleave(mut self, interleave: TensorMapInterleave) -> Self {
        self.metadata.interleave = interleave;
        self
    }

    /// Sets the swizzle mode of the descriptor.
    pub fn with_swizzle(mut self, swizzle: TensorMapSwizzle) -> Self {
        self.metadata.swizzle = swizzle;
        self
    }

    /// Sets the prefetch behaviour of the descriptor.
    pub fn with_prefetch(mut self, prefetch: TensorMapPrefetch) -> Self {
        self.metadata.prefetch = prefetch;
        self
    }

    /// Switches out-of-bounds fill from zero to NaN.
    pub fn with_nan_fill(mut self) -> Self {
        self.metadata.oob_fill = OobFill::NaN;
        self
    }
}
/// Kernel-side handle to a tensor-map descriptor.
///
/// Carries no host data: `E` (element type) and `K` (layout kind) exist only
/// to type-check kernel code.
#[derive(Clone)]
pub struct TensorMap<E: CubePrimitive, K: TensorMapKind> {
    _ty: PhantomData<E>,
    _kind: PhantomData<K>,
}
// `TensorMap` holds only `PhantomData`, so it is trivially copyable.
impl<E: CubePrimitive, K: TensorMapKind> Copy for TensorMap<E, K> {}
// NOTE(review): intentionally empty inherent impl — presumably a placeholder;
// consider removing if nothing is planned here.
impl<E: CubePrimitive, K: TensorMapKind> TensorMap<E, K> {}
impl<E: CubePrimitive, K: TensorMapKind> IntoMut for NativeExpand<TensorMap<E, K>> {
    // Tensor maps are opaque handles: nothing to do to make them mutable.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
// `TensorMap` and raw pointers to it all expand to the same native handle.
impl<E: CubePrimitive, K: TensorMapKind> CubeType for TensorMap<E, K> {
    type ExpandType = NativeExpand<TensorMap<E, K>>;
}
impl<E: CubePrimitive, K: TensorMapKind> CubeType for *const TensorMap<E, K> {
    type ExpandType = NativeExpand<TensorMap<E, K>>;
}
impl<E: CubePrimitive, K: TensorMapKind> CubeType for *mut TensorMap<E, K> {
    type ExpandType = NativeExpand<TensorMap<E, K>>;
}
// Tensor maps are never vectorized: they always report a vector size of 1.
impl<E: CubePrimitive, K: TensorMapKind> Vectorized for TensorMap<E, K> {}
impl<E: CubePrimitive, K: TensorMapKind> VectorizedExpand for NativeExpand<TensorMap<E, K>> {
    fn vector_size(&self) -> VectorSize {
        1
    }
}
impl<E: CubePrimitive, K: TensorMapKind> LaunchArg for TensorMap<E, K> {
    type RuntimeArg<R: Runtime> = TensorMapArg<R, K>;
    // Tensor maps need no compilation-time configuration.
    type CompilationArg = ();

    /// Registers the runtime argument with the launcher, resolving the
    /// element type inside the launcher's scope.
    fn register<R: Runtime>(
        arg: Self::RuntimeArg<R>,
        launcher: &mut KernelLauncher<R>,
    ) -> Self::CompilationArg {
        let ty = launcher.with_scope(|scope| E::as_type(scope));
        launcher.register_tensor_map(arg, ty);
    }

    /// Declares an input tensor map on the kernel builder.
    fn expand(
        _arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> NativeExpand<TensorMap<E, K>> {
        let ty = E::as_type(&builder.scope);
        builder.input_tensor_map(ty).into()
    }

    /// Declares an output tensor map on the kernel builder.
    fn expand_output(
        _arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> NativeExpand<TensorMap<E, K>> {
        let ty = E::as_type(&builder.scope);
        builder.output_tensor_map(ty).into()
    }
}
/// Host stub: calling this outside kernel expansion panics via `unexpanded!`.
/// The expand counterpart registers a [`TmaOps::CommitGroup`] operation.
pub fn tma_group_commit() {
    unexpanded!()
}
/// Expand module for [`tma_group_commit`].
pub mod tma_group_commit {
    use cubecl_ir::TmaOps;
    use super::*;
    pub fn expand(scope: &mut Scope) {
        scope.register(TmaOps::CommitGroup)
    }
}
/// Host stub: calling this outside kernel expansion panics via `unexpanded!`.
/// The expand counterpart registers a [`TmaOps::WaitGroup`] operation with the
/// given `max_pending` count.
pub fn tma_group_wait(_max_pending: u32) {
    unexpanded!()
}
/// Expand module for [`tma_group_wait`].
pub mod tma_group_wait {
    use cubecl_ir::TmaOps;
    use super::*;
    pub fn expand(scope: &mut Scope, max_pending: u32) {
        scope.register(TmaOps::WaitGroup { max_pending })
    }
}
/// Host stub: calling this outside kernel expansion panics via `unexpanded!`.
/// The expand counterpart registers a [`TmaOps::WaitGroupRead`] operation with
/// the given `max_pending` count.
pub fn tma_group_wait_read(_max_pending: u32) {
    unexpanded!()
}
/// Expand module for [`tma_group_wait_read`].
pub mod tma_group_wait_read {
    use cubecl_ir::TmaOps;
    use super::*;
    pub fn expand(scope: &mut Scope, max_pending: u32) {
        scope.register(TmaOps::WaitGroupRead { max_pending })
    }
}
/// Generates a `tma_store_<N>d` host stub and its matching expand module.
///
/// `$dim` is the dimensionality; the trailing identifiers name one `i32`
/// coordinate parameter each (invoked below as e.g. `z, y, x`).
macro_rules! tma_store {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            #[allow(unused)]
            pub fn [<tma_store_ $dim d>]<T: CubePrimitive, T2: CubePrimitive<Scalar = T::Scalar>>(
                src: &Slice<T2>,
                dst: &mut TensorMap<T, Tiled>,
                $($arg: i32),*
            ) {
                // Host stub: only valid inside kernel expansion.
                unexpanded!()
            }
            pub mod [<tma_store_ $dim d>] {
                use cubecl_ir::{Instruction, TmaOps};
                use super::*;
                #[allow(clippy::too_many_arguments)]
                pub fn expand<T: CubePrimitive, T2: CubePrimitive<Scalar = T::Scalar>>(
                    scope: &mut Scope,
                    src: SliceExpand<T2, ReadOnly>,
                    dst: NativeExpand<TensorMap<T, Tiled>>,
                    $($arg: NativeExpand<i32>),*
                ) {
                    // Split the source slice into its base variable and element offset.
                    let (source, source_offset) = src.__to_raw_parts();
                    let dst = *dst.expand;
                    // Coordinates are gathered in macro-argument order.
                    let coordinates = vec![$(*$arg.expand),*];
                    scope.register(Instruction::new(
                        TmaOps::TmaStore {
                            source,
                            coordinates,
                            offset_source: source_offset,
                        },
                        dst,
                    ))
                }
            }
        }
    };
}
// Instantiate stores for 1 through 5 dimensions.
tma_store!(1, x);
tma_store!(2, y, x);
tma_store!(3, z, y, x);
tma_store!(4, w, z, y, x);
tma_store!(5, v, w, z, y, x);
/// Metadata accessors (buffer view, shape, stride, coordinate, length, rank)
/// for [`TensorMap`], plus their expand-side implementations.
mod metadata {
    use cubecl_ir::{ManagedVariable, Metadata, VariableKind};
    use super::*;
    use crate::{
        ir::{Arithmetic, BinaryOperator, Instruction},
        prelude::Array,
    };
    // Host-side stubs: each panics via `unexpanded!` outside kernel expansion
    // and is paired with an `__expand_*` function used by the expansion macros.
    impl<T: Scalar, K: TensorMapKind> TensorMap<T, K> {
        /// Views the map's backing storage as a tensor.
        // NOTE(review): the inherent signature returns `Tensor<Vector<T, N>>`
        // while the expand path yields `NativeExpand<Tensor<T>>` — confirm
        // this asymmetry is intended.
        pub fn buffer<N: Size>(&self) -> Tensor<Vector<T, N>> {
            unexpanded!()
        }
        /// Stride of the backing tensor along `dim`.
        pub fn stride(&self, _dim: usize) -> usize {
            unexpanded!()
        }
        /// Shape of the backing tensor along `dim`.
        pub fn shape(&self, _dim: usize) -> usize {
            unexpanded!()
        }
        /// Coordinate along `dim` corresponding to a linear `index`.
        pub fn coordinate(&self, _index: usize, _dim: usize) -> usize {
            unexpanded!()
        }
        /// Logical length of the backing tensor.
        #[allow(clippy::len_without_is_empty)]
        pub fn len(&self) -> usize {
            unexpanded!()
        }
        /// Length of the underlying buffer.
        #[allow(clippy::len_without_is_empty)]
        pub fn buffer_len(&self) -> usize {
            unexpanded!()
        }
        /// Rank (number of dimensions) of the backing tensor.
        pub fn rank(&self) -> usize {
            unexpanded!()
        }
        /// Reinterprets the element type; see `__expand_downcast_method` for
        /// the type-compatibility check performed at expansion time.
        pub fn downcast<E: CubePrimitive>(&self) -> TensorMap<E, K> {
            unexpanded!()
        }
        // Free-function expand wrappers: each forwards to the corresponding
        // `__expand_*_method` on the expanded handle.
        pub fn __expand_buffer(
            scope: &mut Scope,
            expand: NativeExpand<TensorMap<T, K>>,
        ) -> NativeExpand<Tensor<T>> {
            expand.__expand_buffer_method(scope)
        }
        pub fn __expand_stride(
            scope: &mut Scope,
            expand: NativeExpand<TensorMap<T, K>>,
            dim: NativeExpand<usize>,
        ) -> NativeExpand<usize> {
            expand.__expand_stride_method(scope, dim)
        }
        pub fn __expand_shape(
            scope: &mut Scope,
            expand: NativeExpand<TensorMap<T, K>>,
            dim: NativeExpand<usize>,
        ) -> NativeExpand<usize> {
            expand.__expand_shape_method(scope, dim)
        }
        pub fn __expand_coordinate(
            scope: &mut Scope,
            expand: NativeExpand<TensorMap<T, K>>,
            index: NativeExpand<usize>,
            dim: NativeExpand<usize>,
        ) -> NativeExpand<usize> {
            expand.__expand_coordinate_method(scope, index, dim)
        }
        pub fn __expand_len(
            scope: &mut Scope,
            expand: NativeExpand<TensorMap<T, K>>,
        ) -> NativeExpand<usize> {
            expand.__expand_len_method(scope)
        }
        pub fn __expand_buffer_len(
            scope: &mut Scope,
            expand: NativeExpand<TensorMap<T, K>>,
        ) -> NativeExpand<usize> {
            expand.__expand_buffer_len_method(scope)
        }
        pub fn __expand_rank(
            scope: &mut Scope,
            expand: NativeExpand<TensorMap<T, K>>,
        ) -> NativeExpand<usize> {
            expand.__expand_rank_method(scope)
        }
    }
    impl<T: CubePrimitive, K: TensorMapKind> NativeExpand<TensorMap<T, K>> {
        /// Resolves the tensor backing this map from the scope's kernel
        /// inputs or outputs, depending on how the map was declared.
        pub fn __expand_buffer_method(self, scope: &mut Scope) -> NativeExpand<Tensor<T>> {
            let tensor = match self.expand.kind {
                VariableKind::TensorMapInput(id) => scope.input(id, self.expand.ty),
                VariableKind::TensorMapOutput(id) => scope.output(id, self.expand.ty),
                // Tensor map expansions only ever carry the two kinds above.
                _ => unreachable!(),
            };
            tensor.into()
        }
        /// Emits a `Metadata::Stride` instruction and returns its result.
        pub fn __expand_stride_method(
            self,
            scope: &mut Scope,
            dim: NativeExpand<usize>,
        ) -> NativeExpand<usize> {
            let dim: ManagedVariable = dim.into();
            let out = scope.create_local(usize::as_type(scope));
            scope.register(Instruction::new(
                Metadata::Stride {
                    dim: *dim,
                    var: self.expand.into(),
                },
                out.clone().into(),
            ));
            out.into()
        }
        /// Emits a `Metadata::Shape` instruction and returns its result.
        pub fn __expand_shape_method(
            self,
            scope: &mut Scope,
            dim: NativeExpand<usize>,
        ) -> NativeExpand<usize> {
            let dim: ManagedVariable = dim.into();
            let out = scope.create_local(usize::as_type(scope));
            scope.register(Instruction::new(
                Metadata::Shape {
                    dim: *dim,
                    var: self.expand.into(),
                },
                out.clone().into(),
            ));
            out.into()
        }
        /// Computes `(index / stride(dim)) % shape(dim)` in emitted IR:
        /// the coordinate along `dim` for a linear `index`.
        pub fn __expand_coordinate_method(
            self,
            scope: &mut Scope,
            index: NativeExpand<usize>,
            dim: NativeExpand<usize>,
        ) -> NativeExpand<usize> {
            let index: ManagedVariable = index.into();
            let stride = self.clone().__expand_stride_method(scope, dim.clone());
            let shape = self.clone().__expand_shape_method(scope, dim.clone());
            // num_strides = index / stride(dim)
            let num_strides = scope.create_local(usize::as_type(scope));
            scope.register(Instruction::new(
                Arithmetic::Div(BinaryOperator {
                    lhs: *index,
                    rhs: stride.expand.into(),
                }),
                num_strides.clone().into(),
            ));
            // coordinate = num_strides % shape(dim)
            let coordinate = scope.create_local(usize::as_type(scope));
            scope.register(Instruction::new(
                Arithmetic::Modulo(BinaryOperator {
                    lhs: *num_strides,
                    rhs: shape.expand.into(),
                }),
                coordinate.clone().into(),
            ));
            coordinate.into()
        }
        /// Reinterprets the map as an `Array<u32>` expansion to reuse its
        /// length metadata.
        pub fn __expand_len_method(self, scope: &mut Scope) -> NativeExpand<usize> {
            let elem: NativeExpand<Array<u32>> = self.expand.into();
            elem.__expand_len_method(scope)
        }
        /// Reinterprets the map as an `Array<u32>` expansion to reuse its
        /// buffer-length metadata.
        pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> NativeExpand<usize> {
            let elem: NativeExpand<Array<u32>> = self.expand.into();
            elem.__expand_buffer_len_method(scope)
        }
        /// Emits a `Metadata::Rank` instruction and returns its result.
        pub fn __expand_rank_method(self, scope: &mut Scope) -> NativeExpand<usize> {
            let out = scope.create_local(usize::as_type(scope));
            scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
            out.into()
        }
        /// Retypes the expansion without emitting IR.
        ///
        /// # Panics
        /// Panics if `E` and `T` resolve to different types, unless the pair
        /// is the tf32 special case accepted by `is_tf32`.
        pub fn __expand_downcast_method<E: CubePrimitive>(
            self,
            scope: &mut Scope,
        ) -> NativeExpand<TensorMap<E, K>> {
            if T::as_type(scope) != E::as_type(scope) && !is_tf32::<E, T>(scope) {
                panic!("Downcast should only be used to satisfy the Rust type system.")
            }
            self.expand.into()
        }
    }
}