#ifndef wasm_mixed_arena_h
#define wasm_mixed_arena_h
#include <atomic>
#include <cassert>
#include <memory>
#include <mutex>
#include <thread>
#include <type_traits>
#include <vector>
#include <support/alloc.h>
// Bump-pointer arena allocator. allocSpace() hands out memory from chunks of
// CHUNK_SIZE bytes (a larger request gets one multi-chunk allocation). Memory
// is only released in bulk, by clear() or the destructor; individual
// allocations are never freed and no destructors of objects built with
// alloc<T>() are run.
//
// Each arena records the thread that constructed it. When allocSpace() is
// called from a different thread, the request is forwarded along the `next`
// chain to an arena created by the calling thread. If no such arena exists,
// one is appended using compare-and-swap rather than a lock. As a result,
// allocSpace() only bumps `chunks`/`index` on the arena owned by the calling
// thread. clear() and the destructor take no such precautions; presumably they
// must not race with allocation — TODO confirm.
struct MixedArena {
  // Size in bytes of one allocation chunk.
  static const size_t CHUNK_SIZE = 32768;
  // Alignment of every chunk, and thus the largest alignment allocSpace() can
  // honor.
  static const size_t MAX_ALIGN = 16;

  // Chunks owned by this arena, in allocation order; the last one is current.
  std::vector<void*> chunks;
  // Byte offset of the next free position within chunks.back().
  size_t index = 0;
  // Thread that constructed this arena and bump-allocates from it.
  std::thread::id threadId;
  // Next arena in the per-thread chain (owned by, and deleted with, this
  // arena), or null at the end of the chain.
  std::atomic<MixedArena*> next;

  MixedArena() {
    threadId = std::this_thread::get_id();
    next.store(nullptr);
  }

  // Returns `size` bytes aligned to `align`.
  //
  // `align` must be a non-zero power of two no larger than MAX_ALIGN. Chunks
  // are only MAX_ALIGN-aligned, so larger alignments cannot be honored. A zero
  // alignment would reset the cursor to offset 0 and hand out memory that is
  // already in use.
  //
  // Aborts if the underlying allocation fails.
  void* allocSpace(size_t size, size_t align) {
    assert(align != 0 && (align & (align - 1)) == 0 && align <= MAX_ALIGN);

    auto myId = std::this_thread::get_id();
    if (myId != threadId) {
      // Walk the chain to this thread's arena, appending one if needed.
      MixedArena* curr = this;
      MixedArena* allocated = nullptr;
      while (myId != curr->threadId) {
        auto* seen = curr->next.load();
        if (seen) {
          curr = seen;
          continue;
        }
        // End of the chain: try to publish a new arena. It is constructed
        // here, so it records this thread's id. Other threads may race to
        // append; the CAS decides who wins.
        if (!allocated) {
          allocated = new MixedArena();
        }
        if (curr->next.compare_exchange_strong(seen, allocated)) {
          // Published. The chain now owns it, and it belongs to this thread.
          curr = allocated;
          allocated = nullptr;
          break;
        }
        // Lost the race. The failed CAS loaded the winner's arena into `seen`.
        curr = seen;
      }
      // Discard a speculatively created arena that was never published.
      delete allocated;
      return curr->allocSpace(size, align);
    }

    // Round the cursor up to the requested alignment. Valid because `align` is
    // a power of two.
    index = (index + align - 1) & ~(align - 1);
    if (chunks.empty() || index + size > CHUNK_SIZE) {
      // Start a fresh allocation big enough for `size`; oversized requests get
      // several contiguous chunks. Always take at least one chunk, so that a
      // zero-byte request never asks the system for zero bytes (which may
      // yield null and trigger the abort below).
      size_t numChunks = (size + CHUNK_SIZE - 1) / CHUNK_SIZE;
      if (numChunks == 0) {
        numChunks = 1;
      }
      assert(size <= numChunks * CHUNK_SIZE);
      auto* allocation =
        wasm::aligned_malloc(MAX_ALIGN, numChunks * CHUNK_SIZE);
      if (!allocation) {
        abort();
      }
      chunks.push_back(allocation);
      index = 0;
    }
    uint8_t* ret = static_cast<uint8_t*>(chunks.back()) + index;
    // After a multi-chunk allocation `index` exceeds CHUNK_SIZE, so the next
    // request starts a new chunk instead of reusing the tail.
    index += size;
    return static_cast<void*>(ret);
  }

  // Allocates storage for a T from this arena and constructs it in place as
  // T(*this). The object's destructor is never run by the arena.
  template<class T> T* alloc() {
    static_assert(alignof(T) <= MAX_ALIGN,
                  "maximum alignment not large enough");
    auto* ret = static_cast<T*>(allocSpace(sizeof(T), alignof(T)));
    new (ret) T(*this);
    return ret;
  }

