//! cubek-matmul 0.2.0
//!
//! CubeK: Matrix Multiplication Kernels — problem definitions shared by the
//! matmul kernels. See the crate-level documentation for details.
use std::cmp::max;

use crate::{
    components::global::memory::ViewDirection,
    definition::{MatmulGlobalElems, MatmulSetupError},
};
use cubecl::{
    prelude::*,
    quant::scheme::QuantScheme,
    zspace::{Shape, Strides},
};
use cubek_std::{MatmulProblemSize, MatrixLayout, StageIdent};
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug)]
/// Description of a matmul problem to solve, regardless of actual data
///
/// `m`, `n` and `k` are the trailing (matrix) dimensions; all leading
/// dimensions are treated as batches.
pub struct MatmulProblem {
    /// Number of rows in the output matrix
    pub m: usize,
    /// Number of columns in the output matrix
    pub n: usize,
    /// Reduction dimension
    pub k: usize,

    /// Batch shape for Lhs tensor
    pub lhs_batches: Shape,
    /// Batch shape for Rhs tensor
    pub rhs_batches: Shape,
    /// Batch shape for Out tensor
    pub out_batches: Shape,

    /// Shape of Lhs tensor
    pub lhs_shape: Shape,
    /// Shape of Rhs tensor
    pub rhs_shape: Shape,
    /// Shape of Out tensor
    pub out_shape: Shape,

    /// Strides for the Lhs tensor
    pub lhs_strides: Strides,
    /// Strides for the Rhs tensor
    pub rhs_strides: Strides,
    /// Strides for the Out tensor
    pub out_strides: Strides,

    /// Memory layout of the Lhs matrix.
    pub lhs_layout: MatrixLayout,
    /// Memory layout of the Rhs matrix.
    pub rhs_layout: MatrixLayout,
    /// Memory layout of the Out matrix.
    pub out_layout: MatrixLayout,

    /// Quantization scheme of lhs, if present
    pub lhs_scheme: Option<QuantScheme>,
    /// Quantization scheme of rhs, if present
    pub rhs_scheme: Option<QuantScheme>,

    /// Global element types for the matmul tensors
    /// (presumably the dtypes of lhs/rhs/out in global memory — TODO confirm).
    pub global_dtypes: MatmulGlobalElems,

    /// Address type, defined by the max of each handle's `required_address_type`
    pub address_type: AddressType,
}

impl MatmulProblem {
    /// Builds a problem description from the concrete shapes and strides of all
    /// three tensors, inferring each matrix's memory layout from its
    /// shape/stride pair.
    ///
    /// # Errors
    ///
    /// Returns a [`MatmulSetupError`] when a layout cannot be inferred from a
    /// shape/stride pair.
    ///
    /// # Panics
    ///
    /// Panics if any shape has fewer than two dimensions.
    #[allow(clippy::too_many_arguments)]
    pub fn from_shapes_and_strides(
        lhs_shape: Shape,
        rhs_shape: Shape,
        out_shape: Shape,
        lhs_strides: Strides,
        rhs_strides: Strides,
        out_strides: Strides,
        global_dtypes: MatmulGlobalElems,
        address_type: AddressType,
        lhs_scheme: Option<&QuantScheme>,
        rhs_scheme: Option<&QuantScheme>,
    ) -> Result<Self, MatmulSetupError> {
        let lhs_layout =
            MatrixLayout::from_shape_and_strides(&lhs_shape, &lhs_strides, lhs_scheme)?;
        let rhs_layout =
            MatrixLayout::from_shape_and_strides(&rhs_shape, &rhs_strides, rhs_scheme)?;
        let out_layout = MatrixLayout::from_shape_and_strides(&out_shape, &out_strides, None)?;

        // Index each tensor's trailing (matrix) dims with its OWN rank: with
        // batch broadcasting, lhs/rhs may have fewer dims than the output, so
        // indexing them with the output rank could read the wrong dimension or
        // go out of bounds. Identical to the previous behavior when all ranks
        // agree.
        let lhs_rank = lhs_shape.len();
        let rhs_rank = rhs_shape.len();

        Ok(Self {
            m: lhs_shape[lhs_rank - 2],
            n: rhs_shape[rhs_rank - 1],
            k: lhs_shape[lhs_rank - 1],
            lhs_batches: lhs_shape[..lhs_rank - 2].into(),
            rhs_batches: rhs_shape[..rhs_rank - 2].into(),
            out_batches: out_shape[..out_shape.len() - 2].into(),
            lhs_shape,
            rhs_shape,
            out_shape,
            lhs_strides,
            rhs_strides,
            out_strides,
            lhs_layout,
            rhs_layout,
            out_layout,
            lhs_scheme: lhs_scheme.copied(),
            rhs_scheme: rhs_scheme.copied(),
            global_dtypes,
            address_type,
        })
    }

