#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
#include "./common.h"
#include "./each_mode.h"
#include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"
#include "megbrain/jit/mlir/ir/utils.h"
#include <mlir/Dialect/GPU/GPUDialect.h>
#include <mlir/Dialect/SCF/SCF.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/EDSC/Builders.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>
using namespace mgb;
using namespace jit;
namespace {
//! Rewriter type received by the matchAndRewrite overrides of the
//! ConversionPattern-based lowerings in this file.
using Rewriter = ConversionPatternRewriter;
//! Tensor layout type used to describe the destination of the kernel.
using Layout = megdnn::TensorLayout;
struct GpuLoweringHelper {
GpuLoweringHelper(scf::ForOp* for_op, Value index, const Layout& dest)
: m_for_op(for_op), m_index(index), m_dest(dest) {}
void set_insertion_point(OpBuilder& builder) const {
builder.setInsertionPoint(&(m_for_op->getLoopBody().front().back()));
}
std::vector<Value> map_indices(
OpBuilder& builder, Location loc, Value value) const {
auto type = value.getType().dyn_cast_or_null<MemRefType>();
if (!type) {
return {m_index};
}
std::vector<Value> indices(m_dest.ndim);
ValueBuilderHelper helper(builder, loc);
Value dim_index = m_index;
for (int i = m_dest.ndim - 1; i >= 0; i--) {
indices[i] = helper.modI(dim_index, helper.const_i32(m_dest[i]));
dim_index = helper.divI(dim_index, helper.const_i32(m_dest[i]));
}
Layout src_layout = mlir_type_to_layout(type);
src_layout.init_contiguous_stride();
for (int i = 0; i < type.getRank(); ++i) {
if (src_layout[i] == 1) {
indices[i] = helper.const_i32(0);
}
}
return indices;
}
private:
scf::ForOp* m_for_op;
Value m_index;
Layout m_dest;
};
//! Lowers dialect::AssignOp: the first operand is fetched with
//! get_operand<LoadOp> and the result is stored into the second operand, both
//! addressed by the indices computed for the second operand.
struct AssignOpLowering : public ConversionPattern, public GpuLoweringHelper {
    AssignOpLowering(
            MLIRContext* ctx, scf::ForOp* for_op, mlir::Value index, const Layout& dest)
            : ConversionPattern(dialect::AssignOp::getOperationName(), 2, ctx),
              GpuLoweringHelper(for_op, index, dest) {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands, Rewriter& rewriter) const final {
        // everything below is emitted inside the loop body
        set_insertion_point(rewriter);
        auto loc = op->getLoc();

        auto dst_indices = map_indices(rewriter, loc, operands[1]);
        auto loaded = get_operand<LoadOp>(rewriter, loc, operands[0], dst_indices);
        rewriter.create<StoreOp>(loc, loaded, operands[1], dst_indices);

        // the assignment itself produces no value, so it is simply dropped
        rewriter.eraseOp(op);
        return success();
    }
};
//! Replaces dialect::ConstantScalarOp by an mlir::ConstantOp that carries the
//! same value attribute; the new op is created inside the loop body.
struct ConstantScalarOpLowering : public OpRewritePattern<dialect::ConstantScalarOp>,
                                  public GpuLoweringHelper {
    ConstantScalarOpLowering(
            MLIRContext* ctx, scf::ForOp* for_op, Value index, const Layout& dest)
            : OpRewritePattern<dialect::ConstantScalarOp>(ctx),
              GpuLoweringHelper(for_op, index, dest) {}

    LogicalResult matchAndRewrite(
            dialect::ConstantScalarOp op, PatternRewriter& rewriter) const final {
        set_insertion_point(rewriter);
        auto scalar_value = op.value();
        rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, scalar_value);
        return success();
    }
};
//! Lowers dialect::Dimshuffle to a single get_operand<LoadOp> on its input,
//! addressed by the indices permuted according to the op's pattern.
struct DimshuffleLowering : public ConversionPattern, public GpuLoweringHelper {
    DimshuffleLowering(
            MLIRContext* ctx, scf::ForOp* for_op, Value index, const Layout& dest)
            : ConversionPattern(dialect::Dimshuffle::getOperationName(), 1, ctx),
              GpuLoweringHelper(for_op, index, dest) {}

