#pragma once
#include <vector>
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
namespace megdnn {
namespace winograd {
//! Shorthand for the non-linearity selector declared on the ConvBias param.
using NonlineMode = ::megdnn::ConvBias::Param::NonlineMode;
//! Shorthand for the bias-mode enum declared on ConvBiasForward.
using BiasMode = ConvBiasForward::BiasMode;
/*!
 * \brief Bundle of static entry points named after three transform stages:
 *        filter(), input() and output().
 *
 * This header only declares the members; no definitions are visible here.
 *
 * \tparam ctype                      element type of the raw `filter` / `input`
 *                                    pointers
 * \tparam dst_type                   element type of the `output` pointer
 * \tparam input_filter_compute_type  element type of the filter/input transform
 *                                    buffers and their scratch buffers
 * \tparam output_compute_type        element type of the output transform
 *                                    buffer, the bias and the output scratch
 *                                    buffer
 * \tparam layout                     ConvBias format; defaults to NCHW
 * \tparam format                     MatrixMul format; defaults to DEFAULT
 *
 * NOTE(review): `m` and `r` look like the Winograd F(m, r) output-tile size and
 * kernel size, and `interp_points` like the interpolation points used to build
 * the transform matrices -- confirm against the definitions.
 */
template <
        typename ctype, typename dst_type, typename input_filter_compute_type,
        typename output_compute_type,
        param::ConvBias::Format layout = param::ConvBias::Format::NCHW,
        param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT>
class StrategyHelper {
public:
    /*!
     * \brief Filter-stage entry point.
     *
     * Reads from `filter`; `filter_transform_buf` and `transform_mid_buf` are
     * the only non-const buffers, so results are presumably written there
     * (`transform_mid_buf` presumably being scratch space -- TODO confirm).
     * `oc_start` / `oc_end` presumably bound the output-channel range handled
     * by one call -- verify against callers. `rescale` defaults to 1.0f.
     */
    static void filter(
            const ctype* filter,
            input_filter_compute_type* filter_transform_buf,
            input_filter_compute_type* transform_mid_buf,
            size_t OC, size_t IC,
            size_t oc_start, size_t oc_end,
            size_t m, size_t r,
            const std::vector<float>& interp_points,
            DType dtype,
            float rescale = 1.0f);

    /*!
     * \brief Input-stage entry point.
     *
     * Reads from `input`; `input_transform_buf` and `transform_mid_buf` are
     * the only non-const buffers. `ih_start` / `iw_start` are signed while the
     * extents `IH` / `IW` are unsigned -- presumably so a tile origin can lie
     * in the padding region; TODO confirm. `rescale` defaults to 1.0f.
     */
    static void input(
            const ctype* input,
            input_filter_compute_type* input_transform_buf,
            input_filter_compute_type* transform_mid_buf,
            int ih_start, int iw_start,
            size_t IH, size_t IW, size_t IC, size_t ic,
            size_t unit_idx, size_t nr_units_in_tile,
            size_t m, size_t r,
            const std::vector<float>& interp_points,
            DType dtype,
            float rescale = 1.0f);

    /*!
     * \brief Output-stage entry point.
     *
     * Reads from `output_transform_buf` and `bias`; `output` and
     * `transform_mid_buf` are the only non-const buffers. `bmode` and
     * `nonline_mode` select the bias mode and non-linearity. All three scale
     * factors (`input_filter_scale`, `input_filter_rescale`, `rescale`)
     * default to 1.0f.
     */
    static void output(
            const output_compute_type* output_transform_buf,
            const output_compute_type* bias,
            dst_type* output,
            output_compute_type* transform_mid_buf,
            BiasMode bmode, NonlineMode nonline_mode,
            size_t oh_start, size_t ow_start,
            size_t OH, size_t OW, size_t OC,
            size_t oc_start, size_t oc_index,
            size_t unit_idx, size_t nr_units_in_tile,
            size_t m, size_t r,
            const std::vector<float>& interp_points,
            DType dtype,
            float input_filter_scale = 1.0f,
            float input_filter_rescale = 1.0f,
            float rescale = 1.0f);
};
} }