use super::{
CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, ReadOnly, Slice, SliceExpand,
SliceMut,
};
use crate::{
self as cubecl,
prelude::{Array, Line, Sequence},
};
use crate::{
ir::{self, Instruction},
unexpanded,
};
use cubecl_macros::{comptime_type, cube, intrinsic};
use std::marker::PhantomData;
use cubecl_ir::{CoopMma, ExpandElement, Scope, StorageType, Type};
pub use ir::{MatrixIdent, MatrixLayout};
/// Marker type for a matrix operated on by the cooperative matrix-multiply (`CoopMma`)
/// instructions, holding elements of type `C`.
///
/// It carries no data itself; at expansion time it is represented by [`MatrixExpand`],
/// which wraps the IR matrix variable created by the scope.
#[derive(Copy, Clone)]
pub struct Matrix<C: CubeType> {
    _c: PhantomData<C>,
}
/// Type-level description of a manual matrix-multiply-accumulate.
///
/// `A` and `B` are the input element types. `CD` is the element type of both the
/// accumulator input and the produced output. The comptime shape and storage types live in
/// the expanded form, [`MmaDefinitionExpand`].
#[derive(Copy, Clone)]
pub struct MmaDefinition<A: CubeType, B: CubeType, CD: CubeType> {
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
/// Expanded form of [`Matrix`]: the IR matrix variable plus the operand role it plays.
pub struct MatrixExpand<C: CubeType> {
    /// IR variable backing this matrix (created through `Scope::create_matrix`).
    elem: ExpandElement,
    /// Which operand this matrix is (`A`, `B` or `Accumulator`); `load` rejects accumulators.
    ident: MatrixIdent,
    _c: PhantomData<C>,
}
/// Expanded (comptime) form of [`MmaDefinition`].
///
/// Holds the `m`/`n`/`k` shape of the multiply-accumulate and the storage type of each
/// operand. For definitions built with `new_scaled`, it also holds the scale count and the
/// scale element type.
#[derive(Debug)]
pub struct MmaDefinitionExpand<A: CubeType, B: CubeType, CD: CubeType> {
    /// Shape dimension `m`; `A` has `m * k` elements and the accumulator `m * n`.
    pub m: u32,
    /// Shape dimension `n`; `B` has `k * n` elements and the accumulator `m * n`.
    pub n: u32,
    /// Shape dimension `k`, shared by `A` (`m * k`) and `B` (`k * n`).
    pub k: u32,
    /// Storage type of `A` elements (from `A::as_type`).
    pub a_type: StorageType,
    /// Storage type of `B` elements (from `B::as_type`).
    pub b_type: StorageType,
    /// Storage type of accumulator/output elements (from `CD::as_type`).
    pub cd_type: StorageType,
    /// Scale count passed to `new_scaled`; `None` for unscaled definitions.
    pub scales_factor: Option<u32>,
    /// Scale element type passed to `new_scaled`; `None` for unscaled definitions.
    pub scales_type: Option<StorageType>,
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
impl<C: CubeType> Clone for MatrixExpand<C> {
fn clone(&self) -> Self {
Self {
elem: self.elem.clone(),
ident: self.ident,
_c: self._c,
}
}
}
// Written by hand rather than derived so that `A`, `B` and `CD` are not required to be `Clone`.
impl<A: CubeType, B: CubeType, CD: CubeType> Clone for MmaDefinitionExpand<A, B, CD> {
    /// Field-wise copy of the definition.
    fn clone(&self) -> Self {
        // Every field (u32, StorageType, Option<..>, PhantomData<_>) is `Copy`, so
        // struct-update syntax copies each one out of the borrowed `*self`.
        Self { ..*self }
    }
}
/// At expansion time a [`Matrix`] is represented by [`MatrixExpand`].
impl<C: CubeType> CubeType for Matrix<C> {
    type ExpandType = MatrixExpand<C>;
}
/// At expansion time an [`MmaDefinition`] is represented by [`MmaDefinitionExpand`],
/// which carries its comptime shape and type information.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeType for MmaDefinition<A, B, CD> {
    type ExpandType = MmaDefinitionExpand<A, B, CD>;
}
impl<C: CubeType> IntoMut for MatrixExpand<C> {
    /// Identity conversion: the matrix handle is returned unchanged and nothing is
    /// registered in the scope.
    fn into_mut(self, _: &mut Scope) -> Self {
        self
    }
}
impl<C: CubeType> CubeDebug for MatrixExpand<C> {
    /// Records `name` as the debug name of the underlying IR matrix variable.
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
        let variable = *self.elem;
        scope.update_variable_name(variable, name);
    }
}
impl<A: CubeType, B: CubeType, CD: CubeType> IntoMut for MmaDefinitionExpand<A, B, CD> {
    /// Identity conversion: the definition only holds comptime values, so it is returned
    /// as-is.
    fn into_mut(self, _: &mut Scope) -> Self {
        self
    }
}
// Relies on the trait's default methods: the definition holds only comptime values and no
// IR variable that could be given a debug name.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for MmaDefinitionExpand<A, B, CD> {}
// Constructors for cooperative matrices. Each one returns a `Matrix<C>` backed by an IR
// matrix variable created with `Scope::create_matrix`.
#[cube]
impl<C: CubePrimitive> Matrix<C> {
    /// Creates a matrix of element type `C` without initializing its contents.
    ///
    /// `ident` selects which operand of the multiply-accumulate this matrix is (`A`, `B` or
    /// the accumulator). `m`, `n` and `k` give the shape of that multiply-accumulate.
    /// `layout` is recorded on the IR matrix.
    ///
    /// # Safety
    ///
    /// No fill or load is registered, so the contents are unspecified. The matrix must be
    /// initialized (for example with `fill` or `load`, as `from_value` and `from_slice` do)
    /// before its elements are read.
    ///
    /// # Panics
    ///
    /// Panics at expansion time if `m`, `n` or `k` is not a compile-time constant, because
    /// the IR matrix stores its dimensions as plain `u32` values.
    #[allow(unused_variables)]
    pub unsafe fn uninitialized(
        #[comptime] ident: MatrixIdent,
        m: u32,
        n: u32,
        k: u32,
        layout: MatrixLayout,
    ) -> Self {
        intrinsic!(|scope| {
            let elem = C::as_type(scope);
            // Use `expect` so a non-constant dimension reports which argument was at fault.
            let elem = scope.create_matrix(ir::Matrix::new(
                ident,
                m.constant()
                    .expect("Matrix dimension `m` must be a comptime constant")
                    .as_u32(),
                n.constant()
                    .expect("Matrix dimension `n` must be a comptime constant")
                    .as_u32(),
                k.constant()
                    .expect("Matrix dimension `k` must be a comptime constant")
                    .as_u32(),
                elem,
                layout,
            ));
            MatrixExpand {
                elem,
                ident,
                _c: PhantomData,
            }
        })
    }

