#include "megdnn/oprs.h"
#include "src/common/utils.h"

#include <cstdint>
namespace megdnn {
void CorrelationBase::deduce_layout_fwd(
const TensorLayout& data1, const TensorLayout& data2, TensorLayout& dst) {
megdnn_assert_contiguous(data1);
megdnn_assert_contiguous(data2);
megdnn_assert_contiguous(dst);
auto errmsg = [&]() {
return megdnn_layout_msg(data1) + ", " + megdnn_layout_msg(data2) + ", " +
megdnn_layout_msg(dst);
};
MEGDNN_MARK_USED_VAR(errmsg);
using Format = CorrelationBase::Param::Format;
megdnn_assert(param().format == Format::NCHW);
auto data1_dtype = data1.dtype, data2_dtype = data2.dtype;
megdnn_assert(
data1_dtype == data2_dtype &&
data1_dtype.category() == DTypeCategory::FLOAT);
megdnn_assert(data1.ndim == 4_z, "%s", errmsg().c_str());
megdnn_assert(data2.ndim == 4_z, "%s", errmsg().c_str());
uint32_t pad_size = param().pad_size;
uint32_t kernel_size = param().kernel_size;
uint32_t stride1 = param().stride1;
uint32_t stride2 = param().stride2;
uint32_t max_displacement = param().max_displacement;
int paddedbottomheight = data1[2] + 2 * pad_size;
int paddedbottomwidth = data1[3] + 2 * pad_size;
uint32_t kernel_radius = (kernel_size - 1) / 2;
uint32_t border_size = max_displacement + kernel_radius;
uint32_t top_width =
ceil(static_cast<float>(paddedbottomwidth - border_size * 2) /
static_cast<float>(stride1));
uint32_t top_height =
ceil(static_cast<float>(paddedbottomheight - border_size * 2) /
static_cast<float>(stride1));
uint32_t neighborhood_grid_radius = max_displacement / stride2;
uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
uint32_t top_channels = neighborhood_grid_width * neighborhood_grid_width;
megdnn_assert(top_width >= 1 && top_height >= 1);
dst = TensorLayout{{data1[0], top_channels, top_height, top_width}, data1.dtype};
}
void CorrelationBase::check_layout_fwd(
const TensorLayout& data1, const TensorLayout& data2, const TensorLayout& dst) {
TensorLayout dst_expected;
megdnn_assert_eq_dtype(data1, dst);
megdnn_assert_eq_shape(data1, data2);
deduce_layout_fwd(data1, data2, dst_expected);
megdnn_assert_eq_shape(dst_expected, dst);
}
void CorrelationForward::deduce_layout(
const TensorLayout& data1, const TensorLayout& data2, TensorLayout& dst) {
deduce_layout_fwd(data1, data2, dst);
}
void CorrelationForward::check_exec(
const TensorLayout& data1, const TensorLayout& data2, const TensorLayout& dst,
size_t workspace_in_bytes) {
check_layout_fwd(data1, data2, dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(data1, data2, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void CorrelationBackwardData1::check_exec(
const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2,
const TensorLayout& grad1, size_t workspace_in_bytes) {
check_layout_fwd(grad1, data2, diff);
megdnn_assert_eq_shape(data1, data2);
auto required_workspace_in_bytes =
get_workspace_in_bytes(diff, data1, data2, grad1);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void CorrelationBackwardData2::check_exec(
const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2,
const TensorLayout& grad2, size_t workspace_in_bytes) {
check_layout_fwd(data1, grad2, diff);
megdnn_assert_eq_shape(data1, data2);
auto required_workspace_in_bytes =
get_workspace_in_bytes(diff, data1, data2, grad2);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void CorrelationBackwardData2::deduce_layout(
const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2,
TensorLayout& grad) {
megdnn_assert_eq_shape(data1, data2);
check_layout_fwd(data1, data2, diff);
grad = data2;
}
void CorrelationBackwardData1::deduce_layout(
const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2,
TensorLayout& grad) {
megdnn_assert_eq_shape(data1, data2);
check_layout_fwd(data1, data2, diff);
grad = data1;
}
}