use super::{
CubeDebug, CubePrimitive, CubeType, IntoMut, NativeExpand, ReadOnly, Slice, SliceExpand,
SliceMut,
};
use crate::{self as cubecl, prelude::*};
use crate::{
ir::{self, Instruction},
unexpanded,
};
use core::marker::PhantomData;
use cubecl_macros::{comptime_type, cube, intrinsic};
use cubecl_ir::{CoopMma, ManagedVariable, Scope, StorageType, VectorSize};
pub use ir::{MatrixIdent, MatrixLayout};
/// Marker type for a matrix whose elements are of type `C`.
///
/// The struct carries no data of its own: `C` is only tracked through [`PhantomData`].
/// Its expansion-time counterpart is [`MatrixExpand`].
#[derive(Copy, Clone)]
pub struct Matrix<C>
where
    C: CubeType,
{
    _c: PhantomData<C>,
}
/// Marker type describing a matrix-multiply-accumulate operation over the element types
/// `A`, `B` and `CD`.
///
/// The struct carries no data of its own; the three type parameters are only tracked through
/// [`PhantomData`]. Its expansion-time counterpart is [`MmaDefinitionExpand`].
#[derive(Copy, Clone)]
pub struct MmaDefinition<A, B, CD>
where
    A: CubeType,
    B: CubeType,
    CD: CubeType,
{
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
/// Lets a shared reference to an [`MmaDefinitionExpand`] be used where [`CubeDebug`] is
/// required, by forwarding to the implementation on the owned type.
impl<A, B, CD> CubeDebug for &MmaDefinitionExpand<A, B, CD>
where
    A: CubeType,
    B: CubeType,
    CD: CubeType,
{
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
        // `self` is a `&&MmaDefinitionExpand`; it is deref-coerced to the owned type's receiver.
        <MmaDefinitionExpand<A, B, CD> as CubeDebug>::set_debug_name(self, scope, name);
    }
}
/// Expansion-time representation of a [`Matrix`].
pub struct MatrixExpand<C>
where
    C: CubeType,
{
    /// The IR variable that backs this matrix.
    elem: ManagedVariable,
    /// Which operand of the multiplication this matrix is (`A`, `B` or `Accumulator`).
    ident: MatrixIdent,
    /// Ties the handle to its element type without storing a value of it.
    _c: PhantomData<C>,
}
/// Expansion-time representation of an [`MmaDefinition`].
///
/// Holds the shape and storage types that are copied into the IR matrix descriptors built by
/// the methods of [`MmaDefinition`].
#[derive(Debug)]
pub struct MmaDefinitionExpand<A, B, CD>
where
    A: CubeType,
    B: CubeType,
    CD: CubeType,
{
    /// The `m` dimension of the operation.
    pub m: usize,
    /// The `n` dimension of the operation.
    pub n: usize,
    /// The `k` dimension of the operation.
    pub k: usize,
    /// Storage type resolved from `A`.
    pub a_type: StorageType,
    /// Storage type resolved from `B`.
    pub b_type: StorageType,
    /// Storage type resolved from `CD`.
    pub cd_type: StorageType,
    /// Scale factor; `None` unless the definition was built with `new_scaled`.
    pub scales_factor: Option<usize>,
    /// Storage type of the scales; `None` unless the definition was built with `new_scaled`.
    pub scales_type: Option<StorageType>,
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
impl<C: CubeType> Clone for MatrixExpand<C> {
fn clone(&self) -> Self {
Self {
elem: self.elem.clone(),
ident: self.ident,
_c: self._c,
}
}
}
/// Manual `Clone` so that cloning does not require `A`, `B` or `CD` to be `Clone`.
impl<A, B, CD> Clone for MmaDefinitionExpand<A, B, CD>
where
    A: CubeType,
    B: CubeType,
    CD: CubeType,
{
    fn clone(&self) -> Self {
        // Every data field is copied out of `self`; the markers are rebuilt from scratch.
        let &Self {
            m,
            n,
            k,
            a_type,
            b_type,
            cd_type,
            scales_factor,
            scales_type,
            ..
        } = self;
        Self {
            m,
            n,
            k,
            a_type,
            b_type,
            cd_type,
            scales_factor,
            scales_type,
            _a: PhantomData,
            _b: PhantomData,
            _cd: PhantomData,
        }
    }
}
/// A [`Matrix`] expands to a [`MatrixExpand`] over the same element type.
impl<C> CubeType for Matrix<C>
where
    C: CubeType,
{
    type ExpandType = MatrixExpand<C>;
}
/// An [`MmaDefinition`] expands to an [`MmaDefinitionExpand`] over the same element types.
impl<A, B, CD> CubeType for MmaDefinition<A, B, CD>
where
    A: CubeType,
    B: CubeType,
    CD: CubeType,
{
    type ExpandType = MmaDefinitionExpand<A, B, CD>;
}
/// Converting a [`MatrixExpand`] into its mutable form returns the value unchanged.
impl<C> IntoMut for MatrixExpand<C>
where
    C: CubeType,
{
    fn into_mut(self, _: &mut Scope) -> Self {
        self
    }
}
/// Debug naming for a [`MatrixExpand`] renames the IR variable that backs it.
impl<C> CubeDebug for MatrixExpand<C>
where
    C: CubeType,
{
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
        let variable = *self.elem;
        scope.update_variable_name(variable, name);
    }
}
/// Converting an [`MmaDefinitionExpand`] into its mutable form returns the value unchanged.
impl<A, B, CD> IntoMut for MmaDefinitionExpand<A, B, CD>
where
    A: CubeType,
    B: CubeType,
    CD: CubeType,
{
    fn into_mut(self, _: &mut Scope) -> Self {
        self
    }
}
/// Uses the default [`CubeDebug`] behaviour; no method is overridden.
impl<A, B, CD> CubeDebug for MmaDefinitionExpand<A, B, CD>
where
    A: CubeType,
    B: CubeType,
    CD: CubeType,
{
}
// Constructors for `Matrix`. The `intrinsic!` closures run at expansion time and build the
// matching `MatrixExpand`.
#[cube]
impl<C: CubePrimitive> Matrix<C> {
/// Creates a matrix without emitting any instruction that initializes its contents.
///
/// During expansion the storage type of `C` is resolved and a matrix variable described by
/// `ident`, `m`, `n`, `k`, that storage type and `layout` is created on the scope.
///
/// # Safety
///
/// This function only creates the matrix variable; nothing is written to it here.
/// NOTE(review): presumably the matrix must be filled or loaded before it is read --
/// confirm the exact contract against the backend documentation.
#[allow(unused_variables)]
pub unsafe fn uninitialized(
#[comptime] ident: MatrixIdent,
#[comptime] m: usize,
#[comptime] n: usize,
#[comptime] k: usize,
layout: MatrixLayout,
) -> Self {
intrinsic!(|scope| {
// `elem` is first the storage type of `C`, then shadowed by the created matrix variable.
let elem = C::as_type(scope).storage_type();
let elem = scope.create_matrix(ir::Matrix::new(ident, m, n, k, elem, layout));
MatrixExpand {
elem,
ident,
_c: PhantomData,
}
})
}
/// Creates a matrix and registers a fill of it with `value`.
///
/// Equivalent to calling [`Matrix::uninitialized`] with the same arguments and then
/// expanding `fill` on the result.
#[allow(unused_variables)]
pub fn from_value(
#[comptime] ident: MatrixIdent,
#[comptime] m: usize,
#[comptime] n: usize,
#[comptime] k: usize,
layout: MatrixLayout,
value: C,
) -> Self
where
C: Scalar,
{
// SAFETY: a fill instruction is registered on a clone of this handle right below, before
// the matrix is returned to the caller.
let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };
intrinsic!(|scope| {
fill::expand(scope, mat.clone(), value);
mat
})
}
/// Creates a matrix and registers a load of it from `value` with the given `stride`.
///
/// Equivalent to calling [`Matrix::uninitialized`] with the same arguments and then
/// expanding `load` on the result.
///
/// # Panics
///
/// `load::expand` asserts that the matrix is not an accumulator, so expansion panics when
/// `ident` is `MatrixIdent::Accumulator`.
#[allow(unused_variables)]
pub fn from_slice(
#[comptime] ident: MatrixIdent,
#[comptime] m: usize,
#[comptime] n: usize,
#[comptime] k: usize,
layout: MatrixLayout,
value: &Slice<C>,
stride: u32,
) -> Self {
// SAFETY: a load instruction is registered on a clone of this handle right below, before
// the matrix is returned to the caller.
let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };
intrinsic!(|scope| {
load::expand(scope, mat.clone(), value, stride);
mat
})
}
}
// Methods of `MmaDefinition`. The `intrinsic!` closures run at expansion time against the
// `MmaDefinitionExpand` fields and the scope's runtime properties.
#[cube(self_type = "ref")]
impl<A: Scalar, B: Scalar, CD: Scalar> MmaDefinition<A, B, CD> {
/// Creates a definition with shape `m` x `n` x `k` and no scales.
///
/// The storage types are resolved from `A`, `B` and `CD` at expansion time.
#[allow(unused_variables)]
pub fn new(#[comptime] m: usize, #[comptime] n: usize, #[comptime] k: usize) -> Self {
intrinsic!(|scope| {
let a_type = A::as_type(scope).storage_type();
let b_type = B::as_type(scope).storage_type();
let cd_type = CD::as_type(scope).storage_type();
MmaDefinitionExpand {
m,
n,
k,
a_type,
b_type,
cd_type,
scales_factor: None,
scales_type: None,
_a: PhantomData,
_b: PhantomData,
_cd: PhantomData,
}
})
}
/// Creates a definition with shape `m` x `n` x `k` that also records scales.
///
/// In addition to what [`MmaDefinition::new`] stores, `scale_factor` is kept as the scales
/// factor and the storage type of `S` as the scales type.
#[allow(unused_variables)]
pub fn new_scaled<S: CubePrimitive>(
#[comptime] m: usize,
#[comptime] n: usize,
#[comptime] k: usize,
#[comptime] scale_factor: usize,
) -> Self {
intrinsic!(|scope| {
let a_type = A::as_type(scope).storage_type();
let b_type = B::as_type(scope).storage_type();
let cd_type = CD::as_type(scope).storage_type();
MmaDefinitionExpand {
m,
n,
k,
a_type,
b_type,
cd_type,
scales_factor: Some(scale_factor),
scales_type: Some(S::as_type(scope).storage_type()),
_a: PhantomData,
_b: PhantomData,
_cd: PhantomData,
}
})
}
/// Returns the size of the matrix selected by `ident`, divided by the packing factor of its
/// storage type.
///
/// The undivided size is `m * k` for `A`, `k * n` for `B` and `m * n` for the accumulator.
#[allow(unused)]
pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
intrinsic!(|scope| {
match ident {
MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor(),
MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor(),
MatrixIdent::Accumulator => (self.m * self.n) / self.cd_type.packing_factor(),
}
})
}
/// Returns `num_elems(ident)` multiplied by the register duplication reported for `ident`
/// and divided (integer division) by the constant plane size.
///
/// Both the duplication and the plane size are read from the scope's runtime MMA properties.
#[allow(unused)]
pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
intrinsic!(|scope| {
let elems = self.__expand_num_elems_method(scope, ident);
let plane_dim = scope.runtime_properties.mma.const_plane_size as usize;
let duplication = match ident {
MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
};
// The multiplication happens before the integer division.
(elems * duplication) / plane_dim
})
}
/// Returns `elems_per_lane(ident)` divided by `vector_size(ident)`.
#[allow(unused)]
pub fn vectors_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
intrinsic!(|scope| {
let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
let vector_size = self.__expand_vector_size_method(scope, ident);
elems / vector_size
})
}
/// Returns the register layout that the scope's runtime MMA properties report for `ident`.
#[allow(unused)]
pub fn vector_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
intrinsic!(|scope| {
match ident {
MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
}
})
}
/// Returns the vector size for `ident`, as computed by the runtime's `contiguous_elements`
/// property.
///
/// The property is applied to `ident` and a matrix descriptor built from this definition's
/// shape and the storage type matching `ident`.
#[allow(unused_variables)]
pub fn vector_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(VectorSize) {
intrinsic!(|scope| {
let storage = match ident {
MatrixIdent::A => self.a_type,
MatrixIdent::B => self.b_type,
MatrixIdent::Accumulator => self.cd_type,
};
// The descriptor always uses `ColMajor` here, whatever `ident` is.
let matrix = cubecl_ir::Matrix {
ident,
m: self.m,
n: self.n,
k: self.k,
storage: storage,
layout: MatrixLayout::ColMajor,
};
scope
.runtime_properties
.mma
.contiguous_elements
.apply(ident, matrix)
})
}
/// Returns the `(row, col)` pair produced by the `RowIndex` and `ColIndex` instructions for
/// element `elem_idx` of lane `lane_id` in the matrix selected by `ident`.
///
/// Two locals of type `u32` are created on the scope and one instruction is registered for
/// each of them.
#[allow(unused_variables)]
pub fn position_of_nth(
&self,
lane_id: u32,
elem_idx: u32,
#[comptime] ident: MatrixIdent,
) -> (u32, u32) {
intrinsic!(|scope| {
let lane_id: ManagedVariable = lane_id.into();
let elem_idx: ManagedVariable = elem_idx.into();
let ty = match ident {
MatrixIdent::A => self.a_type,
MatrixIdent::B => self.b_type,
MatrixIdent::Accumulator => self.cd_type,
};
// Unlike `vector_size`, the descriptor here carries the runtime's register layout.
let layout = match ident {
MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
};
let matrix = cubecl_ir::Matrix {
ident,
m: self.m,
n: self.n,
k: self.k,
storage: ty,
layout,
};
let row = scope.create_local(u32::as_type(scope));
let col = scope.create_local(u32::as_type(scope));
scope.register(Instruction::new(
CoopMma::RowIndex {
lane_id: *lane_id,
i: *elem_idx,
matrix,
},
*row,
));
scope.register(Instruction::new(
CoopMma::ColIndex {
lane_id: *lane_id,
i: *elem_idx,
matrix,
},
*col,
));
(row.into(), col.into())
})
}
/// Returns the scales index for lane `lane_id` and the matrix selected by `ident`.
///
/// With `quad_id = lane_id / 4` and `t_id = lane_id % 4`, the result is
/// `quad_id + (t_id % 2) * 8` for `A` and `quad_id` for `B`.
///
/// # Panics
///
/// Panics when `ident` is `MatrixIdent::Accumulator`.
pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
// NOTE(review): the constants 4 and 8 are hard-coded rather than read from the runtime
// properties -- presumably tied to one specific hardware layout; confirm.
let quad_id = lane_id / 4;
let t_id = lane_id % 4;
match ident {
MatrixIdent::A => quad_id + (t_id % 2) * 8,
MatrixIdent::B => quad_id,
MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
}
}
/// Returns the scales factor stored in this definition.
///
/// # Panics
///
/// Panics during expansion when the definition has no scales factor (it is `None` when
/// built with [`MmaDefinition::new`]).
pub fn scales_count(&self) -> comptime_type!(usize) {
intrinsic!(|_| {
self.scales_factor
.expect("Can't retrieve scales count for matrix with no scales")
})
}
/// Returns the runtime's register size in bits divided by the size in bits of the scales
/// storage type.
///
/// # Panics
///
/// Panics during expansion when the definition has no scales type (it is `None` when built
/// with [`MmaDefinition::new`]).
pub fn scales_vector_size(&self) -> comptime_type!(VectorSize) {
intrinsic!(|scope| {
let elem = self
.scales_type
.expect("Can't retrieve scales vector size for matrix with no scales");
scope.runtime_properties.mma.register_size_bits / elem.size_bits()
})
}
/// Registers a `LoadMatrix` instruction that reads from `row` into a newly created array
/// of `num_matrices` entries, and returns that array.
///
/// The instruction records the vector size of `row`, `num_matrices` as its factor and the
/// `transpose` flag. `ident` is not used by the expansion.
#[allow(unused_variables)]
pub fn load_matrix<E: CubePrimitive, NO: Size>(
&self,
row: &Slice<E>,
#[comptime] ident: MatrixIdent,
#[comptime] num_matrices: usize,
#[comptime] transpose: bool,
) -> Array<Vector<E::Scalar, NO>> {
intrinsic!(|scope| {
let slice_vector_size = row.vector_size;
let (buffer, offset) = row.__to_raw_parts();
let out = Array::__expand_new(scope, num_matrices);
scope.register(Instruction::new(
CoopMma::LoadMatrix {
buffer,
offset,
vector_size: slice_vector_size,
factor: num_matrices,
transpose,
},
*out.expand,
));
out
})
}
/// Registers a `LoadMatrix` instruction that reads from `row` into the existing `fragment`.
///
/// The instruction records the vector size of `row`, `num_matrices` as its factor and the
/// `transpose` flag.
#[allow(unused_variables)]
pub fn load_matrix_inplace<E: Scalar, N: Size>(
&self,
row: &Slice<E>,
fragment: &mut Array<Vector<E, N>>,
#[comptime] ident: MatrixIdent,
#[comptime] num_matrices: usize,
#[comptime] transpose: bool,
) {
intrinsic!(|scope| {
// NOTE(review): `vector_size` is computed but never read below; it is unclear from this
// file whether the call is kept for validation -- confirm before removing.
let vector_size = self.__expand_vector_size_method(scope, ident);
let slice_vector_size = row.vector_size;
let (buffer, offset) = row.__to_raw_parts();
scope.register(Instruction::new(
CoopMma::LoadMatrix {
buffer,
offset,
vector_size: slice_vector_size,
factor: num_matrices,
transpose,
},
*fragment.expand,
));
})
}
/// Registers a `StoreMatrix` instruction that writes `registers` to `row`.
///
/// The buffer behind `row` is the instruction's output; the instruction records the offset
/// and vector size of `row`, `num_matrices` as its factor and the `transpose` flag.
#[allow(unused_variables)]
pub fn store_matrix<E: CubePrimitive, N: Size>(
&self,
row: &mut Slice<E, ReadWrite>,
registers: &Array<Vector<E::Scalar, N>>,
#[comptime] ident: MatrixIdent,
#[comptime] num_matrices: usize,
#[comptime] transpose: bool,
) {
intrinsic!(|scope| {
// NOTE(review): `vector_size` is computed but never read below; it is unclear from this
// file whether the call is kept for validation -- confirm before removing.
let vector_size = self.__expand_vector_size_method(scope, ident);
let slice_vector_size = row.vector_size;
let (buffer, offset) = row.__to_raw_parts();
scope.register(Instruction::new(
CoopMma::StoreMatrix {
offset,
vector_size: slice_vector_size,
registers: *registers.expand,
factor: num_matrices,
transpose,
},
buffer,
));
})
}
/// Registers an `ExecuteManual` instruction on the three register arrays and returns a
/// newly created array holding its output.
///
/// The output array has `elems_per_lane(Accumulator) / vector_size(Accumulator)` entries.
#[allow(unused)]
pub fn execute<NA: Size, NB: Size, NC: Size>(
&self,
registers_a: &Array<Vector<A, NA>>,
registers_b: &Array<Vector<B, NB>>,
registers_c: &Array<Vector<CD, NC>>,
) -> Array<Vector<CD, NC>> {
intrinsic!(|scope| {
let acc_elems = self
.clone()
.__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
let acc_vector_size = self
.clone()
.__expand_vector_size_method(scope, MatrixIdent::Accumulator);
let num_registers = acc_elems / acc_vector_size;
let registers_d = Array::__expand_new(scope, num_registers);
let registers_a = *registers_a.expand;
let registers_b = *registers_b.expand;
let registers_c = *registers_c.expand;
// The descriptor passed to the instruction is always that of `A`, with `ColMajor`.
let matrix = cubecl_ir::Matrix {
ident: MatrixIdent::A,
m: self.m,
n: self.n,
k: self.k,
storage: self.a_type,
layout: MatrixLayout::ColMajor,
};
scope.register(Instruction::new(
CoopMma::ExecuteManual {
matrix,
registers_a,
registers_b,
registers_c,
},
*registers_d.expand,
));
registers_d
})
}
/// Registers an `ExecuteManual` instruction on the three register arrays, with
/// `registers_c` used both as an input and as the instruction's output.
#[allow(unused)]
pub fn execute_inplace<NA: Size, NB: Size, NC: Size>(
&self,
registers_a: &Array<Vector<A, NA>>,
registers_b: &Array<Vector<B, NB>>,
registers_c: &mut Array<Vector<CD, NC>>,
) {
intrinsic!(|scope| {
let acc_elems = self
.clone()
.__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
let acc_vector_size = self
.clone()
.__expand_vector_size_method(scope, MatrixIdent::Accumulator);
// NOTE(review): `num_registers` is not read below, since no output array is created here.
let num_registers = acc_elems / acc_vector_size;
let registers_a = *registers_a.expand;
let registers_b = *registers_b.expand;
let registers_c = *registers_c.expand;
// The descriptor passed to the instruction is always that of `A`, with `ColMajor`.
let matrix = cubecl_ir::Matrix {
ident: MatrixIdent::A,
m: self.m,
n: self.n,
k: self.k,
storage: self.a_type,
layout: MatrixLayout::ColMajor,
};
scope.register(Instruction::new(
CoopMma::ExecuteManual {
matrix,
registers_a,
registers_b,
registers_c,
},
registers_c,
));
})
}
/// Registers an `ExecuteScaled` instruction on the three register arrays and the two scales
/// vectors, and returns a newly created array holding its output.
///
/// The output array has `elems_per_lane(Accumulator) / vector_size(Accumulator)` entries.
///
/// # Panics
///
/// Panics during expansion when the definition has no scales factor (it is `None` when
/// built with [`MmaDefinition::new`]).
#[allow(unused)]
pub fn execute_scaled<S: Scalar, NA: Size, NB: Size, NC: Size, NS: Size>(
&self,
registers_a: &Array<Vector<A, NA>>,
registers_b: &Array<Vector<B, NB>>,
registers_c: &Array<Vector<CD, NC>>,
scales_a: Vector<S, NS>,
scales_b: Vector<S, NS>,
) -> Array<Vector<CD, NC>> {
intrinsic!(|scope| {
let acc_elems = self
.clone()
.__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
let acc_vector_size = self
.clone()
.__expand_vector_size_method(scope, MatrixIdent::Accumulator);
let num_registers = acc_elems / acc_vector_size;
let registers_d = Array::__expand_new(scope, num_registers);
let registers_a = *registers_a.expand;
let registers_b = *registers_b.expand;
let registers_c = *registers_c.expand;
// The descriptor passed to the instruction is always that of `A`, with `ColMajor`.
let matrix = cubecl_ir::Matrix {
ident: MatrixIdent::A,
m: self.m,
n: self.n,
k: self.k,
storage: self.a_type,
layout: MatrixLayout::ColMajor,
};
scope.register(Instruction::new(
CoopMma::ExecuteScaled {
matrix,
registers_a,
registers_b,
registers_c,
scales_a: *scales_a.expand,
scales_b: *scales_b.expand,
scales_factor: self
.scales_factor
.expect("Can't execute scaled on matrix with no scales"),
},
*registers_d.expand,
));
registers_d
})
}
}
/// Fills the matrix `mat` with `value`.
///
/// This is the unexpanded signature only; the instruction is emitted by `fill::expand`.
#[allow(unused_variables)]
pub fn fill<C: Scalar>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}