    /// Creates a matrix (parameters as in `uninitialized`) and fills every element with
    /// `value`.
    ///
    /// # Panics
    ///
    /// Panics at expansion time if `m`, `n` or `k` is not a compile-time constant.
    #[allow(unused_variables)]
    pub fn from_value(
        #[comptime] ident: MatrixIdent,
        m: u32,
        n: u32,
        k: u32,
        layout: MatrixLayout,
        value: C,
    ) -> Self {
        // SAFETY: the matrix is filled with `value` below before it is returned.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };
        intrinsic!(|scope| {
            fill::expand(scope, mat.clone(), value);
            mat
        })
    }

    /// Creates a matrix (parameters as in `uninitialized`) and loads its contents from
    /// `value` using `stride`.
    ///
    /// # Panics
    ///
    /// Panics at expansion time if `m`, `n` or `k` is not a compile-time constant. It also
    /// panics if `ident` is `MatrixIdent::Accumulator`, because `load` rejects accumulators
    /// (they must be loaded with `load_with_layout`).
    #[allow(unused_variables)]
    pub fn from_slice(
        #[comptime] ident: MatrixIdent,
        m: u32,
        n: u32,
        k: u32,
        layout: MatrixLayout,
        value: &Slice<C>,
        stride: u32,
    ) -> Self {
        // SAFETY: the matrix is loaded from `value` below before it is returned.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };
        intrinsic!(|scope| {
            load::expand(scope, mat.clone(), value, stride);
            mat
        })
    }
}
// Comptime sizing queries and register-level execution for a manual multiply-accumulate
// described by `MmaDefinition`. Sizing values are computed at expansion time from the
// definition and from `scope.runtime_properties.mma`.
#[cube]
impl<A: CubePrimitive, B: CubePrimitive, CD: CubePrimitive> MmaDefinition<A, B, CD> {
    /// Creates an unscaled definition for an `m` x `n` x `k` multiply-accumulate. The
    /// operand storage types come from `A`, `B` and `CD`.
    #[allow(unused_variables)]
    pub fn new(#[comptime] m: u32, #[comptime] n: u32, #[comptime] k: u32) -> Self {
        intrinsic!(|scope| {
            let a_type = A::as_type(scope);
            let b_type = B::as_type(scope);
            let cd_type = CD::as_type(scope);
            MmaDefinitionExpand {
                m,
                n,
                k,
                a_type,
                b_type,
                cd_type,
                scales_factor: None,
                scales_type: None,
                _a: PhantomData,
                _b: PhantomData,
                _cd: PhantomData,
            }
        })
    }

