use core::marker::PhantomData;
use crate as cubecl;
use crate::ir::ExpandElement;
use crate::{prelude::*, unexpanded};
use cubecl_ir::{LineSize, StorageType, Type};
use cubecl_runtime::server::TensorMapMeta;
use paste::paste;
use serde::{Deserialize, Serialize};
pub use cubecl_runtime::tma::*;
/// Marker trait for the kind of TMA (Tensor Memory Accelerator) map a
/// [`TensorMap`] describes. Each kind converts its host-side arguments into
/// the matching [`TensorMapFormat`] variant.
pub trait TensorMapKind: CubeType + Clone + Copy + Send + Sync + 'static {
    /// Host-side arguments used to build the tensor-map format for this kind.
    type Args: Clone;

    /// Wraps `args` into the [`TensorMapFormat`] variant for this kind.
    fn as_format(args: Self::Args) -> TensorMapFormat;
}
/// Marker kind for a tiled TMA layout (see [`TensorMapFormat::Tiled`]).
// NOTE: declared as a unit struct for consistency with `Im2col`/`Im2colWide`;
// `Tiled {}` remains valid construction syntax for a unit struct.
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct Tiled;

/// Marker kind for an im2col TMA layout (see [`TensorMapFormat::Im2col`]).
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct Im2col;

/// Marker kind for a wide im2col TMA layout (see [`TensorMapFormat::Im2colWide`]).
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct Im2colWide;
// Each kind simply forwards its argument struct into the matching
// `TensorMapFormat` variant.
impl TensorMapKind for Tiled {
    type Args = TiledArgs;

    fn as_format(args: Self::Args) -> TensorMapFormat {
        TensorMapFormat::Tiled(args)
    }
}

impl TensorMapKind for Im2col {
    type Args = Im2colArgs;

    fn as_format(args: Self::Args) -> TensorMapFormat {
        TensorMapFormat::Im2col(args)
    }
}

impl TensorMapKind for Im2colWide {
    type Args = Im2colWideArgs;

    fn as_format(args: Self::Args) -> TensorMapFormat {
        TensorMapFormat::Im2colWide(args)
    }
}
/// Launch argument pairing a tensor handle with the TMA metadata needed to
/// encode a tensor map for it.
pub struct TensorMapArg<'a, R: Runtime, K: TensorMapKind> {
    /// The underlying tensor argument (must be a concrete handle, not an alias).
    pub tensor: TensorArg<'a, R>,
    /// Shape/stride/format metadata used to encode the tensor map.
    pub metadata: TensorMapMeta,
    /// Ties the argument to its map kind without storing any data.
    pub _kind: PhantomData<K>,
}
impl<'a, R: Runtime, K: TensorMapKind> TensorMapArg<'a, R, K> {
    /// Builds a tensor-map argument from a concrete tensor handle.
    ///
    /// Metadata defaults: unit element strides, no interleave, no swizzle,
    /// no prefetch, and zero fill for out-of-bounds reads. Use the `with_*`
    /// builder methods to override any of these.
    ///
    /// # Panics
    /// Panics when `tensor` is not a concrete handle (aliases are not
    /// supported for tensor maps).
    pub fn new(args: K::Args, tensor: TensorArg<'a, R>, ty: StorageType) -> Self {
        let handle = match &tensor {
            TensorArg::Handle { handle, .. } => handle,
            _ => panic!("Can't use alias for TensorMap"),
        };
        let rank = handle.shape.len();
        let metadata = TensorMapMeta {
            format: K::as_format(args),
            rank,
            shape: handle.shape.to_vec(),
            strides: handle.strides.to_vec(),
            elem_stride: vec![1; rank],
            interleave: TensorMapInterleave::None,
            swizzle: TensorMapSwizzle::None,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: ty,
        };
        Self {
            tensor,
            metadata,
            _kind: PhantomData,
        }
    }

    /// Overrides the per-dimension element strides.
    pub fn with_elem_stride(mut self, elem_stride: Vec<usize>) -> Self {
        self.metadata.elem_stride = elem_stride;
        self
    }

    /// Overrides the interleave mode.
    pub fn with_interleave(mut self, interleave: TensorMapInterleave) -> Self {
        self.metadata.interleave = interleave;
        self
    }