/// Expansion of the `fill` function.
pub mod fill {
    use super::*;

    /// Registers a `CoopMma::Fill` instruction that writes `value` into the variable behind
    /// `mat`.
    pub fn expand<C: Scalar>(scope: &mut Scope, mat: MatrixExpand<C>, value: NativeExpand<C>) {
        let fill_value: ManagedVariable = value.into();
        let operation = ir::CoopMma::Fill { value: *fill_value };
        let instruction = Instruction::new(operation, *mat.elem);
        scope.register(instruction);
    }
}
/// Loads the matrix `mat` from the slice `value`, using `stride`, without an explicit layout.
///
/// This is the unexpanded signature only; the instruction is emitted by `load::expand`.
#[allow(unused_variables)]
pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
    unexpanded!()
}

/// Expansion of the `load` function.
pub mod load {
    use super::*;

    /// Registers a `CoopMma::Load` instruction, with no layout, that writes into the variable
    /// behind `mat`.
    ///
    /// # Panics
    ///
    /// Panics when `mat` is an accumulator; `load_with_layout` must be used for those.
    #[allow(unused_variables)]
    pub fn expand<C: CubePrimitive, V: CubePrimitive>(
        scope: &mut Scope,
        mat: MatrixExpand<C>,
        value: SliceExpand<V, ReadOnly>,
        stride: NativeExpand<u32>,
    ) {
        let stride_var: ManagedVariable = stride.into();
        assert_ne!(
            mat.ident,
            MatrixIdent::Accumulator,
            "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
        );
        let (buffer, offset) = value.__to_raw_parts();
        let operation = ir::CoopMma::Load {
            value: buffer,
            stride: *stride_var,
            offset,
            layout: None,
        };
        scope.register(Instruction::new(operation, *mat.elem));
    }
}
/// Loads the matrix `mat` from the slice `value`, using `stride` and an explicit `layout`.
///
/// This is the unexpanded signature only; the instruction is emitted by
/// `load_with_layout::expand`.
#[allow(unused_variables)]
pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
    mat: &Matrix<C>,
    value: &Slice<V>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}