    /// Like `new`, but also records `scale_factor` and the scale element type `S`. These
    /// are required by `scales_count`, `scales_line_size` and `execute_scaled`.
    #[allow(unused_variables)]
    pub fn new_scaled<S: CubePrimitive>(
        #[comptime] m: u32,
        #[comptime] n: u32,
        #[comptime] k: u32,
        #[comptime] scale_factor: u32,
    ) -> Self {
        intrinsic!(|scope| {
            let a_type = A::as_type(scope);
            let b_type = B::as_type(scope);
            let cd_type = CD::as_type(scope);
            MmaDefinitionExpand {
                m,
                n,
                k,
                a_type,
                b_type,
                cd_type,
                scales_factor: Some(scale_factor),
                scales_type: Some(S::as_type(scope)),
                _a: PhantomData,
                _b: PhantomData,
                _cd: PhantomData,
            }
        })
    }

    /// Number of storage elements in operand `ident`: `A` = `m * k`, `B` = `k * n`,
    /// accumulator = `m * n`, each divided by that storage type's packing factor.
    #[allow(unused)]
    pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            match ident {
                MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor() as u32,
                MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor() as u32,
                MatrixIdent::Accumulator => {
                    (self.m * self.n) / self.cd_type.packing_factor() as u32
                }
            }
        })
    }

    /// Number of elements of operand `ident` held by each lane. This is
    /// `num_elems / const_plane_size`, multiplied by that operand's register duplication
    /// factor from `runtime_properties.mma`.
    #[allow(unused)]
    pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elems = self.__expand_num_elems_method(scope, ident);
            let plane_dim = scope.runtime_properties.mma.const_plane_size;
            let duplication = match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
            };
            // Integer division: assumes num_elems is a multiple of the plane size — TODO confirm.
            (elems / plane_dim) * duplication
        })
    }

    /// Number of lines (of `line_size` elements) of operand `ident` held by each lane:
    /// `elems_per_lane / line_size`.
    #[allow(unused)]
    pub fn lines_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
            let line_size = self.__expand_line_size_method(scope, ident);
            elems / line_size
        })
    }

    /// Register layout of operand `ident`, as reported by `runtime_properties.mma`.
    #[allow(unused)]
    pub fn line_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
        intrinsic!(|scope| {
            match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
            }
        })
    }

    /// Number of `ident` elements per register line: `register_size_bits` divided by the
    /// element's bit size, rounded up.
    #[allow(unused_variables)]
    pub fn line_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let bits = match ident {
                MatrixIdent::A => StorageType::size_bits(&self.a_type) as u32,
                MatrixIdent::B => StorageType::size_bits(&self.b_type) as u32,
                MatrixIdent::Accumulator => StorageType::size_bits(&self.cd_type) as u32,
            };
            let register_size = scope.runtime_properties.mma.register_size_bits;
            register_size.div_ceil(bits)
        })
    }

    /// Returns the `(row, col)` coordinates, within operand `ident`, of the `elem_idx`-th
    /// element held by lane `lane_id`.
    ///
    /// Registers one `CoopMma::RowIndex` and one `CoopMma::ColIndex` instruction. They use a
    /// matrix descriptor built from this definition and the register layout of `ident`.
    #[allow(unused_variables)]
    pub fn position_of_nth(
        &self,
        lane_id: u32,
        elem_idx: u32,
        #[comptime] ident: MatrixIdent,
    ) -> (u32, u32) {
        intrinsic!(|scope| {
            let lane_id: ExpandElement = lane_id.into();
            let elem_idx: ExpandElement = elem_idx.into();
            let ty = match ident {
                MatrixIdent::A => self.a_type,
                MatrixIdent::B => self.b_type,
                MatrixIdent::Accumulator => self.cd_type,
            };
            let layout = match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
            };
            let matrix = cubecl_ir::Matrix {
                ident,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: ty,
                layout,
            };
            let row = scope.create_local(Type::new(u32::as_type(scope)));
            let col = scope.create_local(Type::new(u32::as_type(scope)));
            scope.register(Instruction::new(
                CoopMma::RowIndex {
                    lane_id: *lane_id,
                    i: *elem_idx,
                    matrix,
                },
                *row,
            ));
            scope.register(Instruction::new(
                CoopMma::ColIndex {
                    lane_id: *lane_id,
                    i: *elem_idx,
                    matrix,
                },
                *col,
            ));
            (row.into(), col.into())
        })
    }

    /// Index of the scale used by lane `lane_id` for operand `ident`.
    ///
    /// Lanes are grouped in fours (`quad_id = lane_id / 4`). For `A`, odd lanes are offset
    /// by 8; for `B`, the index is just `quad_id`.
    ///
    /// # Panics
    ///
    /// Panics if `ident` is `MatrixIdent::Accumulator`.
    pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
        // NOTE(review): the constants 4 and 8 encode a fixed hardware scale layout; confirm
        // which instruction shape they target.
        let quad_id = lane_id / 4;
        let t_id = lane_id % 4;
        match ident {
            MatrixIdent::A => quad_id + (t_id % 2) * 8,
            MatrixIdent::B => quad_id,
            MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
        }
    }

    /// Scale count recorded by `new_scaled`.
    ///
    /// # Panics
    ///
    /// Panics if the definition was created without scales.
    pub fn scales_count(&self) -> comptime_type!(u32) {
        intrinsic!(|_| {
            self.scales_factor
                .expect("Can't retrieve scales count for matrix with no scales")
        })
    }

    /// Number of scale elements per register line: `register_size_bits` divided by the
    /// scale type's bit size.
    ///
    /// # Panics
    ///
    /// Panics if the definition was created without scales.
    pub fn scales_line_size(&self) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elem = self
                .scales_type
                .expect("Can't retrieve scales line size for matrix with no scales");
            // NOTE(review): truncating division here, while `line_size` rounds up with
            // `div_ceil`; confirm the difference is intended.
            scope.runtime_properties.mma.register_size_bits / elem.size_bits() as u32
        })
    }

    /// Executes the multiply-accumulate on per-lane register sequences.
    ///
    /// Registers a `CoopMma::ExecuteManual` instruction taking the `A`, `B` and `C` register
    /// lines. Returns a newly allocated array of `elems_per_lane / line_size` accumulator
    /// lines, each `line_size` wide, which receives the result.
    #[allow(unused)]
    pub fn execute(
        &self,
        registers_a: &Sequence<Line<A>>,
        registers_b: &Sequence<Line<B>>,
        registers_c: &Sequence<Line<CD>>,
    ) -> Array<Line<CD>> {
        intrinsic!(|scope| {
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_line_size = self
                .clone()
                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_line_size;
            let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);
            let registers_a = registers_a
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            let registers_b = registers_b
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            let registers_c = registers_c
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            // NOTE(review): descriptor is tagged `A` with `A`'s storage type and a fixed
            // `ColMajor` layout; presumably the backend only reads the shape/types — confirm.
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };
            scope.register(Instruction::new(
                CoopMma::ExecuteManual {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                },
                *registers_d.expand,
            ));
            registers_d
        })
    }

    /// Same as `execute`, but also passes per-lane scales for `A` and `B`. It registers a
    /// `CoopMma::ExecuteScaled` instruction carrying the definition's scale count.
    ///
    /// # Panics
    ///
    /// Panics if the definition was created without scales.
    #[allow(unused)]
    pub fn execute_scaled<S: CubePrimitive>(
        &self,
        registers_a: &Sequence<Line<A>>,
        registers_b: &Sequence<Line<B>>,
        registers_c: &Sequence<Line<CD>>,
        scales_a: Line<S>,
        scales_b: Line<S>,
    ) -> Array<Line<CD>> {
        intrinsic!(|scope| {
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_line_size = self
                .clone()
                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_line_size;
            let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);
            let registers_a = registers_a
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            let registers_b = registers_b
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            let registers_c = registers_c
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            // Same descriptor convention as `execute` (tagged `A`, fixed `ColMajor`).
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };
            scope.register(Instruction::new(
                CoopMma::ExecuteScaled {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                    scales_a: *scales_a.expand,
                    scales_b: *scales_b.expand,
                    scales_factor: self
                        .scales_factor
                        .expect("Can't execute scaled on matrix with no scales"),
                },
                *registers_d.expand,
            ));
            registers_d
        })
    }
}
/// Fills `mat` with `value`.
///
/// The body is only `unexpanded!()`. The working implementation is `fill::expand`, which
/// registers a `CoopMma::Fill` instruction; `#[cube]` code is expected to be routed there.
#[allow(unused_variables)]
pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}
pub mod fill {
    use super::*;