    /// Overrides the swizzle mode.
    pub fn with_swizzle(mut self, swizzle: TensorMapSwizzle) -> Self {
        self.metadata.swizzle = swizzle;
        self
    }

    /// Overrides the prefetch setting.
    pub fn with_prefetch(mut self, prefetch: TensorMapPrefetch) -> Self {
        self.metadata.prefetch = prefetch;
        self
    }

    /// Fills out-of-bounds reads with NaN instead of zero.
    pub fn with_nan_fill(mut self) -> Self {
        self.metadata.oob_fill = OobFill::NaN;
        self
    }
}
/// Kernel-side handle to a tensor map descriptor. Carries no runtime data of
/// its own; the element type and map kind live only in the type parameters.
#[derive(Clone)]
pub struct TensorMap<E: CubePrimitive, K: TensorMapKind> {
    // Element type of the mapped tensor.
    _ty: PhantomData<E>,
    // Layout kind (tiled / im2col / im2col-wide).
    _kind: PhantomData<K>,
}
// Manual `Copy` impl: the struct only holds `PhantomData`, and a derive would
// add an unwanted `E: Copy` bound. (An empty inherent `impl` block that used
// to sit here was dead code and has been removed.)
impl<E: CubePrimitive, K: TensorMapKind> Copy for TensorMap<E, K> {}
impl<E: CubePrimitive, K: TensorMapKind> ExpandElementIntoMut for TensorMap<E, K> {
    // Tensor maps are opaque handles: converting to a "mutable" expand
    // element reuses the same element unchanged.
    fn elem_into_mut(_scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
        elem
    }
}
impl<E: CubePrimitive, K: TensorMapKind> CubeType for TensorMap<E, K> {
    type ExpandType = ExpandElementTyped<TensorMap<E, K>>;
}

// Raw-pointer forms expand to the same typed element as the map itself.
impl<E: CubePrimitive, K: TensorMapKind> CubeType for *const TensorMap<E, K> {
    type ExpandType = ExpandElementTyped<TensorMap<E, K>>;
}

impl<E: CubePrimitive, K: TensorMapKind> CubeType for *mut TensorMap<E, K> {
    type ExpandType = ExpandElementTyped<TensorMap<E, K>>;
}
impl<R: Runtime, K: TensorMapKind> ArgSettings<R> for TensorMapArg<'_, R, K> {
    // Registration is delegated entirely to the launcher, which consumes both
    // the tensor handle and the map metadata.
    fn register(&self, launcher: &mut KernelLauncher<R>) {
        launcher.register_tensor_map(self)
    }
}
impl<E: CubePrimitive, K: TensorMapKind> Lined for TensorMap<E, K> {}

impl<E: CubePrimitive, K: TensorMapKind> LinedExpand for ExpandElementTyped<TensorMap<E, K>> {
    // Tensor maps are never vectorized, so the line size is always 1.
    fn line_size(&self) -> LineSize {
        1
    }
}
/// Compilation argument for tensor maps. A unit type: tensor maps carry no
/// compile-time variation, so there is nothing to record.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
pub struct TensorMapCompilationArg;

impl CompilationArg for TensorMapCompilationArg {}
impl<E: CubePrimitive, K: TensorMapKind> LaunchArg for TensorMap<E, K> {
    type RuntimeArg<'a, R: Runtime> = TensorMapArg<'a, R, K>;
    type CompilationArg = TensorMapCompilationArg;

    /// Tensor maps have no compile-time variation, so the compilation arg is
    /// always the unit value.
    fn compilation_arg<R: Runtime>(_runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
        TensorMapCompilationArg
    }

    /// Registers this map as a kernel input and returns its expand element.
    fn expand(
        _arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> ExpandElementTyped<TensorMap<E, K>> {
        let ty = Type::new(E::as_type(&builder.scope));
        builder.input_tensor_map(ty).into()
    }

    /// Registers this map as a kernel output and returns its expand element.
    fn expand_output(
        _arg: &Self::CompilationArg,
        builder: &mut KernelBuilder,
    ) -> ExpandElementTyped<TensorMap<E, K>> {
        let ty = Type::new(E::as_type(&builder.scope));
        builder.output_tensor_map(ty).into()
    }
}
/// Commits previously issued async TMA operations into a group.
/// Placeholder body: in cube code this call is rewritten to
/// `tma_group_commit::expand`, which emits [`TmaOps::CommitGroup`].
pub fn tma_group_commit() {
    unexpanded!()
}

/// Expand (IR-building) counterpart of [`tma_group_commit`].
pub mod tma_group_commit {
    use cubecl_ir::TmaOps;

