use std::cmp::max;
use crate::{
components::global::memory::ViewDirection,
definition::{MatmulGlobalElems, MatmulSetupError},
};
use cubecl::{
prelude::*,
quant::scheme::QuantScheme,
zspace::{Shape, Strides},
};
use cubek_std::{MatmulProblemSize, MatrixLayout, StageIdent};
use serde::{Deserialize, Serialize};
/// Full description of a (possibly batched) matrix multiplication `out = lhs @ rhs`.
///
/// Every tensor shape is laid out as `[batch dims.., rows, cols]`: lhs is
/// `[.., m, k]`, rhs is `[.., k, n]` and out is `[.., m, n]`.
#[derive(Clone, Debug)]
pub struct MatmulProblem {
    /// Number of rows of lhs and out.
    pub m: usize,
    /// Number of columns of rhs and out.
    pub n: usize,
    /// Shared inner dimension (columns of lhs, rows of rhs).
    pub k: usize,
    /// Leading (batch) dimensions of lhs, i.e. `lhs_shape` without its last two dims.
    pub lhs_batches: Shape,
    /// Leading (batch) dimensions of rhs, i.e. `rhs_shape` without its last two dims.
    pub rhs_batches: Shape,
    /// Leading (batch) dimensions of out. In `from_parameters` this is the
    /// broadcast of `lhs_batches` and `rhs_batches`.
    pub out_batches: Shape,
    /// Full lhs shape, batch dims followed by `[m, k]`.
    pub lhs_shape: Shape,
    /// Full rhs shape, batch dims followed by `[k, n]`.
    pub rhs_shape: Shape,
    /// Full out shape, batch dims followed by `[m, n]`.
    pub out_shape: Shape,
    /// Strides of the lhs tensor, paired with `lhs_shape`.
    pub lhs_strides: Strides,
    /// Strides of the rhs tensor, paired with `rhs_shape`.
    pub rhs_strides: Strides,
    /// Strides of the out tensor, paired with `out_shape`.
    pub out_strides: Strides,
    /// Matrix layout of lhs (derived from its shape/strides, or used to build them).
    pub lhs_layout: MatrixLayout,
    /// Matrix layout of rhs (derived from its shape/strides, or used to build them).
    pub rhs_layout: MatrixLayout,
    /// Matrix layout of out (derived from its shape/strides, or used to build them).
    pub out_layout: MatrixLayout,
    /// Quantization scheme of lhs, if any.
    pub lhs_scheme: Option<QuantScheme>,
    /// Quantization scheme of rhs, if any.
    pub rhs_scheme: Option<QuantScheme>,
    /// Global element types of the problem (see `MatmulGlobalElems`).
    pub global_dtypes: MatmulGlobalElems,
    /// Address type carried with the problem (see cubecl's `AddressType`).
    pub address_type: AddressType,
}
impl MatmulProblem {
    /// Builds a problem description from explicit tensor shapes and strides.
    ///
    /// Each operand may have its own rank. The last two dimensions of every
    /// shape are the matrix dimensions, and everything before them is treated
    /// as batch dimensions. `m` and `k` are read from `lhs_shape` (`[.., m, k]`)
    /// and `n` from `rhs_shape` (`[.., k, n]`), each indexed by that operand's
    /// own rank. This means lhs/rhs may carry fewer batch dims than `out`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `MatrixLayout::from_shape_and_strides`
    /// for the lhs, rhs or out operand.
    ///
    /// # Panics
    ///
    /// Panics if any of the three shapes has fewer than two dimensions.
    // Mirrors the struct fields one-to-one, hence the many arguments.
    #[allow(clippy::too_many_arguments)]
    pub fn from_shapes_and_strides(
        lhs_shape: Shape,
        rhs_shape: Shape,
        out_shape: Shape,
        lhs_strides: Strides,
        rhs_strides: Strides,
        out_strides: Strides,
        global_dtypes: MatmulGlobalElems,
        address_type: AddressType,
        lhs_scheme: Option<&QuantScheme>,
        rhs_scheme: Option<&QuantScheme>,
    ) -> Result<Self, MatmulSetupError> {
        let lhs_layout =
            MatrixLayout::from_shape_and_strides(&lhs_shape, &lhs_strides, lhs_scheme)?;
        let rhs_layout =
            MatrixLayout::from_shape_and_strides(&rhs_shape, &rhs_strides, rhs_scheme)?;
        let out_layout = MatrixLayout::from_shape_and_strides(&out_shape, &out_strides, None)?;

        // Index every operand by its *own* rank: with batch broadcasting, lhs/rhs
        // can have fewer leading dims than the output, so using the output rank
        // would read the wrong axis.
        let lhs_rank = lhs_shape.len();
        let rhs_rank = rhs_shape.len();
        let out_rank = out_shape.len();
        assert!(
            lhs_rank >= 2 && rhs_rank >= 2 && out_rank >= 2,
            "Matmul operands must have at least two dimensions \
             (lhs rank: {lhs_rank}, rhs rank: {rhs_rank}, out rank: {out_rank})"
        );

        Ok(Self {
            m: lhs_shape[lhs_rank - 2],
            n: rhs_shape[rhs_rank - 1],
            k: lhs_shape[lhs_rank - 1],
            lhs_batches: lhs_shape[..lhs_rank - 2].into(),
            rhs_batches: rhs_shape[..rhs_rank - 2].into(),
            out_batches: out_shape[..out_rank - 2].into(),
            lhs_shape,
            rhs_shape,
            out_shape,
            lhs_strides,
            rhs_strides,
            out_strides,
            lhs_layout,
            rhs_layout,
            out_layout,
            lhs_scheme: lhs_scheme.copied(),
            rhs_scheme: rhs_scheme.copied(),
            global_dtypes,
            address_type,
        })
    }

