megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file src/tensorrt/test/helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/test/helper.h"

#if MGB_ENABLE_TENSOR_RT

namespace mgb {
namespace tensorrt {
/*!
 * \brief helper class for testing fusions on specific funcs
 *
 * The tensorrt opr would be created based on automatic opr replace pass.
 *
 * All setters return `*this`, so configuration calls can be chained before
 * invoking run(), e.g.
 *
 *     TrtReplaceChecker{nr_input, exp_func, cn}.set_epsilon(1e-3).run(shapes);
 */
class TrtReplaceChecker {
public:
    //! builds the var to be checked from the input vars
    using ExpFunc = thin_function<SymbolVar(const SymbolVarArray&)>;

    /*!
     * \param nr_input number of inputs of the checked expression
     * \param exp_func callback building the checked expression from the
     *      input vars; it is invoked when the graph is first built by run()
     * \param cn comp node the checker is bound to (presumably where inputs
     *      and the graph are placed — TODO confirm against helper.cpp)
     *
     * The comparison tolerance defaults to 1e-5; see set_epsilon().
     */
    TrtReplaceChecker(size_t nr_input, ExpFunc exp_func, CompNode cn)
            : m_nr_input{nr_input},
              m_comp_node{cn},
              m_graph{ComputingGraph::make()},
              m_exp_func{std::move(exp_func)},
              // float literal: avoid an implied double->float narrowing in
              // brace-init (value is unchanged)
              m_epsilon{1e-5f} {}

    //! set input data type, which is float32 by default
    TrtReplaceChecker& set_dtype(size_t idx, DType dtype) {
        m_idx2dtype[idx] = dtype;
        return *this;
    }

    /*!
     * \brief set input rng generator, which is default generator of float32
     *
     * The pointer is stored as-is and is NOT owned by the checker; the caller
     * must keep \p rng_gen alive for as long as run() may be called.
     */
    TrtReplaceChecker& set_rng_gen(size_t idx, HostTensorGeneratorBase* rng_gen) {
        m_idx2rng_gen[idx] = rng_gen;
        return *this;
    }

    //! mark input \p idx to be created as a const var node
    TrtReplaceChecker& set_const_var(size_t idx) {
        m_mark_inp_const.insert(idx);
        return *this;
    }

    //! set epsilon (presumably the tolerance used by run() when comparing
    //! outputs); defaults to 1e-5
    TrtReplaceChecker& set_epsilon(float epsilon) {
        m_epsilon = epsilon;
        return *this;
    }

    /*!
     * \brief run and check correctness
     *
     * The graph would be built (and m_exp_func is invoked) on first call.
     */
    TrtReplaceChecker& run(const TensorShapeArray& input_shapes);

private:
    const size_t m_nr_input;
    const CompNode m_comp_node;
    //! default input generator (float32), presumably used for inputs that
    //! have no entry in m_idx2rng_gen
    HostTensorGenerator<> m_input_gen;
    //! host-side input values; must be filled before ensure_init_graph()
    SmallVector<std::shared_ptr<HostTensorND>> m_inputs_val;
    //! host values of the two checked outputs, presumably
    //! (m_truth_y, m_trt_y) in that order — TODO confirm against run()
    std::tuple<HostTensorND, HostTensorND> m_output_val;
    //! per-input dtype overrides (see set_dtype)
    ThinHashMap<size_t, DType> m_idx2dtype;
    //! per-input generator overrides; non-owning (see set_rng_gen)
    ThinHashMap<size_t, HostTensorGeneratorBase*> m_idx2rng_gen;
    //! indices of inputs marked via set_const_var
    ThinHashSet<size_t> m_mark_inp_const;
    std::shared_ptr<ComputingGraph> m_graph;
    //! compiled function; presumably created lazily on the first run()
    std::unique_ptr<cg::AsyncExecutable> m_func;

    ExpFunc m_exp_func;
    //! presumably the reference output and the output after TensorRT
    //! replacement, respectively
    SymbolVar m_truth_y, m_trt_y;
    float m_epsilon;

    //! init m_graph and related fields; m_inputs_val must have been initialized
    void ensure_init_graph();
};

}  // namespace tensorrt
}  // namespace mgb

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}