#pragma once
#include "megbrain/test/helper.h"
#if MGB_ENABLE_TENSOR_RT
namespace mgb {
namespace tensorrt {

/*!
 * \brief test helper configured with an expression-building callback
 *
 * Holds the number of inputs, a comp node, a freshly made computing graph and
 * the user callback that turns input vars into an output var. Per-input
 * configuration (dtype, RNG, const-ness) is collected by the chainable setters
 * below and consumed by run(), which is defined elsewhere.
 *
 * NOTE(review): the member names (m_truth_y / m_trt_y, m_epsilon) suggest
 * run() compares a reference output with a TensorRT-replaced output within
 * m_epsilon — confirm against the run() implementation.
 */
class TrtReplaceChecker {
public:
    //! callback that builds the checked output var from the input vars
    using ExpFunc = thin_function<SymbolVar(const SymbolVarArray&)>;

    /*!
     * \param nr_input number of inputs handled by this checker
     * \param exp_func expression builder; moved into the checker
     * \param cn comp node stored by the checker (presumably where inputs are
     *      placed — TODO confirm in run())
     */
    TrtReplaceChecker(size_t nr_input, ExpFunc exp_func, CompNode cn)
            : m_nr_input{nr_input},
              m_comp_node{cn},
              m_graph{ComputingGraph::make()},
              m_exp_func{std::move(exp_func)} {}

    //! record \p dtype for input \p idx (a repeated idx overwrites)
    TrtReplaceChecker& set_dtype(size_t idx, DType dtype) {
        auto& slot = m_idx2dtype[idx];
        slot = dtype;
        return *this;
    }

    /*!
     * \brief record a custom generator for input \p idx (a repeated idx
     *      overwrites)
     *
     * Only the raw pointer is stored; ownership stays with the caller.
     * NOTE(review): the generator presumably must outlive run() — confirm.
     */
    TrtReplaceChecker& set_rng_gen(size_t idx, HostTensorGeneratorBase* rng_gen) {
        auto& slot = m_idx2rng_gen[idx];
        slot = rng_gen;
        return *this;
    }

    //! add input \p idx to the set of inputs marked as const
    TrtReplaceChecker& set_const_var(size_t idx) {
        m_mark_inp_const.insert(idx);
        return *this;
    }

    //! override the stored epsilon (default 1e-5)
    TrtReplaceChecker& set_epsilon(float epsilon) {
        m_epsilon = epsilon;
        return *this;
    }

    //! execute the check for the given input shapes; defined out of line
    TrtReplaceChecker& run(const TensorShapeArray& input_shapes);

private:
    //! defined out of line; presumably lazily builds m_graph / m_func
    void ensure_init_graph();

    // NOTE: declaration order below is preserved deliberately, since it
    // fixes member construction order.
    const size_t m_nr_input;
    const CompNode m_comp_node;
    HostTensorGenerator<> m_input_gen;  //!< default input generator
    SmallVector<std::shared_ptr<HostTensorND>> m_inputs_val;
    std::tuple<HostTensorND, HostTensorND> m_output_val;
    ThinHashMap<size_t, DType> m_idx2dtype;  //!< filled by set_dtype()
    ThinHashMap<size_t, HostTensorGeneratorBase*>
            m_idx2rng_gen;  //!< non-owning; filled by set_rng_gen()
    ThinHashSet<size_t> m_mark_inp_const;  //!< filled by set_const_var()
    std::shared_ptr<ComputingGraph> m_graph;  //!< created in the constructor
    std::unique_ptr<cg::AsyncExecutable> m_func;
    ExpFunc m_exp_func;
    SymbolVar m_truth_y;
    SymbolVar m_trt_y;
    float m_epsilon{1e-5};
};

}  // namespace tensorrt
}  // namespace mgb
#endif