    /// Expansion of `fill`: registers a `CoopMma::Fill` instruction that writes `value`
    /// into the IR matrix variable of `mat`.
    pub fn expand<C: CubeType>(
        scope: &mut Scope,
        mat: MatrixExpand<C>,
        value: ExpandElementTyped<C>,
    ) {
        // Keep the converted element bound until the end so it outlives the registration.
        let value: ExpandElement = value.into();
        let fill_op = ir::CoopMma::Fill { value: *value };
        scope.register(Instruction::new(fill_op, *mat.elem));
    }
}
/// Loads the contents of `mat` from `value`, using `stride`.
///
/// The body is only `unexpanded!()`; the working implementation is `load::expand`. That
/// implementation panics for accumulator matrices, which must use `load_with_layout`.
#[allow(unused_variables)]
pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
    unexpanded!()
}
pub mod load {
    use super::*;

    /// Expansion of `load`: registers a `CoopMma::Load` with no explicit layout.
    ///
    /// # Panics
    ///
    /// Panics if `mat` is an accumulator; accumulators need `load_with_layout`.
    #[allow(unused_variables)]
    pub fn expand<C: CubePrimitive, V: CubePrimitive>(
        scope: &mut Scope,
        mat: MatrixExpand<C>,
        value: SliceExpand<V, ReadOnly>,
        stride: ExpandElementTyped<u32>,
    ) {
        assert_ne!(
            mat.ident,
            MatrixIdent::Accumulator,
            "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
        );
        let stride: ExpandElement = stride.into();
        let (source, offset) = value.__to_raw_parts();
        let load_op = ir::CoopMma::Load {
            value: source,
            stride: *stride,
            offset,
            layout: None,
        };
        scope.register(Instruction::new(load_op, *mat.elem));
    }
}
/// Loads the contents of `mat` from `value`, using `stride` and an explicit `layout`.
///
/// Unlike `load`, this also accepts accumulator matrices. The body is only
/// `unexpanded!()`; the working implementation is `load_with_layout::expand`.
#[allow(unused_variables)]
pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
    mat: &Matrix<C>,
    value: &Slice<V>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
