#pragma once
#include "megbrain/test/helper.h"
namespace mgb {
namespace jit {
/*!
 * \brief JIT backends that can be chosen through set_backend()
 *
 * Enumerator values are spelled out to document that they match the
 * implicit numbering (0, 1, 2, 3) of a plain declaration order.
 */
enum class Backend {
    NONE = 0,
    HALIDE = 1,
    NVRTC = 2,
    MLIR = 3
};
//! select the JIT backend to use; defined out of line
//! NOTE(review): presumably affects the backend used by subsequent
//! FusionChecker runs — confirm against the definition
void set_backend(Backend backend);
/*!
 * \brief collect operators related to \p nd as a list; defined out of line
 *
 * NOTE(review): judging by the name, the result is presumably the operators
 * that produce \p nd in reverse topological order, and \p endpoints_set
 * (empty by default) presumably names vars at which traversal stops —
 * confirm against the definition.
 */
std::vector<cg::OperatorNodeBase*> get_rev_topo_order(
SymbolVar nd, ThinHashSet<VarNode*> endpoints_set = {});
class FusionChecker {
public:
using ExpFunc = thin_function<SymbolVar(const SymbolVarArray&)>;
FusionChecker(size_t nr_input, ExpFunc exp_func, CompNode cn)
: m_nr_input{nr_input},
m_comp_node{cn},
m_graph{ComputingGraph::make()},
m_exp_func{std::move(exp_func)} {}
FusionChecker& set_dtype(size_t idx, DType dtype) {
m_idx2dtype[idx] = dtype;
return *this;
}
FusionChecker& disable_inp_grad();
FusionChecker& enable_direct_build() {
m_direct_build = true;
return *this;
}
FusionChecker& disable_opr_type_check() {
m_check_opr_type = false;
return *this;
}
FusionChecker& set_jit_level(uint8_t jit_level) {
m_jit_level = jit_level;
return *this;
}
FusionChecker& run(const TensorShapeArray& input_shapes);
private:
bool m_check_opr_type = true;
bool m_direct_build = false;
const size_t m_nr_input;
uint8_t m_jit_level = 2;
const CompNode m_comp_node;
HostTensorGenerator<> m_input_gen;
SmallVector<std::shared_ptr<HostTensorND>> m_inputs_val;
SmallVector<std::tuple<size_t, HostTensorND, HostTensorND>> m_outputs_val;
ThinHashSet<size_t> m_disable_inp_grad;
ThinHashMap<size_t, DType> m_idx2dtype;
std::shared_ptr<ComputingGraph> m_graph;
std::unique_ptr<cg::AsyncExecutable> m_func;
ExpFunc m_exp_func;
SymbolVar m_truth_y, m_jit_y;
void ensure_init_graph();
};
} }