#include "./compiler_cuda.h"
#if MGB_JIT_HALIDE && MGB_CUDA
#include "../nvrtc/compiler_cuda.h"
#include "./ast_hl.h"
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/jit/utils.h"
#include "megbrain/utils/timer.h"
#include <HalideRuntimeCuda.h>
using namespace mgb;
using namespace jit;
using namespace Halide;
// State that compile_and_load() attaches to a HalideExecutable for the CUDA
// target: the device properties the schedule was generated with, the scheduled
// pipeline itself, and a mutex guarding compilation of that pipeline.
struct HalideCudaTargetTrait::UserData : HalideExecutable::TargetTraitUserData {
    // device properties captured when the user data was created
    DeviceProp dev_prop;
    // pipeline returned by gen_halide_pipeline_schedule()
    Halide::Pipeline pipeline;
    // held around the compile_to_* calls made on `pipeline`
    std::mutex mtx;
};
// Compute the Halide target features for a CUDA comp node.
//
// Target::CUDA is always enabled. In addition exactly one CUDACapabilityXX
// feature is enabled, selected from the compute capability reported by the
// comp node's CUDA environment. Capabilities outside the listed ranges fall
// back to CUDACapability61 and emit a warning.
HalideCudaTargetTrait::FeatureSet HalideCudaTargetTrait::features(
        CompNode comp_node) const {
    const auto& prop = CompNodeEnv::from_comp_node(comp_node).cuda_env().device_prop;
    // encode major.minor as one integer, e.g. 6.1 -> 61
    const int ver = prop.major * 10 + prop.minor;

    // starts out as the fallback; overwritten when a range below matches
    auto capability = Target::CUDACapability61;
    if (30 <= ver && ver < 32) {
        capability = Target::CUDACapability30;
    } else if (32 <= ver && ver < 35) {
        capability = Target::CUDACapability32;
    } else if (35 <= ver && ver < 40) {
        capability = Target::CUDACapability35;
    } else if (50 <= ver && ver < 61) {
        capability = Target::CUDACapability50;
    } else if (61 <= ver && ver < 70) {
        capability = Target::CUDACapability61;
    } else {
        mgb_log_warn(
                "cuda capability(%d.%d) not support for Halide, using compute "
                "capability 6.1",
                prop.major, prop.minor);
    }

    FeatureSet set;
    set.set(Target::CUDA);
    set.set(capability);
    return set;
}
// Compile the Halide pipeline of `hl_exec` for `target`, link the resulting
// object file against the CUDA runtime override library, load the shared
// object and resolve the entry points stored in the returned FunctionHandle.
//
// The scheduled pipeline is obtained through user_data(), which is given a
// factory that builds it from the device properties of `comp_node`. Throws
// InternalError when the stored user data was built for a different
// max_threads_per_block than the one of `comp_node`.
HalideCudaTargetTrait::FunctionHandle HalideCudaTargetTrait::compile_and_load(
        CompNode comp_node, Halide::Target target, const HalideExecutable& hl_exec) {
    auto&& cur_prop = get_dev_prop(comp_node);
    const auto kern_name = next_kernel_name();
    auto&& exec_helper = ExecutableHelper::get();

    // factory handed to user_data(); builds the schedule for this device
    auto make_ud = [&]() -> std::unique_ptr<HalideExecutable::TargetTraitUserData> {
        auto fresh = std::make_unique<UserData>();
        fresh->dev_prop = cur_prop;
        fresh->pipeline =
                gen_halide_pipeline_schedule(hl_exec.halide_output(), cur_prop);
        return fresh;
    };
    auto ud = static_cast<UserData*>(user_data(hl_exec, make_ud));

    // the schedule depends on max_threads_per_block, so it can not be shared
    // with a device that reports a different value
    mgb_throw_if(
            cur_prop.max_threads_per_block != ud->dev_prop.max_threads_per_block,
            InternalError,
            "halide on multiple devices with different "
            "max_threads_per_block is currently not supported");

    auto&& hl_pipeline = ud->pipeline;
    auto hl_inputs = hl_exec.halide_inputs();
    auto obj_name = kern_name + ".o";

    RealTimer timer;
    {
        // compilation of the pipeline is serialized through the user data mutex
        MGB_LOCK_GUARD(ud->mtx);
        hl_pipeline.compile_to_object(
                exec_helper.realpath(obj_name), hl_inputs, kern_name, target);
        if (ExecutableHelper::keep_interm()) {
            // additionally dump the lowered statement as text
            hl_pipeline.compile_to_lowered_stmt(
                    exec_helper.realpath(kern_name + ".stmt"), hl_inputs, Text,
                    target);
        }
    }
    auto time_compile = timer.get_msecs_reset();

    FunctionHandle handle;
    handle.init_uctx_map();
    handle.dl_handle = exec_helper.link_and_load(
            {HalideCudaCompiler::cuda_runtime_lib(), obj_name}, kern_name + ".so");
    exec_helper.remove_interm(obj_name);

    exec_helper.resolve_func(
            handle.get_device_interface, handle.dl_handle,
            "halide_cuda_device_interface");
    exec_helper.resolve_func(handle.execute, handle.dl_handle, kern_name + "_argv");
    exec_helper.resolve_func(
            handle.device_release, handle.dl_handle, "halide_cuda_device_release");
    auto time_link = timer.get_msecs_reset();

    mgb_log("Halide CUDA JIT: compile %s for %s: time_compile=%.3fms "
            "time_link=%.3fms",
            kern_name.c_str(), comp_node.to_string().c_str(), time_compile,
            time_link);
    return handle;
}
// Return the address of the `ctx` member of the cached device properties of
// `comp_node`, as an untyped pointer.
// NOTE(review): presumably this is the pointer passed to Halide as
// user_context and read back by the runtime override library -- confirm
// against the caller.
void* HalideCudaTargetTrait::get_user_context(CompNode comp_node) {
    auto& prop = get_dev_prop(comp_node);
    return &prop.ctx;
}
// Look up (creating on first use) the DeviceProp entry of `comp_node` in
// m_cn2prop, under m_mtx. An entry whose max_threads_per_block is -1 is filled
// in from the comp node's CUDA environment before being returned.
// NOTE(review): -1 is presumably the value a freshly inserted DeviceProp
// starts with -- confirm against the DeviceProp declaration.
HalideCudaTargetTrait::DeviceProp& HalideCudaTargetTrait::get_dev_prop(
        CompNode comp_node) {
    MGB_LOCK_GUARD(m_mtx);
    auto& prop = m_cn2prop[comp_node];
    if (prop.max_threads_per_block != -1) {
        // already filled in by an earlier call
        return prop;
    }

    auto&& cuda_env = CompNodeEnv::from_comp_node(comp_node).cuda_env();
    // activate the comp node before querying the current CUDA context
    comp_node.activate();
    MGB_CUDA_CU_CHECK(cuCtxGetCurrent(&(prop.ctx.ctx)));
    prop.ctx.strm = cuda_env.stream;
    prop.max_threads_per_block = cuda_env.device_prop.maxThreadsPerBlock;
    return prop;
}
// Build a manually scheduled Halide pipeline for the kernel rooted at
// `dst_output`.
//
// Every node reachable from the inputs of `dst_output` is visited once. Nodes
// that are not InputDevValueOp, ReduceOp, InputHostValueShapeOp or BroadcastOp
// are inlined; ReduceOp nodes are collected so that they can be computed at the
// GPU thread loop of the output. The output Func is bounded to its layout, its
// dimensions are fused into a single loop, and that loop is split into GPU
// blocks and threads based on device_prop.max_threads_per_block.
//
// \param dst_output  root AST node; its m_func (or, for a ReduceOp, a wrapper
//                    Func reading from it) becomes the pipeline output
// \param device_prop device properties; only max_threads_per_block is read
// \return a Pipeline holding the single scheduled output Func
Halide::Pipeline HalideCudaTargetTrait::gen_halide_pipeline_schedule(
        const ast_hl::AstNodePtr& dst_output, const DeviceProp& device_prop) {
    using namespace ast_hl;

    // Breadth-first walk over everything feeding dst_output (dst_output itself
    // is scheduled separately below).
    std::unordered_set<AstNodePtr> visited;
    std::unordered_set<ReduceOp*> reduce_set;
    std::queue<AstNodePtr> pending;
    for (const auto& inp : dst_output->m_inputs) {
        pending.push(inp);
    }
    while (!pending.empty()) {
        AstNodePtr node = pending.front();
        pending.pop();
        if (!visited.insert(node).second) {
            // reached before through another path
            continue;
        }
        for (const auto& inp : node->m_inputs) {
            pending.push(inp);
        }
        if (!node->same_type<InputDevValueOp>() && !node->same_type<ReduceOp>() &&
            !node->same_type<InputHostValueShapeOp>() &&
            !node->same_type<BroadcastOp>()) {
            node->m_func.compute_inline();
        }
        if (auto reduce_opr = try_cast_as_op<ReduceOp>(node.get())) {
            reduce_set.insert(reduce_opr);
        }
    }

    std::vector<Func> outputs;

    // Compute every collected reduction (and its m_comp Func, when defined) at
    // loop level `tx` of `f`.
    auto process_reduce = [&](Func f, Var tx) {
        for (auto&& reduce_opr : reduce_set) {
            if (reduce_opr->m_comp.defined()) {
                reduce_opr->m_comp.compute_at(f, tx);
            }
            reduce_opr->m_func.compute_at(f, tx);
        }
    };

    // Bound `func` to `layout`, fuse all of its dimensions into one loop and map
    // that loop onto GPU blocks/threads. Returns the GPU thread variable.
    auto fuse_and_tile = [&device_prop](
                                 Func& func, const std::vector<Var>& vars,
                                 const auto& layout) -> Var {
        // vars[0] is read below; an empty argument list can not be scheduled
        mgb_assert(!vars.empty());

        const size_t total_nr_elems = layout.total_nr_elems();
        const size_t ndim = layout.ndim;
        // Func arguments are matched to layout axes in reverse order
        for (size_t d = 0; d < ndim; ++d) {
            func.bound(vars[d], 0, static_cast<int>(layout[ndim - 1 - d]));
        }
        Var fused = vars[0];
        for (size_t i = 1; i < vars.size(); i++) {
            func.fuse(fused, vars[i], fused);
        }

        const int max_blocks = 65536;
        const int max_threads_num = device_prop.max_threads_per_block;
        // number of elements covered by one launch of max_blocks full blocks
        const int bt = max_blocks * max_threads_num;

        if (total_nr_elems > static_cast<size_t>(bt)) {
            // more elements than one launch covers: add an extra loop `xo`
            // over chunks of `bt` elements, unrolled inside each thread
            Var xo, xi;
            Var bx, tx;
            func.split(fused, xo, xi, bt, TailStrategy::GuardWithIf);
            func.split(xi, bx, tx, Expr{max_threads_num}, TailStrategy::GuardWithIf);
            func.reorder(xo, tx, bx);
            func.unroll(xo);
            func.gpu_threads(tx);
            func.gpu_blocks(bx);
            return tx;
        }

        Var bx, tx;
        func.split(fused, bx, tx, max_threads_num, TailStrategy::GuardWithIf);
        func.gpu_threads(tx);
        func.gpu_blocks(bx);
        return tx;
    };

    // Schedule an output whose own Func is the pipeline output.
    auto schedule_elemwise_like = [&](const AstNodePtr& output) {
        auto& f = output->m_func;
        auto vars = f.args();
        mgb_assert(vars.size() == output->m_layout.ndim);
        Var tx = fuse_and_tile(f, vars, output->m_layout);
        process_reduce(f, tx);
        outputs.push_back(f);
    };

    // Schedule a reduction output: a wrapper Func that reads the reduction
    // becomes the pipeline output, and the reduction is computed at its thread
    // loop.
    auto schedule_reduce = [&](const AstNodePtr& output) {
        auto& f = output->m_func;
        auto& c = try_cast_as_op<ReduceOp>(output.get())->m_comp;
        auto vars = f.args();
        std::vector<Expr> exprs;
        for (const auto& var : vars) {
            exprs.emplace_back(var);
        }
        Func real_out;
        real_out(vars) = f(exprs);

        Var tx = fuse_and_tile(real_out, vars, output->m_layout);
        f.compute_at(real_out, tx);
        if (c.defined()) {
            c.compute_at(real_out, tx);
        }
        process_reduce(real_out, tx);
        outputs.push_back(real_out);
    };

    if (dst_output->same_type<ReduceOp>()) {
        schedule_reduce(dst_output);
    } else {
        schedule_elemwise_like(dst_output);
    }
    return Pipeline(outputs);
}
// Create the executable for `graph` and `args`: a HalideExecutable constructed
// with this compiler's target trait (m_trait).
std::unique_ptr<Executable> HalideCudaCompiler::do_compile(
        const InternalGraph& graph, const JITExecutor::Args& args) {
    std::unique_ptr<Executable> executable =
            std::make_unique<HalideExecutable>(m_trait, graph, args);
    return executable;
}
// Return the name under which the CUDA runtime override library was compiled.
//
// The embedded source defines halide_cuda_acquire_context,
// halide_cuda_release_context and halide_cuda_get_stream so that they read the
// CUcontext / CUstream from the HalideUserContext passed as user_context. It is
// compiled once through ExecutableHelper::compile_cpp_source_secondary(), on the
// first call; later calls return the same string.
const std::string& HalideCudaCompiler::cuda_runtime_lib() {
    static const char* const override_src = R"(
#include <cuda.h>
#include <cstdio>
#include <cstdlib>
namespace {
struct HalideUserContext {
CUcontext ctx;
CUstream strm;
};
HalideUserContext* check_user_context(void* user_context) {
if (!user_context) {
fprintf(stderr, "user_context not provided\n");
abort();
}
return static_cast<HalideUserContext*>(user_context);
}
} // anonymous namespace
extern "C" int halide_cuda_acquire_context(void* user_context, CUcontext* ctx,
bool create) {
if (!user_context && !create) {
// called from halide_cuda_cleanup()
return 1;
}
*ctx = check_user_context(user_context)->ctx;
return 0;
}
extern "C" int halide_cuda_release_context(void* user_context) {
return 0;
}
extern "C" int halide_cuda_get_stream(void* user_context, CUcontext ctx,
CUstream* stream) {
*stream = check_user_context(user_context)->strm;
return 0;
}
)";
    // function-local static: the source is compiled only on first use
    static std::string lib_name =
            ExecutableHelper::get().compile_cpp_source_secondary(
                    override_src, "halide_cuda_runtime_override");
    return lib_name;
}
#endif