#include "./blob_manager_impl.h"
#include <set>
#include <utility>
#include "megbrain/utils/arith_helper.h"
namespace mgb {
namespace imperative {
// Captures the current contents of a blob into host memory.
//
// in_blob: blob to snapshot; its m_comp_node, m_size and m_storage are read.
//
// After construction `blob` points at in_blob and `h_storage` holds m_size
// bytes copied from the blob's device storage.
BlobManagerImpl::BlobData::BlobData(Blob* in_blob) {
    blob = in_blob;

    // Wrap the blob's raw storage so it can serve as the copy source.
    DeviceTensorStorage device_src;
    device_src.reset(in_blob->m_comp_node, in_blob->m_size, in_blob->m_storage);

    // Host-side destination on the same comp node, sized to match the blob.
    h_storage = HostTensorStorage(in_blob->m_comp_node);
    h_storage.ensure_size(in_blob->m_size);
    h_storage.copy_from(device_src, in_blob->m_size);
}
// Adds `blob` to the set of blobs tracked for its comp node.
//
// Holds m_mtx for the duration of the call. Asserts when the insertion
// reports failure (presumably because the blob is already tracked -- confirm
// against BlobSetWithMux::insert).
void BlobManagerImpl::register_blob(Blob* blob) {
    MGB_LOCK_GUARD(m_mtx);
    // Keep the mutation in its own statement so it does not depend on how the
    // assertion macro treats its argument.
    const bool inserted = m_comp2blobs_map[blob->m_comp_node].insert(blob);
    mgb_assert(inserted, "blob is already registered");
}
// Removes `blob` from the set of blobs tracked for its comp node.
//
// Holds m_mtx for the duration of the call. Asserts unless exactly one entry
// was erased.
void BlobManagerImpl::unregister_blob(Blob* blob) {
    MGB_LOCK_GUARD(m_mtx);
    // Keep the mutation in its own statement so it does not depend on how the
    // assertion macro treats its argument.
    const auto nr_erased = m_comp2blobs_map[blob->m_comp_node].erase(blob);
    mgb_assert(1 == nr_erased, "blob was not registered");
}
// Allocates `size` bytes of storage for `blob` and stores it in
// blob->m_storage.
//
// If a custom allocator has been installed it is the only allocation path.
// Otherwise a direct allocation is attempted; on MemAllocError the blob's
// comp node is defragmented and the allocation is retried once. The retry is
// not guarded, so a second failure reaches the caller.
void BlobManagerImpl::alloc_with_defrag(Blob* blob, size_t size) {
    if (!custom_allocator) {
        MGB_TRY { alloc_direct(blob, size); }
        MGB_CATCH(MemAllocError&, {
            mgb_log_warn("memory allocation failed for blob; try defragmenting");
            defrag(blob->m_comp_node);
            alloc_direct(blob, size);
        });
        return;
    }
    blob->m_storage = custom_allocator(blob->m_comp_node, size);
}
// Allocates `size` bytes on the blob's comp node and stores the resulting raw
// storage in blob->m_storage. No defragmentation or retry is attempted here.
//
// Asserts that the blob's comp node is valid.
void BlobManagerImpl::alloc_direct(Blob* blob, size_t size) {
    // Validate the comp node before it is used to construct the storage.
    mgb_assert(blob->m_comp_node.valid());
    DeviceTensorStorage storage(blob->m_comp_node);
    storage.ensure_size(size);
    blob->m_storage = storage.raw_storage();
}
// Returns a device tensor on `cn` with the given layout, backed by newly
// allocated storage.
//
// If a custom allocator has been installed, the storage comes from it.
// Otherwise alloc_workspace() is used; on MemAllocError the comp node is
// defragmented and the allocation is retried once. The retry is not guarded,
// so a second failure reaches the caller.
DeviceTensorND BlobManagerImpl::alloc_workspace_with_defrag(
        CompNode cn, TensorLayout& layout) {
    DeviceTensorND workspace;
    if (!custom_allocator) {
        MGB_TRY { workspace = alloc_workspace(cn, layout); }
        MGB_CATCH(MemAllocError&, {
            mgb_log_warn("memory allocation failed for workspace; try defragmenting");
            defrag(cn);
            workspace = alloc_workspace(cn, layout);
        });
        return workspace;
    }

    // Custom allocator path: the byte count is derived from the layout's
    // dtype and element count.
    DeviceTensorStorage backing(cn);
    const size_t nr_bytes = layout.dtype.size(layout.total_nr_elems());
    backing.reset(cn, nr_bytes, custom_allocator(cn, nr_bytes));
    workspace.reset(backing, layout);
    return workspace;
}
// Allocates storage on `cn` large enough for `layout` and returns a device
// tensor with that layout on top of it. No defragmentation or retry is
// attempted here.
DeviceTensorND BlobManagerImpl::alloc_workspace(CompNode cn, TensorLayout layout) {
    DeviceTensorStorage backing(cn);
    // Byte count is derived from the layout's dtype and element count.
    const size_t nr_bytes = layout.dtype.size(layout.total_nr_elems());
    backing.ensure_size(nr_bytes);

    DeviceTensorND workspace;
    workspace.reset(backing, layout);
    return workspace;
}
// Installs `allocator` as the custom allocator. While it tests true,
// alloc_with_defrag() and alloc_workspace_with_defrag() obtain storage from it
// instead of the built-in allocation path.
void BlobManagerImpl::set_allocator(allocator_t allocator) {
    // `allocator` is a by-value sink parameter: move it instead of copying.
    custom_allocator = std::move(allocator);
}
void BlobManagerImpl::defrag(const CompNode& cn) {
BlobSetWithMux* blobs_set_ptr;
{
MGB_LOCK_GUARD(m_mtx);
blobs_set_ptr = &m_comp2blobs_map[cn];
}
MGB_LOCK_GUARD(blobs_set_ptr->mtx);
std::vector<BlobData> blob_data_arrary;
std::set<Blob::RawStorage> storage_set;
auto alignment = cn.get_mem_addr_alignment();
size_t tot_sz = 0;
for (auto i : blobs_set_ptr->blobs_set) {
if (!i->m_storage)
continue;
if (i->m_storage.use_count() > 1)
continue;
mgb_assert(storage_set.insert(i->m_storage).second);
tot_sz += get_aligned_power2(i->m_size, alignment);
BlobData blob_data(i);
blob_data_arrary.push_back(blob_data);
i->m_storage.reset();
}
storage_set.clear();
if (!blob_data_arrary.size())
return;
CompNode::sync_all();
CompNode::try_coalesce_all_free_memory();
MGB_TRY { cn.free_device(cn.alloc_device(tot_sz)); }
MGB_CATCH(MemAllocError&, {})
std::sort(
blob_data_arrary.begin(), blob_data_arrary.end(),
[](auto& lhs, auto& rhs) { return lhs.blob->id() < rhs.blob->id(); });
for (auto i : blob_data_arrary) {
DeviceTensorStorage d_storage = DeviceTensorStorage(cn);
d_storage.ensure_size(i.blob->m_size);
d_storage.copy_from(i.h_storage, i.blob->m_size);
i.blob->m_storage = d_storage.raw_storage();
}
cn.sync();
}
// Stand-in manager that BlobManager::inst() places into the singleton storage
// after the real BlobManagerImpl has been destroyed.
//
// Every operation except unregister_blob() trips an assertion;
// unregister_blob() is a silent no-op.
struct BlobManagerStub : BlobManager {
    void alloc_direct(Blob* blob, size_t size) {
        mgb_assert(0, "prohibited after global variable destruction");
    }