pub mod load_with_layout {
    use super::*;

    /// Expansion of `load_with_layout`: registers a `CoopMma::Load` carrying
    /// `Some(layout)`.
    #[allow(unused_variables)]
    pub fn expand<C: CubeType, V: CubePrimitive>(
        scope: &mut Scope,
        mat: MatrixExpand<C>,
        value: SliceExpand<V, ReadOnly>,
        stride: ExpandElementTyped<u32>,
        layout: MatrixLayout,
    ) {
        let stride: ExpandElement = stride.into();
        let (source, offset) = value.__to_raw_parts();
        let load_op = ir::CoopMma::Load {
            value: source,
            stride: *stride,
            offset,
            layout: Some(layout),
        };
        scope.register(Instruction::new(load_op, *mat.elem));
    }
}
/// Stores the contents of `mat` into `output`, using `stride` and `layout`.
///
/// The body is only `unexpanded!()`; the working implementation is `store::expand`.
#[allow(unused_variables)]
pub fn store<C: CubePrimitive, O: CubePrimitive>(
    output: &mut SliceMut<O>,
    mat: &Matrix<C>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
pub mod store {
    use super::*;
    use crate::prelude::ReadWrite;

    /// Expansion of `store`: registers a `CoopMma::Store` whose output is the raw slice
    /// variable behind `output`.
    #[allow(unused_variables)]
    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
        scope: &mut Scope,
        output: SliceExpand<O, ReadWrite>,
        mat: MatrixExpand<C>,
        stride: ExpandElementTyped<u32>,
        layout: MatrixLayout,
    ) {
        let stride: ExpandElement = stride.into();
        let (destination, offset) = output.__to_raw_parts();
        let store_op = ir::CoopMma::Store {
            mat: *mat.elem,
            offset,
            stride: *stride,
            layout,
        };
        scope.register(Instruction::new(store_op, destination));
    }
}
/// Matrix multiply-accumulate of `mat_a`, `mat_b` and `mat_c`, writing into `mat_d`
/// (presumably `D = A * B + C`; confirm against the backend).
///
/// The body is only `unexpanded!()`; the working implementation is `execute::expand`, which
/// registers a `CoopMma::Execute` instruction.
#[allow(unused_variables)]
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}
pub mod execute {
    use super::*;