    /// Builds a problem description from matrix sizes, batch shapes and layouts.
    ///
    /// The full shapes are `lhs_batches ++ [m, k]`, `rhs_batches ++ [k, n]` and
    /// `out_batches ++ [m, n]`. `out_batches` is the right-aligned broadcast of
    /// the lhs and rhs batches: shorter shapes are left-padded with `1`, and
    /// each pair of dims must be equal or contain a `1`. Strides are generated
    /// from each operand's layout.
    ///
    /// # Panics
    ///
    /// Panics with "Batches should match" if `lhs_batches` and `rhs_batches`
    /// cannot be broadcast together.
    // Mirrors the struct fields one-to-one, hence the many arguments.
    #[allow(clippy::too_many_arguments)]
    pub fn from_parameters(
        m: usize,
        n: usize,
        k: usize,
        lhs_batches: Shape,
        rhs_batches: Shape,
        lhs_layout: MatrixLayout,
        rhs_layout: MatrixLayout,
        out_layout: MatrixLayout,
        lhs_scheme: Option<&QuantScheme>,
        rhs_scheme: Option<&QuantScheme>,
        global_dtypes: MatmulGlobalElems,
        address_type: AddressType,
    ) -> Self {
        /// Right-aligned broadcast of two batch shapes; `None` if incompatible.
        fn broadcast_batches(lhs: &[usize], rhs: &[usize]) -> Option<Shape> {
            let max_len = max(lhs.len(), rhs.len());
            let lhs_padded = std::iter::repeat_n(1, max_len - lhs.len()).chain(lhs.iter().copied());
            let rhs_padded = std::iter::repeat_n(1, max_len - rhs.len()).chain(rhs.iter().copied());
            lhs_padded
                .zip(rhs_padded)
                .map(|(l, r)| {
                    if l != r && l != 1 && r != 1 {
                        None
                    } else {
                        Some(max(l, r))
                    }
                })
                .collect()
        }

        let out_batches =
            broadcast_batches(&lhs_batches, &rhs_batches).expect("Batches should match");

        let lhs_shape: Shape = lhs_batches.iter().copied().chain([m, k]).collect();
        let rhs_shape: Shape = rhs_batches.iter().copied().chain([k, n]).collect();
        let out_shape: Shape = out_batches.iter().copied().chain([m, n]).collect();

        let lhs_strides = lhs_layout.to_strides(&lhs_shape);
        let rhs_strides = rhs_layout.to_strides(&rhs_shape);
        let out_strides = out_layout.to_strides(&out_shape);

        Self {
            m,
            n,
            k,
            lhs_batches,
            rhs_batches,
            out_batches,
            lhs_shape,
            rhs_shape,
            out_shape,
            lhs_strides,
            rhs_strides,
            out_strides,
            lhs_layout,
            rhs_layout,
            out_layout,
            lhs_scheme: lhs_scheme.copied(),
            rhs_scheme: rhs_scheme.copied(),
            global_dtypes,
            address_type,
        }
    }

