#pragma once
#include "./halide_header.h"
#if MGB_JIT_HALIDE
#include "./ast_hl.h"
#include "megbrain/jit/compiler.h"

#include <atomic>
#include <bitset>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>
namespace mgb {
namespace jit {
class HalideExecutable final : public Executable {
public:
struct FunctionHandle {
struct UctxMap {
CompNode::UnorderedMap<void*> cn2uctx;
std::mutex mtx;
};
UctxMap* uctx_map = nullptr;
void* dl_handle = nullptr;
halide_device_interface_t* (*get_device_interface)() = nullptr;
int (*execute)(void** argv) = nullptr;
int (*device_release)(void* user_context) = nullptr;
void swap(FunctionHandle& rhs) {
using T = std::aligned_storage_t<
sizeof(FunctionHandle), alignof(FunctionHandle)>;
T tmp;
tmp = *reinterpret_cast<T*>(this);
*reinterpret_cast<T*>(this) = reinterpret_cast<T&>(rhs);
reinterpret_cast<T&>(rhs) = tmp;
}
void init_uctx_map() {
mgb_assert(!uctx_map);
uctx_map = new UctxMap;
}
FunctionHandle() = default;
FunctionHandle(FunctionHandle&& rhs) { swap(rhs); }
FunctionHandle(const FunctionHandle&) = delete;
FunctionHandle& operator=(FunctionHandle&& rhs) {
swap(rhs);
return *this;
}
FunctionHandle& operator=(const FunctionHandle&) = delete;
~FunctionHandle();
};
struct TargetTraitUserData {
virtual ~TargetTraitUserData() = default;
};
class TargetTrait {
public:
using FunctionHandle = HalideExecutable::FunctionHandle;
using FeatureSet = std::bitset<Halide::Target::FeatureEnd>;
virtual ~TargetTrait() = default;
virtual FeatureSet features(CompNode comp_node) const = 0;
virtual void* get_user_context(CompNode comp_node) = 0;
virtual FunctionHandle compile_and_load(
CompNode comp_node, Halide::Target halide_target,
const HalideExecutable& hl_exec) = 0;
protected:
TargetTraitUserData* user_data(
const HalideExecutable& hl_exec,
thin_function<std::unique_ptr<TargetTraitUserData>()> maker);
};
HalideExecutable(
std::shared_ptr<TargetTrait> trait, const InternalGraph& graph,
const JITExecutor::Args& args);
~HalideExecutable();
void execute(JITExecutor* fusion_opr) override;
std::vector<Halide::Argument> halide_inputs() const;
const ast_hl::AstNodePtr& halide_output() const { return m_halide_output; }
static halide_type_t dtype_mgb2halide(DType dtype);
private:
std::shared_ptr<TargetTrait> const m_target_trait;
ast_hl::AstNodePtr m_halide_output;
SmallVector<std::pair<size_t, ast_hl::AstNodePtr>> m_value_inputs;
std::mutex m_mtx;
std::unordered_map<TargetTrait::FeatureSet, std::pair<std::mutex, FunctionHandle>>
m_feature_set2func;
CompNode::UnorderedMap<std::atomic<FunctionHandle*>> m_cn2func;
mutable std::unique_ptr<TargetTraitUserData> m_target_trait_user_data;
mutable std::mutex m_target_trait_user_data_mtx;
void invoke(
void* user_context, const FunctionHandle& handle,
const VarNodeArray& inputs, VarNode* output);
static ast_hl::AstNodePtr mgb_var_to_halide_buffer(VarNode* var);
FunctionHandle compile_and_load(CompNode comp_node) const;
};
} }
#endif