    /// Expansion of `execute`: registers a `CoopMma::Execute` that reads `mat_a`, `mat_b`
    /// and `mat_c` and writes `mat_d`.
    pub fn expand<A, B, C, D>(
        scope: &mut Scope,
        mat_a: MatrixExpand<A>,
        mat_b: MatrixExpand<B>,
        mat_c: MatrixExpand<C>,
        mat_d: MatrixExpand<D>,
    ) where
        A: CubePrimitive,
        B: CubePrimitive,
        C: CubePrimitive,
        D: CubePrimitive,
    {
        let execute_op = ir::CoopMma::Execute {
            mat_a: *mat_a.elem,
            mat_b: *mat_b.elem,
            mat_c: *mat_c.elem,
        };
        scope.register(Instruction::new(execute_op, *mat_d.elem));
    }
}
/// Converts `input` into a matrix of element type `O` with the same operand role and shape.
///
/// The body is only `unexpanded!()`; the working implementation is `cast::expand`. When `C`
/// and `O` are the same type, that implementation reuses the input variable instead of
/// casting.
#[allow(unused_variables)]
pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
    unexpanded!()
}
pub mod cast {
    use super::*;

    /// Expansion of `cast`.
    ///
    /// If `C` and `O` are the same Rust type, the input variable is returned under the new
    /// type without registering anything. Otherwise a new matrix is created with the input's
    /// role and `m`/`n`/`k`, element type `O` and layout `MatrixLayout::Undefined`, and a
    /// `CoopMma::Cast` instruction is registered into it.
    #[allow(unused_variables)]
    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
        scope: &mut Scope,
        input: MatrixExpand<C>,
    ) -> MatrixExpand<O> {
        let ident = input.ident;

        let same_type = core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>();
        if same_type {
            return MatrixExpand {
                elem: input.elem,
                ident,
                _c: PhantomData,
            };
        }

        let input = *input.elem;
        let ir::VariableKind::Matrix { mat: input_mat, .. } = input.kind else {
            unreachable!()
        };

        let output_type = O::as_type(scope);
        let output_elem = scope.create_matrix(ir::Matrix::new(
            ident,
            input_mat.m,
            input_mat.n,
            input_mat.k,
            output_type,
            MatrixLayout::Undefined,
        ));
        let output = MatrixExpand {
            ident,
            elem: output_elem,
            _c: PhantomData,
        };

        scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
        output
    }
}
/// `MatrixLayout` is a comptime value: its expanded form is the layout itself.
impl CubeType for MatrixLayout {
    type ExpandType = Self;
}
impl IntoMut for MatrixLayout {
    /// Identity conversion: a layout is a plain comptime value and needs no scope work.
    fn into_mut(self, _: &mut crate::ir::Scope) -> Self {
        self
    }
}
// Relies on the trait's default methods: a layout is a comptime value with no IR variable to
// name.
impl CubeDebug for MatrixLayout {}