#include <algorithm>
#include <memory>
#include <set>
#include <type_traits>
#include <unordered_set>
#include "db/memtable.h"
#include "memory/arena.h"
#include "memtable/stl_wrappers.h"
#include "port/port.h"
#include "rocksdb/memtablerep.h"
#include "rocksdb/utilities/options_type.h"
#include "util/mutexlock.h"
namespace ROCKSDB_NAMESPACE {
namespace {
class VectorRep : public MemTableRep {
public:
VectorRep(const KeyComparator& compare, Allocator* allocator, size_t count);
void Insert(KeyHandle handle) override;
void InsertConcurrently(KeyHandle handle) override;
bool Contains(const char* key) const override;
void MarkReadOnly() override;
size_t ApproximateMemoryUsage() override;
void Get(const LookupKey& k, void* callback_args,
bool (*callback_func)(void* arg, const char* entry)) override;
void BatchPostProcess() override;
~VectorRep() override = default;
class Iterator : public MemTableRep::Iterator {
class VectorRep* vrep_;
std::shared_ptr<std::vector<const char*>> bucket_;
std::vector<const char*>::const_iterator mutable cit_;
const KeyComparator& compare_;
std::string tmp_; bool mutable sorted_;
void DoSort() const;
public:
explicit Iterator(class VectorRep* vrep,
std::shared_ptr<std::vector<const char*>> bucket,
const KeyComparator& compare);
~Iterator() override = default;
bool Valid() const override;
const char* key() const override;
void Next() override;
void Prev() override;
void Seek(const Slice& user_key, const char* memtable_key) override;
Status SeekAndValidate(const Slice& internal_key, const char* memtable_key,
bool allow_data_in_errors,
bool detect_key_out_of_order,
const std::function<Status(const char*, bool)>&
key_validation_callback) override;
void SeekForPrev(const Slice& user_key, const char* memtable_key) override;
void SeekToFirst() override;
void SeekToLast() override;
};
MemTableRep::Iterator* GetIterator(Arena* arena) override;
private:
friend class Iterator;
ALIGN_AS(CACHE_LINE_SIZE) RelaxedAtomic<size_t> bucket_size_;
using Bucket = std::vector<const char*>;
std::shared_ptr<Bucket> bucket_;
mutable port::RWMutex rwlock_;
bool immutable_;
bool sorted_;
const KeyComparator& compare_;
using TlBucket = std::vector<const char*>;
ThreadLocalPtr tl_writes_;
static void DeleteTlBucket(void* ptr) {
auto* v = static_cast<TlBucket*>(ptr);
delete v;
}
};
void VectorRep::Insert(KeyHandle handle) {
auto* key = static_cast<char*>(handle);
{
WriteLock l(&rwlock_);
assert(!immutable_);
bucket_->push_back(key);
}
bucket_size_.FetchAddRelaxed(1);
}
void VectorRep::InsertConcurrently(KeyHandle handle) {
auto* v = static_cast<TlBucket*>(tl_writes_.Get());
if (!v) {
v = new TlBucket();
tl_writes_.Reset(v);
}
v->push_back(static_cast<char*>(handle));
}
bool VectorRep::Contains(const char* key) const {
ReadLock l(&rwlock_);
return std::find(bucket_->begin(), bucket_->end(), key) != bucket_->end();
}
void VectorRep::MarkReadOnly() {
WriteLock l(&rwlock_);
immutable_ = true;
}
size_t VectorRep::ApproximateMemoryUsage() {
return bucket_size_.LoadRelaxed() *
sizeof(std::remove_reference<decltype(*bucket_)>::type::value_type);
}
void VectorRep::BatchPostProcess() {
auto* v = static_cast<TlBucket*>(tl_writes_.Get());
if (v) {
{
WriteLock l(&rwlock_);
assert(!immutable_);
for (auto& key : *v) {
bucket_->push_back(key);
}
}
bucket_size_.FetchAddRelaxed(v->size());
delete v;
tl_writes_.Reset(nullptr);
}
}
VectorRep::VectorRep(const KeyComparator& compare, Allocator* allocator,
size_t count)
: MemTableRep(allocator),
bucket_size_(0),
bucket_(new Bucket()),
immutable_(false),
sorted_(false),
compare_(compare),
tl_writes_(DeleteTlBucket) {
bucket_.get()->reserve(count);
}
VectorRep::Iterator::Iterator(class VectorRep* vrep,
std::shared_ptr<std::vector<const char*>> bucket,
const KeyComparator& compare)
: vrep_(vrep),
bucket_(bucket),
cit_(bucket_->end()),
compare_(compare),
sorted_(false) {}
void VectorRep::Iterator::DoSort() const {
if (!sorted_ && vrep_ != nullptr) {
WriteLock l(&vrep_->rwlock_);
if (!vrep_->sorted_) {
std::sort(bucket_->begin(), bucket_->end(),
stl_wrappers::Compare(compare_));
cit_ = bucket_->begin();
vrep_->sorted_ = true;
}
sorted_ = true;
}
if (!sorted_) {
std::sort(bucket_->begin(), bucket_->end(),
stl_wrappers::Compare(compare_));
cit_ = bucket_->begin();
sorted_ = true;
}
assert(sorted_);
assert(vrep_ == nullptr || vrep_->sorted_);
}
bool VectorRep::Iterator::Valid() const {
DoSort();
return cit_ != bucket_->end();
}
const char* VectorRep::Iterator::key() const {
assert(sorted_);
return *cit_;
}
void VectorRep::Iterator::Next() {
assert(sorted_);
if (cit_ == bucket_->end()) {
return;
}
++cit_;
}
void VectorRep::Iterator::Prev() {
assert(sorted_);
if (cit_ == bucket_->begin()) {
cit_ = bucket_->end();
} else {
--cit_;
}
}
void VectorRep::Iterator::Seek(const Slice& user_key,
const char* memtable_key) {
DoSort();
const char* encoded_key =
(memtable_key != nullptr) ? memtable_key : EncodeKey(&tmp_, user_key);
cit_ = std::equal_range(bucket_->begin(), bucket_->end(), encoded_key,
[this](const char* a, const char* b) {
return compare_(a, b) < 0;
})
.first;
}
Status VectorRep::Iterator::SeekAndValidate(
const Slice& , const char* ,
bool , bool ,
const std::function<Status(const char*, bool)>&
) {
if (vrep_) {
WriteLock l(&vrep_->rwlock_);
if (bucket_->begin() == bucket_->end()) {
return Status::OK();
} else {
return Status::NotSupported("SeekAndValidate() not implemented");
}
} else {
return Status::NotSupported("SeekAndValidate() not implemented");
}
}
void VectorRep::Iterator::SeekForPrev(const Slice& ,
const char* ) {
assert(false);
}
void VectorRep::Iterator::SeekToFirst() {
DoSort();
cit_ = bucket_->begin();
}
void VectorRep::Iterator::SeekToLast() {
DoSort();
cit_ = bucket_->end();
if (bucket_->size() != 0) {
--cit_;
}
}
void VectorRep::Get(const LookupKey& k, void* callback_args,
bool (*callback_func)(void* arg, const char* entry)) {
rwlock_.ReadLock();
VectorRep* vector_rep;
std::shared_ptr<Bucket> bucket;
if (immutable_) {
vector_rep = this;
} else {
vector_rep = nullptr;
bucket.reset(new Bucket(*bucket_)); }
VectorRep::Iterator iter(vector_rep, immutable_ ? bucket_ : bucket, compare_);
rwlock_.ReadUnlock();
for (iter.Seek(k.user_key(), k.memtable_key().data());
iter.Valid() && callback_func(callback_args, iter.key()); iter.Next()) {
}
}
MemTableRep::Iterator* VectorRep::GetIterator(Arena* arena) {
char* mem = nullptr;
if (arena != nullptr) {
mem = arena->AllocateAligned(sizeof(Iterator));
}
ReadLock l(&rwlock_);
if (immutable_) {
if (arena == nullptr) {
return new Iterator(this, bucket_, compare_);
} else {
return new (mem) Iterator(this, bucket_, compare_);
}
} else {
std::shared_ptr<Bucket> tmp;
tmp.reset(new Bucket(*bucket_)); if (arena == nullptr) {
return new Iterator(nullptr, tmp, compare_);
} else {
return new (mem) Iterator(nullptr, tmp, compare_);
}
}
}
}
static std::unordered_map<std::string, OptionTypeInfo> vector_rep_table_info = {
{"count",
{0, OptionType::kSizeT, OptionVerificationType::kNormal,
OptionTypeFlags::kNone}},
};
VectorRepFactory::VectorRepFactory(size_t count) : count_(count) {
RegisterOptions("VectorRepFactoryOptions", &count_, &vector_rep_table_info);
}
MemTableRep* VectorRepFactory::CreateMemTableRep(
const MemTableRep::KeyComparator& compare, Allocator* allocator,
const SliceTransform*, Logger* ) {
return new VectorRep(compare, allocator, count_);
}
}