    use super::*;

    pub fn expand(scope: &mut Scope) {
        scope.register(TmaOps::CommitGroup)
    }
}

/// Waits on committed TMA groups, parameterized by `_max_pending`.
/// Emits [`TmaOps::WaitGroup`]; presumably waits until at most `_max_pending`
/// groups remain in flight — confirm exact semantics against the backend.
pub fn tma_group_wait(_max_pending: u32) {
    unexpanded!()
}

/// Expand (IR-building) counterpart of [`tma_group_wait`].
pub mod tma_group_wait {
    use cubecl_ir::TmaOps;

    use super::*;

    pub fn expand(scope: &mut Scope, max_pending: u32) {
        scope.register(TmaOps::WaitGroup { max_pending })
    }
}

/// Read-variant of [`tma_group_wait`]; emits [`TmaOps::WaitGroupRead`].
/// NOTE(review): presumably only waits until reads have completed — confirm
/// against the backend lowering.
pub fn tma_group_wait_read(_max_pending: u32) {
    unexpanded!()
}

/// Expand (IR-building) counterpart of [`tma_group_wait_read`].
pub mod tma_group_wait_read {
    use cubecl_ir::TmaOps;

    use super::*;

    pub fn expand(scope: &mut Scope, max_pending: u32) {
        scope.register(TmaOps::WaitGroupRead { max_pending })
    }
}
/// Generates `tma_store_{$dim}d` plus its expand module.
///
/// `$dim` is the rank suffix and the `$arg`s name the coordinate parameters,
/// outermost dimension first (e.g. `z, y, x` for 3D).
// NOTE: the coordinate fragments are matched as `ident` rather than `expr`:
// they are used both as function parameter names (pattern position) and as
// expressions, and every call site passes plain identifiers.
macro_rules! tma_store {
    ($dim: literal, $($arg: ident),*) => {
        paste! {
            /// Stores a slice into the tiled tensor map `dst` at the given
            /// coordinates. Placeholder body: in cube code this call is
            /// rewritten to the `expand` function in the module below.
            #[allow(unused)]
            pub fn [<tma_store_ $dim d>]<E: CubePrimitive>(
                src: &Slice<Line<E>>,
                dst: &mut TensorMap<E, Tiled>,
                $($arg: i32),*
            ) {
                unexpanded!()
            }

            /// Expand (IR-building) counterpart of the store function above.
            pub mod [<tma_store_ $dim d>] {
                use cubecl_ir::{Instruction, TmaOps};

                use super::*;

                #[allow(clippy::too_many_arguments)]
                pub fn expand<E: CubePrimitive>(
                    scope: &mut Scope,
                    src: SliceExpand<Line<E>, ReadOnly>,
                    dst: ExpandElementTyped<TensorMap<E, Tiled>>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    // Split the slice into its backing variable plus offset so
                    // the store addresses the correct region.
                    let (source, source_offset) = src.__to_raw_parts();
                    let dst = *dst.expand;
                    let coordinates = vec![$(*$arg.expand),*];
                    scope.register(Instruction::new(
                        TmaOps::TmaStore {
                            source,
                            coordinates,
                            offset_source: source_offset,
                        },
                        dst,
                    ))
                }
            }
        }
    };
}

// Stores for ranks 1 through 5.
tma_store!(1, x);
tma_store!(2, y, x);
tma_store!(3, z, y, x);
tma_store!(4, w, z, y, x);
tma_store!(5, v, w, z, y, x);
mod metadata {
    use cubecl_ir::{ExpandElement, Metadata, Type, VariableKind};

    use super::*;
    use crate::{
        ir::{Arithmetic, BinaryOperator, Instruction},
        prelude::Array,
    };

