#pragma once
#include "lite_build_config.h"
#if LITE_BUILD_WITH_MGE
#include "lite/tensor.h"
#include "tensor_impl_base.h"
#include "megbrain/tensor.h"
#include <unordered_map>
namespace lite {
/*!
 * \brief Tensor implementation that reports LiteBackend::LITE_DEFAULT as its
 * backend type.
 *
 * The object keeps two shared pointers: one to an mgb::HostTensorND and one to
 * an mgb::DeviceTensorND. is_host() reports whether the host pointer is set.
 * All out-of-line members are defined outside this header, so the notes below
 * describe only what the declarations themselves establish.
 */
class TensorImplDft final : public Tensor::TensorImplBase {
    LITE_DYN_TYPE_OBJ_FINAL_DECL;

public:
    //! Default constructor (defined out of line).
    TensorImplDft();

    //! Construct from a device type; \p is_pinned_host defaults to false.
    TensorImplDft(LiteDeviceType device, bool is_pinned_host = false);

    //! Construct from a device type and an initial layout.
    TensorImplDft(
            LiteDeviceType device, const Layout& layout, bool is_pinned_host = false);

    //! Construct from a device id and device type; \p layout defaults to an
    //! empty-braced Layout and \p is_pinned_host to false.
    TensorImplDft(
            int device_id, LiteDeviceType device, const Layout& layout = {},
            bool is_pinned_host = false);

    //! Construct from a device id, a stream id and a device type.
    TensorImplDft(
            int device_id, int stream_id, LiteDeviceType device,
            bool is_pinned_host = false);

    virtual ~TensorImplDft() = default;

    // ---- TensorImplBase overrides: queries --------------------------------

    LiteDeviceType get_device_type() const override;

    int get_device_id() const override;

    //! Always LiteBackend::LITE_DEFAULT for this implementation.
    LiteBackend get_backend_type() const override { return LiteBackend::LITE_DEFAULT; }

    Layout get_layout() const override;

    bool is_pinned_host() const override;

    //! Pointer to the tensor memory.
    void* get_memory_ptr() const override;

    //! Pointer to the memory addressed by \p idx.
    //! NOTE(review): presumably \p idx is a per-dimension index - confirm
    //! against the definition.
    void* get_memory_ptr(const std::vector<size_t>& idx) const override;

    // ---- TensorImplBase overrides: mutation -------------------------------

    void set_layout(const Layout& layout) override;

    //! Reset using caller-provided data.
    //! NOTE(review): ownership of \p prepared_data is not visible from this
    //! header - confirm against the definition.
    void reset(void* prepared_data) override;

    //! Reset using caller-provided data together with a new layout.
    void reset(void* prepared_data, const Layout& layout) override;

    //! Produce a tensor from the range described by \p start, \p end and the
    //! optional \p step (defaults to an empty vector).
    std::shared_ptr<Tensor> slice(
            const std::vector<size_t>& start, const std::vector<size_t>& end,
            const std::vector<size_t>& step = {}) override;

    void fill_zero() override;

    void reshape(const Layout& layout) override;

    void copy_from(const TensorImplBase* src_impl) override;

    void share_memory_with(const TensorImplBase* src_impl) override;

    bool is_continue_memory() const override;

    // ---- Accessors and helpers specific to this implementation ------------

    //! Returns a copy of the shared pointer to the host tensor (may be null).
    std::shared_ptr<mgb::HostTensorND> host_tensor() const { return m_host_tensor; }

    //! Returns a copy of the shared pointer to the device tensor (may be null).
    std::shared_ptr<mgb::DeviceTensorND> dev_tensor() const { return m_dev_tensor; }

    //! Copy from an mgb::DeviceTensorND into this tensor.
    void copy_from_mge_tensor(const mgb::DeviceTensorND& dv);

    //! Install a callback taking this implementation as its argument.
    //! NOTE(review): presumably stored in m_reset_callback - confirm.
    void set_reset_callback(const std::function<void(TensorImplDft*)>& cb);

    //! Install a callback taking this implementation as its argument.
    //! NOTE(review): presumably stored in m_get_memory_callback - confirm.
    void set_get_memory_callback(const std::function<void(TensorImplDft*)>& cb);

    void device_share_host_memory();

    //! NetworkImplDft is granted access to the private members below.
    friend class NetworkImplDft;

private:
    //! True when a host tensor object is held.
    bool is_host() const { return m_host_tensor != nullptr; }

    void copy_from_continue(const TensorImplBase* src_impl);

    void copy_from_fixlayout(const TensorImplBase* src_impl);

    void set_mge_tensor_compnode(const mgb::CompNode& comp_node);

    //! Starts out false. NOTE(review): meaning not visible in this header.
    bool m_record_reset = false;

    //! Callbacks receiving a pointer to a TensorImplDft.
    std::function<void(TensorImplDft*)> m_get_memory_callback;
    std::function<void(TensorImplDft*)> m_reset_callback;

    //! Host-side storage; is_host() tests this pointer for non-null.
    std::shared_ptr<mgb::HostTensorND> m_host_tensor;

    //! Device-side storage.
    std::shared_ptr<mgb::DeviceTensorND> m_dev_tensor;
};
}
#endif