#pragma once
#include "src/common/cv/aligned_allocator.h"
#include "src/common/utils.h"
#include "./helper.h"
#include "megdnn/opr_param_defs.h"
#include <cstdint>
#include <memory>
#include <mutex>
namespace megdnn {
namespace megcv {
using InterpolationMode = megdnn::param::WarpPerspective::InterpolationMode;
using BorderMode = megdnn::param::WarpPerspective::BorderMode;
/**
 * Precomputed interpolation coefficient tables, one per kernel width:
 * 2 taps for linear, 4 for cubic and 8 for lanczos4 (see the sm_tab_* holders).
 * Each table is allocated lazily on first request through its TableHolder.
 *
 * Template parameters:
 *   INTER_BITS_            log2 of the number of sub-pixel positions per axis
 *                          (INTER_TAB_SIZE == 1 << INTER_BITS).
 *   INTER_MAX_             exposed as INTER_MAX; its meaning is defined by the
 *                          implementation file (not visible here).
 *   INTER_REMAP_COEF_BITS_ number of bits of INTER_REMAP_COEF_SCALE; presumably
 *                          the fixed-point scale of the int16 tables -- confirm.
 */
template <int INTER_BITS_ = 5, int INTER_MAX_ = 7, int INTER_REMAP_COEF_BITS_ = 15>
class InterpolationTable {
public:
    using IMode = InterpolationMode;

    // Configuration mirrored from the template arguments.
    static constexpr int INTER_BITS = INTER_BITS_;
    static constexpr int INTER_MAX = INTER_MAX_;
    static constexpr int INTER_REMAP_COEF_BITS = INTER_REMAP_COEF_BITS_;

    // Derived sizes: INTER_TAB_SIZE positions per axis, INTER_TAB_SIZE2 in 2-D.
    static constexpr int INTER_TAB_SIZE = 1 << INTER_BITS;
    static constexpr int INTER_TAB_SIZE2 = INTER_TAB_SIZE * INTER_TAB_SIZE;
    static constexpr int INTER_REMAP_COEF_SCALE = 1 << INTER_REMAP_COEF_BITS;

    /// Returns the coefficient table for `imode`; `fixpt` presumably selects
    /// the int16 table over the float one (implementation not visible here).
    static const void* get_table(InterpolationMode imode, bool fixpt);

#if MEGDNN_X86
    /// x86-only int16 table; presumably backed by Table::bilineartab_ic4_buf.
    static const int16_t* get_linear_ic4_table();
#endif

private:
    /// Storage for one kernel width: ksize*ksize coefficients for each of the
    /// INTER_TAB_SIZE2 sub-pixel positions, in float and in int16 form.
    template <int ksize>
    struct Table {
        float ftab[INTER_TAB_SIZE2 * ksize * ksize];
        int16_t itab[INTER_TAB_SIZE2 * ksize * ksize];
#if MEGDNN_X86
        alignas(128) int16_t bilineartab_ic4_buf[INTER_TAB_SIZE2 * 2 * 8];

        // Heap allocations go through a 128-byte aligned allocator so the
        // alignas(128) buffer above is honoured for `new Table`.
        static void* operator new(std::size_t bytes) {
            using Alloc = ah::aligned_allocator<Table, 128>;
            return Alloc().allocate(bytes / sizeof(Table));
        }
        static void operator delete(void* p) noexcept {
            using Alloc = ah::aligned_allocator<Table, 128>;
            Alloc().deallocate(static_cast<Table*>(p), 0);
        }
#endif
    };

    /// Type-erased handle to a lazily created Table. `get` does not lock
    /// `mtx` itself; presumably callers hold it around get() + initialisation.
    struct TableHolderBase {
        DNN_MUTEX mtx;
        virtual bool get(float** ftab, int16_t** itab) = 0;

    protected:
        // Protected and non-virtual: holders cannot be deleted via a base pointer.
        ~TableHolderBase() = default;
    };

    template <int ksize>
    struct TableHolder final : public TableHolderBase {
        std::unique_ptr<Table<ksize>> table;

        /// Exposes the float/int16 arrays, allocating the Table on first use.
        /// Returns false when the Table was just allocated (its arrays are
        /// left uninitialised), true when it already existed.
        bool get(float** ftab, int16_t** itab) override {
            const bool already_built = table != nullptr;
            if (!already_built)
                table.reset(new Table<ksize>);  // default-init: no zeroing
            Table<ksize>& t = *table;
            *ftab = t.ftab;
            *itab = t.itab;
            return already_built;
        }
    };

    // Coefficient generators; definitions live outside this header.
    static void init_inter_tab_1d(InterpolationMode imode, float* tab, int tabsz);
    static inline void interpolate_linear(float x, float* coeffs);
    static inline void interpolate_cubic(float x, float* coeffs);
    static inline void interpolate_lanczos4(float x, float* coeffs);

    // One lazily built table per supported kernel width.
    static TableHolder<2> sm_tab_linear;
    static TableHolder<4> sm_tab_cubic;
    static TableHolder<8> sm_tab_lanczos4;
};
} }