  // Frees every chunk owned by this arena. Arenas chained through `next`
  // (created for other threads) keep their memory until destruction.
  void clear() {
    for (auto* chunk : chunks) {
      wasm::aligned_free(chunk);
    }
    chunks.clear();
    index = 0;
  }

  // Frees this arena's chunks and, recursively, the rest of the chain.
  ~MixedArena() {
    clear();
    delete next.load();
  }
};
// CRTP base for a growable array of T.
//
// Storage comes from SubType::allocate(size_t), which must point `data` at
// room for `size` elements and set `allocatedElements` to `size`. Buffers are
// never freed here: when the array grows, the old buffer is simply abandoned.
//
// Elements are copied by plain assignment into freshly allocated, unconstructed
// storage and are never destroyed. T is therefore presumably expected to be
// trivially copyable and trivially destructible (e.g. a pointer) — TODO
// confirm.
template<typename SubType, typename T> class ArenaVectorBase {
protected:
  // Element storage; null until the first allocation.
  T* data = nullptr;
  // Number of live elements, and capacity of `data` in elements.
  size_t usedElements = 0, allocatedElements = 0;

  // Switches to a new buffer of `size` elements obtained from SubType and
  // copies the current elements into it. The old buffer is not freed.
  void reallocate(size_t size) {
    T* old = data;
    static_cast<SubType*>(this)->allocate(size);
    for (size_t i = 0; i < usedElements; i++) {
      data[i] = old[i];
    }
  }

public:
  struct Iterator;

  // Element access; the bound is only checked by assert().
  T& operator[](size_t index) const {
    assert(index < usedElements);
    return data[index];
  }

  size_t size() const { return usedElements; }

  bool empty() const { return size() == 0; }

  // Sets the element count to `size`.
  //
  // When growing, the new elements are value-initialized. If capacity is too
  // small, the buffer is reallocated to exactly `size` elements. Shrinking
  // just forgets the trailing elements.
  void resize(size_t size) {
    if (size > allocatedElements) {
      reallocate(size);
    }
    for (size_t i = usedElements; i < size; i++) {
      new (data + i) T();
    }
    usedElements = size;
  }

  T& back() const {
    assert(usedElements > 0);
    return data[usedElements - 1];
  }

  // Removes the last element and returns a reference to it. The slot stays
  // readable until it is overwritten by a later insertion.
  T& pop_back() {
    assert(usedElements > 0);
    usedElements--;
    return data[usedElements];
  }

  // Appends `item`, growing capacity geometrically when full.
  void push_back(T item) {
    if (usedElements == allocatedElements) {
      reallocate((allocatedElements + 1) * 2);
    }
    data[usedElements] = item;
    usedElements++;
  }

  T& front() const {
    assert(usedElements > 0);
    return data[0];
  }

  // Removes the elements in [start_it, end_it), shifting the tail down.
  void erase(Iterator start_it, Iterator end_it) {
    assert(start_it.parent == end_it.parent && start_it.parent == this);
    assert(start_it.index <= end_it.index && end_it.index <= usedElements);
    size_t size = end_it.index - start_it.index;
    for (size_t cur = start_it.index; cur + size < usedElements; ++cur) {
      data[cur] = data[cur + size];
    }
    usedElements -= size;
  }

  void erase(Iterator it) { erase(it, it + 1); }

  // Drops all elements but keeps the current buffer and capacity.
  void clear() { usedElements = 0; }

  // Ensures capacity for at least `size` elements.
  void reserve(size_t size) {
    if (size > allocatedElements) {
      reallocate(size);
    }
  }

  // Replaces the contents with a copy of `list`, which may be any range with
  // size() and begin()/end(). If capacity is too small, a new buffer is
  // allocated; the old contents are not copied, since they are overwritten.
  template<typename ListType> void set(const ListType& list) {
    size_t size = list.size();
    if (allocatedElements < size) {
      static_cast<SubType*>(this)->allocate(size);
    }
    size_t i = 0;
    for (auto elem : list) {
      data[i++] = elem;
    }
    usedElements = size;
  }

  // Element-wise copy from `other` via set(). Note that it returns void.
  void operator=(SubType& other) { set(other); }

  // Despite the name, this is a one-way transfer rather than an exchange:
  // - this vector takes over other's buffer, size and capacity;
  // - other is left empty with no buffer;
  // - this vector's previous buffer is dropped.
  void swap(SubType& other) {
    data = other.data;
    usedElements = other.usedElements;
    allocatedElements = other.allocatedElements;

    other.data = nullptr;
    other.usedElements = other.allocatedElements = 0;
  }

  // Random-access iterator represented as (parent vector, index). It stays
  // valid across reallocation, because every dereference goes through the
  // parent.
  struct Iterator {
    using iterator_category = std::random_access_iterator_tag;
    using value_type = T;
    using difference_type = std::ptrdiff_t;
    using pointer = T*;
    using reference = T&;

    const SubType* parent;
    size_t index;

    Iterator() : parent(nullptr), index(0) {}
    Iterator(const SubType* parent, size_t index)
      : parent(parent), index(index) {}

    bool operator==(const Iterator& other) const {
      return index == other.index && parent == other.parent;
    }
    bool operator!=(const Iterator& other) const { return !(*this == other); }
    bool operator<(const Iterator& other) const {
      assert(parent == other.parent);
      return index < other.index;
    }
    bool operator>(const Iterator& other) const { return other < *this; }
    bool operator<=(const Iterator& other) const { return !(other < *this); }
    bool operator>=(const Iterator& other) const { return !(*this < other); }

    Iterator& operator++() {
      index++;
      return *this;
    }
    Iterator& operator--() {
      index--;
      return *this;
    }
    Iterator operator++(int) {
      Iterator it = *this;
      ++*this;
      return it;
    }
    Iterator operator--(int) {
      Iterator it = *this;
      --*this;
      return it;
    }

    // Negative offsets rely on unsigned wrap-around of `index`.
    Iterator& operator+=(std::ptrdiff_t off) {
      index += off;
      return *this;
    }
    Iterator& operator-=(std::ptrdiff_t off) { return *this += -off; }
    Iterator operator+(std::ptrdiff_t off) const {
      return Iterator(*this) += off;
    }
    Iterator operator-(std::ptrdiff_t off) const { return *this + -off; }
    std::ptrdiff_t operator-(const Iterator& other) const {
      assert(parent == other.parent);
      return index - other.index;
    }
    friend Iterator operator+(std::ptrdiff_t off, const Iterator& it) {
      return it + off;
    }

    T& operator*() const { return (*parent)[index]; }
    T& operator[](std::ptrdiff_t off) const { return (*parent)[index + off]; }
    T* operator->() const { return &(*parent)[index]; }
  };

  Iterator begin() const {
    return Iterator(static_cast<const SubType*>(this), 0);
  }
  Iterator end() const {
    return Iterator(static_cast<const SubType*>(this), usedElements);
  }

  // Fallback that SubType is expected to hide with its own allocate().
  // reallocate()/set() reach this only if SubType does not provide one.
  void allocate(size_t size) { abort(); }

  // Inserts `item` before position `index` (index == size() appends).
  void insertAt(size_t index, T item) {
    assert(index <= size());
    if (usedElements == allocatedElements) {
      // Grow geometrically, like push_back. Relying on resize() alone would
      // grow to exactly size() + 1, reallocating on every insertion and
      // abandoning each old buffer (reallocate() never frees it).
      reallocate((allocatedElements + 1) * 2);
    }
    resize(size() + 1);
    for (auto i = size() - 1; i > index; --i) {
      data[i] = data[i - 1];
    }
    data[index] = item;
  }

  // Removes and returns the element at `index`, shifting the tail down.
  T removeAt(size_t index) {
    assert(index < size());
    auto item = data[index];
    for (auto i = index; i < size() - 1; ++i) {
      data[i] = data[i + 1];
    }
    resize(size() - 1);
    return item;
  }
};
// Growable array whose buffers are carved out of a MixedArena. Buffers are
// never freed individually; superseded buffers remain in the arena until it is
// cleared or destroyed.
template<typename T>
class ArenaVector : public ArenaVectorBase<ArenaVector<T>, T> {
private:
  // Arena that backs every buffer this vector allocates.
  MixedArena& allocator;

public:
  ArenaVector(MixedArena& allocator) : allocator(allocator) {}

  // Takes over other's buffer and leaves other empty. The new vector also
  // shares other's arena for future growth.
  ArenaVector(ArenaVector<T>&& other) noexcept : allocator(other.allocator) {
    // swap() lives in a dependent base class, so it must be named through
    // `this->`. An unqualified call is not found by two-phase name lookup and
    // fails to compile on conforming compilers.
    this->swap(other);
  }

  // Takes over other's buffer and leaves other empty. swap() overwrites the
  // buffer, size and capacity, so no prior clear() is needed.
  //
  // This vector keeps its own arena for future growth. The adopted buffer,
  // however, lives in other's arena, so presumably that arena must outlive
  // this vector — TODO confirm.
  ArenaVector<T>& operator=(ArenaVector<T>&& other) noexcept {
    if (this != &other) {
      this->swap(other);
    }
    return *this;
  }

  // Called by ArenaVectorBase when it needs a buffer. Takes a fresh
  // `size`-element buffer from the arena and records it as the capacity.
  void allocate(size_t size) {
    this->allocatedElements = size;
    this->data = static_cast<T*>(
      allocator.allocSpace(sizeof(T) * this->allocatedElements, alignof(T)));
  }
};
#endif