    // Metadata queries available on a `TensorMap` inside kernel code. The
    // `unexpanded!()` bodies are placeholders: cube code rewrites each call to
    // its `__expand_*` counterpart below, which emits IR instead of running.
    impl<T: CubePrimitive, K: TensorMapKind> TensorMap<T, K> {
        /// Views the map's underlying storage as a plain tensor.
        pub fn buffer(&self) -> Tensor<Line<T>> {
            unexpanded!()
        }

        /// Stride of the underlying tensor along `_dim`.
        pub fn stride(&self, _dim: usize) -> usize {
            unexpanded!()
        }

        /// Shape of the underlying tensor along `_dim`.
        pub fn shape(&self, _dim: usize) -> usize {
            unexpanded!()
        }

        /// Coordinate along `_dim` for the linear `_index`; expands to
        /// `(index / stride(dim)) % shape(dim)`.
        pub fn coordinate(&self, _index: usize, _dim: usize) -> usize {
            unexpanded!()
        }

        /// Length of the underlying tensor (delegates to the array `Length`
        /// metadata when expanded).
        #[allow(clippy::len_without_is_empty)]
        pub fn len(&self) -> usize {
            unexpanded!()
        }

        /// Buffer length of the underlying tensor (delegates to the array
        /// buffer-length metadata when expanded).
        #[allow(clippy::len_without_is_empty)]
        pub fn buffer_len(&self) -> usize {
            unexpanded!()
        }

        /// Rank (number of dimensions) of the underlying tensor.
        pub fn rank(&self) -> usize {
            unexpanded!()
        }

        /// Reinterprets the element type as `E` without converting any data.
        /// Exists purely to satisfy the Rust type system; panics at expand
        /// time if the storage types differ (tf32/f32 excepted).
        pub fn try_cast_unchecked<E: CubePrimitive>(&self) -> TensorMap<E, K> {
            unexpanded!()
        }

        // The `__expand_*` associated functions below are the expand-time
        // entry points; each simply forwards to the `__expand_*_method`
        // implementation on the expand element.
        pub fn __expand_buffer(
            scope: &mut Scope,
            expand: ExpandElementTyped<TensorMap<T, K>>,
        ) -> ExpandElementTyped<Tensor<Line<T>>> {
            expand.__expand_buffer_method(scope)
        }