    /// Builds a problem description from explicit `(m, n, k)` sizes, batch
    /// shapes and matrix layouts, deriving full tensor shapes, the broadcast
    /// output batch shape, and strides from each layout.
    ///
    /// # Panics
    ///
    /// Panics if `lhs_batches` and `rhs_batches` cannot be broadcast together
    /// (a pair of corresponding dims differ and neither is 1).
    #[allow(clippy::too_many_arguments)]
    pub fn from_parameters(
        m: usize,
        n: usize,
        k: usize,
        lhs_batches: Shape,
        rhs_batches: Shape,
        lhs_layout: MatrixLayout,
        rhs_layout: MatrixLayout,
        out_layout: MatrixLayout,
        lhs_scheme: Option<&QuantScheme>,
        rhs_scheme: Option<&QuantScheme>,
        global_dtypes: MatmulGlobalElems,
        address_type: AddressType,
    ) -> Self {
        // NumPy-style broadcast of two batch shapes: the shorter shape is
        // left-padded with 1s, then each dim pair must be equal or contain a 1.
        // Returns None on an incompatible pair.
        fn broadcast_batches(lhs: &[usize], rhs: &[usize]) -> Option<Shape> {
            let max_len = max(lhs.len(), rhs.len());
            let lhs_padded = std::iter::repeat_n(1, max_len - lhs.len()).chain(lhs.iter().cloned());

            let rhs_padded = std::iter::repeat_n(1, max_len - rhs.len()).chain(rhs.iter().cloned());

            lhs_padded
                .zip(rhs_padded)
                .map(|(l, r)| {
                    if l != r && l != 1 && r != 1 {
                        None
                    } else {
                        Some(max(l, r))
                    }
                })
                .collect()
        }

        let out_batches =
            broadcast_batches(&lhs_batches, &rhs_batches).expect("Batches should match");

        // Full shapes are batches followed by the trailing matrix dims.
        let lhs_shape: Shape = lhs_batches.iter().cloned().chain([m, k]).collect();
        let rhs_shape: Shape = rhs_batches.iter().cloned().chain([k, n]).collect();
        let out_shape: Shape = out_batches.iter().cloned().chain([m, n]).collect();

        let lhs_strides = lhs_layout.to_strides(&lhs_shape);
        let rhs_strides = rhs_layout.to_strides(&rhs_shape);
        let out_strides = out_layout.to_strides(&out_shape);

        Self {
            m,
            n,
            k,
            lhs_batches,
            rhs_batches,
            out_batches,
            lhs_shape,
            rhs_shape,
            out_shape,
            lhs_strides,
            rhs_strides,
            out_strides,
            lhs_layout,
            rhs_layout,
            out_layout,
            lhs_scheme: lhs_scheme.copied(),
            rhs_scheme: rhs_scheme.copied(),
            global_dtypes,
            address_type,
        }
    }

    /// Returns the total number of batches of the output
    /// (the product of all output batch dims; 1 when there are none).
    pub fn num_batches(&self) -> usize {
        self.out_batches.iter().product()
    }
}

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
/// Interpretation of matrix multiplication based on input shapes.
///
/// A dimension counts as scalar-like when it equals 1 and vector-like
/// otherwise; the eight variants enumerate every combination of (M, N, K).
pub enum MatmulKind {
    /// (M, K) @ (K, N) → (M, N), with M, K, N > 1
    General,

