#pragma once

#include "megbrain_build_config.h"

#if MGB_JIT && MGB_JIT_MLIR
#if MGB_CUDA

#include "megbrain/jit/compiler.h"

#include <mlir/IR/Module.h>

#include <cuda.h>

#include <cstdint>
#include <mutex>
#include <string>
#include <utility>

namespace mgb {
namespace jit {

/*!
 * \brief JIT Executable that launches a kernel built from an MLIR module via
 *      the CUDA driver API.
 *
 * Only compiled when MGB_JIT, MGB_JIT_MLIR and MGB_CUDA are all enabled.
 * Member functions are defined out of line.
 *
 * Copying is disabled: the object holds CUDA driver handles (CUmodule /
 * CUfunction) and a std::mutex inside its per-key caches, none of which can
 * be meaningfully duplicated.
 */
class MLIRCUDAExecutable final : public Executable {
public:
    /*!
     * \param module MLIR module containing the kernel. It is taken by
     *      non-const reference; NOTE(review): presumably the constructor
     *      transforms/consumes it — confirm against the implementation.
     * \param kernel_name name of the kernel function to launch
     */
    MLIRCUDAExecutable(mlir::OwningModuleRef& module, const std::string& kernel_name);
    ~MLIRCUDAExecutable();

    MLIRCUDAExecutable(const MLIRCUDAExecutable&) = delete;
    MLIRCUDAExecutable& operator=(const MLIRCUDAExecutable&) = delete;

    //! run this executable for the given fused JIT operator
    void execute(JITExecutor* fusion_opr) override final;

    //! NOTE(review): presumably the name of the MLIR attribute carrying the
    //! compiled GPU binary blob — confirm in the implementation file
    static const std::string sm_blob_annotation;

private:
    //! cached kernel data plus per-comp-node CUDA handles for one cache key
    struct FuncCache {
        //! CUDA module/function handles loaded for a single comp node
        struct Func {
            int block_size{-1};  //!< -1 presumably means "not yet determined"
            CUmodule module{nullptr};
            CUfunction func{nullptr};
        };

        std::mutex mtx;  //!< presumably guards cn2func — confirm in exec()
        std::string kernel_data;
        CompNode::UnorderedMap<Func> cn2func;

        void exec(const JITExecutor* fusion_opr, const MLIRCUDAExecutable* cuda_exe);
    };

    std::string m_kernel_name;
    std::string m_kernel_data;
    //! keyed by a pair of uint32 values; NOTE(review): presumably the device
    //! compute capability (major, minor) — TODO confirm
    ThinHashMap<std::pair<uint32_t, uint32_t>, FuncCache> m_func_cache;
};

}  // namespace jit
}  // namespace mgb

#endif  // MGB_CUDA
#endif  // MGB_JIT && MGB_JIT_MLIR