/// Expansion of the `load_with_layout` function.
pub mod load_with_layout {
    use super::*;

    /// Registers a `CoopMma::Load` instruction carrying `layout` that writes into the variable
    /// behind `mat`.
    ///
    /// Unlike `load::expand`, no check is made on the identity of `mat`.
    #[allow(unused_variables)]
    pub fn expand<C: CubeType, V: CubePrimitive>(
        scope: &mut Scope,
        mat: MatrixExpand<C>,
        value: SliceExpand<V, ReadOnly>,
        stride: NativeExpand<u32>,
        layout: MatrixLayout,
    ) {
        let stride_var: ManagedVariable = stride.into();
        let (buffer, offset) = value.__to_raw_parts();
        let operation = ir::CoopMma::Load {
            value: buffer,
            stride: *stride_var,
            offset,
            layout: Some(layout),
        };
        scope.register(Instruction::new(operation, *mat.elem));
    }
}
/// Stores the matrix `mat` into the slice `output`, using `stride` and `layout`.
///
/// This is the unexpanded signature only; the instruction is emitted by `store::expand`.
#[allow(unused_variables)]
pub fn store<C: CubePrimitive, O: CubePrimitive>(
    output: &mut SliceMut<O>,
    mat: &Matrix<C>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}

/// Expansion of the `store` function.
pub mod store {
    use crate::prelude::ReadWrite;

