#pragma once
#include "lite_build_config.h"
#if LITE_BUILD_WITH_MGE
#include "common.h"
#include "megbrain/dtype.h"
#include "network_impl.h"
#include "megbrain/graph/cg.h"
namespace lite {
class UserStaticMemAlloc final : public mgb::cg::DeviceMemoryAllocator {
std::shared_ptr<Allocator> m_allocator = nullptr;
public:
UserStaticMemAlloc(std::shared_ptr<Allocator> allocator) : m_allocator(allocator) {}
void alloc_static(
LComputingGraph*, LDeviceTensorStorage& dest, size_t size) override {
if (size < dest.size()) {
return;
}
auto cn = dest.comp_node_allow_invalid();
LITE_ASSERT(cn.valid(), "The compnode is invalid when alloc memory.");
LiteDeviceType device_type = get_device_from_locator(cn.locator_logical());
int device_id = cn.locator_logical().device;
auto ptr_alloc = static_cast<mgb::dt_byte*>(m_allocator->allocate(
device_type, device_id, size, cn.get_mem_addr_alignment()));
auto storage = std::shared_ptr<mgb::dt_byte>(
ptr_alloc,
[allocator = m_allocator, device_type, device_id](void* ptr) {
allocator->free(device_type, device_id, ptr);
});
dest.reset(cn, size, storage);
}
void alloc_dynamic(
mgb::VarNode*, mgb::DeviceTensorStorage& dest, size_t size) override {
alloc_static(nullptr, dest, size);
}
void defrag_prealloc_contig(
mgb::ComputingGraph*, mgb::CompNode comp_node, size_t size) override {
LiteDeviceType device_type =
get_device_from_locator(comp_node.locator_logical());
int device_id = comp_node.locator_logical().device;
auto ptr_tmp = m_allocator->allocate(
device_type, device_id, size, comp_node.get_mem_addr_alignment());
m_allocator->free(device_type, device_id, ptr_tmp);
}
};
} #endif