#pragma once
#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"
#include <cstring>
namespace megdnn {
namespace naive {
namespace convolution3d {
//! Steps a running index through consecutive groups of \c grp_size items,
//! exposing both the current group and the position inside that group.
struct GroupCounter {
    const size_t grp_size;  //!< number of items that make up one group
    size_t cur_grp = 0;     //!< index of the group currently pointed at
    size_t cur_off = 0;     //!< offset of the current item inside its group

    explicit GroupCounter(size_t grp_size) : grp_size{grp_size} {}

    //! Advance by one item; once the offset reaches \c grp_size it is reset to
    //! zero and the group index is bumped.
    void next() {
        ++cur_off;
        if (cur_off != grp_size) {
            return;
        }
        cur_off = 0;
        ++cur_grp;
    }
};
//! Forward-pass policy: the destination element accumulates src * filter.
struct StrategyFwd {
    //! Add the product of the source and filter elements onto the destination
    //! element; the source and filter elements are left untouched.
    template <typename SrcT, typename FltT, typename DstT>
    static void on(SrcT& src_val, FltT& flt_val, DstT& dst_val) {
        dst_val += src_val * flt_val;
    }

    //! Reset a destination element to zero before its products are summed.
    template <typename DstT>
    static void init_dval(DstT& dst_val) {
        dst_val = 0;
    }
};
//! Backward-data policy: the source-side element accumulates filter * dst.
struct StrategyBwdData {
    //! Add the product of the filter and destination elements onto the
    //! source-side element; filter and destination elements are not modified.
    template <typename SrcT, typename FltT, typename DstT>
    static void on(SrcT& src_val, FltT& flt_val, DstT& dst_val) {
        src_val += flt_val * dst_val;
    }

    //! Intentionally a no-op: the destination element is only read by on().
    template <typename DstT>
    static void init_dval(DstT&) {}
};
//! Backward-filter policy: the filter element accumulates src * dst.
struct StrategyBwdFlt {
    //! Add the product of the source and destination elements onto the filter
    //! element; source and destination elements are not modified.
    template <typename SrcT, typename FltT, typename DstT>
    static void on(SrcT& src_val, FltT& flt_val, DstT& dst_val) {
        flt_val += src_val * dst_val;
    }