    /// Total number of output batches (product of `out_batches`; `1` when unbatched).
    pub fn num_batches(&self) -> usize {
        self.out_batches.iter().product()
    }
}
/// Shape category of a matmul, determined by which of `m`, `n`, `k` equal `1`.
///
/// A dimension equal to `1` counts as "scalar"; any other value counts as "vector".
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum MatmulKind {
    /// `m`, `n` and `k` all differ from 1.
    General,
    /// `n == 1`; `m` and `k` differ from 1.
    MatVec,
    /// `m == 1`; `n` and `k` differ from 1.
    VecMat,
    /// `m == 1` and `k == 1`; `n` differs from 1.
    ScalarVec,
    /// `n == 1` and `k == 1`; `m` differs from 1.
    VecScalar,
    /// `m == 1` and `n == 1`; `k` differs from 1.
    InnerProduct,
    /// `k == 1`; `m` and `n` differ from 1.
    OuterProduct,
    /// `m`, `n` and `k` all equal 1.
    ScalarProduct,
}
impl From<MatmulProblemSize> for MatmulKind {
    /// Classifies a problem size by which of `m`, `n`, `k` are exactly `1`.
    ///
    /// A dimension equal to `1` is treated as scalar; every other value
    /// (including `0`) is treated as a vector extent.
    fn from(matmul_size: MatmulProblemSize) -> Self {
        let unit_m = matmul_size.m() == 1;
        let unit_n = matmul_size.n() == 1;
        let unit_k = matmul_size.k() == 1;

        match (unit_m, unit_n, unit_k) {
            (true, true, true) => MatmulKind::ScalarProduct,
            (true, true, false) => MatmulKind::InnerProduct,
            (true, false, true) => MatmulKind::ScalarVec,
            (true, false, false) => MatmulKind::VecMat,
            (false, true, true) => MatmulKind::VecScalar,
            (false, true, false) => MatmulKind::MatVec,
            (false, false, true) => MatmulKind::OuterProduct,
            (false, false, false) => MatmulKind::General,
        }
    }
}
/// Converts a `usize` problem dimension to the `u32` used by `MatmulProblemSize`.
///
/// # Panics
///
/// Panics if `value` does not fit in a `u32`. A silent `as` cast would wrap,
/// e.g. turning `2^32 + 1` into `1`, and misclassify the problem.
fn matmul_dim_to_u32(name: &str, value: usize) -> u32 {
    u32::try_from(value)
        .unwrap_or_else(|_| panic!("Matmul dimension `{name}` ({value}) does not fit in u32"))
}

/// Extracts the `(m, n, k)` size of a [`MatmulProblem`] with checked narrowing.
fn matmul_problem_size(problem: &MatmulProblem) -> MatmulProblemSize {
    MatmulProblemSize::new(
        matmul_dim_to_u32("m", problem.m),
        matmul_dim_to_u32("n", problem.n),
        matmul_dim_to_u32("k", problem.k),
    )
}

impl From<MatmulProblem> for MatmulProblemSize {
    /// Returns the `(m, n, k)` size of the problem.
    ///
    /// # Panics
    ///
    /// Panics if any dimension exceeds `u32::MAX`.
    fn from(problem: MatmulProblem) -> Self {
        matmul_problem_size(&problem)
    }
}

impl From<MatmulProblem> for MatmulKind {
    /// Classifies the problem from its `(m, n, k)` size.
    ///
    /// # Panics
    ///
    /// Panics if any dimension exceeds `u32::MAX`.
    fn from(problem: MatmulProblem) -> Self {
        matmul_problem_size(&problem).into()
    }
}

impl From<&MatmulProblem> for MatmulKind {
    /// Classifies the problem from its `(m, n, k)` size without consuming it.
    ///
    /// # Panics
    ///
    /// Panics if any dimension exceeds `u32::MAX`.
    fn from(problem: &MatmulProblem) -> Self {
        matmul_problem_size(problem).into()
    }
}
/// Identifies one of the three global operands of a matmul.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum MatmulIdent {
    /// Left-hand input operand.
    Lhs,
    /// Right-hand input operand.
    Rhs,
    /// Output operand; converts to `StageIdent::Acc`.
    Out,
}
impl MatmulIdent {
pub fn into_stage(self) -> StageIdent {
self.into()
}
pub fn view_direction(&self) -> ViewDirection {
match self {
MatmulIdent::Lhs => ViewDirection::Col,
MatmulIdent::Rhs => ViewDirection::Row,
MatmulIdent::Out => ViewDirection::None,
}
}
}
impl From<MatmulIdent> for StageIdent {
    /// Maps a matmul operand to its stage identifier.
    ///
    /// `Lhs` and `Rhs` keep their names; `Out` maps to [`StageIdent::Acc`].
    fn from(ident: MatmulIdent) -> Self {
        match ident {
            MatmulIdent::Lhs => Self::Lhs,
            MatmulIdent::Rhs => Self::Rhs,
            MatmulIdent::Out => Self::Acc,
        }
    }
}
impl From<StageIdent> for MatmulIdent {
    /// Maps a stage identifier back to its matmul operand.
    ///
    /// Both `StageIdent::Acc` and `StageIdent::Out` collapse onto
    /// [`MatmulIdent::Out`], so this is not a one-to-one inverse.
    fn from(stage_ident: StageIdent) -> Self {
        match stage_ident {
            StageIdent::Lhs => Self::Lhs,
            StageIdent::Rhs => Self::Rhs,
            StageIdent::Acc | StageIdent::Out => Self::Out,
        }
    }
}