#include "lite/network.h"
#include "function_base.h"
#include "network_impl_base.h"
#include "parse_info/parse_info_base.h"
#include "parse_model/model_parser.h"
#include "type_info.h"
#if LITE_BUILD_WITH_MGE
#include "mge/function_dft.h"
#include "mge/network_impl.h"
#endif
#include <fstream>
#include <memory>
using namespace lite;
// Build a network for `config.backend` and push the config / IO description
// into the backend implementation.
// Throws (via LITE_ASSERT) when no implementation could be created, e.g. for a
// backend other than LITE_DEFAULT, instead of dereferencing a null m_impl.
Network::Network(const Config& config, const NetworkIO& network_io) {
    LITE_ERROR_HANDLER_BEGIN
    m_config = config;
    m_network_io = network_io;
    if (config.backend == LiteBackend::LITE_DEFAULT) {
        m_impl = call_func<
                NetworkImplDft, std::unique_ptr<lite::Network::NetworkImplBase>>(
                "create_network");
    }
    // Only LITE_DEFAULT creates an implementation above; any other backend
    // (or a failed creation) would otherwise crash on the calls below.
    LITE_ASSERT(
            m_impl != nullptr,
            "failed to create the network implementation for backend %d.",
            static_cast<int>(config.backend));
    m_impl->set_config(config);
    m_impl->set_io(network_io);
    LITE_ERROR_HANDLER_END
}
// Same as Network(config, network_io), with the arguments in the other order.
// Delegating keeps the two constructors from diverging: the body was a
// verbatim copy of the (Config, NetworkIO) constructor.
Network::Network(const NetworkIO& network_io, const Config& config)
        : Network(config, network_io) {}
