use burn_backend::{
DType,
ops::{ConvOptions, conv::calculate_conv_output_sizes},
};
use burn_std::{Metadata, Shape};
use core::iter;
use cubecl::{
prelude::*,
std::tensor::{TensorHandle, into_contiguous_pitched},
};
use cubek::convolution::components::ConvSetupError;
use crate::{
CubeRuntime,
kernel::{
AddOp, into_contiguous_aligned, launch_binop,
matmul::{MatmulStrategy, matmul},
utils::split_dim,
},
ops::{reshape, swap_dims},
tensor::CubeTensor,
};
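
/// Number of batches that can be processed in a single launch.
///
/// `out_shape` is the number of output elements in one batch; each batch needs
/// `out_shape.div_ceil(plane_size)` cubes, and the total cube count of a launch is
/// bounded, so large batches may have to be split across several runs. Returns the
/// largest divisor of `batch_size` that fits within the limit, or an error if even
/// a single batch is too big.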
#[cfg(not(test))]
pub(crate) fn batches_per_run(
batch_size: usize,
out_shape: usize,
plane_size: usize,
) -> Result<usize, ConvSetupError> {
use cubek::matmul::definition::MatmulAvailabilityError;
    let cube_count_per_batch = out_shape.div_ceil(plane_size);
    // Cube counts are capped at `u16::MAX` per dispatch dimension, so large batches
    // may need to be split across multiple launches.
    let max_cube_count = u16::MAX as usize;
    let max_simultaneous = Ord::min(max_cube_count / cube_count_per_batch, batch_size);
if max_simultaneous == 0 {
return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static(
cube_count_per_batch as u32,
1,
1,
))
.into());
}
    // `per_run == 1` always divides `batch_size`, so `find` is guaranteed to succeed.
    Ok((1..=max_simultaneous)
        .rev()
        .find(|per_run| batch_size.is_multiple_of(*per_run))
        .expect("Logically not possible"))
}
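
/// Test builds always process one batch per run, so the multi-run splitting path is
/// exercised regardless of the cube-count limits of the device running the tests.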
#[cfg(test)]
#[allow(unused)]
pub(crate) fn batches_per_run(
    _batch_size: usize,
    _out_shape: usize,
    _plane_size: usize,
) -> Result<usize, ConvSetupError> {
Ok(1)
}
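
/// Convolution with a `1x1` kernel, lowered to a single matmul.
///
/// When the kernel is a unit kernel and the spatial shape is preserved, im2col is
/// the identity, so the NHWC input collapses to a `[batch * spatial, in_channels]`
/// matrix that is multiplied by the `[in_channels, out_channels]` weight. Grouped
/// convolutions and non-qualifying shapes are rejected with a `ConvSetupError`.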
pub fn conv_im2col_1x1<R: CubeRuntime, const N: usize>(
input: CubeTensor<R>,
mut weight: CubeTensor<R>,
bias: Option<CubeTensor<R>>,
options: ConvOptions<N>,
) -> Result<CubeTensor<R>, ConvSetupError> {
if options.groups != 1 {
return Err(ConvSetupError::Groups(options.groups));
}
    // Channels-last (NHWC) layout: channels are the innermost dimension.
    let rank = input.meta.num_dims();
    let dim_c = rank - 1;
let batch_size = input.meta.shape()[0];
let in_channels = input.meta.shape()[dim_c];
let in_shape = &input.meta.shape()[1..dim_c];
let out_channels = weight.meta.shape()[0];
let kernel_shape = &weight.meta.shape()[1..dim_c];
    // This fast path only applies to a strictly unit (`1x...x1`) kernel.
    if kernel_shape.iter().any(|s| *s != 1) {
        return Err(ConvSetupError::Unknown);
    }
let out_shape = calculate_conv_output_sizes(
kernel_shape,
&options.stride,
&options.padding,
&options.dilation,
in_shape,
);
    // With a 1x1 kernel, im2col degenerates to a reshape only when the spatial
    // shape is preserved.
    if in_shape != out_shape {
        return Err(ConvSetupError::Unknown);
    }
    let mut split_m = vec![batch_size];
    split_m.extend(out_shape.iter().copied());
    let input = reshape_input(input);
    let dtype = input.dtype;
    // Collapse the weight to a 2D `[out_channels, in_channels]` view; if the input
    // channel dimension is not unit-stride, copy it into an aligned contiguous buffer.
    let weight = if weight.meta.strides()[dim_c] != 1 {
        *weight.meta = Metadata::new(
            [out_channels, in_channels],
            [weight.meta.strides()[0], weight.meta.strides()[dim_c]],
        );
        into_contiguous_aligned(weight)
    } else {
        *weight.meta = Metadata::new([out_channels, in_channels], [weight.meta.strides()[0], 1]);
        weight
    };
    // [batch * spatial, in_c] x [in_c, out_c] -> [batch * spatial, out_c], then
    // split the leading dimension back into [batch, spatial...].
    let weight = swap_dims(weight, 0, 1);
    let out = matmul(input, weight, None, MatmulStrategy::default(), dtype)?;
    let mut out = split_dim(out, 0, &split_m);
    if let Some(bias) = bias {
        // Reshape bias to [1, ..., 1, out_channels] so it broadcasts over the batch
        // and spatial dimensions.
        let mut bias_shape = iter::repeat_n(1, rank - 1).collect::<Vec<_>>();
bias_shape.push(out_channels);
let bias = reshape(bias, bias_shape.into());
out = launch_binop::<R, AddOp>(out, bias);
}
Ok(out)
}
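
/// Collapses an NHWC input to a 2D `[batch * spatial, in_channels]` matrix view.
///
/// If the spatial dimensions are not packed against the channel dimension, the
/// tensor is first copied into a contiguous, row-pitched layout so that the
/// collapsed strides are valid.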
fn reshape_input<R: CubeRuntime>(input: CubeTensor<R>) -> CubeTensor<R> {
let rank = input.meta.num_dims();
let dim_c = rank - 1;
let dtype = input.dtype;
let batch_size = input.meta.shape()[0];
let in_c: usize = input.meta.shape()[dim_c];
let in_shape: Shape = input.meta.shape()[1..dim_c].into();
    // If the spatial dims are not packed, copy into a contiguous (possibly
    // row-pitched) buffer so the collapsed 2D view below is valid.
    let mut input = if !is_spatial_contiguous(input.meta.shape(), input.meta.strides()) {
let (client, device) = (input.client.clone(), input.device.clone());
let contiguous = into_contiguous_pitched(&client, input.binding(), dtype.into());
from_handle(client, device, contiguous, dtype)
} else {
input
};
    *input.meta = Metadata::new(
        [batch_size * in_shape.num_elements(), in_c],
        [input.meta.strides()[dim_c - 1], input.meta.strides()[dim_c]],
    );
input
}
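
/// Checks that the channel (innermost) dimension has unit stride and that each
/// dimension after the batch is packed against the next, i.e.
/// `strides[i] == strides[i + 1] * shape[i + 1]`. The batch stride itself is not
/// checked, so a pitched batch dimension still counts as spatially contiguous.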
fn is_spatial_contiguous(shape: &[usize], strides: &[usize]) -> bool {
let rank = shape.len();
let dim_c = rank - 1;
if strides[dim_c] != 1 {
return false;
}
for i in (1..dim_c).rev() {
if strides[i + 1] * shape[i + 1] != strides[i] {
return false;
}
}
true
}
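
/// Rewraps a raw `TensorHandle` produced by a cubecl kernel into a `CubeTensor`.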
fn from_handle<R: CubeRuntime>(
client: ComputeClient<R>,
device: R::Device,
handle: TensorHandle<R>,
dtype: DType,
) -> CubeTensor<R> {
    CubeTensor::new(client, handle.handle, *handle.metadata, device, dtype)
}
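
// A minimal sanity-check sketch for `is_spatial_contiguous` (illustrative shapes
// only, not part of the upstream test suite). Strides are in elements, NHWC layout.
#[cfg(test)]
mod spatial_contiguity_tests {
    use super::is_spatial_contiguous;

    #[test]
    fn packed_nhwc_passes() {
        // [batch, h, w, c] with suffix-product strides: fully packed.
        assert!(is_spatial_contiguous(&[2, 4, 4, 8], &[128, 32, 8, 1]));
    }

    #[test]
    fn pitched_batch_passes() {
        // Only dimensions after the batch are checked, so a padded batch
        // stride (256 instead of 128) is still accepted.
        assert!(is_spatial_contiguous(&[2, 4, 4, 8], &[256, 32, 8, 1]));
    }

    #[test]
    fn non_unit_channel_stride_fails() {
        // An NCHW tensor permuted into an NHWC view: the channel stride is 16, not 1.
        assert!(!is_spatial_contiguous(&[2, 4, 4, 8], &[128, 4, 1, 16]));
    }
}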