use super::{bilinear_interpolate, deform_im2col, index};
use crate::{
CubeRuntime,
kernel::{
cast, into_contiguous_aligned,
matmul::{MatmulStrategy, matmul},
reduce::reduce_dim,
slice_assign,
utils::{address_type, decompose_linear},
},
ops::{
numeric::{empty_device_dtype, zeros_client},
reshape, swap_dims,
},
tensor::CubeTensor,
};
use burn_backend::{DType, Shape, TensorMetadata, ops::DeformConvOptions};
use cubecl::{
CubeDim, CubeLaunch, calculate_cube_count_elemwise, cube,
features::AtomicUsage,
ir::FloatKind,
prelude::*,
std::{FastDivmod, tensor::layout::linear::LinearView},
};
use cubek::{
convolution::components::ConvSetupError,
reduce::components::instructions::ReduceOperationConfig,
};
use std::marker::PhantomData;
#[allow(
clippy::single_range_in_vec_init,
clippy::type_complexity,
clippy::too_many_arguments
)]
pub(crate) fn deform_conv2d_backward<R: CubeRuntime>(
input: CubeTensor<R>,
offset: CubeTensor<R>,
weight: CubeTensor<R>,
mask: Option<CubeTensor<R>>,
bias: Option<CubeTensor<R>>,
out_grad: CubeTensor<R>,
options: DeformConvOptions<2>,
) -> Result<
(
CubeTensor<R>,
CubeTensor<R>,
CubeTensor<R>,
Option<CubeTensor<R>>,
Option<CubeTensor<R>>,
),
ConvSetupError,
> {
let [_, _, out_h, out_w] = out_grad.meta.shape().dims();
let [_, _, kernel_h, kernel_w] = weight.meta.shape().dims();
let gradient_bias = bias.map(|bias| {
let grad = reduce_dim(
out_grad.clone(),
None,
0,
Default::default(),
ReduceOperationConfig::Sum,
)
.unwrap();
let grad = reduce_dim(
grad,
None,
2,
Default::default(),
ReduceOperationConfig::Sum,
)
.unwrap();
let grad = reduce_dim(
grad,
None,
3,
Default::default(),
ReduceOperationConfig::Sum,
)
.unwrap();
reshape(grad, bias.meta.shape.clone())
});
let input = into_contiguous_aligned(input);
let offset = into_contiguous_aligned(offset);
let weight = into_contiguous_aligned(weight);
let mask = mask.map(|it| into_contiguous_aligned(it));
let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs(
input.clone(),
weight.clone(),
offset.clone(),
mask.clone(),
out_grad.clone(),
&options,
(kernel_h, kernel_w),
)?;
let weight_grad = compute_weight_grad(
input,
offset,
mask,
out_grad,
options,
(kernel_h, kernel_w),
(out_h, out_w),
)?;
Ok((
input_gradient,
offset_gradient,
weight_grad,
mask_gradient,
gradient_bias,
))
}
/// Gradient of the loss with respect to the convolution weight.
///
/// Rebuilds the deformable im2col `columns` matrix from the forward inputs, then for every
/// weight group multiplies the output gradient `[out_c / groups, N]` by the transposed
/// group columns `[N, rows / groups]`. The result is reshaped to
/// `[out_channels, in_channels / groups, kernel_h, kernel_w]`.
fn compute_weight_grad<R: CubeRuntime>(
    input: CubeTensor<R>,
    offset: CubeTensor<R>,
    mask: Option<CubeTensor<R>>,
    out_grad: CubeTensor<R>,
    options: DeformConvOptions<2>,
    kernel_dims: (usize, usize),
    out_dims: (usize, usize),
) -> Result<CubeTensor<R>, ConvSetupError> {
    let [_, in_channels, _, _] = input.meta.shape().dims();
    let [_, out_channels, _, _] = out_grad.meta.shape().dims();
    let (kernel_h, kernel_w) = kernel_dims;
    let groups = options.weight_groups;
    let dtype = input.dtype;

    let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims)?;
    let [total_rows, col_len] = columns.meta.shape().dims();
    let rows_per_group = total_rows / groups;

    // [batch, out_c, h, w] -> [out_c, batch, h, w] -> [groups, out_c / groups, col_len]
    let lhs = reshape(
        swap_dims(out_grad, 0, 1),
        Shape::new([groups, out_channels / groups, col_len]),
    );
    // [groups, rows_per_group, col_len] -> [groups, col_len, rows_per_group]
    let rhs = swap_dims(
        reshape(columns, Shape::new([groups, rows_per_group, col_len])),
        1,
        2,
    );

    let grad_weight = matmul(lhs, rhs, None, MatmulStrategy::default(), dtype)?;
    Ok(reshape(
        grad_weight,
        Shape::new([out_channels, in_channels / groups, kernel_h, kernel_w]),
    ))
}
/// Gradients produced by `backward_gradient_inputs`, in the order
/// `(input, offset, mask)`; the mask gradient is `None` when no mask was given.
type InputGradients<R> = (CubeTensor<R>, CubeTensor<R>, Option<CubeTensor<R>>);
/// Gradients with respect to the input image, the offsets and (if present) the mask.
///
/// The output gradient is first back-projected into column space one weight group at a
/// time (`columns[g] = weight[g]^T @ out_grad[g]`). This yields a matrix with
/// `groups * in_c_per_group * kernel_h * kernel_w` rows and `batch * out_h * out_w` columns.
/// That matrix feeds both the offset/mask-gradient kernel and the col2img input-gradient
/// kernel.
fn backward_gradient_inputs<R: CubeRuntime>(
    image: CubeTensor<R>,
    weight: CubeTensor<R>,
    offset: CubeTensor<R>,
    mask: Option<CubeTensor<R>>,
    out_grad: CubeTensor<R>,
    options: &DeformConvOptions<2>,
    kernel_dims: (usize, usize),
) -> Result<InputGradients<R>, ConvSetupError> {
    let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.meta.shape().dims();
    let [batch_size, _, out_h, out_w] = out_grad.meta.shape().dims();
    let groups = options.weight_groups;
    let dtype = weight.dtype;

    // Per-group column matrix size: one row per (in channel, ky, kx), one column per
    // (batch, out_y, out_x).
    let rows = in_c_per_group * kernel_h * kernel_w;
    let cols = batch_size * out_h * out_w;

    let mut columns = empty_device_dtype(
        out_grad.client.clone(),
        out_grad.device.clone(),
        Shape::new([groups, rows, cols]),
        dtype,
    );
    let weight = reshape(weight, Shape::new([groups, out_channels / groups, rows]));
    let out_grad = reshape(
        swap_dims(out_grad, 0, 1),
        Shape::new([groups, out_channels / groups, cols]),
    );

    for g in 0..groups {
        let weight_t = swap_dims(index(weight.clone(), g), 0, 1);
        let grad_g = index(out_grad.clone(), g);
        let product = matmul(weight_t, grad_g, None, MatmulStrategy::default(), dtype)?;
        let product = reshape(product, Shape::new([1, rows, cols]));
        columns = slice_assign(
            columns,
            &[
                burn_backend::Slice::from(g..g + 1),
                burn_backend::Slice::from(0..rows),
                burn_backend::Slice::from(0..cols),
            ],
            product,
        );
    }

    let columns = reshape(columns, Shape::new([rows * groups, cols]));
    let input_shape = image.shape();

    let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient(
        columns.clone(),
        image,
        offset.clone(),
        mask.clone(),
        options,
        kernel_dims,
    )?;
    let input_gradient =
        compute_input_grad(columns, offset, mask, options, kernel_dims, input_shape)?;

    Ok((input_gradient, offset_gradient, mask_gradient))
}
/// Computes the offset gradient and, when `mask` is provided, the mask gradient.
///
/// `columns` is the output gradient back-projected into column space. This launches
/// [`deform_col2img_coord_kernel`] with one thread per element of `offset`. The returned
/// gradients have the same shapes and dtypes as `offset` and `mask`.
fn compute_offset_and_mask_gradient<R: CubeRuntime>(
    columns: CubeTensor<R>,
    image: CubeTensor<R>,
    offset: CubeTensor<R>,
    mask: Option<CubeTensor<R>>,
    options: &DeformConvOptions<2>,
    kernel_dims: (usize, usize),
) -> Result<(CubeTensor<R>, Option<CubeTensor<R>>), ConvSetupError> {
    let (kernel_h, kernel_w) = kernel_dims;
    let [batches, _, out_h, out_w] = offset.meta.shape().dims();
    let offset_groups = options.offset_groups;

    // Thread layout. Dim 1 of `offset` is `[offset_group, kernel_y, kernel_x, (y, x)]`
    // flattened, so this decomposition walks `offset` in linear order.
    let pos_shape = [batches, offset_groups, kernel_h, kernel_w, 2, out_h, out_w]
        .into_iter()
        .collect();

    // Gradients mirror the shape and dtype of the tensor they differentiate. They are left
    // uninitialised: the kernel writes every offset element. NOTE(review): every mask element
    // is written only if the mask has `offset_groups * kernel_h * kernel_w` channels.
    let alloc_like = |tensor: &CubeTensor<R>| {
        empty_device_dtype(
            offset.client.clone(),
            offset.device.clone(),
            tensor.shape(),
            tensor.dtype,
        )
    };
    let grad_offset = alloc_like(&offset);
    let grad_mask = mask.as_ref().map(alloc_like);

    let launch_len = offset.meta.num_elements();
    let cube_dim = CubeDim::new(&image.client, launch_len);
    let cube_count = calculate_cube_count_elemwise(&image.client, launch_len, cube_dim);
    let dtype: StorageType = image.dtype.into();

    let args = DeformConv2dCol2ImgCoordArgsLaunch::new(
        options.stride[0],
        options.stride[1],
        options.dilation[0],
        options.dilation[1],
        InputScalar::new(options.padding[0] as f32, dtype.elem_type()),
        InputScalar::new(options.padding[1] as f32, dtype.elem_type()),
        offset_groups,
        kernel_h,
        kernel_w,
    );

    // SAFETY: `launch_unchecked` skips bounds checking. The kernel exits early for threads
    // past `grad_offset`'s length, and every other index is derived from the tensors' own
    // shapes/strides. NOTE(review): this assumes `columns` and `mask` match the layout
    // produced by `backward_gradient_inputs`.
    unsafe {
        deform_col2img_coord_kernel::launch_unchecked(
            &grad_offset.client,
            cube_count,
            cube_dim,
            address_type!(image, offset, mask, grad_offset, grad_mask),
            image.into_tensor_arg(),
            offset.into_tensor_arg(),
            mask.map(|mask| mask.into_tensor_arg()).into(),
            columns.into_tensor_arg(),
            grad_offset.clone().into_linear_view(),
            grad_mask
                .clone()
                .map(|grad_mask| grad_mask.into_tensor_arg())
                .into(),
            pos_shape,
            args,
            dtype,
        )
    };

    Ok((grad_offset, grad_mask))
}
/// Launch arguments for [`deform_col2img_coord_kernel`].
#[derive(CubeLaunch, CubeType)]
struct DeformConv2dCol2ImgCoordArgs {
    // Convolution stride along height / width.
    stride_h: usize,
    stride_w: usize,
    // Kernel dilation along height / width.
    dilation_h: usize,
    dilation_w: usize,
    // Padding along height / width, carried as a scalar of the kernel's float type because
    // it is subtracted from the float sampling coordinates.
    pad_h: InputScalar,
    pad_w: InputScalar,
    // Number of offset groups; each covers `in_channels / offset_groups` input channels.
    offset_groups: usize,
    // Kernel spatial size.
    kernel_height: usize,
    kernel_width: usize,
}
/// Computes the gradient of the loss with respect to the deformable-conv offsets and,
/// optionally, the modulation mask.
///
/// Launched with one thread per offset element. Each thread's linear position is
/// decomposed with `pos_shape` into
/// `[batch, offset_group, kernel_y, kernel_x, dir, out_y, out_x]`, where `dir == 0` is the
/// y-offset and `dir == 1` the x-offset.
///
/// For every input channel of its offset group, the thread accumulates
/// `mask * d(bilinear sample)/d(coordinate) * column_value`. Threads with `dir == 0` also
/// accumulate `column_value * bilinear_sample` into the mask gradient.
///
/// `columns` has one row per `(in_channel, kernel_y, kernel_x)`, i.e. row
/// `channel * kernel_h * kernel_w + kernel_y * kernel_w + kernel_x`, and one column per
/// `(batch, out_y, out_x)`.
#[allow(clippy::collapsible_if)]
#[cube(launch_unchecked, address_type = "dynamic")]
fn deform_col2img_coord_kernel<F: Float>(
    image: &Tensor<F>,
    offset: &Tensor<F>,
    mask: &ComptimeOption<Tensor<F>>,
    columns: &Tensor<F>,
    grad_offset: &mut LinearView<F, ReadWrite>,
    grad_mask: &mut ComptimeOption<Tensor<F>>,
    pos_shape: Sequence<FastDivmod<usize>>,
    args: &DeformConv2dCol2ImgCoordArgs,
    #[define(F)] _dtype: StorageType,
) {
    if ABSOLUTE_POS >= grad_offset.shape() {
        terminate!();
    }
    let out_h = offset.shape(2);
    let out_w = offset.shape(3);
    let in_channels = image.shape(1);
    let height = image.shape(2);
    let width = image.shape(3);
    let kernel_w = args.kernel_width;
    let kernel_h = args.kernel_height;
    let mut grad_offset_val = F::new(0.0);
    let mut grad_mask_val = F::new(0.0);
    let (_, pos) = decompose_linear(ABSOLUTE_POS, &pos_shape);
    let [batch, offset_group, kernel_y, kernel_x, dir, out_y, out_x] = *pos else {
        unreachable!()
    };
    let channels_per_offset_group = in_channels / args.offset_groups;
    // Flattened kernel position; selects the matching row inside each channel's block of
    // `kernel_h * kernel_w` rows in `columns`.
    let kernel_idx = kernel_y * kernel_w + kernel_x;
    // Column of `columns` for this (batch, out_y, out_x).
    let col_n = batch * out_h * out_w + out_y * out_w + out_x;
    // Row of the first channel of this offset group at (kernel_y, kernel_x). The kernel
    // position term is required: without it every kernel position would read the rows of
    // kernel position 0.
    let col_base_idx = (offset_group * channels_per_offset_group * kernel_h * kernel_w
        + kernel_idx)
        * columns.stride(0)
        + col_n * columns.stride(1);
    let mut image_base_idx =
        batch * image.stride(0) + offset_group * channels_per_offset_group * image.stride(1);
    // Offset channel of the (y, x) pair for this kernel position; x follows y.
    let offset_pos_1 =
        offset_group * kernel_h * kernel_w * 2 + kernel_y * kernel_w * 2 + kernel_x * 2;
    let offset_base_idx = batch * offset.stride(0)
        + offset_pos_1 * offset.stride(1)
        + out_y * offset.stride(2)
        + out_x * offset.stride(3);
    let offset_y_idx = offset_base_idx;
    let offset_x_idx = offset_base_idx + offset.stride(1);
    let offset_y = offset[offset_y_idx];
    let offset_x = offset[offset_x_idx];
    let mask_pos_1 = offset_group * kernel_h * kernel_w + kernel_y * kernel_w + kernel_x;
    // Without a mask the modulation factor is 1.
    #[comptime]
    let mask_value = match &mask {
        ComptimeOption::Some(mask) => {
            let mask_idx = batch * mask.stride(0)
                + mask_pos_1 * mask.stride(1)
                + out_y * mask.stride(2)
                + out_x * mask.stride(3);
            mask[mask_idx]
        }
        ComptimeOption::None => F::new(1.0),
    };
    let is_y_direction = dir == 0;
    for col_c in 0..channels_per_offset_group {
        // Consecutive channels are `kernel_h * kernel_w` rows apart in `columns`.
        let col_pos = col_base_idx + col_c * kernel_h * kernel_w * columns.stride(0);
        // Deformed sampling point in input coordinates.
        let y = F::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h)
            - args.pad_h.get::<F>()
            + offset_y;
        let x = F::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w)
            - args.pad_w.get::<F>()
            + offset_x;
        let weight =
            get_coordinate_weight(image, image_base_idx, height, width, y, x, is_y_direction);
        let columns_value = columns[col_pos];
        grad_offset_val += mask_value * weight * columns_value;
        // The mask gradient is accumulated by the y-direction thread only, so that each
        // mask element is produced exactly once.
        if grad_mask.is_some() && is_y_direction {
            grad_mask_val +=
                columns_value * bilinear_interpolate(image, height, width, y, x, image_base_idx);
        }
        image_base_idx += image.stride(1);
    }
    grad_offset[ABSOLUTE_POS] = grad_offset_val;
    #[comptime]
    if let ComptimeOption::Some(grad_mask) = grad_mask {
        if is_y_direction {
            let idx = batch * grad_mask.stride(0)
                + mask_pos_1 * grad_mask.stride(1)
                + out_y * grad_mask.stride(2)
                + out_x * grad_mask.stride(3);
            grad_mask[idx] = grad_mask_val
        }
    }
}
/// Partial derivative of the bilinear interpolation of one channel plane of `input`,
/// sampled at the fractional position `(y, x)`. The derivative is taken with respect to `y`
/// when `is_y_direction` is true and with respect to `x` otherwise.
///
/// * `offset` - linear index of the start of the channel plane inside `input`. The four
///   neighbouring pixels are addressed from it with `input.stride(2)` / `input.stride(3)`.
/// * `height`, `width` - plane size used for bounds checks. Neighbours outside the plane
///   contribute zero.
#[cube]
fn get_coordinate_weight<F: Float>(
    input: &Tensor<F>,
    offset: usize,
    height: usize,
    width: usize,
    y: F,
    x: F,
    is_y_direction: bool,
) -> F {
    let stride_y = input.stride(2);
    let stride_x = input.stride(3);
    // Floor and bounds checks are done on f32 copies of the coordinates.
    let y = f32::cast_from(y);
    let x = f32::cast_from(x);
    let y_low = f32::floor(y);
    let x_low = f32::floor(x);
    let y_high = y_low + 1.;
    let x_high = x_low + 1.;
    let valid_y_low = y_low >= 0. && y_low < height as f32;
    let valid_y_high = y_high >= 0. && y_high < height as f32;
    let valid_x_low = x_low >= 0. && x_low < width as f32;
    let valid_x_high = x_high >= 0. && x_high < width as f32;
    // Naming: "bottom" is the `y_low` row, "top" the `y_high` row.
    let bottom_left = if valid_y_low && valid_x_low {
        input[offset + y_low as usize * stride_y + x_low as usize * stride_x]
    } else {
        F::new(0.0)
    };
    let bottom_right = if valid_y_low && valid_x_high {
        input[offset + y_low as usize * stride_y + x_high as usize * stride_x]
    } else {
        F::new(0.0)
    };
    let top_left = if valid_y_high && valid_x_low {
        input[offset + y_high as usize * stride_y + x_low as usize * stride_x]
    } else {
        F::new(0.0)
    };
    let top_right = if valid_y_high && valid_x_high {
        input[offset + y_high as usize * stride_y + x_high as usize * stride_x]
    } else {
        F::new(0.0)
    };
    if is_y_direction {
        // d/dy of the bilinear blend: row difference (high - low), blended along x.
        let delta_x = F::cast_from(x - x_low);
        delta_x * (top_right - bottom_right) + (F::new(1.0) - delta_x) * (top_left - bottom_left)
    } else {
        // d/dx of the bilinear blend: column difference (high - low), blended along y.
        let delta_y = F::cast_from(y - y_low);
        delta_y * (top_right - top_left) + (F::new(1.0) - delta_y) * (bottom_right - bottom_left)
    }
}
fn compute_input_grad<R: CubeRuntime>(
columns: CubeTensor<R>,
offset: CubeTensor<R>,
mask: Option<CubeTensor<R>>,
options: &DeformConvOptions<2>,
kernel_dims: (usize, usize),
input_shape: Shape,
) -> Result<CubeTensor<R>, LaunchError> {
let client = offset.client.clone();
let device = offset.device.clone();
let supports_fadd = client
.properties()
.atomic_type_usage(Type::new(StorageType::Atomic(FloatKind::F32.into())))
.contains(AtomicUsage::Add);
let supports_same_type = client
.properties()
.atomic_type_usage(Type::new(StorageType::Atomic(columns.dtype.into())))
.contains(AtomicUsage::Add);
let [batches, in_channels, height, width] = input_shape.dims();
let [_, _, out_h, out_w] = offset.meta.shape().dims();
let (kernel_h, kernel_w) = kernel_dims;
let pos_shape = [in_channels, kernel_h, kernel_w, batches, out_h, out_w];
let pos_shape = pos_shape.into_iter().collect();
let shape = Shape::new([batches, in_channels, height, width]);
let grad_in = match supports_fadd && supports_same_type {
true => zeros_client(client.clone(), device.clone(), shape, columns.dtype),
false => zeros_client(client.clone(), device.clone(), shape, DType::F32),
};
let grad_arg = grad_in.clone().into_tensor_arg();
let num_elements = columns.meta.num_elements();
let cube_dim = CubeDim::new(&offset.client, num_elements);
let cube_count = calculate_cube_count_elemwise(&offset.client, num_elements, cube_dim);
let launch = match supports_fadd {
true => deform_col2img_kernel::launch_unchecked::<IntrinsicFloatAtomicAddFamily, R>,
false => deform_col2img_kernel::launch_unchecked::<CASFloatAtomicAdd, R>,
};
let dtype = offset.dtype;
let dtypes: [StorageType; 2] = match supports_same_type {
true => [dtype.into(), dtype.into()],
false => [dtype.into(), DType::F32.into()],
};
unsafe {
launch(
&grad_in.client,
cube_count,
cube_dim,
address_type!(offset, mask, columns, grad_in),
offset.into_tensor_arg(),
mask.map(|mask| mask.into_tensor_arg()).into(),
reshape(columns, Shape::new([num_elements])).into_linear_view(),
grad_arg,
pos_shape,
DeformConv2dCol2ImgArgsLaunch::new(
options.stride[0],
options.stride[1],
options.dilation[0],
options.dilation[1],
InputScalar::new(options.padding[0] as f32, dtypes[0].elem_type()),
InputScalar::new(options.padding[1] as f32, dtypes[0].elem_type()),
options.offset_groups,
kernel_h,
kernel_w,
),
dtypes,
)
};
Ok(if !supports_same_type || !supports_fadd {
cast(grad_in, dtype)
} else {
grad_in
})
}
/// Launch arguments for [`deform_col2img_kernel`].
#[derive(CubeLaunch, CubeType)]
struct DeformConv2dCol2ImgArgs {
    // Convolution stride along height / width.
    stride_h: usize,
    stride_w: usize,
    // Kernel dilation along height / width.
    dilation_h: usize,
    dilation_w: usize,
    // Padding along height / width, carried as a scalar of the kernel's float type because
    // it is subtracted from the float sampling coordinates.
    pad_h: InputScalar,
    pad_w: InputScalar,
    // Number of offset groups; each covers `in_channels / offset_groups` input channels.
    offset_groups: usize,
    // Kernel spatial size.
    kernel_height: usize,
    kernel_width: usize,
}
/// Scatters the column-space gradient back onto the input image (col2img), producing the
/// gradient with respect to the deformable-conv input.
///
/// Launched with one thread per element of `columns`. Each linear position is decomposed
/// with `pos_shape` into `[in_channel, kernel_y, kernel_x, batch, out_y, out_x]`.
///
/// Each thread recomputes the deformed sampling point `(y, x)`. It then adds
/// `mask * bilinear_weight * column_value` to every in-bounds integer pixel within distance
/// < 1 of that point (at most four) through `FAdd`'s atomic add, since several threads can
/// target the same pixel.
///
/// `F` is the element type of `offset`, `mask` and `columns`. The atomic element type of
/// `grad_input` is `ProxyType<FAdd, FP>`.
#[cube(launch_unchecked, address_type = "dynamic")]
fn deform_col2img_kernel<F: Float, FP: Float, FAdd: FloatAtomicAddFamily>(
    offset: &Tensor<F>,
    mask: &ComptimeOption<Tensor<F>>,
    columns: &LinearView<F>,
    grad_input: &mut Tensor<Atomic<ProxyType<FAdd, FP>>>,
    pos_shape: Sequence<FastDivmod<usize>>,
    args: &DeformConv2dCol2ImgArgs,
    #[define(F, FP)] _dtype: [StorageType; 2],
) {
    if ABSOLUTE_POS >= columns.shape() {
        terminate!();
    }
    let n_in_channels = grad_input.shape(1);
    let height = grad_input.shape(2);
    let width = grad_input.shape(3);
    let kernel_h = args.kernel_height;
    let kernel_w = args.kernel_width;
    let n_offset_groups = args.offset_groups;
    let (_, pos) = decompose_linear(ABSOLUTE_POS, &pos_shape);
    let [in_channel, kernel_y, kernel_x, batch, out_y, out_x] = *pos else {
        unreachable!()
    };
    let channels_per_offset_group = n_in_channels / n_offset_groups;
    let offset_group = in_channel / channels_per_offset_group;
    // Offset channel of the (y, x) pair for this kernel position; x follows y.
    let offset_pos_1 =
        offset_group * kernel_h * kernel_w * 2 + kernel_y * kernel_w * 2 + kernel_x * 2;
    let offset_base_idx = batch * offset.stride(0)
        + offset_pos_1 * offset.stride(1)
        + out_y * offset.stride(2)
        + out_x * offset.stride(3);
    let offset_y_idx = offset_base_idx;
    let offset_x_idx = offset_base_idx + offset.stride(1);
    let offset_y = offset[offset_y_idx];
    let offset_x = offset[offset_x_idx];
    // Without a mask the modulation factor is 1.
    #[comptime]
    let mask_value = match mask {
        ComptimeOption::Some(mask) => {
            let mask_pos_1 = offset_group * kernel_h * kernel_w + kernel_y * kernel_w + kernel_x;
            mask[batch * mask.stride(0)
                + mask_pos_1 * mask.stride(1)
                + out_y * mask.stride(2)
                + out_x * mask.stride(3)]
        }
        ComptimeOption::None => F::new(1.0),
    };
    // Deformed sampling point in input coordinates.
    let y = F::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h)
        - args.pad_h.get::<F>()
        + offset_y;
    let x = F::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w)
        - args.pad_w.get::<F>()
        + offset_x;
    // Visit the 3x3 integer neighbourhood around floor(y), floor(x). Because of the
    // `abs(...) < 1` checks, only dy, dx in {0, 1} can pass, i.e. the bilinear corners.
    for dy in -1..=1i32 {
        #[unroll]
        for dx in -1..=1i32 {
            let yp = y.floor() + F::cast_from(dy);
            let xp = x.floor() + F::cast_from(dx);
            if yp >= F::new(0.0)
                && yp < F::cast_from(height)
                && xp >= F::new(0.0)
                && xp < F::cast_from(width)
                && F::abs(y - yp) < F::new(1.0)
                && F::abs(x - xp) < F::new(1.0)
            {
                let gradient_pos = batch * grad_input.stride(0)
                    + in_channel * grad_input.stride(1)
                    + usize::cast_from(yp) * grad_input.stride(2)
                    + usize::cast_from(xp) * grad_input.stride(3);
                // Bilinear weight of pixel (yp, xp) for the sample at (y, x).
                let weight = (F::new(1.0) - F::abs(y - yp)) * (F::new(1.0) - F::abs(x - xp));
                let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS];
                FAdd::Op::<FP>::float_atomic_add::<F>(&mut grad_input[gradient_pos], value);
            }
        }
    }
}
/// Element type stored in the atomic gradient buffer when family `FADF` is instantiated
/// with proxy float `FP`.
type ProxyType<FADF, FP> = <<FADF as FloatAtomicAddFamily>::Op<FP> as FloatAtomicAdd>::ProxyType;
/// Type-level selector for a float atomic-add strategy, parameterised by the proxy float
/// type used for accumulation.
#[cube]
trait FloatAtomicAddFamily: Send + Sync + 'static {
    type Op<ProxyType: Float>: FloatAtomicAdd;
}
/// Atomically adds a float value into a slot stored as `Atomic<Self::ProxyType>`.
#[cube]
trait FloatAtomicAdd: Send + Sync + 'static {
    // Element type actually held by the atomic: the float itself (intrinsic variant) or
    // its bit pattern (CAS variant).
    type ProxyType: Numeric;
    fn float_atomic_add<F: Float>(ptr: &mut Atomic<Self::ProxyType>, value: F);
}
/// Atomic add using the device's native float atomic `fetch_add` on type `F`.
#[derive(CubeType)]
struct IntrinsicFloatAtomicAdd<F: Float> {
    // Only carries the float type; no runtime data.
    #[cube(comptime)]
    _ty: PhantomData<F>,
}
/// Fallback atomic add that stores an `f32` as its `u32` bit pattern and updates it with a
/// compare-exchange loop; used when native float atomic add is unavailable.
#[derive(CubeType)]
struct CASFloatAtomicAdd;
/// Family selecting [`IntrinsicFloatAtomicAdd`] for the requested proxy float type.
struct IntrinsicFloatAtomicAddFamily;
impl FloatAtomicAddFamily for IntrinsicFloatAtomicAddFamily {
    type Op<ProxyType: Float> = IntrinsicFloatAtomicAdd<ProxyType>;
}
// The CAS strategy ignores the proxy float type: it always works on `u32` bit patterns.
impl FloatAtomicAddFamily for CASFloatAtomicAdd {
    type Op<ProxyType: Float> = Self;
}
/// Casts `value` to the accumulation type `FAdd` and applies a native atomic `fetch_add`.
#[cube]
impl<FAdd: Float> FloatAtomicAdd for IntrinsicFloatAtomicAdd<FAdd> {
    type ProxyType = FAdd;
    fn float_atomic_add<F: Float>(ptr: &mut Atomic<FAdd>, value: F) {
        let value = FAdd::cast_from(value);
        ptr.fetch_add(value);
    }
}
/// Float atomic add emulated with compare-exchange on the `u32` bit pattern of an `f32`.
#[cube]
impl FloatAtomicAdd for CASFloatAtomicAdd {
    type ProxyType = u32;
    fn float_atomic_add<F: Float>(ptr: &mut Atomic<Self::ProxyType>, value: F) {
        // Accumulation is always done in f32, regardless of `F`.
        let value = f32::cast_from(value);
        // Adding zero is a no-op; skip the atomic traffic.
        if value != 0.0 {
            let mut v = ptr.load();
            // Retry until the exchange observes the value the sum was computed from.
            // NOTE(review): this uses `compare_exchange_weak` and treats "returned == expected"
            // as success. If a backend's weak CAS can fail spuriously while returning the
            // expected value, an update could be lost; confirm the cubecl semantics.
            loop {
                let prev = v;
                let v_float = f32::from_bits(v);
                let new = (v_float + value).to_bits();
                v = ptr.compare_exchange_weak(v, new);
                if prev == v {
                    break;
                }
            }
        }
    }
}