use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_matmul::components::{
AccG, AvailableLineSizes, LhsG, MatmulLineSizes, MatmulPrecision, MatmulSelection,
MatmulSetupError, RhsG,
global::GlobalWriter,
stage::{ContiguousTilingLayout, RowMajorTilingOrder},
};
use cubecl_std::{
CubeOption,
tensor::{View, layout::Coords2d},
};
use crate::{
components::{ConvGemmConfig, ConvolutionProblem, global::entry_point::ConvolutionLaunch},
kernels::layered::selector::RuntimeArgs,
};
/// Tiling layout alias for convolution: a [`ContiguousTilingLayout`] whose tiles
/// are ordered by [`RowMajorTilingOrder`].
pub type ConvTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
/// Shorthand for the configuration type of the global convolution family `F`.
///
/// The fully-qualified `<F as GlobalConvolutionFamily>` path names the trait
/// explicitly, so the alias needs no (unenforced) bound on `F`.
pub type GlobalConfig<F> = <F as GlobalConvolutionFamily>::Config;
/// A family of global convolutions that all share one configuration type.
///
/// For every matmul precision `MP` the family names a concrete
/// [`GlobalConvolution`] (see [`Self::Convolution`]). That convolution's `Config`
/// is pinned to [`Self::Config`]. The family must also implement
/// [`ConvolutionLaunch`] for that same config type.
pub trait GlobalConvolutionFamily: 'static + ConvolutionLaunch<Self::Config> {
    /// Configuration type shared by every convolution in this family.
    /// It is the value produced by [`Self::setup`].
    type Config: ConvGemmConfig;

    /// Concrete global convolution for precision `MP`.
    /// Its config type is required to equal [`Self::Config`].
    type Convolution<MP: MatmulPrecision>: GlobalConvolution<MP, Config = Self::Config>;

    /// Maps the available line sizes to the ones this family will use.
    ///
    /// NOTE(review): the name suggests this only narrows the input set; confirm
    /// against implementors.
    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes;

    /// Builds this family's configuration from the problem description.
    ///
    /// The inputs are the compute `client` for runtime `R`, the convolution
    /// `problem`, the matmul `selection` and the chosen `line_sizes`.
    ///
    /// # Errors
    /// Returns a [`MatmulSetupError`] when no valid configuration can be produced.
    fn setup<R, MP>(
        client: &ComputeClient<R::Server>,
        problem: &ConvolutionProblem,
        selection: &MatmulSelection,
        line_sizes: &MatmulLineSizes,
    ) -> Result<Self::Config, MatmulSetupError>
    where
        R: Runtime,
        MP: MatmulPrecision;
}
#[cube]
/// Global-level convolution for the matmul precision `MP`.
///
/// Implementors supply reader, writer and accumulator types plus constructors
/// for each. [`Self::execute`] consumes the readers and the writer and works on
/// the accumulators. Every method receives the comptime [`Self::Config`].
pub trait GlobalConvolution<MP: MatmulPrecision>: 'static + Send + Sync {
    /// Reader over the lhs input; built by `init_lhs_global_reader`.
    type LhsGlobalReader: CubeType;
    /// Reader over the rhs input; built by `init_rhs_global_reader`.
    type RhsGlobalReader: CubeType;
    /// Reader over the accumulator input; built by `init_bias_global_reader`
    /// from the (optional) bias view.
    type AccGlobalReader: CubeType;
    /// Comptime configuration passed to every method of this trait.
    type Config: ConvGemmConfig;
    /// Writer for `MP::Acc` output; built by `init_global_writer`.
    type GlobalWriter: GlobalWriter<MP::Acc>;
    /// Accumulator storage; built by `init_accumulator` and handed to `execute`
    /// by mutable reference.
    type Accumulators: CubeType;

    /// Runs the global convolution.
    ///
    /// It takes ownership of the three readers and the writer, and may mutate
    /// `acc` in place.
    ///
    /// NOTE(review): `k_range` is presumably a `(start, end)` range along the
    /// reduction (K) dimension; confirm against implementors.
    fn execute(
        lhs_reader: Self::LhsGlobalReader,
        rhs_reader: Self::RhsGlobalReader,
        acc_reader: Self::AccGlobalReader,
        writer: Self::GlobalWriter,
        acc: &mut Self::Accumulators,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    );

    /// Builds the lhs reader over the 2-D `lhs` view.
    ///
    /// NOTE(review): `offset` and `slice_size` are presumably the origin and
    /// extent of the slice this unit reads; confirm. `runtime_args` carries
    /// the runtime arguments and is only borrowed.
    fn init_lhs_global_reader(
        lhs: View<Line<LhsG<MP>>, Coords2d>,
        offset: Coords2d,
        slice_size: Coords2d,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsGlobalReader;

    /// Builds the rhs reader over the 2-D `rhs` view.
    fn init_rhs_global_reader(
        rhs: View<Line<RhsG<MP>>, Coords2d>,
        #[comptime] config: Self::Config,
    ) -> Self::RhsGlobalReader;

    /// Builds the accumulator reader from the bias view.
    ///
    /// The bias is wrapped in `CubeOption`.
    ///
    /// NOTE(review): the reader is presumably expected to handle the case where
    /// no bias is supplied; confirm in implementors.
    fn init_bias_global_reader(
        bias: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
        #[comptime] config: Self::Config,
    ) -> Self::AccGlobalReader;

    /// Builds the writer over the read-write 2-D output view `out`.
    fn init_global_writer(
        out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
        #[comptime] config: Self::Config,
    ) -> Self::GlobalWriter;

    /// Creates the accumulators that are later passed to `execute`.
    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulators;
}