#include "./memory_manager.h"
#include "src/common/utils.h"
#include "test/common/utils.h"
namespace {
// These using-directives also make names such as MemoryManagerHolder visible
// to the unqualified out-of-line member definitions further down this file.
using namespace megdnn;
using namespace test;

/*!
 * \brief factory used by MemoryManagerHolder::get() when no manager has been
 *      registered for \p handle yet
 *
 * \param handle the megdnn handle the new manager will allocate through
 * \return a newly created HandleMemoryManager wrapping \p handle, returned as
 *      its MemoryManager base
 */
std::unique_ptr<MemoryManager> create_memory_manager_from_handle(Handle* handle) {
    std::unique_ptr<MemoryManager> manager = make_unique<HandleMemoryManager>(handle);
    return manager;
}
}  // anonymous namespace
//! Out-of-line definition of the static data member MemoryManagerHolder::m_instance
//! (presumably the shared holder used by tests — TODO confirm against header).
megdnn::test::MemoryManagerHolder megdnn::test::MemoryManagerHolder::m_instance{};
megdnn::test::HandleMemoryManager::HandleMemoryManager(Handle* handle)
: MemoryManager(), m_handle(handle) {}
void* megdnn::test::HandleMemoryManager::malloc(size_t size) {
auto comp_handle = m_handle->megcore_computing_handle();
megcoreDeviceHandle_t dev_handle;
megcore_check(megcoreGetDeviceHandle(comp_handle, &dev_handle));
void* ptr;
megcore_check(megcoreMalloc(dev_handle, &ptr, size));
return ptr;
}
void megdnn::test::HandleMemoryManager::free(void* ptr) {
auto comp_handle = m_handle->megcore_computing_handle();
megcoreDeviceHandle_t dev_handle;
megcore_check(megcoreGetDeviceHandle(comp_handle, &dev_handle));
megcore_check(megcoreFree(dev_handle, ptr));
}
/*!
 * \brief return the memory manager registered for \p handle, creating a
 *      default HandleMemoryManager on first use
 *
 * Lookup and insertion happen while holding m_map_mutex. The returned pointer
 * is owned by the map. It is destroyed (and thus dangles) if update() replaces
 * the entry for \p handle or clear() is called.
 */
megdnn::test::MemoryManager* megdnn::test::MemoryManagerHolder::get(Handle* handle) {
    std::lock_guard<std::mutex> guard(m_map_mutex);

    auto found = m_map.find(handle);
    if (found != m_map.end())
        return found->second.get();

    // First request for this handle: build and register a default manager.
    auto created = create_memory_manager_from_handle(handle);
    auto* raw = created.get();
    m_map.emplace(handle, std::move(created));
    return raw;
}
/*!
 * \brief register \p memory_manager for \p handle, taking ownership of it
 *
 * Any manager previously stored for \p handle is destroyed by the assignment,
 * which invalidates pointers to it obtained earlier from get().
 */
void MemoryManagerHolder::update(
        Handle* handle, std::unique_ptr<MemoryManager> memory_manager) {
    std::lock_guard<std::mutex> guard{m_map_mutex};
    auto& slot = m_map[handle];
    slot = std::move(memory_manager);
}
void MemoryManagerHolder::clear() {
std::lock_guard<std::mutex> lock(m_map_mutex);
m_map.clear();
}