    //! Intentionally a no-op: the destination element is only read by on().
    template <typename DstT>
    static void init_dval(DstT&) {}
};
/*!
 * \brief Naive 3D convolution loop nest parameterised by an accumulation policy.
 *
 * For every batch item, output channel and output position the destination
 * element is first passed to Strategy::init_dval(). Then, for every filter tap
 * whose input coordinate lies inside the source volume and for every input
 * channel of the output channel's group, Strategy::on(src_elem, filter_elem,
 * dst_elem) is called.
 *
 * The filter buffer is indexed as a dense array:
 *  - NCDHW: [group][oc in group][ic in group][fd][fh][fw]
 *  - NDHWC: [group][oc in group][fd][fh][fw][ic in group]
 * whereas \p src and \p dst are addressed through their layout strides.
 *
 * \tparam stype element type accessed through \p src
 * \tparam ftype element type of the filter buffer
 * \tparam dtype element type accessed through \p dst
 * \tparam Strategy policy exposing static on() and init_dval()
 * \param src tensor addressed with input coordinates
 * \param fptr start of the filter buffer
 * \param dst tensor addressed with output coordinates
 * \param filter_meta format, group sizes, filter extent, padding, stride,
 *      dilation and flip flag of the convolution
 */
template <typename stype, typename ftype, typename dtype, class Strategy>
void compute3d(
        _megdnn_tensor_in src, ftype* __restrict fptr, _megdnn_tensor_out dst,
        const Convolution3D::CanonizedFilterMeta& filter_meta) {
    using Format = param::Convolution3D::Format;

    const bool is_ncdhw = filter_meta.format == Format::NCDHW;
    if (!is_ncdhw) {
        megdnn_assert(filter_meta.format == Format::NDHWC, "invalid conv format");
    }

    // Axis of the first spatial dimension (depth) and of the channel dimension.
    const size_t spatial_start = is_ncdhw ? 2 : 1;
    const size_t channel_pos = is_ncdhw ? 1 : 4;

    const auto batch = src.layout.shape[0];
    const auto in_d = src.layout.shape[spatial_start];
    const auto in_h = src.layout.shape[spatial_start + 1];
    const auto in_w = src.layout.shape[spatial_start + 2];

    const auto flt_d = filter_meta.spatial[0];
    const auto flt_h = filter_meta.spatial[1];
    const auto flt_w = filter_meta.spatial[2];

    const auto out_c = dst.layout.shape[channel_pos];
    const auto out_d = dst.layout.shape[spatial_start];
    const auto out_h = dst.layout.shape[spatial_start + 1];
    const auto out_w = dst.layout.shape[spatial_start + 2];

    // Element strides of the dense filter buffer along each logical axis.
    size_t flt_stride_oc, flt_stride_ic, flt_stride_spatial;
    if (is_ncdhw) {
        flt_stride_spatial = 1;
        flt_stride_ic = flt_d * flt_h * flt_w;
        flt_stride_oc = flt_stride_ic * filter_meta.icpg;
    } else {
        flt_stride_ic = 1;
        flt_stride_spatial = filter_meta.icpg;
        flt_stride_oc = flt_stride_spatial * flt_d * flt_h * flt_w;
    }
    const size_t flt_stride_grp = flt_stride_oc * filter_meta.ocpg;

    const int pad_d = filter_meta.padding[0];
    const int pad_h = filter_meta.padding[1];
    const int pad_w = filter_meta.padding[2];

    const size_t stride_d = filter_meta.stride[0];
    const size_t stride_h = filter_meta.stride[1];
    const size_t stride_w = filter_meta.stride[2];

    int dil_d = filter_meta.dilation[0];
    int dil_h = filter_meta.dilation[1];
    int dil_w = filter_meta.dilation[2];

    stype* __restrict sptr = src.ptr<stype>();
    dtype* __restrict dptr = dst.ptr<dtype>();

    // Input coordinate touched by filter tap 0 at output position 0.
    int off_d = -pad_d;
    int off_h = -pad_h;
    int off_w = -pad_w;
    if (filter_meta.should_flip) {
        // Flipped filter: tap 0 starts at the far end of the dilated window
        // and successive taps step backwards.
        off_d += filter_meta.dilated_spatial[0] - 1;
        off_h += filter_meta.dilated_spatial[1] - 1;
        off_w += filter_meta.dilated_spatial[2] - 1;
        dil_d = -dil_d;
        dil_h = -dil_h;
        dil_w = -dil_w;
    }

    // Map logical (n, c, d, h, w) coordinates to an element offset using the
    // strides of the given layout, honouring the tensor format.
    auto linear_offset = [&filter_meta](
                                 size_t n, size_t c, size_t d, size_t h, size_t w,
                                 const TensorLayout& layout) -> size_t {
        if (filter_meta.format != Format::NCDHW) {
            megdnn_assert(filter_meta.format == Format::NDHWC, "invalid conv format");
            return n * layout.stride[0] + d * layout.stride[1] + h * layout.stride[2] +
                   w * layout.stride[3] + c * layout.stride[4];
        }
        return n * layout.stride[0] + c * layout.stride[1] + d * layout.stride[2] +
               h * layout.stride[3] + w * layout.stride[4];
    };

    for (size_t n = 0; n < batch; ++n) {
        // Tracks which group the output channel belongs to and its index
        // inside that group.
        GroupCounter gc_out{filter_meta.ocpg};
        for (size_t oc = 0; oc < out_c; ++oc, gc_out.next()) {
            // Input channels of the current group: [ic_begin, ic_end).
            const size_t ic_begin = gc_out.cur_grp * filter_meta.icpg;
            const size_t ic_end = ic_begin + filter_meta.icpg;
            // Filter offset of this output channel.
            const size_t flt_oc_base = gc_out.cur_grp * flt_stride_grp +
                                       gc_out.cur_off * flt_stride_oc;

            for (size_t od = 0; od < out_d; ++od) {
                for (size_t oh = 0; oh < out_h; ++oh) {
                    for (size_t ow = 0; ow < out_w; ++ow) {
                        dtype& dval = dptr[linear_offset(n, oc, od, oh, ow, dst.layout)];
                        Strategy::init_dval(dval);

                        for (size_t fd = 0; fd < flt_d; ++fd) {
                            for (size_t fh = 0; fh < flt_h; ++fh) {
                                for (size_t fw = 0; fw < flt_w; ++fw) {
                                    // Unsigned arithmetic: a coordinate that
                                    // would be negative wraps to a huge value
                                    // and is rejected by the range test below.
                                    const size_t id = stride_d * od + fd * dil_d + off_d;
                                    const size_t ih = stride_h * oh + fh * dil_h + off_h;
                                    const size_t iw = stride_w * ow + fw * dil_w + off_w;
                                    if (id >= in_d || ih >= in_h || iw >= in_w) {
                                        continue;
                                    }

                                    const size_t flt_tap_base =
                                            flt_oc_base +
                                            (fd * flt_h * flt_w + fh * flt_w + fw) *
                                                    flt_stride_spatial;
                                    for (size_t ic = ic_begin; ic < ic_end; ++ic) {
                                        stype& sval = sptr[linear_offset(
                                                n, ic, id, ih, iw, src.layout)];
                                        ftype& fval =
                                                fptr[flt_tap_base +
                                                     (ic - ic_begin) * flt_stride_ic];
                                        Strategy::on(sval, fval, dval);
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
/*!
 * \brief Forward 3D convolution taking the filter as a raw pointer.
 *
 * Asserts that the filter meta describes three spatial dimensions, then runs
 * compute3d() with StrategyFwd.
 *
 * \param src input tensor
 * \param fptr start of the filter buffer
 * \param dst output tensor
 * \param filter_meta canonized description of the filter
 */
template <typename stype, typename ftype, typename dtype>
void forward(
        _megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst,
        const Convolution3D::CanonizedFilterMeta& filter_meta) {
    megdnn_assert(filter_meta.spatial_ndim == 3);
    // compute3d() takes a mutable filter pointer; StrategyFwd only writes to
    // the destination element, so constness is dropped for the call only.
    ftype* mutable_fptr = const_cast<ftype*>(fptr);
    compute3d<stype, ftype, dtype, StrategyFwd>(src, mutable_fptr, dst, filter_meta);
}
/*!
 * \brief Forward 3D convolution taking the filter as a tensor.
 *
 * Extracts the typed data pointer of \p filter and delegates to the overload
 * that accepts a raw filter pointer.
 *
 * \param src input tensor
 * \param filter filter tensor
 * \param dst output tensor
 * \param filter_meta canonized description of the filter
 */
template <typename stype, typename ftype, typename dtype>
void forward(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const Convolution3D::CanonizedFilterMeta& filter_meta) {
    const ftype* filter_ptr = filter.ptr<ftype>();
    forward<stype, ftype, dtype>(src, filter_ptr, dst, filter_meta);
}
/*!
 * \brief Gradient of the 3D convolution with respect to its input.
 *
 * \p grad is cleared first because StrategyBwdData only accumulates into it.
 * compute3d() is then run with \p grad in the source position and \p diff in
 * the destination position.
 *
 * \param filter filter tensor
 * \param diff gradient with respect to the convolution output
 * \param grad receives the gradient with respect to the convolution input
 * \param filter_meta canonized description of the filter
 */
template <typename ftype, typename dtype, typename gtype>
void backward_data(
        _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        const Convolution3D::CanonizedFilterMeta& filter_meta) {
    const auto grad_bytes = grad.layout.span().dist_byte();
    memset(grad.raw_ptr(), 0, grad_bytes);

    megdnn_assert(filter_meta.spatial_ndim == 3);

    ftype* filter_ptr = filter.ptr<ftype>();
    compute3d<gtype, ftype, dtype, StrategyBwdData>(grad, filter_ptr, diff, filter_meta);
}
/*!
 * \brief Gradient of the 3D convolution with respect to its filter.
 *
 * \p grad is cleared first because StrategyBwdFlt only accumulates into it.
 * compute3d() is then run with the data pointer of \p grad in the filter
 * position.
 *
 * \param src input tensor of the convolution
 * \param diff gradient with respect to the convolution output
 * \param grad receives the gradient with respect to the filter
 * \param filter_meta canonized description of the filter
 */
template <typename stype, typename dtype, typename gtype>
void backward_filter(
        _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        const Convolution3D::CanonizedFilterMeta& filter_meta) {
    const auto grad_bytes = grad.layout.span().dist_byte();
    memset(grad.raw_ptr(), 0, grad_bytes);

    megdnn_assert(filter_meta.spatial_ndim == 3);

    gtype* grad_ptr = grad.ptr<gtype>();
    compute3d<stype, gtype, dtype, StrategyBwdFlt>(src, grad_ptr, diff, filter_meta);
}
} } }