use cubecl_core::prelude::*;
use cubecl_core::{CubeElement, server::Allocation};
use cubecl_matmul::components::MatmulIdent;
use cubecl_matmul::components::MatmulSelection;
use cubecl_matmul::components::global::GlobalConfig;
use cubecl_matmul::{MatmulInputHandleRef, components::AvailableLineSizes};
use cubecl_matmul::components::global::args::MatmulArgs;
use cubecl_matmul::tests::layered::matmul_test_launcher::TensorRawParts;
use cubecl_matmul::tests::test_utils::Sample;
use crate::{
components::{
ConvGemmConfig as _, ConvolutionProblem,
global::{
args::{ConcreteInputsFactory, ConcreteOutputFactory},
entry_point::ConvolutionLaunch,
},
},
kernels::layered::algorithm::Algorithm,
};
use super::test_utils::TestPrecision;
/// Concrete input bundle an `Args: MatmulArgs` implementation produces for the
/// given Lhs/Rhs/output element types; shorthand for the associated type.
type Input<Args, Lhs, Rhs, EO> = <Args as MatmulArgs>::Input<Lhs, Rhs, EO>;
/// Concrete output bundle an `Args: MatmulArgs` implementation produces for
/// output element type `EO`.
type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
/// Launches one convolution test case with algorithm `A`, argument scheme
/// `Args` and precision `P` on runtime `R`, then checks the device result
/// against a host reference via `P::assert_result`.
///
/// Setup-failure behavior is controlled by the `MATMUL_TEST_MODE` env var:
/// `"panic"` makes an un-launchable configuration panic; any other value
/// (including unset, or `"skip"`) prints a message and skips the test. The
/// test is also skipped when the configured cube dimensions exceed what the
/// hardware reports it can run.
pub fn test_convolution_algorithm<A, Args, P, R>(
    client: ComputeClient<R::Server>,
    problem: ConvolutionProblem,
    selection: MatmulSelection,
) where
    A: Algorithm,
    Args: MatmulArgs,
    P: TestPrecision,
    R: Runtime,
    Args::Input<P::EG, P::EG, P::EG>: ConcreteInputsFactory,
    Args::Output<P::EG>: ConcreteOutputFactory,
{
    // Only the exact value "panic" escalates a setup error into a hard
    // failure; every other value (and an unset variable) means "skip".
    let panic_on_launch_err =
        matches!(std::env::var("MATMUL_TEST_MODE").as_deref(), Ok("panic"));

    let lhs = tensor_raw_parts::<P, R>(&client, &problem, MatmulIdent::Lhs);
    let rhs = tensor_raw_parts::<P, R>(&client, &problem, MatmulIdent::Rhs);
    let out = tensor_raw_parts::<P, R>(&client, &problem, MatmulIdent::Out);

    // Inputs are pinned to line size 1; the output picks the largest
    // candidate that is compatible with each tensor's strides and shape.
    let line_sizes = AvailableLineSizes {
        lhs: vec![1],
        rhs: vec![1],
        out: R::io_optimized_line_sizes_unchecked(size_of::<P::EG>()).collect(),
    }
    .filter_lhs_with_tensor(&lhs.strides, &lhs.shape, problem.lhs_layout)
    .filter_rhs_with_tensor(&rhs.strides, &rhs.shape, problem.rhs_layout)
    .filter_out_with_tensor(&out.strides, &out.shape)
    .pick_max()
    .unwrap();

    let config = match A::setup::<R, (P::EG, P::EG, P::EG, P::ES, P::ES, f32)>(
        &client,
        &problem,
        &selection,
        &line_sizes,
    ) {
        Ok(config) => config,
        Err(err) => {
            let msg = format!("Can't launch the test: {err}");
            if panic_on_launch_err {
                panic!("{msg}");
            } else {
                println!("{msg}");
                return;
            }
        }
    };

    // Skip (rather than fail) configurations that request more resources
    // than the device exposes.
    let props = &client.properties().hardware;
    if !props.max_cube_dim.can_contain(config.cube_dim())
        || config.cube_dim().num_elems() > props.max_units_per_cube
    {
        println!("Skipping test, too many resources requested");
        return;
    }

    let elem_size = size_of::<P::EG>();
    // SAFETY: handle, strides, shape and elem_size all come from the same
    // allocation made in `tensor_raw_parts`, so they describe the buffer
    // exactly.
    let lhs_handle = unsafe {
        TensorHandleRef::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size)
    };
    // SAFETY: same invariant as above, for the rhs allocation.
    let rhs_handle = unsafe {
        TensorHandleRef::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size)
    };
    // SAFETY: same invariant as above, for the out allocation.
    let out_handle = unsafe {
        TensorHandleRef::from_raw_parts(&out.handle, &out.strides, &out.shape, elem_size)
    };

    // Let the algorithm transform the inputs into its preferred layout
    // (e.g. im2col / TMA staging) before wrapping them for launch.
    let lhs_handle = A::into_tensor_handle::<R, P::EG>(&client, &lhs_handle, MatmulIdent::Lhs);
    let rhs_handle = A::into_tensor_handle::<R, P::EG>(&client, &rhs_handle, MatmulIdent::Rhs);

    let lhs_handle = MatmulInputHandleRef::new(lhs_handle.as_ref());
    let rhs_handle = MatmulInputHandleRef::new(rhs_handle.as_ref());

    let inputs = <Input<Args, P::EG, P::EG, P::EG> as ConcreteInputsFactory>::create(
        &client,
        &lhs_handle,
        &rhs_handle,
        None,
        &selection,
        &problem,
        &config.line_sizes(),
        config,
    );
    let output = <Output<Args, P::EG> as ConcreteOutputFactory>::create(
        &client,
        &out_handle,
        &selection,
        &problem,
        &config.line_sizes(),
        config,
    );

    // SAFETY: launch_unchecked skips bound checks; the cube dim/count were
    // validated against hardware limits above and the handles describe live
    // allocations owned by this function.
    unsafe {
        A::GlobalConvolution::launch_unchecked::<
            ((P::EG, P::EG, P::EG, P::ES, P::ES, P::EA), Args),
            R,
        >(
            &client,
            config.cube_dim(),
            A::cube_count(&selection, &problem),
            inputs,
            output,
            &problem,
            config,
        );
    }

    P::assert_result::<R>(
        &lhs.original_data.unwrap(),
        &rhs.original_data.unwrap(),
        &problem,
        &client,
        out.handle,
        &out.shape,
        &out.strides,
    );
}
/// Allocates and initializes the raw device buffers for one operand of the
/// convolution test.
///
/// `Lhs` and `Rhs` are filled with deterministic sample data (seed 1234) and
/// keep a host-side copy in `original_data` for the reference check; `Out` is
/// zero-initialized and carries no reference copy. The two input arms were
/// previously duplicated verbatim; they are merged with an or-pattern.
fn tensor_raw_parts<P: TestPrecision, R: Runtime>(
    client: &ComputeClient<R::Server>,
    problem: &ConvolutionProblem,
    ident: MatmulIdent,
) -> TensorRawParts<P::EG> {
    match ident {
        // Both inputs follow the same path: sample on device, read a host copy
        // back so the CPU reference implementation can use the same data.
        MatmulIdent::Lhs | MatmulIdent::Rhs => {
            let shape = shape(problem, ident);
            let handle = P::EG::sample::<R>(client, &shape, 1234);
            let data = client.read_one_tensor(handle.as_copy_descriptor());
            let original_data = P::EG::from_bytes(&data).to_owned();
            TensorRawParts {
                handle: handle.handle,
                scale: None,
                shape,
                strides: handle.strides,
                original_data: Some(original_data),
            }
        }
        // The output starts zeroed; the kernel writes it, so there is no
        // host-side reference copy to keep.
        MatmulIdent::Out => {
            let zero = P::EG::from_int(0);
            let data = vec![zero; tensor_size(problem, MatmulIdent::Out)];
            let shape = shape(problem, MatmulIdent::Out);
            let Allocation { handle, strides } =
                client.create_tensor(P::EG::as_bytes(&data), &shape, size_of::<P::EG>());
            TensorRawParts {
                handle,
                scale: None,
                shape,
                strides,
                original_data: None,
            }
        }
    }
}
/// Element count of the operand identified by `ident`, using the GEMM view of
/// the convolution problem: lhs is m×k, rhs is k×n, out is m×n.
pub(crate) fn tensor_size(problem: &ConvolutionProblem, ident: MatmulIdent) -> usize {
    let (rows, cols) = match ident {
        MatmulIdent::Lhs => (problem.m, problem.k),
        MatmulIdent::Rhs => (problem.k, problem.n),
        MatmulIdent::Out => (problem.m, problem.n),
    };
    rows * cols
}
/// Logical 4D shape of the operand identified by `ident`:
/// lhs is NHWC input, rhs is the OHWI weight tensor, out is NHWC output.
pub(crate) fn shape(problem: &ConvolutionProblem, ident: MatmulIdent) -> Vec<usize> {
    let dims: [usize; 4] = match ident {
        MatmulIdent::Lhs => [
            problem.batches,
            problem.shape[0],
            problem.shape[1],
            problem.channels,
        ],
        MatmulIdent::Rhs => {
            // Kernel extents are stored in a narrower integer type; widen them.
            let kh = problem.kernel_size[0] as usize;
            let kw = problem.kernel_size[1] as usize;
            [problem.n, kh, kw, problem.channels]
        }
        MatmulIdent::Out => [
            problem.batches,
            problem.out_shape[0],
            problem.out_shape[1],
            problem.n,
        ],
    };
    dims.to_vec()
}