#pragma once
#if LITE_BUILD_WITH_MGE
#include "function_base.h"
#include "network_impl.h"
#include "network_impl_base.h"
#include "tensor_impl.h"
namespace lite {
//! Raise an error reporting that `func_name` is not implemented by the Dft
//! backend. The message is "<func_name> is not available in Dft backend." and is
//! handed to LITE_THROW (defined elsewhere; presumably throws or aborts — confirm).
//! `func_name` may be anything std::string can be constructed from.
//! The expansion is wrapped in do { } while (0) so it behaves as ONE statement:
//! it is safe in an unbraced if/else and may appear more than once per scope.
#define THROW_FUNC_ERROR(func_name)                                            \
    do {                                                                       \
        const std::string dft_func_error_msg =                                 \
                std::string(func_name) + " is not available in Dft backend."; \
        LITE_THROW(dft_func_error_msg.c_str());                                \
    } while (0)
//! Dft tensor factory with no constructor arguments.
//! "create_tensor" yields a default-constructed TensorImplDft; any other name
//! is rejected through THROW_FUNC_ERROR.
template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(std::string func_name) {
    const bool is_create_tensor = func_name == "create_tensor";
    if (is_create_tensor) {
        std::shared_ptr<Tensor::TensorImplBase> tensor_impl =
                std::make_shared<TensorImplDft>();
        return tensor_impl;
    }
    THROW_FUNC_ERROR(func_name);
}
//! Dft tensor factory taking (device_type, is_pinned_host).
//! For "create_tensor" both values are forwarded unchanged to the TensorImplDft
//! constructor; any other name is rejected through THROW_FUNC_ERROR.
template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name, LiteDeviceType device_type, bool is_pinned_host) {
    const bool is_create_tensor = func_name == "create_tensor";
    if (is_create_tensor) {
        std::shared_ptr<Tensor::TensorImplBase> tensor_impl =
                std::make_shared<TensorImplDft>(device_type, is_pinned_host);
        return tensor_impl;
    }
    THROW_FUNC_ERROR(func_name);
}
//! Dft tensor factory taking (device_id, device_type, layout, is_pinned_host).
//! For "create_tensor" the arguments are forwarded, in that order, to the
//! TensorImplDft constructor; any other name is rejected through THROW_FUNC_ERROR.
template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name, int device_id, LiteDeviceType device_type,
        const Layout layout, bool is_pinned_host) {
    const bool is_create_tensor = func_name == "create_tensor";
    if (is_create_tensor) {
        std::shared_ptr<Tensor::TensorImplBase> tensor_impl =
                std::make_shared<TensorImplDft>(
                        device_id, device_type, layout, is_pinned_host);
        return tensor_impl;
    }
    THROW_FUNC_ERROR(func_name);
}
//! Dft tensor factory taking (device_type, layout, is_pinned_host).
//! For "create_tensor" the arguments are forwarded, in that order, to the
//! TensorImplDft constructor; any other name is rejected through THROW_FUNC_ERROR.
template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name, LiteDeviceType device_type, const Layout layout,
        bool is_pinned_host) {
    const bool is_create_tensor = func_name == "create_tensor";
    if (is_create_tensor) {
        std::shared_ptr<Tensor::TensorImplBase> tensor_impl =
                std::make_shared<TensorImplDft>(device_type, layout, is_pinned_host);
        return tensor_impl;
    }
    THROW_FUNC_ERROR(func_name);
}
//! Dft tensor factory taking (device_id, stream_id, device_type, is_pinned_host).
//! For "create_tensor" the arguments are forwarded, in that order, to the
//! TensorImplDft constructor; any other name is rejected through THROW_FUNC_ERROR.
template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name, int device_id, int stream_id, LiteDeviceType device_type,
        bool is_pinned_host) {
    const bool is_create_tensor = func_name == "create_tensor";
    if (is_create_tensor) {
        std::shared_ptr<Tensor::TensorImplBase> tensor_impl =
                std::make_shared<TensorImplDft>(
                        device_id, stream_id, device_type, is_pinned_host);
        return tensor_impl;
    }
    THROW_FUNC_ERROR(func_name);
}
//! Dft network factory: "create_network" yields a default-constructed
//! NetworkImplDft owned by the returned unique_ptr; any other name is rejected
//! through THROW_FUNC_ERROR.
template <>
inline std::unique_ptr<Network::NetworkImplBase> call_func<
        NetworkImplDft, std::unique_ptr<Network::NetworkImplBase>>(
        std::string func_name) {
    const bool is_create_network = func_name == "create_network";
    if (is_create_network) {
        std::unique_ptr<Network::NetworkImplBase> impl =
                std::make_unique<NetworkImplDft>();
        return impl;
    }
    THROW_FUNC_ERROR(func_name);
}
//! For "parse_model", heap-allocates a default-constructed NetworkImplDft and
//! returns it as a raw base pointer (ownership presumably passes to the caller —
//! confirm at the call site). Any other name is rejected through THROW_FUNC_ERROR.
template <>
inline Network::NetworkImplBase* try_call_func<
        NetworkImplDft, Network::NetworkImplBase*>(std::string func_name) {
    if (func_name.compare("parse_model") == 0) {
        Network::NetworkImplBase* parsed_impl = new NetworkImplDft();
        return parsed_impl;
    }
    THROW_FUNC_ERROR(func_name);
}
//! Forward to NetworkImplDft::func_name(...) on the object behind `network_impl`.
//! The macro does not take the pointer as an argument: a variable literally named
//! `network_impl` must be in scope where it expands (every call_func
//! specialization using it declares `Network::NetworkImplBase* network_impl`).
//! cast_final_safe<NetworkImplDft>() presumably performs a checked downcast of the
//! base implementation — TODO confirm its failure behavior in network_impl_base.h.
//! NOTE(review): some call sites pass no trailing arguments (e.g.
//! CALL_FUNC(use_tensorrt)), leaving __VA_ARGS__ empty; before C++20 that relies
//! on a compiler extension.
#define CALL_FUNC(func_name, ...) \
network_impl->cast_final_safe<NetworkImplDft>().func_name(__VA_ARGS__)
//! Dispatch a setter that takes one size_t on the Dft network implementation.
//! Supported names: "set_cpu_threads_number" and
//! "set_network_algo_workspace_limit"; `num` is passed through unchanged.
//! Any other name is rejected through THROW_FUNC_ERROR.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl, size_t num) {
    if (func_name == "set_cpu_threads_number") {
        CALL_FUNC(set_cpu_threads_number, num);
        return;
    }
    if (func_name == "set_network_algo_workspace_limit") {
        CALL_FUNC(set_network_algo_workspace_limit, num);
        return;
    }
    THROW_FUNC_ERROR(func_name);
}
//! Dispatch an argument-less void member of the Dft network implementation.
//! Supported names: "use_tensorrt", "set_cpu_inplace_mode" and
//! "enable_global_layout_transform". Any other name is rejected through
//! THROW_FUNC_ERROR.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl) {
    if (func_name == "use_tensorrt") {
        CALL_FUNC(use_tensorrt);
        return;
    }
    if (func_name == "set_cpu_inplace_mode") {
        CALL_FUNC(set_cpu_inplace_mode);
        return;
    }
    if (func_name == "enable_global_layout_transform") {
        CALL_FUNC(enable_global_layout_transform);
        return;
    }
    THROW_FUNC_ERROR(func_name);
}
//! Query returning size_t: "get_cpu_threads_number" returns the Dft network's
//! get_cpu_threads_number() result (converted to size_t). Any other name is
//! rejected through THROW_FUNC_ERROR.
template <>
inline size_t call_func<NetworkImplDft, size_t>(
        std::string func_name, Network::NetworkImplBase* network_impl) {
    if (func_name == "get_cpu_threads_number") {
        const auto thread_count = CALL_FUNC(get_cpu_threads_number);
        return thread_count;
    }
    THROW_FUNC_ERROR(func_name);
}
//! Query returning bool: "is_cpu_inplace_mode" returns the Dft network's
//! is_cpu_inplace_mode() result. Any other name is rejected through
//! THROW_FUNC_ERROR.
template <>
inline bool call_func<NetworkImplDft, bool>(
        std::string func_name, Network::NetworkImplBase* network_impl) {
    if (func_name == "is_cpu_inplace_mode") {
        const bool inplace_enabled = CALL_FUNC(is_cpu_inplace_mode);
        return inplace_enabled;
    }
    THROW_FUNC_ERROR(func_name);
}
//! "set_runtime_thread_affinity": hands the callback (moved, since this by-value
//! parameter is not used afterwards) to the Dft network implementation.
//! Any other name is rejected through THROW_FUNC_ERROR.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        ThreadAffinityCallback thread_affinity_callback) {
    if (func_name == "set_runtime_thread_affinity") {
        CALL_FUNC(set_runtime_thread_affinity, std::move(thread_affinity_callback));
        return;
    }
    THROW_FUNC_ERROR(func_name);
}
//! "set_network_algo_policy": forwards (strategy, shared_batch_size,
//! binary_equal_between_batch) unchanged to the Dft network implementation.
//! Any other name is rejected through THROW_FUNC_ERROR.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
        bool binary_equal_between_batch) {
    if (func_name == "set_network_algo_policy") {
        CALL_FUNC(
                set_network_algo_policy, strategy, shared_batch_size,
                binary_equal_between_batch);
        return;
    }
    THROW_FUNC_ERROR(func_name);
}
//! "set_memory_allocator": installs `user_allocator` on the Dft network
//! implementation. Any other name is rejected through THROW_FUNC_ERROR.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        std::shared_ptr<Allocator> user_allocator) {
    if (func_name == "set_memory_allocator") {
        // `user_allocator` is a by-value sink that is not touched again, so move
        // it rather than copy: this avoids an extra atomic ref-count increment /
        // decrement, the same way the thread-affinity callback is forwarded.
        return CALL_FUNC(set_memory_allocator, std::move(user_allocator));
    }
    THROW_FUNC_ERROR(func_name);
}
//! Dispatch a member taking a file/path string on the Dft network implementation.
//! Supported names: "enable_io_txt_dump", "enable_io_bin_dump" and
//! "dump_layout_transform_model"; `file_name` is passed through unchanged.
//! Any other name is rejected through THROW_FUNC_ERROR.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        std::string file_name) {
    if (func_name == "enable_io_txt_dump") {
        CALL_FUNC(enable_io_txt_dump, file_name);
    } else if (func_name == "enable_io_bin_dump") {
        CALL_FUNC(enable_io_bin_dump, file_name);
    } else if (func_name == "dump_layout_transform_model") {
        CALL_FUNC(dump_layout_transform_model, file_name);
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}
//! Dispatch a member that relates this Dft network to another implementation.
//! Supported names: "share_runtime_memory_with" and "shared_weight_with"; the
//! source implementation pointer is passed through unchanged.
//! Any other name is rejected through THROW_FUNC_ERROR.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        Network::NetworkImplBase* src_network_impl) {
    if (func_name == "share_runtime_memory_with") {
        CALL_FUNC(share_runtime_memory_with, src_network_impl);
        return;
    }
    if (func_name == "shared_weight_with") {
        CALL_FUNC(shared_weight_with, src_network_impl);
        return;
    }
    THROW_FUNC_ERROR(func_name);
}
//! "get_model_io_info" for a model given by file path: returns
//! get_model_io_info_dft(model_path, config). Any other name is rejected through
//! THROW_FUNC_ERROR.
template <>
inline NetworkIO call_func<NetworkImplDft, NetworkIO>(
        std::string func_name, std::string model_path, Config config) {
    const bool wants_io_info = func_name == "get_model_io_info";
    if (wants_io_info) {
        return get_model_io_info_dft(model_path, config);
    }
    THROW_FUNC_ERROR(func_name);
}
//! "get_model_io_info" for a model held in memory (`model_mem`, `size` bytes):
//! returns get_model_io_info_dft(model_mem, size, config). Any other name is
//! rejected through THROW_FUNC_ERROR.
template <>
inline NetworkIO call_func<NetworkImplDft, NetworkIO>(
        std::string func_name, const void* model_mem, size_t size, Config config) {
    const bool wants_io_info = func_name == "get_model_io_info";
    if (wants_io_info) {
        return get_model_io_info_dft(model_mem, size, config);
    }
    THROW_FUNC_ERROR(func_name);
}
#undef THROW_FUNC_ERROR
} #endif