cubek-matmul 0.2.0

CubeK: Matrix Multiplication Kernels
Documentation
use std::fmt::Display;

use cubek_std::cube_count::CubeCountPlan;

use crate::{
    components::batch::{
        BatchMatmulFamily,
        naive::{NaiveBatchMatmulFamily, NaiveBlueprint},
    },
    definition::{MatmulAvailabilityError, MatmulElems, MatmulProblem, MatmulSetupError},
    routines::{BlueprintStrategy, DeviceSettings, ExpandInfo, LaunchInfo, Routine},
};

pub struct NaiveRoutine {}

#[derive(Default, Clone)]
pub struct NaiveStrategy {}

impl Display for NaiveStrategy {
    fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Ok(())
    }
}

impl From<()> for NaiveStrategy {
    fn from(_value: ()) -> Self {
        Self {}
    }
}

impl Routine<()> for NaiveRoutine {
    type Strategy = NaiveStrategy;
    type BatchMatmul = NaiveBatchMatmulFamily;
    type Blueprint = <Self::BatchMatmul as BatchMatmulFamily<()>>::Blueprint;
    type Config = <Self::BatchMatmul as BatchMatmulFamily<()>>::Config;

    fn expand_blueprint<R: cubecl::Runtime>(
        problem: &MatmulProblem,
        device_settings: &DeviceSettings<R>,
        _strategy: &BlueprintStrategy<(), Self>,
    ) -> Result<ExpandInfo<Self::Blueprint>, MatmulSetupError> {
        let dtypes = MatmulElems::from_globals(&problem.global_dtypes);
        let blueprint = NaiveBlueprint {
            vector_size_out: device_settings.vector_sizes.out as u32,
            dtypes: dtypes.clone(),
        };
        Ok(ExpandInfo { blueprint, dtypes })
    }

    fn prepare<R: cubecl::Runtime>(
        problem: &MatmulProblem,
        device_settings: &DeviceSettings<R>,
        expand_info: ExpandInfo<Self::Blueprint>,
    ) -> Result<LaunchInfo<Self::Blueprint>, MatmulSetupError> {
        let ExpandInfo { blueprint, dtypes } = expand_info;

        Self::validate_blueprint(
            &device_settings.client,
            &blueprint,
            problem,
            &dtypes,
            &device_settings.vector_sizes,
        )?;

        let cube_dim = Self::BatchMatmul::cubedim_resource(
            &blueprint,
            &dtypes,
            &device_settings.vector_sizes,
        )?
        .to_cube_dim(device_settings.plane_dim)?;

        Ok(LaunchInfo {
            blueprint,
            dtypes,
            cube_dim,
            cube_count_plan: simple_cube_count(
                &problem.lhs_shape,
                &problem.rhs_shape,
                &problem.out_shape,
                cube_dim.x,
                cube_dim.y,
            )?,
            address_type: problem.address_type,
            vector_sizes: device_settings.vector_sizes,
        })
    }
}

#[allow(clippy::result_large_err)]
fn simple_cube_count(
    lhs_shape: &[usize],
    rhs_shape: &[usize],
    output_shape: &[usize],
    cube_dim_x: u32,
    cube_dim_y: u32,
) -> Result<CubeCountPlan, MatmulSetupError> {
    let ndims = lhs_shape.len();
    let m = lhs_shape[ndims - 2];
    let n = rhs_shape[ndims - 1];

    let m_cubes = f32::ceil(m as f32 / cube_dim_x as f32) as u32;
    let n_cubes = f32::ceil(n as f32 / cube_dim_y as f32) as u32;
    let mut batch_cubes = 1u32;

    #[allow(clippy::needless_range_loop)]
    for i in 0..ndims - 2 {
        batch_cubes *= output_shape[i] as u32;
    }

    let cube_count_plan = CubeCountPlan::new_from_problem((m_cubes, n_cubes, batch_cubes).into());
    let max_cube_count = u16::MAX as u32;

    if m_cubes > max_cube_count || n_cubes > max_cube_count || batch_cubes > max_cube_count {
        return Err(MatmulSetupError::Unavailable(
            MatmulAvailabilityError::CubeCountTooBig(cube_count_plan.resolve()),
        ));
    }

    Ok(cube_count_plan)
}