#pragma once
#include "lite/tensor.h"
#include "misc.h"
#include "type_info.h"
#include <unordered_map>
#include <utility>
namespace lite {
//! Abstract interface for the object that backs a Tensor.
//!
//! Declared as a nested class of Tensor and derived from DynTypeObj. Every
//! operation is pure virtual, so this header only fixes the signatures; the
//! actual behavior is defined by the concrete implementations, which are not
//! visible here. The per-method notes below describe what the signatures show
//! and mark anything inferred from names as an assumption.
class Tensor::TensorImplBase : public DynTypeObj {
public:
//! Virtual destructor so an implementation can be destroyed through a
//! TensorImplBase pointer.
virtual ~TensorImplBase() = default;
//! Return the LiteDeviceType associated with this tensor implementation.
virtual LiteDeviceType get_device_type() const = 0;
//! Return an integer device id (NOTE(review): numbering scheme is defined by
//! the implementation -- confirm there).
virtual int get_device_id() const = 0;
//! Return the LiteBackend associated with this tensor implementation.
virtual LiteBackend get_backend_type() const = 0;
//! Return the current Layout by value.
virtual Layout get_layout() const = 0;
//! Report whether the tensor is "pinned host" (NOTE(review): presumably
//! page-locked host memory -- confirm against the implementation).
virtual bool is_pinned_host() const = 0;
//! Return an untyped pointer to the tensor memory.
virtual void* get_memory_ptr() const = 0;
//! Return an untyped pointer selected by \p idx (NOTE(review): presumably the
//! address of the element at that multi-dimensional index -- confirm).
virtual void* get_memory_ptr(const std::vector<size_t>& idx) const = 0;
//! Set the layout to \p layout.
virtual void set_layout(const Layout& layout) = 0;
//! Reset the tensor to use \p prepared_data (NOTE(review): ownership and
//! lifetime of the buffer are not expressed by the signature -- presumably
//! the caller keeps ownership; confirm).
virtual void reset(void* prepared_data) = 0;
//! Same as reset(void*) but also takes the Layout to apply.
virtual void reset(void* prepared_data, const Layout& layout) = 0;
//! Reshape the tensor according to \p layout.
virtual void reshape(const Layout& layout) = 0;
//! Produce a new Tensor described by \p start, \p end and an optional
//! \p step, each given as one size_t per entry (NOTE(review): presumably one
//! entry per dimension, and an empty \p step presumably means unit stride --
//! confirm). Language note: the default argument of a virtual function is
//! chosen from the static type of the call expression, so overriders should
//! repeat the same default to avoid surprises.
virtual std::shared_ptr<Tensor> slice(
const std::vector<size_t>& start, const std::vector<size_t>& end,
const std::vector<size_t>& step = {}) = 0;
//! Fill the tensor with zeros (per the name; exact extent is up to the
//! implementation).
virtual void fill_zero() = 0;
//! Copy from the implementation \p src_impl into this one.
virtual void copy_from(const TensorImplBase* src_impl) = 0;
//! Make this implementation share memory with \p src_impl (NOTE(review):
//! lifetime coupling between the two objects is implementation-defined).
virtual void share_memory_with(const TensorImplBase* src_impl) = 0;
//! Report whether the tensor memory is continuous (NOTE(review): presumably
//! "contiguous", i.e. no gaps between elements -- confirm).
virtual bool is_continue_memory() const = 0;
};
//! Accessor utility for the implementation object stored in a Tensor.
//!
//! Every function reads or writes Tensor::m_tensor_impl directly.
//! NOTE(review): this presumably relies on TensorHelper being declared a
//! friend of Tensor in "lite/tensor.h" -- that declaration is not visible here.
class TensorHelper {
public:
    //! Return the implementation object held by \p tensor.
    //!
    //! \p tensor is checked with LITE_ASSERT before it is dereferenced. It is
    //! taken by const reference because the function only borrows it; passing
    //! the shared_ptr by value would cost an atomic refcount round-trip.
    static inline std::shared_ptr<Tensor::TensorImplBase> implement(
            const std::shared_ptr<Tensor>& tensor) {
        LITE_ASSERT(tensor);
        return tensor->m_tensor_impl;
    }

    //! Raw-pointer overload: return the implementation object held by
    //! \p tensor, which is checked with LITE_ASSERT before it is dereferenced.
    static inline std::shared_ptr<Tensor::TensorImplBase> implement(
            const Tensor* tensor) {
        LITE_ASSERT(tensor);
        return tensor->m_tensor_impl;
    }

    //! Replace the implementation object held by \p tensor with \p impl.
    //!
    //! \p tensor is checked with LITE_ASSERT; \p impl is not checked, so a
    //! null implementation can be installed. \p impl is a sink parameter:
    //! it is taken by value and moved into place to avoid an extra
    //! refcount increment/decrement.
    static inline void implement(
            const std::shared_ptr<Tensor>& tensor,
            std::shared_ptr<Tensor::TensorImplBase> impl) {
        LITE_ASSERT(tensor);
        tensor->m_tensor_impl = std::move(impl);
    }
};
}