/**
* \file dnn/src/common/cond_take/predicate.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/arch.h"
#include "src/common/opr_param_defs_enumv.cuh"
#if MEGDNN_CC_HOST
#include "megdnn/opr_param_defs.h"
#endif
#ifndef __device__
#define __device__
#define __host__
#define def_device
#endif
#include <cmath>
namespace megdnn {
namespace cond_take {
typedef param_enumv::CondTake::Mode PEnum;
struct KParam {
float val, eps;
#if MEGDNN_CC_HOST
KParam(const param::CondTake& p) : val(p.val), eps(p.eps) {}
#endif
};
template <uint32_t mode, typename ctype>
struct Pred;
#define do_inst_eq_f(_ct) \
template <> \
struct Pred<PEnum::EQ, _ct> { \
typedef _ct ctype; \
ctype val, eps; \
Pred(const KParam& p) : val(p.val), eps(p.eps) {} \
__device__ __host__ bool operator()(ctype x) const { \
return fabsf(val - x) < eps; \
} \
};
#define do_inst_eq_i(_ct) \
template <> \
struct Pred<PEnum::EQ, _ct> { \
typedef _ct ctype; \
ctype val; \
Pred(const KParam& p) : val(p.val) {} \
__device__ __host__ bool operator()(ctype x) const { return val == x; } \
};
#define inst_eq_f(_dt) do_inst_eq_f(DTypeTrait<_dt>::ctype)
#define inst_eq_i(_dt) do_inst_eq_i(DTypeTrait<_dt>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(inst_eq_f)
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(inst_eq_i)
inst_eq_i(::megdnn::dtype::Bool)
#undef inst_eq_f
#undef inst_eq_i
template <typename ctype_>
struct Pred<PEnum::NEQ, ctype_> {
typedef ctype_ ctype;
Pred<PEnum::EQ, ctype> eq;
Pred(const KParam& p) : eq(p) {}
__device__ __host__ bool operator()(ctype x) const { return !this->eq(x); }
};
#define DEF_OP(_name, _op) \
template <typename ctype_> \
struct Pred<PEnum::_name, ctype_> { \
typedef ctype_ ctype; \
ctype val; \
Pred(const KParam& p) : val(p.val) {} \
__device__ __host__ bool operator()(ctype x) const { return x _op val; } \
}
DEF_OP(LT, <);
DEF_OP(LEQ, <=);
DEF_OP(GT, >);
DEF_OP(GEQ, >=);
#undef DEF_OP
#define MEGDNN_FOREACH_COND_TAKE_MODE(cb) cb(EQ) cb(NEQ) cb(LT) cb(LEQ) cb(GT) cb(GEQ)
} // namespace cond_take
} // namespace megdnn
#ifdef def_device
#undef __device__
#undef __host__
#endif
// vim: ft=cpp syntax=cpp.doxygen