use std::ops::Deref;
use cubecl::CubeDim;
use cubek_matmul::{
components::global::{GlobalConfig, memory::GlobalMemoryConfig},
definition::{MatmulLineSizes, MatmulSetupError},
};
use std::fmt::Debug;
use std::hash::Hash;
use super::*;
/// Configuration of a convolution that is executed on top of a global matmul configuration.
///
/// Implementors dereference to a [`GlobalConfig`], so matmul-level settings can be read
/// directly off the convolution config. The wrapped matmul config can also be obtained by
/// value through [`ConvGemmConfig::matmul_config`].
pub trait ConvGemmConfig:
    Deref<Target: GlobalConfig>
    + Clone
    + Copy
    + PartialEq
    + Eq
    + Hash
    + Debug
    + Send
    + Sync
    + 'static
{
    /// Type of the underlying global matmul configuration.
    type GlobalMatmulConfig: GlobalConfig;

    // --- Convolution-level parameters -------------------------------------------------

    /// Convolution hyper-parameters (kernel size, stride, dilation, padding, ...).
    fn params(&self) -> ConvolutionParams;

    /// The convolution operation this configuration was built for.
    fn operation(&self) -> ConvolutionOperation;

    /// Whether spatial positions must be bounds-checked by the kernel.
    fn check_spatial_bounds(&self) -> bool;

    // --- Matmul / launch-level parameters ---------------------------------------------

    /// The underlying global matmul configuration, returned by value.
    fn matmul_config(&self) -> Self::GlobalMatmulConfig;

    /// Line sizes used for the global operands.
    fn line_sizes(&self) -> MatmulLineSizes;

    /// Cube dimensions the kernel is launched with.
    fn cube_dim(&self) -> CubeDim;

    // --- Global memory configurations -------------------------------------------------

    /// Global memory configuration of the left-hand-side operand.
    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig;

    /// Global memory configuration of the right-hand-side operand.
    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig;

    /// Global memory configuration of the output.
    fn out_global_memory_config(&self) -> GlobalMemoryConfig;
}
/// Convolution configuration that pairs a global matmul config with the convolution's
/// hyper-parameters.
///
/// Dereferences to the wrapped matmul config `M`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConvolutionConfig<M: GlobalConfig> {
    /// Global matmul configuration the convolution runs on.
    pub matmul: M,
    /// Kernel size, stride, dilation, padding, dimensionality and operation.
    pub params: ConvolutionParams,
}
/// Convolution hyper-parameters, stored in fixed three-element arrays.
///
/// Only the first `dimensionality.num_dims()` entries of each array are meaningful. The
/// constructors in this module leave the remaining trailing entries at zero.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConvolutionParams {
    /// Kernel extent per spatial dimension.
    pub kernel_size: [u32; 3],
    /// Stride per spatial dimension.
    pub stride: [u32; 3],
    /// Dilation per spatial dimension.
    pub dilation: [u32; 3],
    /// Padding per spatial dimension (signed).
    pub padding: [i32; 3],
    /// Number of spatial dimensions in use.
    pub dimensionality: Dimensionality,
    /// The convolution operation these parameters describe.
    pub operation: ConvolutionOperation,
}
impl ConvolutionParams {
pub fn from_problem(problem: &ConvolutionProblem) -> Self {
let dims = problem.dimensionality.num_dims();
let mut params = ConvolutionParams {
kernel_size: [0; 3],
stride: [0; 3],
dilation: [0; 3],
padding: [0; 3],
dimensionality: problem.dimensionality,
operation: problem.operation,
};
params.kernel_size[0..dims].copy_from_slice(&problem.kernel_size);
params.stride[0..dims].copy_from_slice(&problem.stride);
params.dilation[0..dims].copy_from_slice(&problem.dilation);
params.padding[0..dims].copy_from_slice(&problem.padding);
params
}
}
impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
type Target = M;
fn deref(&self) -> &Self::Target {
&self.matmul
}
}
impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
    type GlobalMatmulConfig = M;

    /// Returns a copy of the wrapped global matmul config.
    fn matmul_config(&self) -> Self::GlobalMatmulConfig {
        self.matmul
    }

    /// Forwards to the matmul config's `global_line_sizes`.
    fn line_sizes(&self) -> MatmulLineSizes {
        self.matmul.global_line_sizes()
    }

    /// Forwards to the matmul config's `cube_dim`.
    fn cube_dim(&self) -> CubeDim {
        self.matmul.cube_dim()
    }

    /// Returns `true` when any active spatial dimension has non-zero padding.
    ///
    /// Only the first `dimensionality.num_dims()` entries of `padding` are inspected.
    /// NOTE(review): presumably non-zero padding is what lets the kernel address positions
    /// outside the input, hence the need for spatial bounds checks. Confirm against the kernel.
    ///
    /// # Panics
    ///
    /// Panics if `num_dims()` is greater than 3.
    fn check_spatial_bounds(&self) -> bool {
        let spatial_dims = self.params.dimensionality.num_dims();
        self.params.padding[..spatial_dims]
            .iter()
            .any(|&pad| pad != 0)
    }

    /// Returns a copy of the stored convolution parameters.
    fn params(&self) -> ConvolutionParams {
        self.params
    }

    /// Returns the operation recorded in the convolution parameters.
    fn operation(&self) -> ConvolutionOperation {
        self.params.operation
    }

    /// Global memory config taken from the matmul's LHS reader config.
    fn lhs_global_memory_config(&self) -> GlobalMemoryConfig {
        self.matmul.lhs_reader_config().gmem_config
    }

    /// Global memory config taken from the matmul's RHS reader config.
    fn rhs_global_memory_config(&self) -> GlobalMemoryConfig {
        self.matmul.rhs_reader_config().gmem_config
    }

    /// Global memory config taken from the matmul's writer config.
    fn out_global_memory_config(&self) -> GlobalMemoryConfig {
        self.matmul.writer_config().gmem_config
    }
}
impl<M: GlobalConfig> ConvolutionConfig<M> {
    /// Builds a convolution config around an existing global matmul config.
    ///
    /// The number of active spatial dimensions is taken from `kernel_size.len()`. Each
    /// per-dimension slice is copied into a three-element array whose unused trailing entries
    /// are zero. `dim` and `operation` are stored unchanged.
    ///
    /// # Errors
    ///
    /// This constructor currently always returns `Ok`. The `Result` presumably leaves room for
    /// future validation; do not rely on it never failing.
    ///
    /// # Panics
    ///
    /// Panics if `kernel_size` has more than 3 entries. Also panics if `stride`, `dilation` or
    /// `padding` does not have the same length as `kernel_size`.
    // Takes the matmul config plus one argument per convolution parameter. The argument count
    // mirrors the data being stored, so the lint is silenced rather than bundling arguments.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        matmul: M,
        kernel_size: &[u32],
        stride: &[u32],
        dilation: &[u32],
        padding: &[i32],
        dim: Dimensionality,
        operation: ConvolutionOperation,
    ) -> Result<Self, MatmulSetupError> {
        // Copies `values` into the leading `len` slots of a zero-initialised `[T; 3]`.
        fn widen<T: Copy + Default>(values: &[T], len: usize) -> [T; 3] {
            let mut widened = [T::default(); 3];
            widened[..len].copy_from_slice(values);
            widened
        }

        let dims = kernel_size.len();

        // These checks reject exactly the inputs that would otherwise panic inside
        // `copy_from_slice` / slice indexing, but with an actionable message.
        assert!(
            dims <= 3,
            "convolution supports at most 3 spatial dimensions, got {dims}"
        );
        for (name, len) in [
            ("stride", stride.len()),
            ("dilation", dilation.len()),
            ("padding", padding.len()),
        ] {
            assert_eq!(
                len, dims,
                "`{name}` must have one entry per spatial dimension ({dims})"
            );
        }

        let params = ConvolutionParams {
            kernel_size: widen(kernel_size, dims),
            stride: widen(stride, dims),
            dilation: widen(dilation, dims),
            padding: widen(padding, dims),
            dimensionality: dim,
            operation,
        };

        Ok(Self { matmul, params })
    }
}