#include "atn/PredictionContextMergeCache.h"

#include <cassert>
#include <memory>
#include <utility>

#include "misc/MurmurHash.h"

using namespace antlr4::atn;
using namespace antlr4::misc;
// Builds an empty cache whose behaviour (notably the max size used by put(),
// get() and compact()) is taken from `options`.
PredictionContextMergeCache::PredictionContextMergeCache(
    const PredictionContextMergeCacheOptions &options)
    : _options(options) {}
// Records `value` as the cached merge result for the ordered pair
// (key1, key2). The entry becomes most-recently-used. If the cache then holds
// more than getOptions().getMaxSize() entries, least-recently-used entries
// are evicted.
//
// key1, key2: merge operands; both must be non-null (asserted).
// value:      merge result to store.
// Returns the value now associated with the pair. When the max size is 0
// (caching disabled) nothing is stored and `value` is returned unchanged.
// Rethrows any allocation failure after removing the just-inserted slot.
Ref<const PredictionContext> PredictionContextMergeCache::put(
    const Ref<const PredictionContext> &key1,
    const Ref<const PredictionContext> &key2,
    Ref<const PredictionContext> value) {
  assert(key1);
  assert(key2);

  if (getOptions().getMaxSize() == 0) {
    // Cache disabled: hand the value straight back without storing it.
    return value;
  }

  auto [existing, inserted] =
      _entries.try_emplace(std::make_pair(key1.get(), key2.get()));
  if (inserted) {
    try {
      existing->second = std::make_unique<Entry>();
    } catch (...) {
      // Allocation failed: remove the fresh slot so the map never keeps an
      // entry with a null Entry pointer, then propagate.
      _entries.erase(existing);
      throw;
    }
    Entry *const entry = existing->second.get();
    // The map key holds only raw pointers. Storing the Refs in the entry
    // presumably keeps the keyed contexts alive while the entry exists.
    entry->key = std::make_pair(key1, key2);
    entry->value = std::move(value);
    pushToFront(entry);
  } else {
    Entry *const entry = existing->second.get();
    // Only reassign when the stored result actually differs.
    if (entry->value != value) {
      entry->value = std::move(value);
    }
    moveToFront(entry);
  }
  // Trim back to capacity, but never evict the entry just written.
  compact(existing->second.get());
  return existing->second->value;
}
// Looks up the cached merge result for the ordered pair (key1, key2).
// Both keys must be non-null (asserted). On a hit, the entry is moved to the
// front of the recency list, so this const method still updates LRU order.
// Returns nullptr on a miss, or whenever the max size is 0 (cache disabled).
Ref<const PredictionContext> PredictionContextMergeCache::get(
    const Ref<const PredictionContext> &key1,
    const Ref<const PredictionContext> &key2) const {
  assert(key1);
  assert(key2);

  const bool disabled = getOptions().getMaxSize() == 0;
  if (disabled) {
    return nullptr;
  }

  const auto found = _entries.find(std::make_pair(key1.get(), key2.get()));
  if (found != _entries.end()) {
    Entry *const hit = found->second.get();
    moveToFront(hit);
    return hit->value;
  }
  return nullptr;
}
void PredictionContextMergeCache::clear() {
Container().swap(_entries);
_head = _tail = nullptr;
_size = 0;
}
// Relinks `entry`, which must already be in the list, so it becomes the head
// (most-recently-used position). If it is already the head, nothing changes.
void PredictionContextMergeCache::moveToFront(Entry *entry) const {
  Entry *const before = entry->prev;
  if (before == nullptr) {
    // Only the head has no predecessor, so it is already in front.
    assert(entry == _head);
    return;
  }
  Entry *const after = entry->next;

  // Unlink `entry` from its current position.
  before->next = after;
  if (after == nullptr) {
    // `entry` was the tail; its predecessor takes over.
    assert(entry == _tail);
    _tail = before;
  } else {
    after->prev = before;
  }

  // Splice it in ahead of the current head.
  entry->prev = nullptr;
  entry->next = _head;
  _head->prev = entry;
  _head = entry;
  assert(entry->prev == nullptr);
}
// Links a new `entry` (not yet in the list) in as the head and bumps the
// entry count.
void PredictionContextMergeCache::pushToFront(Entry *entry) {
  ++_size;
  entry->prev = nullptr;
  entry->next = _head;
  if (_head == nullptr) {
    // The list was empty: the new entry is both head and tail.
    assert(entry->next == nullptr);
    _tail = entry;
  } else {
    _head->prev = entry;
  }
  _head = entry;
  assert(entry->prev == nullptr);
}
// Unlinks `entry` from the recency list, decrements the count, and erases
// its slot from the map.
void PredictionContextMergeCache::remove(Entry *entry) {
  Entry *const before = entry->prev;
  Entry *const after = entry->next;

  if (before == nullptr) {
    assert(entry == _head);
    _head = after;
  } else {
    before->next = after;
  }

  if (after == nullptr) {
    assert(entry == _tail);
    _tail = before;
  } else {
    after->prev = before;
  }

  --_size;
  // The map slot owns the Entry, so erasing it destroys `entry`. This must be
  // the last use of the pointer.
  _entries.erase(std::make_pair(entry->key.first.get(), entry->key.second.get()));
}
void PredictionContextMergeCache::compact(const Entry *preserve) {
Entry *entry = _tail;
while (entry != nullptr && _size > getOptions().getMaxSize()) {
Entry *next = entry->prev;
if (entry != preserve) {
remove(entry);
}
entry = next;
}
}
// Hashes a key pair by mixing the two contexts' hashCode() values in order
// (first, then second) with MurmurHash. The hash depends on the contexts'
// values, not on their addresses.
size_t PredictionContextMergeCache::PredictionContextHasher::operator()(
    const PredictionContextPair &value) const {
  const size_t seed = MurmurHash::initialize();
  const size_t withFirst = MurmurHash::update(seed, value.first->hashCode());
  const size_t withBoth = MurmurHash::update(withFirst, value.second->hashCode());
  // 2 matches the number of update() calls above (presumably finish()'s
  // element-count argument).
  return MurmurHash::finish(withBoth, 2);
}
// Key equality for the cache map: two pairs match when both contexts compare
// equal by value (dereferenced), not by pointer identity.
bool PredictionContextMergeCache::PredictionContextComparer::operator()(
    const PredictionContextPair &lhs, const PredictionContextPair &rhs) const {
  if (!(*lhs.first == *rhs.first)) {
    return false;
  }
  return *lhs.second == *rhs.second;
}