#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
/*!
 * \brief Naive (reference) implementation of the LocalForward operator.
 *
 * Besides the usual exec()/workspace interface it exposes the float kernel
 * parameter bundle and a kernel-dispatch hook. dispatch_float_noncontig_batch
 * is virtual so a derived class can substitute its own kernel choice.
 */
class LocalForwardImpl : public LocalForward {
public:
    using LocalForward::LocalForward;

    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) override;

    //! This implementation requests no workspace, whatever the layouts are.
    size_t get_workspace_in_bytes(
            const TensorLayout& /*unused*/, const TensorLayout& /*unused*/,
            const TensorLayout& /*unused*/) override {
        return 0;
    }

    /*!
     * \brief Argument bundle passed to a float_noncontig_batch_kern.
     *
     * NOTE(review): keep the field order stable; the struct may be
     * aggregate-initialized positionally elsewhere. Field meanings noted
     * below are inferred from their names and should be confirmed against
     * make_float_kern_param.
     */
    struct FloatNoncontigBatchKernParam {
        RefPtr src;
        RefPtr filter;
        RefPtr dst;
        // presumably: batch size, input channels/height/width,
        // output channels/height/width, filter height/width
        size_t n;
        size_t ic;
        size_t ih;
        size_t iw;
        size_t oc;
        size_t oh;
        size_t ow;
        size_t fh;
        size_t fw;
        // presumably: padding (ph, pw) and stride (sh, sw)
        uint32_t ph;
        uint32_t pw;
        uint32_t sh;
        uint32_t sw;
        // presumably the per-batch strides of the input and output tensors
        ptrdiff_t inp_bs;
        ptrdiff_t out_bs;
        void* workspace;
    };

    //! Signature shared by all float kernels selected via the dispatch hook.
    using float_noncontig_batch_kern = void (*)(const FloatNoncontigBatchKernParam&);

    //! Packs the exec() arguments into a FloatNoncontigBatchKernParam.
    FloatNoncontigBatchKernParam make_float_kern_param(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) const;

    //! Returns the float kernel to use for the given layouts; overridable.
    virtual float_noncontig_batch_kern dispatch_float_noncontig_batch(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst);

protected:
    //! Presumably runs exec() through the float non-contiguous-batch kernel path.
    void exec_use_float_noncontig_batch(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace);

private:
    // NOTE(review): is_xcorr presumably selects cross-correlation versus
    // convolution (flipped filter); confirm in the definition.
    template <bool is_xcorr, typename dtype>
    static void naive_kern(const FloatNoncontigBatchKernParam& param);
};
/*!
 * \brief Naive (reference) implementation of LocalBackwardData.
 *
 * exec() is only declared here; its definition lives in a separate
 * translation unit. No workspace is ever requested.
 */
class LocalBackwardDataImpl : public LocalBackwardData {
public:
    using LocalBackwardData::LocalBackwardData;

    //! Computes the output tensor \p grad from \p filter and \p diff.
    void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) override;

    //! Workspace requirement: always zero bytes, independent of the layouts.
    size_t get_workspace_in_bytes(
            const TensorLayout& /*unused*/, const TensorLayout& /*unused*/,
            const TensorLayout& /*unused*/) override {
        return 0;
    }
};
/*!
 * \brief Naive (reference) implementation of LocalBackwardFilter.
 *
 * exec() is only declared here; its definition lives in a separate
 * translation unit. No workspace is ever requested.
 */
class LocalBackwardFilterImpl : public LocalBackwardFilter {
public:
    using LocalBackwardFilter::LocalBackwardFilter;

    //! Computes the output tensor \p grad from \p src and \p diff.
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) override;

    //! Workspace requirement: always zero bytes, independent of the layouts.
    size_t get_workspace_in_bytes(
            const TensorLayout& /*unused*/, const TensorLayout& /*unused*/,
            const TensorLayout& /*unused*/) override {
        return 0;
    }
};
} }
// Unpacks a FloatNoncontigBatchKernParam-like object into typed locals for use
// inside a kernel body.
//
//   _p     : the parameter object. Every use is parenthesized, so any
//            expression works (e.g. `*param_ptr`). It is evaluated once per
//            field, so it must be free of side effects.
//   _dtype : element type that the raw data pointers are cast to.
//
// Declares in the caller's scope:
//   src, filter                        : const _dtype*
//   dst, workspace                     : _dtype*
//   N, IC, IH, IW, OC, OH, OW, FH, FW  : const int (implicitly converted)
//   PH, PW, SH, SW                     : const uint32_t
//   INP_BS, OUT_BS                     : const ptrdiff_t
//
// The expansion deliberately omits the final semicolon. Invoke it as a
// statement: `UNPACK_LOCAL_FLOAT_NONCONTIG_BATCH_KERN_PARAM(param, float);`
#define UNPACK_LOCAL_FLOAT_NONCONTIG_BATCH_KERN_PARAM(_p, _dtype)                     \
    const _dtype* src = static_cast<const _dtype*>((_p).src.get_ptr());              \
    const _dtype* filter = static_cast<const _dtype*>((_p).filter.get_ptr());        \
    _dtype* dst = static_cast<_dtype*>((_p).dst.get_ptr());                          \
    _dtype* workspace = static_cast<_dtype*>((_p).workspace);                        \
    const int N = (_p).n, IC = (_p).ic, IH = (_p).ih, IW = (_p).iw, OC = (_p).oc,   \
              OH = (_p).oh, OW = (_p).ow, FH = (_p).fh, FW = (_p).fw;               \
    const uint32_t PH = (_p).ph, PW = (_p).pw, SH = (_p).sh, SW = (_p).sw;           \
    const ptrdiff_t INP_BS = (_p).inp_bs, OUT_BS = (_p).out_bs