// Load a model from a caller-supplied memory buffer of `size` bytes.
// The buffer is wrapped with a do-nothing deleter, so it is never freed here;
// the caller keeps ownership (presumably it must stay valid for as long as the
// implementation may read it — confirm against the backend).
void Network::load_model(void* model_mem, size_t size) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    auto keep_alive_deleter = [](void*) {};
    prase_model(std::shared_ptr<void>(model_mem, keep_alive_deleter), size);
    LITE_ERROR_HANDLER_END
}
// Read the whole model file at `model_path` into a heap buffer and load it.
// Throws (via LITE_ASSERT) if the file cannot be opened, sized, allocated or
// fully read. The FILE handle is owned by a unique_ptr so it is closed on every
// exit path, including when one of those asserts throws.
void Network::load_model(std::string model_path) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    std::unique_ptr<FILE, int (*)(FILE*)> fin{
            fopen(model_path.c_str(), "rb"), &fclose};
    LITE_ASSERT(
            fin != nullptr, "failed to open %s: %s", model_path.c_str(),
            strerror(errno));
    LITE_ASSERT(
            fseek(fin.get(), 0, SEEK_END) == 0, "failed to seek %s",
            model_path.c_str());
    // ftell reports failure as -1; check before converting to size_t so an
    // error is not misread as an enormous file size.
    long end_pos = ftell(fin.get());
    LITE_ASSERT(end_pos >= 0, "failed to get the size of %s", model_path.c_str());
    size_t size = static_cast<size_t>(end_pos);
    LITE_ASSERT(
            fseek(fin.get(), 0, SEEK_SET) == 0, "failed to seek %s",
            model_path.c_str());
    std::shared_ptr<void> buf{malloc(size), ::free};
    // malloc(0) may legitimately return nullptr, so only a non-empty request
    // with a null result is an allocation failure.
    LITE_ASSERT(
            buf != nullptr || size == 0, "failed to allocate %zu bytes for %s",
            size, model_path.c_str());
    auto nr = fread(buf.get(), 1, size, fin.get());
    LITE_ASSERT(nr == size);
    // Close the file before parsing; the model lives in `buf` from here on.
    fin.reset();
    prase_model(buf, size);
    LITE_ERROR_HANDLER_END
}
void Network::prase_model(std::shared_ptr<void> model_data, size_t size) {
std::unordered_map<std::string, LiteAny> separate_config_map;
ModelParser model_parser(model_data, size);
if (model_parser.parse_model_info(
m_config, m_network_io, separate_config_map, m_extra_info)) {
if (m_config.backend == LiteBackend::LITE_DEFAULT &&
m_impl->get_backend_type() != LiteBackend::LITE_DEFAULT) {
m_impl.reset(try_call_func<NetworkImplDft, lite::Network::NetworkImplBase*>(
"parse_model"));
}
m_impl->set_config(m_config);
m_impl->set_io(m_network_io);
}
size_t model_length;
auto&& model_shared_ptr = model_parser.parse_model(model_length, m_config);
m_impl->load_model(model_shared_ptr, model_length, separate_config_map);
m_loaded = true;
update_from_implement();
}
// Defaulted destructor, defined out-of-line in this file rather than in the
// header. NOTE(review): presumably so the implementation type owned by m_impl
// is complete at the point of destruction — confirm against lite/network.h.
Network::~Network() = default;
// Copy the device type reported by the implementation back into m_config,
// so the public config reflects what the implementation reports.
void Network::update_from_implement() {
    const auto reported_device = m_impl->get_device_type();
    m_config.device_type = reported_device;
}
// Ask the implementation to compute only the outputs configured in NetworkIO.
// Must be called before a model is loaded (asserted below).
void Network::compute_only_configured_output() {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(
            !m_loaded,
            "compute_only_configured_output should be used before model "
            "loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->compute_only_configured_output();
    LITE_ERROR_HANDLER_END
}
// Look up an input/output tensor by name for the given phase.
// Only valid after a model has been loaded (asserted below).
std::shared_ptr<Tensor> Network::get_io_tensor(
        std::string name, LiteTensorPhase phase) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_io_tensor should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    auto tensor = m_impl->get_io_tensor(name, phase);
    return tensor;
    LITE_ERROR_HANDLER_END
}
// Fetch the input tensor at position `index`; requires a loaded model.
std::shared_ptr<Tensor> Network::get_input_tensor(size_t index) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_input_tensor should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    auto input = m_impl->get_input_tensor(index);
    return input;
    LITE_ERROR_HANDLER_END
}
// Fetch the output tensor at position `index`; requires a loaded model.
std::shared_ptr<Tensor> Network::get_output_tensor(size_t index) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_output_tensor should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    auto output = m_impl->get_output_tensor(index);
    return output;
    LITE_ERROR_HANDLER_END
}
// Register a callback for asynchronous execution and return *this for chaining.
// Rejected when force_output_use_user_specified_memory is enabled in m_config
// (see the assert message).
Network& Network::set_async_callback(const AsyncCallback& callback) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(
            !m_config.options.force_output_use_user_specified_memory,
            "Async mode can't run with force_output_use_user_specified_memory which "
            "output data is written to use specific memory.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    // `callback` is a const reference: std::move would only yield a const
    // rvalue that is copied anyway, so pass it directly.
    m_impl->set_async_callback(callback);
    return *this;
    LITE_ERROR_HANDLER_END
}
// Register the start callback on the implementation; returns *this for chaining.
Network& Network::set_start_callback(const StartCallback& callback) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    // std::move on a const reference is a no-op cast that still copies.
    m_impl->set_start_callback(callback);
    return *this;
    LITE_ERROR_HANDLER_END
}
// Register the finish callback on the implementation; returns *this for chaining.
Network& Network::set_finish_callback(const FinishCallback& callback) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    // std::move on a const reference is a no-op cast that still copies.
    m_impl->set_finish_callback(callback);
    return *this;
    LITE_ERROR_HANDLER_END
}
// Select the device id; only allowed before a model is loaded (asserted below).
// Returns *this for chaining.
Network& Network::set_device_id(int device_id) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(!m_loaded, "set_device_id should be used before model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    auto& impl = *m_impl;
    impl.set_device_id(device_id);
    return *this;
    LITE_ERROR_HANDLER_END
}
// Select the stream id; only allowed before a model is loaded (asserted below).
// Returns *this for chaining.
Network& Network::set_stream_id(int stream_id) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(!m_loaded, "set_stream_id should be used before model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    auto& impl = *m_impl;
    impl.set_stream_id(stream_id);
    return *this;
    LITE_ERROR_HANDLER_END
}
// Run one inference through the implementation; requires a loaded model.
void Network::forward() {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "forward should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl.get());
    auto* impl = m_impl.get();
    impl->forward();
    LITE_ERROR_HANDLER_END
}
// Delegate to the implementation's wait(); requires a loaded model.
void Network::wait() {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "wait should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    auto& impl = *m_impl;
    impl.wait();
    LITE_ERROR_HANDLER_END
}
// Name of the input at position `index`; requires a loaded model.
std::string Network::get_input_name(size_t index) const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_input_name should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    std::string input_name = m_impl->get_input_name(index);
    return input_name;
    LITE_ERROR_HANDLER_END
}
// Name of the output at position `index`; requires a loaded model.
std::string Network::get_output_name(size_t index) const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_output_name should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    std::string output_name = m_impl->get_output_name(index);
    return output_name;
    LITE_ERROR_HANDLER_END
}
// Names of all inputs, converted to std::string; requires a loaded model.
std::vector<std::string> Network::get_all_input_name() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_all_input_name should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    const auto input_names = m_impl->get_all_input_name();
    return std::vector<std::string>(input_names.begin(), input_names.end());
    LITE_ERROR_HANDLER_END
}
// Names of all outputs, converted to std::string; requires a loaded model.
std::vector<std::string> Network::get_all_output_name() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_all_output_name should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    const auto output_names = m_impl->get_all_output_name();
    return std::vector<std::string>(output_names.begin(), output_names.end());
    LITE_ERROR_HANDLER_END
}
// Device id reported by the implementation. No m_loaded check here, so it can
// be queried before a model is loaded.
int Network::get_device_id() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    const int device_id = m_impl->get_device_id();
    return device_id;
    LITE_ERROR_HANDLER_END
}
// Stream id reported by the implementation. No m_loaded check here, so it can
// be queried before a model is loaded.
int Network::get_stream_id() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    const int stream_id = m_impl->get_stream_id();
    return stream_id;
    LITE_ERROR_HANDLER_END
}
void Network::enable_profile_performance(std::string profile_file_path) {
LITE_ERROR_HANDLER_BEGIN
m_impl->enable_profile_performance(profile_file_path);
LITE_ERROR_HANDLER_END
}
// Extra info string filled in by prase_model when the model carries it
// (empty otherwise). Returned by reference to the member.
const std::string& Network::get_model_extra_info() {
    LITE_ERROR_HANDLER_BEGIN
    return this->m_extra_info;
    LITE_ERROR_HANDLER_END
}
// Device type reported by the implementation.
// Guards against a null m_impl, consistent with get_device_id/get_stream_id;
// the original dereferenced it unconditionally.
LiteDeviceType Network::get_device_type() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_device_type();
    LITE_ERROR_HANDLER_END
}
// Dump static memory allocation info into `log_dir`.
// Only available when not building for a TEE environment and MGB_ENABLE_JSON
// is enabled; otherwise it throws. Requires a loaded model.
void Network::get_static_memory_alloc_info(const std::string& log_dir) const {
    LITE_ERROR_HANDLER_BEGIN
#if !defined(__IN_TEE_ENV__) && MGB_ENABLE_JSON
    // The message previously named get_all_output_name (copy-paste error).
    LITE_ASSERT(
            m_loaded,
            "get_static_memory_alloc_info should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->get_static_memory_alloc_info(log_dir);
#else
    LITE_MARK_USED_VAR(log_dir);
    LITE_THROW("Doesn't support get_static_memory_alloc_info().Please check macro.");
#endif
    LITE_ERROR_HANDLER_END
}
void Runtime::set_cpu_threads_number(
std::shared_ptr<Network> network, size_t nr_threads) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
!NetworkHelper::loaded(network),
"set_cpu_threads_number should be used before model loaded.");
call_func<NetworkImplDft, void>(
"set_cpu_threads_number", network_impl, nr_threads);
return;
}
LITE_THROW("set_cpu_threads_number is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
void Runtime::use_tensorrt(std::shared_ptr<Network> network) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
!NetworkHelper::loaded(network),
"use_tensorrt should be used before model loaded.");
call_func<NetworkImplDft, void>("use_tensorrt", network_impl);
return;
}
LITE_THROW("use_tensorrt is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
size_t Runtime::get_cpu_threads_number(const std::shared_ptr<Network> network) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
return call_func<NetworkImplDft, size_t>(
"get_cpu_threads_number", network_impl);
}
LITE_THROW("get_cpu_threads_number is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
void Runtime::set_runtime_thread_affinity(
std::shared_ptr<Network> network,
const ThreadAffinityCallback& thread_affinity_callback) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
NetworkHelper::loaded(network),
"set_runtime_thread_affinity should be used after model "
"loaded.");
call_func<NetworkImplDft, void>(
"set_runtime_thread_affinity", network_impl, thread_affinity_callback);
return;
}
LITE_THROW("set_runtime_thread_affinity is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
void Runtime::set_cpu_inplace_mode(std::shared_ptr<Network> network) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
!NetworkHelper::loaded(network),
"set_cpu_inplace_mode should be used before model loaded.");
call_func<NetworkImplDft, void>("set_cpu_inplace_mode", network_impl);
return;
}
LITE_THROW("set_cpu_inplace_mode is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
bool Runtime::is_cpu_inplace_mode(const std::shared_ptr<Network> network) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
return call_func<NetworkImplDft, bool>("is_cpu_inplace_mode", network_impl);
}
LITE_THROW("is_cpu_inplace_mode is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
void Runtime::set_network_algo_policy(
std::shared_ptr<Network> network, LiteAlgoSelectStrategy strategy,
uint32_t shared_batch_size, bool binary_equal_between_batch) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
call_func<NetworkImplDft, void>(
"set_network_algo_policy", network_impl, strategy, shared_batch_size,
binary_equal_between_batch);
return;
}
LITE_THROW("set_network_algo_policy is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
// Limit the algorithm workspace size of `network`. Default backend only; the
// model must already be loaded.
// The assert message previously named set_network_algo_policy (copy-paste
// error), which pointed users at the wrong API.
void Runtime::set_network_algo_workspace_limit(
        std::shared_ptr<Network> network, size_t workspace_limit) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(
                NetworkHelper::loaded(network),
                "set_network_algo_workspace_limit should be used after model "
                "loaded.");
        call_func<NetworkImplDft, void>(
                "set_network_algo_workspace_limit", network_impl, workspace_limit);
        return;
    }
    LITE_THROW(
            "set_network_algo_workspace_limit is not aviliable in the "
            "backend.");
    LITE_ERROR_HANDLER_END
}
void Runtime::set_memory_allocator(
std::shared_ptr<Network> network, std::shared_ptr<Allocator> user_allocator) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
!NetworkHelper::loaded(network),
"set_memory_allocator should be used before model loaded.");
call_func<NetworkImplDft, void>(
"set_memory_allocator", network_impl, user_allocator);
return;
}
LITE_THROW("set_memory_allocator is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
void Runtime::share_runtime_memory_with(
std::shared_ptr<Network> dst_network, std::shared_ptr<Network> src_network) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl_dst = NetworkHelper::implement(dst_network);
if (network_impl_dst->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
!NetworkHelper::loaded(dst_network),
"share_runtime_memory_with should be used before model "
"loaded.");
call_func<NetworkImplDft, void>(
"share_runtime_memory_with", network_impl_dst,
NetworkHelper::implement(src_network));
return;
}
LITE_THROW("share_runtime_memory_with is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
void Runtime::enable_io_txt_dump(
std::shared_ptr<Network> network, std::string io_txt_out_file) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
call_func<NetworkImplDft, void>(
"enable_io_txt_dump", network_impl, io_txt_out_file);
return;
}
LITE_THROW("enable_io_txt_dump is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
void Runtime::enable_io_bin_dump(
std::shared_ptr<Network> network, std::string io_bin_out_dir) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
call_func<NetworkImplDft, void>(
"enable_io_bin_dump", network_impl, io_bin_out_dir);
return;
}
LITE_THROW("enable_io_bin_dump is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
void Runtime::shared_weight_with_network(
std::shared_ptr<Network> dst_network,
const std::shared_ptr<Network> src_network) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl_dst = NetworkHelper::implement(dst_network);
if (network_impl_dst->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
NetworkHelper::loaded(src_network),
"shared_weight_with_network should be used after the src "
"network "
"loaded.");
auto src_implment = NetworkHelper::implement(src_network);
call_func<NetworkImplDft, void>(
"shared_weight_with", network_impl_dst, src_implment);
NetworkHelper::loaded(dst_network, true);
return;
}
LITE_THROW("shared_weight_with_network is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
void Runtime::enable_global_layout_transform(std::shared_ptr<Network> network) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
!NetworkHelper::loaded(network),
"enable_global_layout_transform should be used before model loaded.");
call_func<NetworkImplDft, void>("enable_global_layout_transform", network_impl);
return;
}
LITE_THROW("enable_global_layout_transform is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
void Runtime::dump_layout_transform_model(
std::shared_ptr<Network> network, std::string optimized_model_path) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
NetworkHelper::loaded(network),
"dump_layout_transform_model should be used after model loaded.");
call_func<NetworkImplDft, void>(
"dump_layout_transform_model", network_impl, optimized_model_path);
return;
}
LITE_THROW("dump_layout_transform_model is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
// Read the IO description of the model file at `model_path` without building a
// Network. The backend is taken from `config`; only LITE_DEFAULT is supported.
NetworkIO Runtime::get_model_io_info(
        const std::string& model_path, const Config& config) {
    LITE_ERROR_HANDLER_BEGIN
    if (config.backend != LiteBackend::LITE_DEFAULT) {
        LITE_THROW("get_model_io_info is not aviliable in the backend.");
    }
    return call_func<NetworkImplDft, NetworkIO>("get_model_io_info", model_path, config);
    LITE_ERROR_HANDLER_END
}
// Read the IO description of an in-memory model (`size` bytes at `model_mem`)
// without building a Network. Only the LITE_DEFAULT backend is supported.
NetworkIO Runtime::get_model_io_info(
        const void* model_mem, size_t size, const Config& config) {
    LITE_ERROR_HANDLER_BEGIN
    if (config.backend != LiteBackend::LITE_DEFAULT) {
        LITE_THROW("get_model_io_info is not aviliable in the backend.");
    }
    return call_func<NetworkImplDft, NetworkIO>(
            "get_model_io_info", model_mem, size, config);
    LITE_ERROR_HANDLER_END
}