cubecl-convolution 0.6.0

CubeCL Convolution Kernels Engine
Documentation
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::{
    CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs,
    tensor::r#virtual::{ReadWrite, VirtualTensor},
};

use crate::base::{Convolution, ConvolutionFamily, RuntimeArgs};

use cubecl_matmul::components::{
    Ident,
    global::{
        GlobalConfig,
        args::{MatmulArgs, TensorInput, TensorInputIdent, TensorOutput},
    },
};

/// Input type that the `MatmulArgs` implementation `Args` associates with element type `EI`.
type Input<Args, EI> = <Args as MatmulArgs>::Input<EI>;
/// Output type that the `MatmulArgs` implementation `Args` associates with element type `EO`.
type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;

/// Kernel entry point that runs the convolution selected by `GMM` for one cube.
///
/// The kernel wraps the raw inputs/outputs in virtual tensors, derives this cube's
/// offsets from its position, and hands everything to
/// `GMM::Convolution::<(EI, ES, EA, EO)>::execute`.
///
/// # Type parameters
/// * `Args` - `MatmulArgs` implementation describing how inputs and output are passed.
/// * `EI`, `ES`, `EA`, `EO` - numeric types forwarded together as the
///   `(EI, ES, EA, EO)` tuple to the convolution; `EI` is the element type of the
///   lhs/rhs views and `EO` of the output and bias views.
///   NOTE(review): `ES`/`EA` are presumably the stage and accumulator types - confirm.
/// * `GMM` - convolution family providing the `Convolution` implementation and its config.
///
/// # Arguments
/// * `inputs` - kernel inputs in the form defined by `Args`.
/// * `bias` - optional bias tensor with `EO` elements.
/// * `output` - kernel output in the form defined by `Args`.
/// * `runtime_args` - runtime values; `size_k` is read here as the end of the k range.
/// * `config` - compile-time configuration of the selected convolution.
#[cube(launch_unchecked)]
pub(crate) fn implicit_conv<
    Args: MatmulArgs,
    EI: Numeric,
    ES: Numeric,
    EA: Numeric,
    EO: Numeric,
    GMM: ConvolutionFamily,
>(
    inputs: &Input<Args, EI>,
    bias: &CubeOption<Tensor<Line<EO>>>,
    output: &mut Output<Args, EO>,
    runtime_args: RuntimeArgs,
    #[comptime] config: GMM::Config,
) {
    // Shared state from which the lhs/rhs/out accessors below are created.
    let mut state = Args::init_state(inputs, output);

    let lhs = TensorInput::<EI, EO, Args>::new(&state, TensorInputIdent::Lhs);
    let rhs = TensorInput::<EI, EO, Args>::new(&state, TensorInputIdent::Rhs);
    let mut out = TensorOutput::<EI, EO, Args>::new(&mut state);

    // Shadow the accessors with virtual tensors, the form the loaders and writer take.
    let lhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&lhs);
    let rhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&rhs);
    let out = VirtualTensor::<EO, ReadWrite>::new::<TensorOutput<EI, EO, Args>>(&mut out);

    // Offsets of this cube: cube position times the stage extent along m (x) and n (y).
    let x_offset = CUBE_POS_X * config.tiling_scheme().elements_in_stage_m();
    let y_offset = CUBE_POS_Y * config.tiling_scheme().elements_in_stage_n();
    // The k range always starts at 0 and ends at the runtime-provided `size_k`.
    let k_range = (0, runtime_args.size_k);

    // Wrap the bias as a virtual tensor when present; otherwise forward `None`.
    let bias = match bias {
        CubeOption::Some(bias) => {
            CubeOption::new_Some(VirtualTensor::<EO>::new::<Tensor<Line<EO>>>(bias))
        }
        CubeOption::None => CubeOption::new_None(),
    };

    // Build the lhs/rhs/bias loaders, the writer and the accumulator, then execute.
    GMM::Convolution::<(EI, ES, EA, EO)>::execute(
        GMM::Convolution::<(EI, ES, EA, EO)>::init_lhs_loader(
            lhs,
            x_offset,
            k_range.0,
            &runtime_args,
            config,
        ),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_rhs_loader(
            rhs,
            k_range.0,
            y_offset,
            &runtime_args,
            config,
        ),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_bias_loader(bias, y_offset, config),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_writer(out, x_offset, y_offset),
        &mut GMM::Convolution::<(EI, ES, EA, EO)>::init_accumulator(config),
        k_range,
        config,
    );
}

