//! cubek-convolution 0.2.0-pre.5
//!
//! CubeK: Convolution Kernels — convolution-specific configuration types
//! extending the matmul global config.
use std::ops::Deref;

use cubecl::CubeDim;
use cubek_matmul::{
    components::global::{GlobalConfig, memory::GlobalMemoryConfig},
    definition::{MatmulSetupError, MatmulVectorSizes},
};
use std::{fmt::Debug, hash::Hash};

use super::*;

/// Convolution specific config, extends regular matmul `Config`.
///
/// Implementors wrap a matmul [`GlobalConfig`] (reachable through `Deref`)
/// and add the convolution-specific parameters on top of it.
pub trait ConvGemmConfig:
    Deref<Target: GlobalConfig> + Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
{
    /// The wrapped global matmul configuration type.
    type GlobalMatmulConfig: GlobalConfig;

    /// Returns (a copy of) the underlying matmul configuration.
    fn matmul_config(&self) -> Self::GlobalMatmulConfig;

    /// The convolution parameters (kernel size, stride, dilation, padding, ...).
    fn params(&self) -> ConvolutionParams;
    /// The convolution operation variant this config was built for.
    fn operation(&self) -> ConvolutionOperation;
    /// Vector (line) sizes used for global memory accesses, from the matmul config.
    fn vector_sizes(&self) -> MatmulVectorSizes;
    /// Whether spatial bounds checks are required (i.e. any dimension is padded).
    fn check_spatial_bounds(&self) -> bool;
    /// The cube dimensions launched for this kernel.
    fn cube_dim(&self) -> CubeDim;
    /// Global memory config for the LHS (input) tensor.
    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig;
    /// Global memory config for the RHS (weight) tensor.
    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig;
    /// Global memory config for the output tensor.
    fn out_global_memory_config(&self) -> GlobalMemoryConfig;
}

/// Convolution config: a matmul [`GlobalConfig`] plus convolution parameters.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionConfig<M: GlobalConfig> {
    /// The underlying global matmul configuration (also exposed via `Deref`).
    pub matmul: M,
    /// Convolution-specific geometry parameters.
    pub params: ConvolutionParams,
}

/// Geometry parameters of a convolution.
///
/// Spatial arrays are fixed at 3 entries; only the first
/// `dimensionality.num_dims()` entries are meaningful, the rest are
/// zero-filled (see [`ConvolutionParams::from_problem`]).
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConvolutionParams {
    /// Kernel (filter) size for each spatial dimension.
    pub kernel_size: [u32; 3],
    /// Stride for each spatial dimension.
    pub stride: [u32; 3],
    /// Dilation for each spatial dimension.
    pub dilation: [u32; 3],
    /// Padding for each spatial dimension; signed, unlike the other arrays.
    pub padding: [i32; 3],
    /// Number of active spatial dimensions (at most 3, given the array sizes).
    pub dimensionality: Dimensionality,
    /// The convolution operation variant these params describe.
    pub operation: ConvolutionOperation,
}

impl ConvolutionParams {
    /// Builds convolution params from a problem description.
    ///
    /// The problem's spatial slices (each expected to hold
    /// `dimensionality.num_dims()` entries) are copied into 3-element
    /// arrays; trailing slots stay zero.
    pub fn from_problem(problem: &ConvolutionProblem) -> Self {
        let dims = problem.dimensionality.num_dims();

        // Copy a `dims`-long slice into a zero-padded fixed array.
        // Panics (like `copy_from_slice`) if the slice length differs.
        let spatial = |src: &[u32]| -> [u32; 3] {
            let mut out = [0u32; 3];
            out[..dims].copy_from_slice(src);
            out
        };
        let mut padding = [0i32; 3];
        padding[..dims].copy_from_slice(&problem.padding);

        ConvolutionParams {
            kernel_size: spatial(&problem.kernel_size),
            stride: spatial(&problem.stride),
            dilation: spatial(&problem.dilation),
            padding,
            dimensionality: problem.dimensionality,
            operation: problem.operation,
        }
    }
}

/// Delegates all [`GlobalConfig`] accessors to the wrapped matmul config.
impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
    type Target = M;

    fn deref(&self) -> &Self::Target {
        &self.matmul
    }
}

impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
    type GlobalMatmulConfig = M;

    /// The wrapped matmul configuration, by value (`M: GlobalConfig` is `Copy`-able here).
    fn matmul_config(&self) -> Self::GlobalMatmulConfig {
        self.matmul
    }

    /// The stored convolution parameters.
    fn params(&self) -> ConvolutionParams {
        self.params
    }

    /// The operation variant recorded in the params.
    fn operation(&self) -> ConvolutionOperation {
        self.params.operation
    }

    /// Global vector sizes come straight from the matmul config.
    fn vector_sizes(&self) -> MatmulVectorSizes {
        self.matmul.global_vector_sizes()
    }

    /// Spatial bounds checks are needed exactly when some active
    /// dimension has non-zero padding.
    fn check_spatial_bounds(&self) -> bool {
        let dims = self.params.dimensionality.num_dims();
        self.params.padding[..dims].iter().any(|&pad| pad != 0)
    }

    /// Cube dimensions are delegated to the matmul config.
    fn cube_dim(&self) -> CubeDim {
        self.matmul.cube_dim()
    }

    /// LHS global memory config, taken from the matmul LHS reader.
    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig {
        self.matmul.lhs_reader_config().gmem_config
    }

    /// RHS global memory config, taken from the matmul RHS reader.
    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig {
        self.matmul.rhs_reader_config().gmem_config
    }

    /// Output global memory config, taken from the matmul writer.
    fn out_global_memory_config(&self) -> GlobalMemoryConfig {
        self.matmul.writer_config().gmem_config
    }
}

impl<M: GlobalConfig> ConvolutionConfig<M> {
    /// Creates a convolution config wrapping `matmul`.
    ///
    /// All spatial slices must have `kernel_size.len()` entries (and at
    /// most 3); they are copied into zero-padded fixed arrays. Currently
    /// always returns `Ok`; the `Result` keeps the signature aligned with
    /// other fallible setup paths.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        matmul: M,
        kernel_size: &[u32],
        stride: &[u32],
        dilation: &[u32],
        padding: &[i32],
        dim: Dimensionality,
        operation: ConvolutionOperation,
    ) -> Result<Self, MatmulSetupError> {
        let dims = kernel_size.len();

        // Copy a `dims`-long slice into a zero-padded fixed array.
        // Panics (like `copy_from_slice`) if lengths disagree.
        let spatial = |src: &[u32]| -> [u32; 3] {
            let mut out = [0u32; 3];
            out[..dims].copy_from_slice(src);
            out
        };
        let mut pad = [0i32; 3];
        pad[..dims].copy_from_slice(padding);

        let params = ConvolutionParams {
            kernel_size: spatial(kernel_size),
            stride: spatial(stride),
            dilation: spatial(dilation),
            padding: pad,
            dimensionality: dim,
            operation,
        };

        Ok(Self { matmul, params })
    }
}