#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
#include "./common.h"
#include "./each_mode.h"
#include "./numerical.h"
#include "./types.h"
#include "megbrain/common.h"
#include "megbrain/exception.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
namespace mgb {
namespace jit {
using Mode = megdnn::param::Elemwise::Mode;
template <Mode mode>
mlir::Value lower_mode(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands);
#define cb(mode, fun) \
template <> \
mlir::Value lower_mode<Mode::mode>( \
mlir::OpBuilder & builder, mlir::Location loc, ValueRange operands) { \
ValueBuilderHelper helper(builder, loc); \
return helper.fun(operands); \
}
cb(ABS, abs);
cb(CEIL, ceil);
cb(COS, cos);
cb(EXP, exp);
cb(FLOOR, floor);
cb(LOG, log);
cb(NEGATE, neg);
cb(SIN, sin);
cb(TANH, tanh);
cb(ADD, add);
cb(MAX, max);
cb(MIN, min);
cb(MOD, mod);
cb(MUL, mul);
cb(SUB, sub);
cb(TRUE_DIV, div);
#undef cb
template <>
mlir::Value lower_mode<Mode::ACOS>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto x = operands[0];
auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x));
auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
auto pi_over_2 = helper.const_f32(1.57079637f);
return helper.sub(pi_over_2, asin);
}
template <>
mlir::Value lower_mode<Mode::ASIN>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto x = operands[0];
auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x));
return atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
}
template <>
mlir::Value lower_mode<Mode::ERFCINV>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto minus_sqrt2 = helper.const_f32(-1.4142135623f);
auto x = helper.mul(helper.const_f32(0.5f), operands[0]);
return helper.div(ndtri_approx(helper, x), minus_sqrt2);
}
template <>
mlir::Value lower_mode<Mode::ERFC>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.sub(helper.const_f32(1.f), erf_approx(helper, operands[0]));
}
template <>
mlir::Value lower_mode<Mode::ERFINV>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto sqrt2 = helper.const_f32(1.4142135623f);
auto x = helper.mul(
helper.const_f32(0.5f), helper.add(operands[0], helper.const_f32(1.f)));
return helper.div(ndtri_approx(helper, x), sqrt2);
}
template <>
mlir::Value lower_mode<Mode::ERF>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return erf_approx(helper, operands[0]);
}
template <>
mlir::Value lower_mode<Mode::EXPM1>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.sub(helper.exp(operands[0]), helper.const_f32(1.f));
}
template <>
mlir::Value lower_mode<Mode::FAST_TANH>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto square = helper.mul(operands[0], operands[0]);
return helper.div(
helper.mul(operands[0], helper.add(helper.const_f32(27.f), square)),
helper.add(
helper.const_f32(27.f), helper.mul(helper.const_f32(9.f), square)));
}
template <>
mlir::Value lower_mode<Mode::H_SWISH>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto const_3 = helper.const_f32(3.f);
auto const_0 = helper.const_f32(0.f);
auto const_6 = helper.const_f32(6.f);
auto tmp = helper.add(operands[0], const_3);
return helper.div(
helper.mul(operands[0], helper.min(helper.max(tmp, const_0), const_6)),
const_6);
}
template <>
mlir::Value lower_mode<Mode::LOG1P>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.log(helper.add(operands[0], helper.const_f32(1.f)));
}
template <>
mlir::Value lower_mode<Mode::RELU>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.max(operands[0], helper.const_f32(0.f));
}
template <>
mlir::Value lower_mode<Mode::ROUND>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.gt(operands[0], helper.const_f32(0.f)),
helper.floor(helper.add(operands[0], helper.const_f32(0.5f))),
helper.ceil(helper.sub(operands[0], helper.const_f32(0.5f))));
}
template <>
mlir::Value lower_mode<Mode::SIGMOID>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.div(
helper.const_f32(1.f),
helper.add(helper.exp(helper.neg(operands[0])), helper.const_f32(1.f)));
}
template <>
mlir::Value lower_mode<Mode::ABS_GRAD>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.gt(operands[0], helper.const_f32(0.f)), operands[1],
helper.neg(operands[1]));
}
template <>
mlir::Value lower_mode<Mode::ATAN2>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return atan2_approx(helper, operands[0], operands[1]);
}
template <>
mlir::Value lower_mode<Mode::EQ>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.eq(operands[0], operands[1]), helper.const_f32(1.f),
helper.const_f32(0.f));
}
template <>
mlir::Value lower_mode<Mode::FAST_TANH_GRAD>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto x_pow2 = helper.mul(operands[0], operands[0]);
auto deno = helper.add(helper.const_f32(3.f), x_pow2);
return helper.mul(
helper.div(
helper.add(
helper.add(
helper.div(
helper.mul(helper.const_f32(-48.f), x_pow2),
deno),
helper.const_f32(27.f)),
x_pow2),
helper.mul(deno, helper.const_f32(9.f))),
operands[1]);
}
template <>
mlir::Value lower_mode<Mode::FLOOR_DIV>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.floor(helper.div(operands[0], operands[1]));
}
template <>
mlir::Value lower_mode<Mode::FUSE_ADD_H_SWISH>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto sum = helper.add(operands[0], operands[1]);
auto const_3 = helper.const_f32(3.f);
auto const_0 = helper.const_f32(0.f);
auto const_6 = helper.const_f32(6.f);
auto tmp = helper.add(sum, const_3);
return helper.div(
helper.mul(sum, helper.min(helper.max(tmp, const_0), const_6)), const_6);
}
template <>
mlir::Value lower_mode<Mode::FUSE_ADD_RELU>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
auto sum = helper.add(operands[0], operands[1]);
return helper.max(sum, helper.const_f32(0.f));
}
template <>
mlir::Value lower_mode<Mode::FUSE_ADD_SIGMOID>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.div(
helper.const_f32(1.f),
helper.add(
helper.exp(helper.neg(helper.add(operands[0], operands[1]))),
helper.const_f32(1.f)));
}
template <>
mlir::Value lower_mode<Mode::FUSE_ADD_TANH>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.tanh(helper.add(operands[0], operands[1]));
}
template <>
mlir::Value lower_mode<Mode::H_SWISH_GRAD>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.lt(operands[0], helper.const_f32(-3.f)), helper.const_f32(0.f),
helper.select(
helper.gt(operands[0], helper.const_f32(3.f)), operands[1],
helper.mul(
helper.div(
helper.add(
helper.mul(
helper.const_f32(2.f), operands[0]),
helper.const_f32(3.f)),
helper.const_f32(6.f)),
operands[1])));
}
template <>
mlir::Value lower_mode<Mode::LEQ>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.le(operands[0], operands[1]), helper.const_f32(1.f),
helper.const_f32(0.f));
}
template <>
mlir::Value lower_mode<Mode::LOG_SUM_EXP>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.log(helper.add(helper.exp(operands[0]), helper.exp(operands[1])));
}
template <>
mlir::Value lower_mode<Mode::LT>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.lt(operands[0], operands[1]), helper.const_f32(1.f),
helper.const_f32(0.f));
}
template <>
mlir::Value lower_mode<Mode::POW>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.exp(helper.mul(operands[1], helper.log(operands[0])));
}
template <>
mlir::Value lower_mode<Mode::SIGMOID_GRAD>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.mul(
helper.mul(operands[0], helper.sub(helper.const_f32(1.f), operands[0])),
operands[1]);
}
template <>
mlir::Value lower_mode<Mode::SWITCH_GT0>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.gt(operands[0], helper.const_f32(0.f)), operands[1],
helper.const_f32(0.f));
}
template <>
mlir::Value lower_mode<Mode::TANH_GRAD>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.mul(
helper.sub(helper.const_f32(1.0f), helper.mul(operands[0], operands[0])),
operands[1]);
}
template <>
mlir::Value lower_mode<Mode::COND_LEQ_MOV>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.le(operands[0], operands[1]), operands[2], helper.const_f32(0.f));
}
template <>
mlir::Value lower_mode<Mode::FUSE_MUL_ADD3>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.add(helper.mul(operands[0], operands[1]), operands[2]);
}
mlir::Value lower_elemwise_to_std(
mlir::Operation* op, mlir::OpBuilder& builder, mlir::Location loc,
ValueRange operands) {
auto mode = llvm::dyn_cast<dialect::Elemwise>(op).mode();
switch (mode) {
#define cb(_, _mode) \
case Mode::_mode: \
return lower_mode<Mode::_mode>(builder, loc, operands);
MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb);
MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb);
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb);
default:
return nullptr;
}
#undef cb
}
mlir::Value lower_typecvt_to_std(
mlir::Operation* op, mlir::OpBuilder& builder, mlir::Location loc,
mlir::Value input) {
auto&& typecvt = llvm::dyn_cast<dialect::TypeCvt>(op);
mlir::Type idtype = typecvt.idtype();
mlir::Type odtype =
megdnn_dtype_to_mlir_type(typecvt.dtype(), builder.getContext());
mlir::Type itype = input.getType();
mlir::Type otype = signless(odtype);
mgb_assert(signless(idtype) == itype);
if (mlir::FPExtOp::areCastCompatible(itype, otype)) {
return builder.create<mlir::FPExtOp>(loc, otype, input);
} else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) {
return builder.create<mlir::FPTruncOp>(loc, otype, input);
} else if (
mlir::FPToSIOp::areCastCompatible(itype, otype) and
odtype.isSignedInteger()) {
return builder.create<mlir::FPToSIOp>(loc, otype, input);
} else if (
mlir::FPToUIOp::areCastCompatible(itype, otype) and
odtype.isUnsignedInteger()) {
return builder.create<mlir::FPToUIOp>(loc, otype, input);
} else if (
mlir::SIToFPOp::areCastCompatible(itype, otype) and
idtype.isSignedInteger()) {
return builder.create<mlir::SIToFPOp>(loc, otype, input);
} else if (
mlir::UIToFPOp::areCastCompatible(itype, otype) and
idtype.isUnsignedInteger()) {
return builder.create<mlir::UIToFPOp>(loc, otype, input);
} else {
std::string tmp;
llvm::raw_string_ostream os(tmp);
os << "cannot convert from " << idtype << " to " << odtype;
mgb_throw_raw(InternalError{tmp});
}
return nullptr;
}
} }
#endif