#pragma once
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
#include "megbrain/tensor.h"
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/IR/Value.h>
namespace mgb {
namespace jit {
class ValueBuilderHelper {
public:
ValueBuilderHelper(mlir::OpBuilder& b, mlir::Location location)
: m_builder{b}, m_location{location} {};
#define cb(name) \
mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \
mlir::Value name(mlir::Value lhs)
cb(abs);
cb(ceil);
cb(cos);
cb(exp);
cb(exp2);
cb(floor);
cb(log);
cb(log10);
cb(log2);
cb(neg);
cb(rsqrt);
cb(sin);
cb(sqrt);
cb(tanh);
#undef cb
#define cb(name) \
mlir::Value name(mlir::ValueRange operands) { \
return name(operands[0], operands[1]); \
} \
mlir::Value name(mlir::Value lhs, mlir::Value rhs)
cb(add);
cb(bit_and);
cb(bit_or);
cb(div);
cb(divI);
cb(eq);
cb(ge);
cb(gt);
cb(le);
cb(lt);
cb(max);
cb(min);
cb(mod);
cb(modI);
cb(mul);
cb(sub);
#undef cb
mlir::Value const_f32(float val);
mlir::Value const_i32(int32_t val);
mlir::Value select(mlir::Value cond, mlir::Value true_val, mlir::Value false_val);
private:
mlir::OpBuilder& m_builder;
mlir::Location m_location;
};
template <typename Op>
mlir::Value get_operand(
mlir::OpBuilder& builder, const mlir::Location& loc, const mlir::Value& val,
const mlir::ValueRange& index) {
if (val.getType().isa<mlir::MemRefType>()) {
return builder.create<Op>(loc, val, index);
} else {
return val;
}
}
mlir::AffineMap get_affinemap(
mlir::OpBuilder& builder, const mlir::Value& val, const TensorLayout& layout);
mlir::Value get_affine_load_op(
mlir::OpBuilder& builder, const mlir::Location& loc, const mlir::Value& val,
const mlir::ValueRange& index, const TensorLayout& dst);
} }
#endif