#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
#include "megbrain/jit/mlir/ir/utils.h"
#include "./types.h"
#include "megbrain/common.h"
#include "megbrain/exception.h"
#include "megdnn/basic_types.h"
#include "megdnn/oprs/general.h"
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/Types.h>
#include <mlir/Support/LLVM.h>
using namespace mgb;
using namespace jit;
/*!
 * \brief Create a buffer of the given memref type together with its matching
 *      deallocation, and return the allocated value.
 *
 * The AllocOp is created at the rewriter's current insertion point and then
 * moved in front of the first operation of its enclosing block. The matching
 * DeallocOp is moved directly in front of the last operation of that block.
 *
 * \param type memref type of the buffer to allocate
 * \param loc source location attached to both created operations
 * \param rewriter rewriter used to create the operations
 * \return the value produced by the AllocOp
 */
mlir::Value jit::insert_alloc_and_dealloc(
        mlir::MemRefType type, mlir::Location loc, mlir::PatternRewriter& rewriter) {
    auto buffer = rewriter.create<mlir::AllocOp>(loc, type);
    auto* buffer_op = buffer.getOperation();
    auto* enclosing_block = buffer_op->getBlock();

    // Hoist the allocation to the very beginning of the block.
    buffer_op->moveBefore(&enclosing_block->front());

    // Release the buffer right before the block's last operation
    // (presumably the terminator -- confirm against callers).
    auto release = rewriter.create<mlir::DeallocOp>(loc, buffer);
    auto* release_op = release.getOperation();
    release_op->moveBefore(&enclosing_block->back());

    return buffer;
}
/*!
 * \brief Deduce the result memref type of an elementwise operation.
 *
 * Operands whose type is a FloatType are skipped; every other operand must be
 * a MemRefType and contributes its layout to the shape deduction performed by
 * megdnn::Elemwise::deduce_shape. The element dtype of the result is derived
 * from the type of the first operand.
 *
 * \param operands operands of the elementwise op; must not be empty
 * \return memref type built from the deduced shape and the first operand's
 *      dtype
 */
mlir::Type jit::deduce_elemwise_res_type(mlir::ValueRange operands) {
    // operands[0] is read below for both the context and the dtype, so an
    // empty range must be rejected explicitly instead of indexing out of range.
    mgb_assert(
            !operands.empty(),
            "deduce_elemwise_res_type requires at least one operand");

    megdnn::TensorShapeArray srcs;
    for (auto operand : operands) {
        // Float-typed operands are not memrefs and are left out of the shape
        // deduction.
        if (operand.getType().isa<mlir::FloatType>()) {
            continue;
        }
        auto type = operand.getType().dyn_cast_or_null<mlir::MemRefType>();
        mgb_assert(type, "currently only support MemRefType");
        srcs.push_back(mlir_type_to_layout(type));
    }

    megdnn::TensorShape dst;
    megdnn::Elemwise::deduce_shape(srcs, dst);

    mlir::Builder builder(operands[0].getContext());
    return layout_to_mlir_type(
            {dst, mlir_type_to_megdnn_dtype(operands[0].getType())}, builder);
}
/*!
 * \brief Convert an MLIR type into a megdnn tensor layout.
 *
 * For a MemRefType the rank, the per-axis sizes and the element dtype are
 * copied into the layout. Any other type yields a default-constructed layout.
 *
 * NOTE(review): the value returned by getDimSize() is stored unchanged; this
 * looks like it expects statically shaped memrefs -- confirm with callers.
 *
 * \param type MLIR type to convert
 * \return layout describing \p type, or a default layout if it is not a memref
 */
megdnn::TensorLayout jit::mlir_type_to_layout(mlir::Type type) {
    megdnn::TensorLayout layout;
    if (!type.isa<mlir::MemRefType>()) {
        return layout;
    }

    // The isa<> check above guarantees that this cast succeeds.
    auto memref = type.cast<mlir::MemRefType>();
    layout.ndim = memref.getRank();
    for (size_t axis = 0; axis < layout.ndim; ++axis) {
        layout.shape[axis] = memref.getDimSize(axis);
    }
    layout.dtype = mlir_type_to_megdnn_dtype(memref.getElementType());
    return layout;
}
/*!
 * \brief Build an MLIR memref type from a megdnn tensor layout.
 *
 * The first layout.ndim extents of the layout become the memref shape. The
 * element type is obtained from the layout's dtype via
 * megdnn_dtype_to_mlir_type and passed through signless() before use.
 *
 * \param layout layout providing the shape and dtype
 * \param builder builder whose context is used to create the element type
 * \return memref type with the layout's shape and converted element type
 */
mlir::MemRefType jit::layout_to_mlir_type(
        const megdnn::TensorLayout& layout, mlir::Builder& builder) {
    // Size the vector up front and fill it by index, one entry per axis.
    std::vector<int64_t> dims(layout.ndim);
    for (size_t axis = 0; axis < layout.ndim; ++axis) {
        dims[axis] = layout[axis];
    }

    mlir::Type elem_type =
            megdnn_dtype_to_mlir_type(layout.dtype, builder.getContext());
    return mlir::MemRefType::get(dims, signless(elem_type));
}
#endif