use std::fmt::Display;
use cubecl_core::{
CubeElement, Runtime,
client::ComputeClient,
prelude::{Float, Numeric},
server::{self},
};
use cubecl_matmul::tests::test_utils::{CastInto, Sample};
use cubecl_runtime::MmaConfig;
use crate::components::ConvolutionProblem;
/// Numeric precision triple used by the convolution test harness.
///
/// `EG` is the global (storage) element type, `ES` the element type the
/// per-tap products are computed in, and `EA` the accumulator type.
pub trait TestPrecision {
    /// Global/storage element type; must be sampleable and castable to `ES`.
    type EG: Numeric + CubeElement + Display + CastInto<Self::ES> + Sample;
    /// Stage element type; castable to the accumulator type.
    type ES: Numeric + Display + CastInto<Self::EA>;
    /// Accumulator type; castable back to `EG` for comparison with GPU output.
    type EA: Numeric + Display + CastInto<Self::EG>;
    /// Computes a CPU reference result for `problem` from `lhs`/`rhs` and
    /// asserts that the GPU output in `out` (laid out per `shape`/`strides`)
    /// matches it within a precision-dependent tolerance. Panics on mismatch.
    fn assert_result<R: Runtime>(
        lhs: &[Self::EG],
        rhs: &[Self::EG],
        problem: &ConvolutionProblem,
        client: &ComputeClient<R::Server>,
        out: server::Handle,
        shape: &[usize],
        strides: &[usize],
    );
}
/// Blanket precision pair: any float storage type `EG` staged through `ES`,
/// always accumulating in `f32`.
impl<EG, ES> TestPrecision for (EG, ES)
where
    EG: Float + CubeElement + Display + CastInto<ES> + Sample,
    ES: Numeric + Display + CastInto<f32>,
    f32: CastInto<EG>,
{
    type EG = EG;
    type ES = ES;
    type EA = f32;

    fn assert_result<R: Runtime>(
        lhs: &[EG],
        rhs: &[EG],
        problem: &ConvolutionProblem,
        client: &ComputeClient<R::Server>,
        out: server::Handle,
        shape: &[usize],
        strides: &[usize],
    ) {
        // Probe whether the device advertises CMMA support for a given tile
        // shape with this precision pairing; the two probes below differ only
        // in `k`.
        let cmma_available = |m, k, n| {
            client.properties().features.cmma.contains(&MmaConfig {
                a_type: ES::as_type_native().expect("To be a native type"),
                b_type: ES::as_type_native().expect("To be a native type"),
                cd_type: EG::as_type_native().expect("To be a native type"),
                m,
                k,
                n,
            })
        };
        let maybe_f16 = cmma_available(16, 16, 16);
        let maybe_tf32 = cmma_available(16, 8, 16);

        // Loosen the tolerance when the kernel may have run through
        // reduced-precision CMMA tiles. Evaluation order of the float
        // expression is kept as-is so the tolerance is bit-identical.
        let epsilon = if maybe_f16 || maybe_tf32 {
            3.0 * 10e-6 / EG::EPSILON.to_f32().unwrap() * half::f16::EPSILON.to_f32()
        } else {
            3.0 * 10e-6
        };

        // CPU reference in EA precision, cast back to EG for comparison.
        let reference = conv_cpu_reference::<Self>(lhs, rhs, problem);
        let expected: Vec<EG> = reference.into_iter().map(|x| x.cast_into()).collect();

        if let Err(e) =
            assert_equals_approx::<R, EG>(client, out, shape, strides, &expected, epsilon)
        {
            panic!("{}", e);
        }
    }
}
/// Reads `output` back from the device and compares it element-wise against
/// `expected`, allowing a relative error of `epsilon` (rescaled for `F`'s
/// machine epsilon). Returns a descriptive error on the first mismatch.
pub(crate) fn assert_equals_approx<R: Runtime, F: Float + CubeElement + Display>(
    client: &ComputeClient<R::Server>,
    output: server::Handle,
    shape: &[usize],
    strides: &[usize],
    expected: &[F],
    epsilon: f32,
) -> Result<(), String> {
    // Pull the tensor back to host memory and reinterpret the bytes as F.
    let bytes = client.read_one_tensor(output.copy_descriptor(shape, strides, size_of::<F>()));
    let actual = F::from_bytes(&bytes);

    // Rescale the caller's f32-relative tolerance to F's machine epsilon,
    // never tightening below the original value.
    let epsilon = (epsilon / f32::EPSILON * F::EPSILON.to_f32().unwrap()).max(epsilon);

    for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
        let a_f32 = a.to_f32().unwrap();
        let e_f32 = e.to_f32().unwrap();
        // Relative tolerance, floored at the absolute epsilon for values near 0.
        let allowed_error = (epsilon * e_f32.abs()).max(epsilon);
        let difference = (a_f32 - e_f32).abs();
        if difference >= allowed_error {
            return Err(format!(
                "Values differ more than epsilon: index={} actual={}, expected={}, difference={}, epsilon={}",
                i, *a, *e, difference, epsilon
            ));
        }
    }

    Ok(())
}
/// CPU reference implementation of a 2D convolution, used to produce expected
/// values for the GPU convolution tests.
///
/// Layouts (derived from the stride computations below):
/// - `lhs` input: NHWC (`[batches, h, w, channels]`)
/// - `rhs` weights: OHWI (`[out_channels, kh, kw, channels]`)
/// - output: NHWC (`[batches, out_h, out_w, out_channels]`)
///
/// Each product is computed in `P::ES` precision and accumulated in `P::EA`;
/// taps that fall outside the (padded) input contribute zero.
pub(crate) fn conv_cpu_reference<P: TestPrecision>(
    lhs: &[P::EG],
    rhs: &[P::EG],
    problem: &ConvolutionProblem,
) -> Vec<P::EA> {
    let n = problem.batches;
    let h = problem.shape[0];
    let w = problem.shape[1];
    let c = problem.channels;
    let out_h = problem.out_shape[0];
    let out_w = problem.out_shape[1];
    let out_channels = problem.n;
    let kh = problem.kernel_size[0] as usize;
    let kw = problem.kernel_size[1] as usize;
    let padding = &problem.padding;
    let stride = &problem.stride;
    let dilation = &problem.dilation;

    // Row-major strides for the NHWC input.
    let lhs_stride_n = h * w * c;
    let lhs_stride_h = w * c;
    let lhs_stride_w = c;
    let lhs_stride_c = 1;
    // Row-major strides for the OHWI weights.
    let rhs_stride_out_c = kh * kw * c;
    let rhs_stride_kh = kw * c;
    let rhs_stride_kw = c;
    let rhs_stride_in_c = 1;
    // Row-major strides for the NHWC output.
    let out_stride_n = out_h * out_w * out_channels;
    let out_stride_h = out_w * out_channels;
    let out_stride_w = out_channels;
    let out_stride_c = 1;

    let mut out = vec![P::EA::from_int(0); n * out_h * out_w * out_channels];

    for nth_batch in 0..n {
        let batch_in = nth_batch * lhs_stride_n;
        let batch_out = nth_batch * out_stride_n;
        for out_y in 0..out_h {
            let out_offset = batch_out + out_y * out_stride_h;
            for out_x in 0..out_w {
                let out_offset = out_offset + out_x * out_stride_w;
                for out_c in 0..out_channels {
                    let out_pos = out_offset + out_c * out_stride_c;
                    let weight_offset = out_c * rhs_stride_out_c;
                    let mut acc = P::EA::from_int(0);
                    for in_c in 0..c {
                        let in_offset = batch_in + in_c * lhs_stride_c;
                        let weight_offset = weight_offset + in_c * rhs_stride_in_c;
                        for ky in 0..kh {
                            let weight_offset = weight_offset + ky * rhs_stride_kh;
                            for kx in 0..kw {
                                let weight_pos = weight_offset + kx * rhs_stride_kw;
                                // Input coordinates may be negative or exceed
                                // the input extent due to padding, so compute
                                // them in i32 and bounds-check before indexing.
                                let in_y = out_y as i32 * stride[0] as i32
                                    + ky as i32 * dilation[0] as i32
                                    - padding[0];
                                let in_x = out_x as i32 * stride[1] as i32
                                    + kx as i32 * dilation[1] as i32
                                    - padding[1];
                                if in_y >= 0 && in_y < h as i32 && in_x >= 0 && in_x < w as i32 {
                                    let in_pos = in_offset
                                        + in_y as usize * lhs_stride_h
                                        + in_x as usize * lhs_stride_w;
                                    let value: P::ES = lhs[in_pos].cast_into();
                                    let weight: P::ES = rhs[weight_pos].cast_into();
                                    acc += (value * weight).cast_into();
                                }
                            }
                        }
                    }
                    out[out_pos] = acc;
                }
            }
        }
    }
    out
}