/// Builds one `FastDivmod` launch argument per entry of `shape`.
///
/// Each dimension is converted to `u32` with an `as` cast before being passed to
/// `FastDivmodArgs::new`, so a value above `u32::MAX` is truncated rather than rejected.
pub(crate) fn shape_divmod<'a, R: Runtime>(
    client: &ComputeClient<R::Server, R::Channel>,
    shape: &[usize],
) -> SequenceArg<'a, R, FastDivmod> {
    SequenceArg {
        values: shape
            .iter()
            .map(|&dim| FastDivmodArgs::new(client, dim as u32))
            .collect(),
    }
}

pub mod config {
    use std::ops::Deref;

    use crate::{ConvGemmConfig, base::Dimensionality};
    use cubecl_matmul::components::{
        InputIdent, MatmulLineSizes, MatmulSetupError, MatrixLayout, TilingScheme,
        global::{
            GlobalConfig, PlaneRoleConfig, SpecializedLoadingSides, load::LoaderMode,
            multi_stage::EventLoadingMode,
        },
    };

    use super::*;

    /// Global matmul configuration `M` extended with convolution parameters.
    ///
    /// Per-dimension parameters live in fixed arrays of three entries.
    /// `ConvolutionConfig::new` fills the leading entries from the slices it is given
    /// and leaves any remaining entries at zero.
    #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
    pub struct ConvolutionConfig<M: GlobalConfig> {
        /// Wrapped matmul configuration; also the `Deref` target of this type.
        matmul: M,
        /// Kernel size per dimension, read through `ConvGemmConfig::kernel_size`.
        kernel_size: [u32; 3],
        /// Stride per dimension, read through `ConvGemmConfig::stride`.
        stride: [u32; 3],
        /// Dilation per dimension, read through `ConvGemmConfig::dilation`.
        dilation: [u32; 3],
        /// Padding per dimension (signed), read through `ConvGemmConfig::padding`.
        padding: [i32; 3],
        /// Dimensionality passed to `new`, returned by `ConvGemmConfig::dimensionality`.
        dimensionality: Dimensionality,
        /// Stage count reported by `GlobalConfig::num_stages` for every input identifier.
        num_stages: u32,
    }

    /// Dereferences to the wrapped matmul configuration, so `M`'s methods are reachable
    /// directly on a `ConvolutionConfig<M>`.
    impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
        type Target = M;

