#include <random>
#include "db/memtable.h"
#include "memory/arena.h"
#include "memtable/inlineskiplist.h"
#include "rocksdb/memtablerep.h"
#include "rocksdb/utilities/options_type.h"
#include "util/string_util.h"
namespace ROCKSDB_NAMESPACE {
namespace {
class SkipListRep : public MemTableRep {
InlineSkipList<const MemTableRep::KeyComparator&> skip_list_;
const MemTableRep::KeyComparator& cmp_;
const SliceTransform* transform_;
const size_t lookahead_;
friend class LookaheadIterator;
public:
explicit SkipListRep(const MemTableRep::KeyComparator& compare,
Allocator* allocator, const SliceTransform* transform,
const size_t lookahead)
: MemTableRep(allocator),
skip_list_(compare, allocator),
cmp_(compare),
transform_(transform),
lookahead_(lookahead) {}
KeyHandle Allocate(const size_t len, char** buf) override {
*buf = skip_list_.AllocateKey(len);
return static_cast<KeyHandle>(*buf);
}
void Insert(KeyHandle handle) override {
skip_list_.Insert(static_cast<char*>(handle));
}
bool InsertKey(KeyHandle handle) override {
return skip_list_.Insert(static_cast<char*>(handle));
}
void InsertWithHint(KeyHandle handle, void** hint) override {
skip_list_.InsertWithHint(static_cast<char*>(handle), hint);
}
bool InsertKeyWithHint(KeyHandle handle, void** hint) override {
return skip_list_.InsertWithHint(static_cast<char*>(handle), hint);
}
void InsertWithHintConcurrently(KeyHandle handle, void** hint) override {
skip_list_.InsertWithHintConcurrently(static_cast<char*>(handle), hint);
}
bool InsertKeyWithHintConcurrently(KeyHandle handle, void** hint) override {
return skip_list_.InsertWithHintConcurrently(static_cast<char*>(handle),
hint);
}
void InsertConcurrently(KeyHandle handle) override {
skip_list_.InsertConcurrently(static_cast<char*>(handle));
}
bool InsertKeyConcurrently(KeyHandle handle) override {
return skip_list_.InsertConcurrently(static_cast<char*>(handle));
}
bool Contains(const char* key) const override {
return skip_list_.Contains(key);
}
size_t ApproximateMemoryUsage() override {
return 0;
}
void Get(const LookupKey& k, void* callback_args,
bool (*callback_func)(void* arg, const char* entry)) override {
SkipListRep::Iterator iter(&skip_list_);
Slice dummy_slice;
for (iter.Seek(dummy_slice, k.memtable_key().data());
iter.Valid() && callback_func(callback_args, iter.key());
iter.Next()) {
}
}
Status GetAndValidate(const LookupKey& k, void* callback_args,
bool (*callback_func)(void* arg, const char* entry),
bool allow_data_in_errors, bool detect_key_out_of_order,
const std::function<Status(const char*, bool)>&
key_validation_callback) override {
SkipListRep::Iterator iter(&skip_list_);
Slice dummy_slice;
Status status = iter.SeekAndValidate(
dummy_slice, k.memtable_key().data(), allow_data_in_errors,
detect_key_out_of_order, key_validation_callback);
for (; iter.Valid() && status.ok() &&
callback_func(callback_args, iter.key());
status = iter.NextAndValidate(allow_data_in_errors)) {
}
return status;
}
uint64_t ApproximateNumEntries(const Slice& start_ikey,
const Slice& end_ikey) override {
return skip_list_.ApproximateNumEntries(start_ikey, end_ikey);
}
void UniqueRandomSample(const uint64_t num_entries,
const uint64_t target_sample_size,
std::unordered_set<const char*>* entries) override {
entries->clear();
assert(target_sample_size > 0);
assert(num_entries > 0);
SkipListRep::Iterator iter(&skip_list_);
if (target_sample_size >
static_cast<uint64_t>(std::sqrt(1.0 * num_entries))) {
Random* rnd = Random::GetTLSInstance();
iter.SeekToFirst();
uint64_t counter = 0, num_samples_left = target_sample_size;
for (; iter.Valid() && (num_samples_left > 0); iter.Next(), counter++) {
if (rnd->Next() % (num_entries - counter) < num_samples_left) {
entries->insert(iter.key());
num_samples_left--;
}
}
} else {
for (uint64_t i = 0; i < target_sample_size; i++) {
for (uint64_t j = 0; j < 5; j++) {
iter.RandomSeek();
if ((entries->insert(iter.key())).second) {
break;
}
}
}
}
}
~SkipListRep() override = default;
class Iterator : public MemTableRep::Iterator {
InlineSkipList<const MemTableRep::KeyComparator&>::Iterator iter_;
public:
explicit Iterator(
const InlineSkipList<const MemTableRep::KeyComparator&>* list)
: iter_(list) {}
~Iterator() override = default;
bool Valid() const override { return iter_.Valid(); }
const char* key() const override {
assert(Valid());
return iter_.key();
}
void Next() override {
assert(Valid());
iter_.Next();
}
void Prev() override {
assert(Valid());
iter_.Prev();
}
void Seek(const Slice& user_key, const char* memtable_key) override {
if (memtable_key != nullptr) {
iter_.Seek(memtable_key);
} else {
iter_.Seek(EncodeKey(&tmp_, user_key));
}
}
void SeekForPrev(const Slice& user_key, const char* memtable_key) override {
if (memtable_key != nullptr) {
iter_.SeekForPrev(memtable_key);
} else {
iter_.SeekForPrev(EncodeKey(&tmp_, user_key));
}
}
void RandomSeek() override { iter_.RandomSeek(); }
void SeekToFirst() override { iter_.SeekToFirst(); }
void SeekToLast() override { iter_.SeekToLast(); }
Status NextAndValidate(bool allow_data_in_errors) override {
assert(Valid());
return iter_.NextAndValidate(allow_data_in_errors);
}
Status SeekAndValidate(const Slice& user_key, const char* memtable_key,
bool allow_data_in_errors,
bool detect_key_out_of_order,
const std::function<Status(const char*, bool)>&
key_validation_callback) override {
if (memtable_key != nullptr) {
return iter_.SeekAndValidate(memtable_key, allow_data_in_errors,
detect_key_out_of_order,
key_validation_callback);
} else {
return iter_.SeekAndValidate(
EncodeKey(&tmp_, user_key), allow_data_in_errors,
detect_key_out_of_order, key_validation_callback);
}
}
Status PrevAndValidate(bool allow_data_in_error) override {
assert(Valid());
return iter_.PrevAndValidate(allow_data_in_error);
}
protected:
std::string tmp_; };
class LookaheadIterator : public MemTableRep::Iterator {
public:
explicit LookaheadIterator(const SkipListRep& rep)
: rep_(rep), iter_(&rep_.skip_list_), prev_(iter_) {}
~LookaheadIterator() override = default;
bool Valid() const override { return iter_.Valid(); }
const char* key() const override {
assert(Valid());
return iter_.key();
}
void Next() override {
assert(Valid());
bool advance_prev = true;
if (prev_.Valid()) {
auto k1 = rep_.UserKey(prev_.key());
auto k2 = rep_.UserKey(iter_.key());
if (k1.compare(k2) == 0) {
advance_prev = false;
} else if (rep_.transform_) {
auto t1 = rep_.transform_->Transform(k1);
auto t2 = rep_.transform_->Transform(k2);
advance_prev = t1.compare(t2) == 0;
}
}
if (advance_prev) {
prev_ = iter_;
}
iter_.Next();
}
void Prev() override {
assert(Valid());
iter_.Prev();
prev_ = iter_;
}
void Seek(const Slice& internal_key, const char* memtable_key) override {
const char* encoded_key = (memtable_key != nullptr)
? memtable_key
: EncodeKey(&tmp_, internal_key);
if (prev_.Valid() && rep_.cmp_(encoded_key, prev_.key()) >= 0) {
iter_ = prev_;
size_t cur = 0;
while (cur++ <= rep_.lookahead_ && iter_.Valid()) {
if (rep_.cmp_(encoded_key, iter_.key()) <= 0) {
return;
}
Next();
}
}
iter_.Seek(encoded_key);
prev_ = iter_;
}
void SeekForPrev(const Slice& internal_key,
const char* memtable_key) override {
const char* encoded_key = (memtable_key != nullptr)
? memtable_key
: EncodeKey(&tmp_, internal_key);
iter_.SeekForPrev(encoded_key);
prev_ = iter_;
}
void SeekToFirst() override {
iter_.SeekToFirst();
prev_ = iter_;
}
void SeekToLast() override {
iter_.SeekToLast();
prev_ = iter_;
}
protected:
std::string tmp_;
private:
const SkipListRep& rep_;
InlineSkipList<const MemTableRep::KeyComparator&>::Iterator iter_;
InlineSkipList<const MemTableRep::KeyComparator&>::Iterator prev_;
};
MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override {
if (lookahead_ > 0) {
void* mem =
arena ? arena->AllocateAligned(sizeof(SkipListRep::LookaheadIterator))
: operator new(sizeof(SkipListRep::LookaheadIterator));
return new (mem) SkipListRep::LookaheadIterator(*this);
} else {
void* mem = arena ? arena->AllocateAligned(sizeof(SkipListRep::Iterator))
: operator new(sizeof(SkipListRep::Iterator));
return new (mem) SkipListRep::Iterator(&skip_list_);
}
}
};
}
// Option-type table that the SkipListFactory constructor passes to
// RegisterOptions together with a pointer to `lookahead_`.
// It describes a single option, "lookahead", of type size_t.
// NOTE(review): the leading 0 is presumably the offset of the field within
// the registered options object, and kDontSerialize is presumably set because
// the lookahead value is already part of the string returned by GetId() -
// confirm against OptionTypeInfo.
static std::unordered_map<std::string, OptionTypeInfo> skiplist_factory_info = {
{"lookahead",
{0, OptionType::kSizeT, OptionVerificationType::kNormal,
OptionTypeFlags::kDontSerialize }},
};
// Builds a factory whose reps use the given `lookahead`, and registers the
// `lookahead_` member under the "SkipListFactoryOptions" name using the
// skiplist_factory_info option table.
SkipListFactory::SkipListFactory(size_t lookahead) : lookahead_(lookahead) {
  const char* const options_name = "SkipListFactoryOptions";
  size_t* const options_ptr = &lookahead_;
  RegisterOptions(options_name, options_ptr, &skiplist_factory_info);
}
// Returns the factory identifier: Name() on its own when no lookahead is
// configured, otherwise "<Name()>:<lookahead_>".
std::string SkipListFactory::GetId() const {
  std::string result = Name();
  if (lookahead_ == 0) {
    // lookahead_ is unsigned, so "== 0" is the complement of "> 0".
    return result;
  }
  result += ":";
  result += std::to_string(lookahead_);
  return result;
}
// Allocates a new SkipListRep configured with this factory's lookahead_.
// The Logger argument is not used. The returned object is created with
// `new`; the caller receives the raw pointer.
MemTableRep* SkipListFactory::CreateMemTableRep(
    const MemTableRep::KeyComparator& compare, Allocator* allocator,
    const SliceTransform* transform, Logger* /*logger*/) {
  MemTableRep* const rep =
      new SkipListRep(compare, allocator, transform, lookahead_);
  return rep;
}
}