    /*!
     * \brief Scatter \p index into a new vector as directed by \p pattern.
     *
     * For every position i with pattern[i] >= 0, index[i] is placed at
     * position pattern[i] of the result; negative entries contribute nothing.
     * The result has max(pattern) + 1 entries, and positions that no pattern
     * entry refers to are left default-constructed.
     *
     * \return the rearranged indices; empty if \p pattern is empty
     */
    static std::vector<mlir::Value> get_index_from_pattern(
            const std::vector<int32_t>& pattern,
            const std::vector<mlir::Value>& index) {
        // std::max_element returns end() for an empty range, which must not be
        // dereferenced
        if (pattern.empty()) {
            return {};
        }
        size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
        std::vector<mlir::Value> res(ndim);
        for (size_t i = 0; i < pattern.size(); i++) {
            int32_t j = pattern[i];
            if (j >= 0) {
                res[j] = index[i];
            }
        }
        return res;
    }

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands, Rewriter& rewriter) const final {
        auto loc = op->getLoc();
        set_insertion_point(rewriter);
        // the pattern is registered under the Dimshuffle operation name, so the
        // cast cannot legitimately fail; llvm::cast states that instead of
        // using an unchecked dyn_cast result
        auto pattern = llvm::cast<dialect::Dimshuffle>(op).pattern();
        auto index = map_indices(rewriter, loc, operands[0]);
        auto shuffled_index = get_index_from_pattern(pattern, index);
        rewriter.replaceOp(
                op, get_operand<LoadOp>(rewriter, loc, operands[0], shuffled_index));
        return success();
    }
};
//! Lowers dialect::Elemwise: each operand is fetched with get_operand<LoadOp>
//! using its own mapped indices, and lower_elemwise_to_std builds the
//! replacement from the fetched values.
struct ElemwiseLowering : public ConversionPattern, public GpuLoweringHelper {
    ElemwiseLowering(
            MLIRContext* ctx, scf::ForOp* for_op, Value index, const Layout& dest)
            : ConversionPattern(dialect::Elemwise::getOperationName(), 1, ctx),
              GpuLoweringHelper(for_op, index, dest) {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands, Rewriter& rewriter) const final {
        set_insertion_point(rewriter);
        auto loc = op->getLoc();

        // fetch one operand at the position selected by the loop index
        auto fetch_operand = [&](mlir::Value val) {
            auto val_indices = map_indices(rewriter, loc, val);
            return get_operand<LoadOp>(rewriter, loc, val, val_indices);
        };
        auto inputs = llvm::to_vector<4>(llvm::map_range(operands, fetch_operand));

        auto lowered = lower_elemwise_to_std(op, rewriter, loc, inputs);
        rewriter.replaceOp(op, lowered);
        return success();
    }
};
//! Replaces dialect::ReturnOp by gpu::ReturnOp. The loop, index and layout
//! arguments are accepted only so that all patterns can be constructed with
//! the same argument list; they are ignored.
struct ReturnOpLowering : public ConversionPattern {
    ReturnOpLowering(MLIRContext* ctx, scf::ForOp*, Value, const Layout&)
            : ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx) {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value>, Rewriter& rewriter) const final {
        // the new terminator goes to the end of the block that holds op
        auto* enclosing_block = op->getBlock();
        rewriter.setInsertionPointToEnd(enclosing_block);
        rewriter.replaceOpWithNewOp<gpu::ReturnOp>(op);
        return success();
    }
};
//! Lowers dialect::TypeCvt: the single operand is fetched with
//! get_operand<LoadOp> and lower_typecvt_to_std builds the replacement.
struct TypeCvtLowering : public ConversionPattern, public GpuLoweringHelper {
    TypeCvtLowering(
            MLIRContext* ctx, scf::ForOp* for_op, Value index, const Layout& dest)
            : ConversionPattern(dialect::TypeCvt::getOperationName(), 1, ctx),
              GpuLoweringHelper(for_op, index, dest) {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands, Rewriter& rewriter) const final {
        set_insertion_point(rewriter);
        auto loc = op->getLoc();

        auto src_indices = map_indices(rewriter, loc, operands[0]);
        auto src = get_operand<LoadOp>(rewriter, loc, operands[0], src_indices);

        auto converted = lower_typecvt_to_std(op, rewriter, loc, src);
        rewriter.replaceOp(op, converted);
        return success();
    }
};
//! Module-level pass that lowers the ops of MgbDialect found in a FuncOp to
//! the gpu, scf and standard dialects and moves the result into a GPU kernel
//! function.
class MgbToGpuLoweringPass
: public PassWrapper<MgbToGpuLoweringPass, OperationPass<ModuleOp>> {
public:
//! Registers the dialects whose operations this pass creates.
void getDependentDialects(DialectRegistry& registry) const override;
//! Entry point of the pass; operates on the ModuleOp returned by
//! getOperation().
void runOnOperation() final;
private:
//! Emits the computation blockDim.x * blockIdx.x + threadIdx.x and returns
//! its result.
Value get_idx(OpBuilder& builder, Location loc);
//! Returns the layout derived from the lhs type of an AssignOp found in
//! \p func_op; asserts if there is none.
Layout get_dest_layout(FuncOp func_op);
};
//! Declare the dialects whose operations are created while lowering.
void MgbToGpuLoweringPass::getDependentDialects(DialectRegistry& registry) const {
    registry.insert<gpu::GPUDialect>();
    registry.insert<scf::SCFDialect>();
    registry.insert<StandardOpsDialect>();
}
void MgbToGpuLoweringPass::runOnOperation() {
ModuleOp module_op = getOperation();
FuncOp func_op;
module_op.walk([&](FuncOp fop) {
func_op = fop;
return WalkResult::interrupt();
});
mgb_assert(func_op, "FuncOp not found in the body of ModuleOp");
Location loc = func_op.getLoc();
OpBuilder builder(&(func_op.getBody().front().back()));
auto it = func_op.getArguments().end();
Value nr_threads = *(--it);
Value nr_elements = *(--it);
Value idx = get_idx(builder, loc);
auto for_op = builder.create<scf::ForOp>(loc, idx, nr_elements, nr_threads);
Layout dest = get_dest_layout(func_op);
Value for_idx = for_op.getInductionVar();
OwningRewritePatternList patterns;
patterns
.insert<AssignOpLowering, ConstantScalarOpLowering, DimshuffleLowering,
ElemwiseLowering, ReturnOpLowering, TypeCvtLowering>(
&getContext(), &for_op, for_idx, dest);
ConversionTarget target(getContext());
target.addLegalDialect<gpu::GPUDialect, scf::SCFDialect, StandardOpsDialect>();
target.addIllegalDialect<MgbDialect>();
if (failed(applyPartialConversion(func_op, target, std::move(patterns)))) {
signalPassFailure();
}
std::string kernel_name = func_op.getName().str() + "_kernel";
builder.setInsertionPoint(func_op);
gpu::GPUModuleOp gpu_module_op = builder.create<gpu::GPUModuleOp>(loc, kernel_name);
builder.setInsertionPointToStart(&gpu_module_op.body().front());
gpu::GPUFuncOp gpu_func_op =
builder.create<gpu::GPUFuncOp>(loc, kernel_name, func_op.getType());
gpu_func_op.setAttr(
gpu::GPUDialect::getKernelFuncAttrName(), builder.getUnitAttr());
gpu_func_op.body().takeBody(func_op.getBody());
SymbolTable(module_op).erase(func_op);
}
//! Emit blockDim.x * blockIdx.x + threadIdx.x with \p builder and return the
//! resulting value of index type.
Value MgbToGpuLoweringPass::get_idx(OpBuilder& builder, Location loc) {
    IndexType index_ty = builder.getIndexType();
    StringAttr dim_x = builder.getStringAttr("x");

    Value bdim = builder.create<gpu::BlockDimOp>(loc, index_ty, dim_x);
    Value bid = builder.create<gpu::BlockIdOp>(loc, index_ty, dim_x);
    Value tid = builder.create<gpu::ThreadIdOp>(loc, index_ty, dim_x);

    Value block_offset = builder.create<MulIOp>(loc, bdim, bid);
    Value global_idx = builder.create<AddIOp>(loc, block_offset, tid);
    return global_idx;
}
//! Return the layout built from the lhs type of the AssignOp that the walk
//! over \p func_op visits first; asserts when the function has no AssignOp.
Layout MgbToGpuLoweringPass::get_dest_layout(FuncOp func_op) {
    bool has_assign = false;
    Layout layout;
    func_op.walk([&](dialect::AssignOp aop) {
        // one AssignOp is enough, so stop walking right away
        auto lhs_type = aop.lhs().getType();
        layout = mlir_type_to_layout(lhs_type);
        has_assign = true;
        return WalkResult::interrupt();
    });
    mgb_assert(has_assign, "AssignOp not found in the body of FuncOp");
    return layout;
}
}
//! Factory for the pass that lowers MgbDialect operations to a GPU kernel.
std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_gpu_pass() {
    std::unique_ptr<mlir::Pass> pass = std::make_unique<MgbToGpuLoweringPass>();
    return pass;
}
#endif