    use super::*;

    /// Registers a `CoopMma::Store` instruction that reads the variable behind `mat` and has
    /// the buffer behind `output` as its output.
    #[allow(unused_variables)]
    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
        scope: &mut Scope,
        output: SliceExpand<O, ReadWrite>,
        mat: MatrixExpand<C>,
        stride: NativeExpand<u32>,
        layout: MatrixLayout,
    ) {
        let stride_var: ManagedVariable = stride.into();
        let (destination, offset) = output.__to_raw_parts();
        let operation = ir::CoopMma::Store {
            mat: *mat.elem,
            offset,
            stride: *stride_var,
            layout,
        };
        scope.register(Instruction::new(operation, destination));
    }
}
/// Executes the matrix operation on `mat_a`, `mat_b` and `mat_c`, with `mat_d` as the output.
///
/// This is the unexpanded signature only; the instruction is emitted by `execute::expand`.
#[allow(unused_variables)]
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}

/// Expansion of the `execute` function.
pub mod execute {
    use super::*;

    /// Registers a `CoopMma::Execute` instruction whose inputs are the variables behind
    /// `mat_a`, `mat_b` and `mat_c` and whose output is the variable behind `mat_d`.
    pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
        scope: &mut Scope,
        mat_a: MatrixExpand<A>,
        mat_b: MatrixExpand<B>,
        mat_c: MatrixExpand<C>,
        mat_d: MatrixExpand<D>,
    ) {
        let operation = ir::CoopMma::Execute {
            mat_a: *mat_a.elem,
            mat_b: *mat_b.elem,
            mat_c: *mat_c.elem,
        };
        let instruction = Instruction::new(operation, *mat_d.elem);
        scope.register(instruction);
    }
}
/// Casts the matrix `input` to a matrix with element type `O`.
///
/// This is the unexpanded signature only; the instruction is emitted by `cast::expand`.
#[allow(unused_variables)]
pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
    unexpanded!()
}