    void alloc_with_defrag(Blob* blob, size_t size) {
        mgb_assert(0, "prohibited after global variable destruction");
    }

    // NOTE(review): no return statement follows the assertion; this relies on
    // a failed mgb_assert never returning -- confirm.
    DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout& layout) {
        mgb_assert(0, "prohibited after global variable destruction");
    }

    void register_blob(Blob* blob) {
        mgb_assert(0, "prohibited after global variable destruction");
    }

    // Intentionally does nothing.
    void unregister_blob(Blob* blob) {}

    void defrag(const CompNode& cn) {
        mgb_assert(0, "prohibited after global variable destruction");
    }

    virtual void set_allocator(allocator_t allocator) {
        mgb_assert(0, "prohibited after global variable destruction");
    }
};
// Returns the process-wide BlobManager.
//
// The manager lives in a static buffer large enough for either a
// BlobManagerImpl or a BlobManagerStub. A function-local static Keeper
// constructs the BlobManagerImpl in that buffer on the first call. When the
// Keeper is destroyed it destroys that object and constructs a
// BlobManagerStub in the same buffer, so the pointer returned here keeps
// referring to a live object afterwards.
BlobManager* BlobManager::inst() {
    static std::aligned_union_t<0, BlobManagerImpl, BlobManagerStub> storage;

    struct Keeper {
        Keeper() { new (&storage) BlobManagerImpl(); }
        ~Keeper() {
            // Destroy the real manager through the base-class destructor,
            // then reuse the same buffer for the stub.
            auto* current = reinterpret_cast<BlobManager*>(&storage);
            current->~BlobManager();
            new (&storage) BlobManagerStub();
        }
    };
    static Keeper keeper;

    return reinterpret_cast<BlobManager*>(&storage);
}
} }