megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation

/**
 * \file src/mge/function_dft.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once
#if LITE_BUILD_WITH_MGE
#include "function_base.h"
#include "network_impl.h"
#include "network_impl_base.h"
#include "tensor_impl.h"
namespace lite {

// Raise the uniform "unsupported function" error of the Dft backend.
//
// Every call_func / try_call_func specialization below dispatches on a
// runtime `func_name`; any name it does not recognise ends up here.
//
// The expansion is wrapped in `do { ... } while (0)` so it behaves as a single
// statement: the temporary `msg_info` stays out of the caller's scope (two
// uses in one scope no longer redeclare it), and the macro is safe as the body
// of an unbraced `if` / `else`.
#define THROW_FUNC_ERROR(func_name)                                         \
    do {                                                                    \
        auto msg_info = (func_name) + " is not available in Dft backend."; \
        LITE_THROW(msg_info.c_str());                                       \
    } while (0)

// The functions used by dft's tensor.cpp are as follows. Each specialization
// only recognises "create_tensor" and forwards its remaining arguments to the
// matching TensorImplDft constructor; any other name raises an error.

// "create_tensor": default-constructed TensorImplDft.
template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(std::string func_name) {
    if (func_name == "create_tensor") {
        return std::make_shared<TensorImplDft>();
    }
    THROW_FUNC_ERROR(func_name);
}

// "create_tensor": TensorImplDft(device_type, is_pinned_host).
template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name, LiteDeviceType device_type, bool is_pinned_host) {
    if (func_name == "create_tensor") {
        return std::make_shared<TensorImplDft>(device_type, is_pinned_host);
    }
    THROW_FUNC_ERROR(func_name);
}

// "create_tensor": TensorImplDft(device_id, device_type, layout, is_pinned_host).
template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name, int device_id, LiteDeviceType device_type,
        const Layout layout, bool is_pinned_host) {
    if (func_name == "create_tensor") {
        return std::make_shared<TensorImplDft>(
                device_id, device_type, layout, is_pinned_host);
    }
    THROW_FUNC_ERROR(func_name);
}

// "create_tensor": TensorImplDft(device_type, layout, is_pinned_host).
template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name, LiteDeviceType device_type, const Layout layout,
        bool is_pinned_host) {
    if (func_name == "create_tensor") {
        return std::make_shared<TensorImplDft>(device_type, layout, is_pinned_host);
    }
    THROW_FUNC_ERROR(func_name);
}

// "create_tensor": TensorImplDft(device_id, stream_id, device_type, is_pinned_host).
template <>
inline std::shared_ptr<Tensor::TensorImplBase> call_func<
        TensorImplDft, std::shared_ptr<Tensor::TensorImplBase>>(
        std::string func_name, int device_id, int stream_id, LiteDeviceType device_type,
        bool is_pinned_host) {
    if (func_name == "create_tensor") {
        return std::make_shared<TensorImplDft>(
                device_id, stream_id, device_type, is_pinned_host);
    }
    THROW_FUNC_ERROR(func_name);
}

// The functions used by dft's network.cpp are as follows.

// "create_network": a fresh NetworkImplDft owned by the returned unique_ptr.
template <>
inline std::unique_ptr<Network::NetworkImplBase> call_func<
        NetworkImplDft, std::unique_ptr<Network::NetworkImplBase>>(
        std::string func_name) {
    if (func_name == "create_network") {
        return std::make_unique<NetworkImplDft>();
    }
    THROW_FUNC_ERROR(func_name);
}

// "parse_model": returns a heap-allocated NetworkImplDft through a raw
// pointer. Nothing here retains it, so ownership passes to the caller, which
// must delete it (or wrap it in a smart pointer).
template <>
inline Network::NetworkImplBase* try_call_func<
        NetworkImplDft, Network::NetworkImplBase*>(std::string func_name) {
    if (func_name == "parse_model") {
        return new NetworkImplDft();
    }
    THROW_FUNC_ERROR(func_name);
}

// Forward a member call to the concrete NetworkImplDft behind the type-erased
// `network_impl` pointer, which must be a variable in the calling scope.
// NOTE(review): cast_final_safe<NetworkImplDft>() is presumably a checked
// downcast that fails if the object is not a NetworkImplDft — TODO confirm.
#define CALL_FUNC(func_name, ...) \
    network_impl->cast_final_safe<NetworkImplDft>().func_name(__VA_ARGS__)

// Setters taking a single size_t.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl, size_t num) {
    if (func_name == "set_cpu_threads_number") {
        CALL_FUNC(set_cpu_threads_number, num);
    } else if (func_name == "set_network_algo_workspace_limit") {
        CALL_FUNC(set_network_algo_workspace_limit, num);
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}

// Argument-less switches on the network.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl) {
    if (func_name == "use_tensorrt") {
        CALL_FUNC(use_tensorrt);
    } else if (func_name == "set_cpu_inplace_mode") {
        CALL_FUNC(set_cpu_inplace_mode);
    } else if (func_name == "enable_global_layout_transform") {
        CALL_FUNC(enable_global_layout_transform);
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}

// size_t getters.
template <>
inline size_t call_func<NetworkImplDft, size_t>(
        std::string func_name, Network::NetworkImplBase* network_impl) {
    if (func_name == "get_cpu_threads_number") {
        return CALL_FUNC(get_cpu_threads_number);
    }
    THROW_FUNC_ERROR(func_name);
}

// bool getters.
template <>
inline bool call_func<NetworkImplDft, bool>(
        std::string func_name, Network::NetworkImplBase* network_impl) {
    if (func_name == "is_cpu_inplace_mode") {
        return CALL_FUNC(is_cpu_inplace_mode);
    }
    THROW_FUNC_ERROR(func_name);
}

// Install a thread-affinity callback; the callback is moved into the callee.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        ThreadAffinityCallback thread_affinity_callback) {
    if (func_name == "set_runtime_thread_affinity") {
        CALL_FUNC(set_runtime_thread_affinity, std::move(thread_affinity_callback));
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}

// Configure the algorithm-selection policy of the network.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
        bool binary_equal_between_batch) {
    if (func_name == "set_network_algo_policy") {
        CALL_FUNC(
                set_network_algo_policy, strategy, shared_batch_size,
                binary_equal_between_batch);
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}

// Install a user allocator. The shared_ptr is moved rather than copied to
// avoid an extra atomic reference-count increment/decrement.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        std::shared_ptr<Allocator> user_allocator) {
    if (func_name == "set_memory_allocator") {
        CALL_FUNC(set_memory_allocator, std::move(user_allocator));
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}

// Operations parameterised by a file path. Only one branch runs, so moving
// `file_name` into it is safe.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        std::string file_name) {
    if (func_name == "enable_io_txt_dump") {
        CALL_FUNC(enable_io_txt_dump, std::move(file_name));
    } else if (func_name == "enable_io_bin_dump") {
        CALL_FUNC(enable_io_bin_dump, std::move(file_name));
    } else if (func_name == "dump_layout_transform_model") {
        CALL_FUNC(dump_layout_transform_model, std::move(file_name));
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}

// Share resources with another network; `src_network_impl` is passed through
// unchanged as a non-owning pointer.
template <>
inline void call_func<NetworkImplDft, void>(
        std::string func_name, Network::NetworkImplBase* network_impl,
        Network::NetworkImplBase* src_network_impl) {
    if (func_name == "share_runtime_memory_with") {
        CALL_FUNC(share_runtime_memory_with, src_network_impl);
    } else if (func_name == "shared_weight_with") {
        CALL_FUNC(shared_weight_with, src_network_impl);
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}

// "get_model_io_info" for a model stored on disk at `model_path`.
template <>
inline NetworkIO call_func<NetworkImplDft, NetworkIO>(
        std::string func_name, std::string model_path, Config config) {
    if (func_name == "get_model_io_info") {
        return get_model_io_info_dft(model_path, config);
    }
    THROW_FUNC_ERROR(func_name);
}

// "get_model_io_info" for a model held in memory (`size` bytes at `model_mem`).
template <>
inline NetworkIO call_func<NetworkImplDft, NetworkIO>(
        std::string func_name, const void* model_mem, size_t size, Config config) {
    if (func_name == "get_model_io_info") {
        return get_model_io_info_dft(model_mem, size, config);
    }
    THROW_FUNC_ERROR(func_name);
}

// Both helper macros are private to this header; undefine them so they do not
// leak into every translation unit that includes it.
#undef CALL_FUNC
#undef THROW_FUNC_ERROR

}  // namespace lite
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}