cubek-matmul 0.2.0

CubeK: Matrix Multiplication Kernels
Documentation
use cubecl::{
    Runtime, VectorizationError,
    client::ComputeClient,
    ir::VectorSize,
    tensor_vector_size_parallel,
    zspace::{Shape, Strides},
};
use cubek_std::MatrixLayout;

use std::fmt::Debug;

use crate::definition::error::MatmulSetupError;

/// Vector sizes selected for the global-memory accesses of each matmul tensor.
///
/// Each value is the number of elements handled by a single vectorized
/// load/store on the corresponding tensor.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct MatmulVectorSizes {
    /// Vector size used when reading the left-hand-side input.
    pub lhs: VectorSize,
    /// Vector size used when reading the right-hand-side input.
    pub rhs: VectorSize,
    /// Vector size used when writing the output.
    pub out: VectorSize,
}

/// Per-tensor lists of vector sizes that are still eligible.
///
/// Each list starts out as the set of sizes the compute client reports as
/// supported, and is then narrowed step by step according to the problem's
/// shapes/strides and hardware constraints until a final size is picked.
#[derive(Debug, Clone)]
pub struct AvailableVectorSizes {
    /// Remaining candidates for the left-hand-side input.
    pub lhs: Vec<VectorSize>,
    /// Remaining candidates for the right-hand-side input.
    pub rhs: Vec<VectorSize>,
    /// Remaining candidates for the output.
    pub out: Vec<VectorSize>,
}

impl AvailableVectorSizes {
    pub fn from_type_size_tma<R: Runtime>(client: &ComputeClient<R>, elem_out: usize) -> Self {
        // TMA requires vector size 1 for inputs
        AvailableVectorSizes {
            lhs: vec![1],
            rhs: vec![1],
            out: client.io_optimized_vector_sizes(elem_out).collect(),
        }
    }

    pub fn from_type_sizes<R: Runtime>(
        client: &ComputeClient<R>,
        elem_lhs: usize,
        elem_rhs: usize,
        elem_out: usize,
    ) -> Self {
        AvailableVectorSizes {
            lhs: client.io_optimized_vector_sizes(elem_lhs).collect(),
            rhs: client.io_optimized_vector_sizes(elem_rhs).collect(),
            out: client.io_optimized_vector_sizes(elem_out).collect(),
        }
    }

    /// Filter available vector sizes considering tensor shapes and strides for Lhs
    pub fn filter_lhs_with_tensor(
        self,
        strides: &Strides,
        shape: &Shape,
        layout: MatrixLayout,
    ) -> Self {
        let rank = strides.len();

        let target = tensor_vector_size_parallel(
            self.lhs.iter().copied(),
            shape,
            strides,
            match layout {
                MatrixLayout::RowMajor => rank - 1,
                MatrixLayout::ColMajor => rank - 2,
            },
        );

        self.filter_lhs(move |x| *x == target)
    }

    /// Filter available vector sizes considering tensor shapes and strides for Rhs
    pub fn filter_rhs_with_tensor(
        self,
        strides: &Strides,
        shape: &Shape,
        layout: MatrixLayout,
    ) -> Self {
        let rank = strides.len();

        let target = tensor_vector_size_parallel(
            self.rhs.iter().copied(),
            shape,
            strides,
            match layout {
                MatrixLayout::RowMajor => rank - 1,
                MatrixLayout::ColMajor => rank - 2,
            },
        );

        self.filter_rhs(move |x| *x == target)
    }

    /// Filter available vector sizes considering tensor shapes and strides for output
    pub fn filter_out_with_tensor(self, strides: &Strides, shape: &Shape) -> Self {
        let rank = strides.len();

        let target =
            tensor_vector_size_parallel(self.out.iter().copied(), shape, strides, rank - 1);

        self.filter_out(move |x| *x == target)
    }

    /// Filter available vector sizes for Lhs
    pub fn filter_lhs<F>(self, pred: F) -> Self
    where
        F: FnMut(&usize) -> bool,
    {
        Self {
            lhs: self.lhs.iter().copied().filter(pred).collect(),
            rhs: self.rhs,
            out: self.out,
        }
    }

    /// Filter available vector sizes for Rhs
    pub fn filter_rhs<F>(self, pred: F) -> Self
    where
        F: FnMut(&usize) -> bool,
    {
        Self {
            lhs: self.lhs,
            rhs: self.rhs.iter().copied().filter(pred).collect(),
            out: self.out,
        }
    }

    /// Filter available vector sizes for output
    pub fn filter_out<F>(self, pred: F) -> Self
    where
        F: FnMut(&usize) -> bool,
    {
        Self {
            lhs: self.lhs,
            rhs: self.rhs,
            out: self.out.iter().copied().filter(pred).collect(),
        }
    }

    /// Pick the largest remaining vector size for each tensor
    pub fn pick_max(self) -> Result<MatmulVectorSizes, MatmulSetupError> {
        let pick = |v: Vec<usize>| {
            v.into_iter().max().ok_or(MatmulSetupError::Vectorization(
                VectorizationError::NoValidVectorization,
            ))
        };

        Ok(MatmulVectorSizes {
            lhs: pick(self.lhs)?,
            rhs: pick(self.rhs)?,
            out: pick(self.out)?,
        })
    }
}