        pub fn __expand_stride(
            scope: &mut Scope,
            expand: ExpandElementTyped<TensorMap<T, K>>,
            dim: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<usize> {
            expand.__expand_stride_method(scope, dim)
        }

        pub fn __expand_shape(
            scope: &mut Scope,
            expand: ExpandElementTyped<TensorMap<T, K>>,
            dim: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<usize> {
            expand.__expand_shape_method(scope, dim)
        }

        pub fn __expand_coordinate(
            scope: &mut Scope,
            expand: ExpandElementTyped<TensorMap<T, K>>,
            index: ExpandElementTyped<usize>,
            dim: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<usize> {
            expand.__expand_coordinate_method(scope, index, dim)
        }

        pub fn __expand_len(
            scope: &mut Scope,
            expand: ExpandElementTyped<TensorMap<T, K>>,
        ) -> ExpandElementTyped<usize> {
            expand.__expand_len_method(scope)
        }

        pub fn __expand_buffer_len(
            scope: &mut Scope,
            expand: ExpandElementTyped<TensorMap<T, K>>,
        ) -> ExpandElementTyped<usize> {
            expand.__expand_buffer_len_method(scope)
        }

        pub fn __expand_rank(
            scope: &mut Scope,
            expand: ExpandElementTyped<TensorMap<T, K>>,
        ) -> ExpandElementTyped<usize> {
            expand.__expand_rank_method(scope)
        }
    }

    impl<T: CubePrimitive, K: TensorMapKind> ExpandElementTyped<TensorMap<T, K>> {
        /// Resolves the tensor map back to the kernel input/output tensor
        /// binding it came from.
        pub fn __expand_buffer_method(
            self,
            scope: &mut Scope,
        ) -> ExpandElementTyped<Tensor<Line<T>>> {
            let tensor = match self.expand.kind {
                VariableKind::TensorMapInput(id) => scope.input(id, self.expand.ty),
                VariableKind::TensorMapOutput(id) => scope.output(id, self.expand.ty),
                // Tensor maps can only originate from kernel inputs/outputs.
                _ => unreachable!(),
            };
            tensor.into()
        }

        /// Emits a `Metadata::Stride` query into a fresh local and returns it.
        pub fn __expand_stride_method(
            self,
            scope: &mut Scope,
            dim: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<usize> {
            let dim: ExpandElement = dim.into();
            let out = scope.create_local(Type::new(usize::as_type(scope)));
            scope.register(Instruction::new(
                Metadata::Stride {
                    dim: *dim,
                    var: self.expand.into(),
                },
                out.clone().into(),
            ));
            out.into()
        }

        /// Emits a `Metadata::Shape` query into a fresh local and returns it.
        pub fn __expand_shape_method(
            self,
            scope: &mut Scope,
            dim: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<usize> {
            let dim: ExpandElement = dim.into();
            let out = scope.create_local(Type::new(usize::as_type(scope)));
            scope.register(Instruction::new(
                Metadata::Shape {
                    dim: *dim,
                    var: self.expand.into(),
                },
                out.clone().into(),
            ));
            out.into()
        }

        /// Emits IR computing `(index / stride(dim)) % shape(dim)`, i.e. the
        /// coordinate along `dim` of the linear `index`.
        pub fn __expand_coordinate_method(
            self,
            scope: &mut Scope,
            index: ExpandElementTyped<usize>,
            dim: ExpandElementTyped<usize>,
        ) -> ExpandElementTyped<usize> {
            let index: ExpandElement = index.into();
            let stride = self.clone().__expand_stride_method(scope, dim.clone());
            let shape = self.clone().__expand_shape_method(scope, dim.clone());
            // num_strides = index / stride(dim)
            let num_strides = scope.create_local(Type::new(usize::as_type(scope)));
            scope.register(Instruction::new(
                Arithmetic::Div(BinaryOperator {
                    lhs: *index,
                    rhs: stride.expand.into(),
                }),
                num_strides.clone().into(),
            ));
            // coordinate = num_strides % shape(dim)
            let coordinate = scope.create_local(Type::new(usize::as_type(scope)));
            scope.register(Instruction::new(
                Arithmetic::Modulo(BinaryOperator {
                    lhs: *num_strides,
                    rhs: shape.expand.into(),
                }),
                coordinate.clone().into(),
            ));
            coordinate.into()
        }

        /// Reuses the `Array` length expansion; the `u32` element type only
        /// serves to name an array type and does not affect the length query.
        pub fn __expand_len_method(self, scope: &mut Scope) -> ExpandElementTyped<usize> {
            let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
            elem.__expand_len_method(scope)
        }

        /// Reuses the `Array` buffer-length expansion (see note on
        /// `__expand_len_method`).
        pub fn __expand_buffer_len_method(self, scope: &mut Scope) -> ExpandElementTyped<usize> {
            let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
            elem.__expand_buffer_len_method(scope)
        }

        /// Emits a `Metadata::Rank` query into a fresh local and returns it.
        // NOTE(review): the local is created with `u32::as_type` while every
        // sibling query uses `usize::as_type`, yet the return type is
        // `ExpandElementTyped<usize>` — confirm this mismatch is intentional
        // (e.g. usize lowering to u32 on device).
        pub fn __expand_rank_method(self, scope: &mut Scope) -> ExpandElementTyped<usize> {
            let out = scope.create_local(Type::new(u32::as_type(scope)));
            scope.register(Instruction::new(Metadata::Rank { var: *self.expand }, *out));
            out.into()
        }

        /// Expand-time check for `try_cast_unchecked`: the cast is a no-op, so
        /// it is only permitted between identical storage types (or the
        /// tf32/f32 pairing).
        pub fn __expand_try_cast_unchecked_method<E: CubePrimitive>(
            self,
            scope: &mut Scope,
        ) -> ExpandElementTyped<TensorMap<E, K>> {
            if T::as_type(scope) != E::as_type(scope) && !is_tf32::<E, T>(scope) {
                panic!("Try cast unchecked should only be used to satisfy the rust type system.")
            }
            self.expand.into()
        }
    }
}