#pragma once
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_CUDA
#include <cuda.h>
#include <nvrtc.h>
#include "megbrain/jit/compiler.h"
/*!
 * \brief evaluate an NVRTC call and report any failure
 *
 * \p expr is evaluated exactly once and its result is stored in a local
 * nvrtcResult. When the result differs from NVRTC_SUCCESS, control is passed
 * to ::mgb::jit::_on_nvrtc_error together with the stringified expression,
 * the result code and the call site (__FILE__, __func__, __LINE__).
 *
 * The do { ... } while (0) wrapper makes the macro usable as a single
 * statement, e.g. inside an unbraced if/else.
 */
#define MGB_NVRTC_CHECK(expr)                                                \
    do {                                                                     \
        nvrtcResult mgb_nvrtc_check_res_ = (expr);                           \
        if (!mgb_likely(mgb_nvrtc_check_res_ == NVRTC_SUCCESS)) {            \
            ::mgb::jit::_on_nvrtc_error(                                     \
                    #expr, mgb_nvrtc_check_res_, __FILE__, __func__,         \
                    __LINE__);                                               \
        }                                                                    \
    } while (0)
namespace mgb {
namespace jit {
/*!
 * \brief failure handler invoked by MGB_NVRTC_CHECK; declared [[noreturn]]
 *
 * Only the declaration is visible in this header; how the error is reported
 * is decided by the out-of-line definition.
 *
 * \param expr text of the expression that produced the failing result
 *      (MGB_NVRTC_CHECK passes the stringified macro argument)
 * \param nvrtc_res result code returned by the NVRTC call
 * \param file source file of the call site (__FILE__ when called via the
 *      macro)
 * \param func enclosing function of the call site (__func__ when called via
 *      the macro)
 * \param line line number of the call site (__LINE__ when called via the
 *      macro)
 */
[[noreturn]] void _on_nvrtc_error(
const char* expr, nvrtcResult nvrtc_res, const char* file, const char* func,
int line);
/*!
 * \brief Executable subclass for the CUDA JIT backend
 *
 * Stores a source string and a name, plus a map of FuncCache entries guarded
 * by a mutex. All member functions are defined out of line.
 */
class CudaExecutable final : public Executable {
public:
    //! takes \p source and \p name by value
    CudaExecutable(std::string source, std::string name);
    ~CudaExecutable();

    //! implementation of the Executable::execute interface
    void execute(JITExecutor* fusion_opr) final override;

private:
    /*!
     * \brief per-key compilation state
     *
     * NOTE(review): presumably one instance per device architecture, with
     * `ptx` holding the compiled PTX text -- confirm against the
     * implementation file.
     */
    struct FuncCache {
        //! driver-API handles associated with one comp node
        struct Func {
            //! -1 until assigned
            int block_size = -1;
            CUmodule module = nullptr;
            CUfunction func = nullptr;
        };

        std::mutex mtx;
        std::string ptx;
        //! comp node => Func
        CompNode::UnorderedMap<Func> cn2func;

        //! NOTE(review): major/minor look like a compute capability pair;
        //! verify at the call site
        void compile(
                const std::string& cache_category, int major, int minor,
                const CudaExecutable* cuda_exe);

        void exec(const JITExecutor* fusion_opr, const CudaExecutable* cuda_exe);
    };

    const std::string m_source;
    const std::string m_name;
    std::mutex m_mtx;
    //! NOTE(review): key is presumably the (major, minor) pair handed to
    //! FuncCache::compile -- confirm
    ThinHashMap<std::pair<uint32_t, uint32_t>, FuncCache> m_func_cache;
};
/*!
 * \brief Compiler subclass for the CUDA JIT backend
 *
 * Only property() is defined inline; the remaining overrides are defined out
 * of line.
 */
class CudaCompiler final : public Compiler {
    // private by default class access
    std::unique_ptr<Executable> do_compile(
            const InternalGraph& graph, const JITExecutor::Args& args) override;

public:
    static constexpr size_t MAX_CUDA_NR_INPUT = 38;

    /*!
     * \brief static description of this compiler
     *
     * Returns a Property built from the NEED_INPUT_COLLAPSE and BIND_NDIM
     * flags, JITFeatureBits::NONE, and the value 64.
     *
     * NOTE(review): the trailing 64 is a literal and differs from
     * MAX_CUDA_NR_INPUT (38); confirm which Property field it initializes
     * and whether the two values are meant to be independent.
     */
    Property property() const override {
        using Flag = Property::Flag;
        const auto flags = Flag::NEED_INPUT_COLLAPSE | Flag::BIND_NDIM;
        return Property{flags, JITFeatureBits::NONE, 64};
    }

    size_t get_nr_workspace_outputs(JITExecutor* opr) const override;

    void init_workspace_size_infer(JITExecutor* opr) override;
};
} }
#endif