    /// (M, K) @ (K, 1) → (M, 1)
    MatVec,

    /// (1, K) @ (K, N) → (1, N)
    VecMat,

    /// (1, 1) @ (1, N) → (1, N)
    ScalarVec,

    /// (M, 1) @ (1, 1) → (M, 1)
    VecScalar,

    /// (1, K) @ (K, 1) → (1, 1)
    InnerProduct,

    /// (M, 1) @ (1, N) → (M, N)
    OuterProduct,

    /// (1, 1) @ (1, 1) → (1, 1)
    ScalarProduct,
}

impl From<MatmulProblemSize> for MatmulKind {
    /// Classifies the problem by which of its m/n/k dimensions are degenerate.
    fn from(matmul_size: MatmulProblemSize) -> Self {
        // A dimension of exactly 1 is scalar-like; anything else is vector-like.
        let is_vec = |dim: u32| dim != 1;

        match (
            is_vec(matmul_size.m()),
            is_vec(matmul_size.n()),
            is_vec(matmul_size.k()),
        ) {
            (false, false, false) => MatmulKind::ScalarProduct,
            (false, false, true) => MatmulKind::InnerProduct,
            (false, true, false) => MatmulKind::ScalarVec,
            (false, true, true) => MatmulKind::VecMat,
            (true, false, false) => MatmulKind::VecScalar,
            (true, false, true) => MatmulKind::MatVec,
            (true, true, false) => MatmulKind::OuterProduct,
            (true, true, true) => MatmulKind::General,
        }
    }
}

impl From<MatmulProblem> for MatmulProblemSize {
    fn from(problem: MatmulProblem) -> Self {
        MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32)
    }
}

impl From<MatmulProblem> for MatmulKind {
    fn from(problem: MatmulProblem) -> Self {
        MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32).into()
    }
}

impl From<&MatmulProblem> for MatmulKind {
    /// Classifies a problem by its (m, n, k) sizes via [`MatmulProblemSize`].
    fn from(problem: &MatmulProblem) -> Self {
        // Narrowing `usize as u32` casts mirror the owned-value conversion.
        let size = MatmulProblemSize::new(problem.m as u32, problem.n as u32, problem.k as u32);
        size.into()
    }
}

#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
/// Identifier for all three tensors in a matmul
///
/// Useful to specialize some functions depending on the tensor
pub enum MatmulIdent {
    /// Left-hand side input tensor.
    Lhs,
    /// Right-hand side input tensor.
    Rhs,
    /// Output tensor.
    Out,
}

impl MatmulIdent {
    /// Equivalent to into, but type inference works better within Cube functions
    pub fn into_stage(self) -> StageIdent {
        self.into()
    }

    pub fn view_direction(&self) -> ViewDirection {
        match self {
            MatmulIdent::Lhs => ViewDirection::Col,
            MatmulIdent::Rhs => ViewDirection::Row,
            MatmulIdent::Out => ViewDirection::None,
        }
    }
}

impl From<MatmulIdent> for StageIdent {
    /// Maps matmul tensors onto stage identifiers; the output tensor maps to
    /// the accumulator stage.
    fn from(matmul_ident: MatmulIdent) -> Self {
        match matmul_ident {
            MatmulIdent::Lhs => Self::Lhs,
            MatmulIdent::Rhs => Self::Rhs,
            MatmulIdent::Out => Self::Acc,
        }
    }
}

impl From<StageIdent> for MatmulIdent {
    /// Inverse mapping: both the `Acc` and `Out` stages correspond to the
    /// matmul output tensor.
    fn from(stage_ident: StageIdent) -> Self {
        match stage_ident {
            StageIdent::Lhs => Self::Lhs,
            StageIdent::Rhs => Self::Rhs,
            StageIdent::Acc | StageIdent::Out => Self::Out,
        }
    }
}