        fn deref(&self) -> &M {
            let Self { matmul, .. } = self;
            matmul
        }
    }

    /// `GlobalConfig` implementation for the convolution wrapper.
    ///
    /// Every query is forwarded to the wrapped matmul configuration, with two
    /// exceptions: `num_stages` returns the value stored on the wrapper, and
    /// `cube_dim` is computed here from the plane dimension and the tiling scheme.
    impl<M: GlobalConfig> GlobalConfig for ConvolutionConfig<M> {
        type StageConfig = M::StageConfig;

        fn stage_config(&self) -> Self::StageConfig {
            M::stage_config(&self.matmul)
        }

        fn global_line_size<I: Into<Ident>>(&self, ident: I) -> u32 {
            M::global_line_size(&self.matmul, ident)
        }

        fn matrix_layout<I: Into<Ident>>(&self, ident: I) -> MatrixLayout {
            M::matrix_layout(&self.matmul, ident)
        }

        fn num_loading_planes<I: Into<Ident>>(&self, ident: I) -> u32 {
            M::num_loading_planes(&self.matmul, ident)
        }

        fn plane_dim(&self) -> u32 {
            M::plane_dim(&self.matmul)
        }

        fn check_row_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
            M::check_row_bounds(&self.matmul, ident)
        }

        fn check_col_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
            M::check_col_bounds(&self.matmul, ident)
        }

        fn check_k_bounds(&self) -> bool {
            M::check_k_bounds(&self.matmul)
        }

        fn precompute_job(&self) -> bool {
            M::precompute_job(&self.matmul)
        }

        /// Returns the wrapper's own stage count; the identifier is ignored.
        fn num_stages(&self, _ident: InputIdent) -> u32 {
            self.num_stages
        }

        fn loader_mode(&self) -> LoaderMode {
            M::loader_mode(&self.matmul)
        }

        fn tiling_scheme(&self) -> TilingScheme {
            M::tiling_scheme(&self.matmul)
        }

        fn event_loading_mode(&self, ident: InputIdent) -> EventLoadingMode {
            M::event_loading_mode(&self.matmul, ident)
        }

        fn plane_role_config(&self) -> PlaneRoleConfig {
            M::plane_role_config(&self.matmul)
        }

        fn specialized_loading_sides(&self) -> SpecializedLoadingSides {
            M::specialized_loading_sides(&self.matmul)
        }

        /// Cube dimensions: plane dimension on x, tiles per stage along m on y, 1 on z.
        fn cube_dim(&self) -> CubeDim {
            let x = M::plane_dim(&self.matmul);
            let y = M::tiling_scheme(&self.matmul).tiles_in_stage_m();
            CubeDim::new(x, y, 1)
        }
    }

    impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
        fn kernel_size(&self, dim: u32) -> u32 {
            self.kernel_size[dim as usize]
        }

        fn dilation(&self, dim: u32) -> u32 {
            self.dilation[dim as usize]
        }

        fn stride(&self, dim: u32) -> u32 {
            self.stride[dim as usize]
        }

        fn padding(&self, dim: u32) -> i32 {
            self.padding[dim as usize]
        }

        fn dimensionality(&self) -> Dimensionality {
            self.dimensionality
        }

        fn line_sizes(&self) -> cubecl_matmul::components::MatmulLineSizes {
            MatmulLineSizes {
                lhs: self.global_line_size(Ident::Lhs) as u8,
                rhs: self.global_line_size(Ident::Rhs) as u8,
                out: self.global_line_size(Ident::Out) as u8,
            }
        }
    }

    impl<M: GlobalConfig> ConvolutionConfig<M> {
        /// Wraps `matmul` together with the convolution parameters.
        ///
        /// The number of dimensions is taken from `kernel_size.len()`. Each slice is
        /// copied into the leading entries of a three-entry array whose remaining
        /// entries stay zero.
        ///
        /// # Errors
        /// The body as written always returns `Ok`.
        ///
        /// # Panics
        /// Panics if `kernel_size` has more than three entries, or if `stride`,
        /// `dilation` or `padding` differ in length from `kernel_size`.
        #[allow(clippy::too_many_arguments)]
        pub fn new(
            matmul: M,
            kernel_size: &[u32],
            stride: &[u32],
            dilation: &[u32],
            padding: &[i32],
            dim: Dimensionality,
            num_stages: u32,
        ) -> Result<Self, MatmulSetupError> {
            // Copies `values` into the first `rank` slots of a zero-filled triple.
            fn spread<T: Copy + Default>(values: &[T], rank: usize) -> [T; 3] {
                let mut slots = [T::default(); 3];
                slots[..rank].copy_from_slice(values);
                slots
            }

            let rank = kernel_size.len();

            // Field expressions run in written order, so the copies happen as
            // kernel size, stride, dilation, padding.
            Ok(Self {
                matmul,
                kernel_size: spread(kernel_size, rank),
                stride: spread(stride, rank),
                dilation: spread(dilation, rank),
                padding: spread(padding, rank),
                dimensionality: dim,
                num_stages,
            })
        }

        /// Consumes the wrapper and returns the inner matmul configuration.
        pub fn to_matmul_config(self) -> M {
            let Self { matmul, .. } = self;
            matmul
        }
    }
}