/// Expansion of the `cast` function.
pub mod cast {
    use super::*;

    /// Returns a matrix with element type `O` holding the cast of `input`.
    ///
    /// When `C` and `O` are the same type, no instruction is registered and the variable
    /// behind `input` is reused. Otherwise a new matrix with the same identity and dimensions,
    /// the storage type of `O` and an undefined layout is created, and a `CoopMma::Cast`
    /// instruction into it is registered.
    #[allow(unused_variables)]
    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
        scope: &mut Scope,
        input: MatrixExpand<C>,
    ) -> MatrixExpand<O> {
        let ident = input.ident;
        let same_type = core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>();
        if same_type {
            // Only the marker type changes; the underlying variable is carried over.
            return MatrixExpand {
                elem: input.elem,
                ident,
                _c: PhantomData,
            };
        }

        let source = *input.elem;
        let source_mat = match source.kind {
            ir::VariableKind::Matrix { mat, .. } => mat,
            _ => unreachable!(),
        };

        let target_storage = O::as_type(scope).storage_type();
        let target = ir::Matrix::new(
            ident,
            source_mat.m,
            source_mat.n,
            source_mat.k,
            target_storage,
            MatrixLayout::Undefined,
        );
        let output = MatrixExpand {
            ident,
            elem: scope.create_matrix(target),
            _c: PhantomData,
        };
        let operation = ir::CoopMma::Cast { input: source };
        scope.register(Instruction::new(operation, *output.elem));
        output
    }
}
/// A [`MatrixLayout`] is its own expansion type.
impl CubeType for MatrixLayout {
    type ExpandType = MatrixLayout;
}

/// Converting a [`MatrixLayout`] into its mutable form returns the value unchanged.
impl IntoMut for MatrixLayout {
    fn into_mut(self, _: &mut crate::ir::Scope) -> MatrixLayout {
        self
    }
}

/// Uses the default [`CubeDebug`] behaviour; no method is overridden.
impl CubeDebug for MatrixLayout {}