#include "util/thread_local.h"
#include <atomic>
#include <string>
#include <thread>
#include "port/port.h"
#include "rocksdb/env.h"
#include "test_util/sync_point.h"
#include "test_util/testharness.h"
#include "test_util/testutil.h"
#include "util/autovector.h"
namespace ROCKSDB_NAMESPACE {
class ThreadLocalTest : public testing::Test {
public:
ThreadLocalTest() : env_(Env::Default()) {}
Env* env_;
};
namespace {

// State shared between a test body and the worker threads it starts.
//   mu / cv   - synchronisation primitives owned by the test body.
//   unref     - counter bumped by the tests' unref handlers (may be null).
//   total     - number of workers the test intends to start.
//   started / completed - rendezvous counters, only touched under *mu.
//   doWrite   - tells ConcurrentReadWriteTest workers to keep re-Resetting.
//   tls1      - a ThreadLocalPtr owned by Params (constructing a Params
//               therefore consumes one ThreadLocalPtr id).
//   tls2      - optional second slot, owned and assigned by the test body.
struct Params {
  Params(port::Mutex* m, port::CondVar* c, int* u, int n,
         UnrefHandler handler = nullptr)
      : mu(m), cv(c), unref(u), total(n), tls1(handler) {}

  port::Mutex* mu;
  port::CondVar* cv;
  int* unref;
  int total;
  int started = 0;
  int completed = 0;
  bool doWrite = false;
  ThreadLocalPtr tls1;
  ThreadLocalPtr* tls2 = nullptr;
};

// Exposes ThreadLocalPtr::TEST_PeekId() through a subclass so the tests can
// observe which id ThreadLocalPtr will hand out next.
class IDChecker : public ThreadLocalPtr {
 public:
  static uint32_t PeekId() { return TEST_PeekId(); }
};

}  // namespace
// Checks how ThreadLocalPtr ids are allocated and recycled. Every Params owns
// one ThreadLocalPtr (its tls1 member), so constructing a Params takes one id
// and destroying it gives that id back. PeekId() must advance by one for each
// new live instance, and an id that has been given back must be reused before
// a fresh one is handed out.
// NOTE(review): presumably compiled out under the clang static analyzer
// because of the manual new/delete of p2 below; confirm.
#ifndef __clang_analyzer__
TEST_F(ThreadLocalTest, UniqueIdTest) {
port::Mutex mu;
port::CondVar cv(&mu);
uint32_t base_id = IDChecker::PeekId();
// Four simultaneously live instances take four consecutive ids.
{
Params p1(&mu, &cv, nullptr, 1u);
ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
Params p2(&mu, &cv, nullptr, 1u);
ASSERT_EQ(IDChecker::PeekId(), base_id + 2u);
Params p3(&mu, &cv, nullptr, 1u);
ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
Params p4(&mu, &cv, nullptr, 1u);
ASSERT_EQ(IDChecker::PeekId(), base_id + 4u);
}
// All four ids were given back when the scope closed, so allocation starts
// over at base_id.
ASSERT_EQ(IDChecker::PeekId(), base_id + 0u);
Params p1(&mu, &cv, nullptr, 1u);
ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
Params* p2 = new Params(&mu, &cv, nullptr, 1u);
ASSERT_EQ(IDChecker::PeekId(), base_id + 2u);
Params p3(&mu, &cv, nullptr, 1u);
ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
delete p2;
// p2's id (base_id + 1) is free again and is the next one to be handed out.
ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
Params p4(&mu, &cv, nullptr, 1u);
// p4 reused base_id + 1; the next id is the never-used base_id + 3.
ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
Params p5(&mu, &cv, nullptr, 1u);
ASSERT_EQ(IDChecker::PeekId(), base_id + 4u);
}
#endif
// Starts 1024 threads strictly one after another. Each thread must see empty
// values in both slots, then overwrites each slot twice while checking every
// read-back, and finally signals completion. Running the threads must not
// allocate any further ThreadLocalPtr ids.
TEST_F(ThreadLocalTest, SequentialReadWriteTest) {
  uint32_t expected_id = IDChecker::PeekId();
  port::Mutex mu;
  port::CondVar cv(&mu);
  Params p(&mu, &cv, nullptr, 1);
  ThreadLocalPtr tls2;
  p.tls2 = &tls2;
  // Constructing p (its tls1) and tls2 must have consumed ids.
  ASSERT_GT(IDChecker::PeekId(), expected_id);
  expected_id = IDChecker::PeekId();

  auto worker = [](Params* shared) {
    // tls1 is exercised first, then tls2; a failed ASSERT returns from the
    // worker immediately, just like the unrolled sequence would.
    for (ThreadLocalPtr* slot : {&shared->tls1, shared->tls2}) {
      ASSERT_TRUE(slot->Get() == nullptr);
      slot->Reset(reinterpret_cast<int*>(1));
      ASSERT_TRUE(slot->Get() == reinterpret_cast<int*>(1));
      slot->Reset(reinterpret_cast<int*>(2));
      ASSERT_TRUE(slot->Get() == reinterpret_cast<int*>(2));
    }
    shared->mu->Lock();
    ++shared->completed;
    shared->cv->SignalAll();
    shared->mu->Unlock();
  };

  for (int round = 1; round <= 1024; ++round) {
    ASSERT_EQ(IDChecker::PeekId(), expected_id);
    env_->StartThreadTyped(worker, &p);
    // Wait until this round's worker has reported completion.
    mu.Lock();
    while (p.completed != round) {
      cv.Wait();
    }
    mu.Unlock();
    ASSERT_EQ(IDChecker::PeekId(), expected_id);
  }
}
// Runs two groups of 16 threads concurrently against shared ThreadLocalPtr
// state. Within each group the threads first wait until all of that group's
// threads have started. Each thread then stores thread-specific marker
// pointers (`own` in its group's tls1, `own + 1` in the tls2 shared by both
// groups). It repeatedly checks that it reads back its own markers until
// NowMicros() has advanced by 1000 * 1000 (one second, assuming microseconds).
// Threads of the p2 group (doWrite) also re-Reset both slots on every
// iteration. Only three ids may be consumed overall: tls2 plus the tls1 of p1
// and p2.
TEST_F(ThreadLocalTest, ConcurrentReadWriteTest) {
uint32_t base_id = IDChecker::PeekId();
ThreadLocalPtr tls2;
port::Mutex mu1;
port::CondVar cv1(&mu1);
Params p1(&mu1, &cv1, nullptr, 16);
p1.tls2 = &tls2;
port::Mutex mu2;
port::CondVar cv2(&mu2);
Params p2(&mu2, &cv2, nullptr, 16);
p2.doWrite = true;
p2.tls2 = &tls2;
auto func = [](void* ptr) {
auto& p = *static_cast<Params*>(ptr);
p.mu->Lock();
// Each thread's start position (1..total) doubles as its unique marker.
size_t own = ++(p.started);
p.cv->SignalAll();
while (p.started != p.total) {
p.cv->Wait();
}
p.mu->Unlock();
// Offset the writer group's markers so they never collide with p1's.
if (p.doWrite) {
own += 8192;
}
ASSERT_TRUE(p.tls1.Get() == nullptr);
ASSERT_TRUE(p.tls2->Get() == nullptr);
auto* env = Env::Default();
auto start = env->NowMicros();
// The markers are only compared, never dereferenced.
p.tls1.Reset(reinterpret_cast<size_t*>(own));
p.tls2->Reset(reinterpret_cast<size_t*>(own + 1));
while (env->NowMicros() - start < 1000 * 1000) {
for (int iter = 0; iter < 100000; ++iter) {
ASSERT_TRUE(p.tls1.Get() == reinterpret_cast<size_t*>(own));
ASSERT_TRUE(p.tls2->Get() == reinterpret_cast<size_t*>(own + 1));
if (p.doWrite) {
p.tls1.Reset(reinterpret_cast<size_t*>(own));
p.tls2->Reset(reinterpret_cast<size_t*>(own + 1));
}
}
}
p.mu->Lock();
++(p.completed);
p.cv->SignalAll();
p.mu->Unlock();
};
for (int th = 0; th < p1.total; ++th) {
env_->StartThreadTyped(func, &p1);
}
for (int th = 0; th < p2.total; ++th) {
env_->StartThreadTyped(func, &p2);
}
// Wait for both groups to report completion before checking ids.
mu1.Lock();
while (p1.completed != p1.total) {
cv1.Wait();
}
mu1.Unlock();
mu2.Lock();
while (p2.completed != p2.total) {
cv2.Wait();
}
mu2.Unlock();
ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
}
// Verifies when ThreadLocalPtr invokes its UnrefHandler, for 1, 2, 4, ... 128
// threads:
//  1. threads that never store a value trigger no unref calls;
//  2. threads that store a value into tls1 and tls2 trigger exactly one unref
//     call per slot per thread (2 * total) by the time they are joined, even
//     though each slot is Reset twice with the same value;
//  3. deleting a ThreadLocalPtr while the threads are still alive runs the
//     handler for that slot's values (total calls); the other slot's values
//     are released when the threads exit (another total calls).
TEST_F(ThreadLocalTest, Unref) {
  // Handler installed on the slots under test. The stored value is always the
  // owning Params, so the counter is reachable through it.
  auto unref = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);
    p.mu->Lock();
    ++(*p.unref);
    p.mu->Unlock();
  };

  // Case 1: threads rendezvous and exit without touching any slot.
  auto func0 = [](Params* ptr) {
    auto& p = *ptr;
    p.mu->Lock();
    ++(p.started);
    p.cv->SignalAll();
    while (p.started != p.total) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };
  for (int th = 1; th <= 128; th += th) {
    port::Mutex mu;
    port::CondVar cv(&mu);
    int unref_count = 0;
    Params p(&mu, &cv, &unref_count, th, unref);
    for (int i = 0; i < p.total; ++i) {
      env_->StartThreadTyped(func0, &p);
    }
    env_->WaitForJoin();
    ASSERT_EQ(unref_count, 0);
  }

  // Case 2: every thread stores a value in both slots, then exits.
  auto func1 = [](Params* ptr) {
    auto& p = *ptr;
    p.mu->Lock();
    ++(p.started);
    p.cv->SignalAll();
    while (p.started != p.total) {
      p.cv->Wait();
    }
    p.mu->Unlock();
    ASSERT_TRUE(p.tls1.Get() == nullptr);
    ASSERT_TRUE(p.tls2->Get() == nullptr);
    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);
    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);
  };
  for (int th = 1; th <= 128; th += th) {
    port::Mutex mu;
    port::CondVar cv(&mu);
    int unref_count = 0;
    ThreadLocalPtr tls2(unref);
    Params p(&mu, &cv, &unref_count, th, unref);
    p.tls2 = &tls2;
    for (int i = 0; i < p.total; ++i) {
      env_->StartThreadTyped(func1, &p);
    }
    env_->WaitForJoin();
    ASSERT_EQ(unref_count, 2 * p.total);
  }

  // Case 3: every thread stores a value in both slots, reports completion and
  // then stays alive until the main thread resets `completed` to 0.
  auto func2 = [](Params* ptr) {
    auto& p = *ptr;
    p.mu->Lock();
    ++(p.started);
    p.cv->SignalAll();
    while (p.started != p.total) {
      p.cv->Wait();
    }
    p.mu->Unlock();
    ASSERT_TRUE(p.tls1.Get() == nullptr);
    ASSERT_TRUE(p.tls2->Get() == nullptr);
    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);
    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);
    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();
    while (p.completed != 0) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };
  for (int th = 1; th <= 128; th += th) {
    port::Mutex mu;
    port::CondVar cv(&mu);
    int unref_count = 0;
    Params p(&mu, &cv, &unref_count, th, unref);
    p.tls2 = new ThreadLocalPtr(unref);
    for (int i = 0; i < p.total; ++i) {
      env_->StartThreadTyped(func2, &p);
    }
    mu.Lock();
    while (p.completed != p.total) {
      cv.Wait();
    }
    mu.Unlock();
    // Deleting the slot while its threads are alive must unref their values.
    delete p.tls2;
    p.tls2 = nullptr;
    // EXPECT rather than ASSERT: the workers are still blocked on cv waiting
    // for `completed` to drop to 0. A fatal failure here would return from
    // the test body and destroy mu, cv and p underneath them.
    EXPECT_EQ(unref_count, p.total);
    mu.Lock();
    p.completed = 0;
    cv.SignalAll();
    mu.Unlock();
    env_->WaitForJoin();
    ASSERT_EQ(unref_count, 2 * p.total);
  }
}
// Swap() stores the new value in the calling thread's slot and returns the
// value that was stored there before.
TEST_F(ThreadLocalTest, Swap) {
  ThreadLocalPtr slot;
  slot.Reset(reinterpret_cast<void*>(1));

  void* previous = slot.Swap(nullptr);
  ASSERT_EQ(reinterpret_cast<int64_t>(previous), 1);

  previous = slot.Swap(reinterpret_cast<void*>(2));
  ASSERT_TRUE(previous == nullptr);
  ASSERT_EQ(reinterpret_cast<int64_t>(slot.Get()), 2);

  previous = slot.Swap(reinterpret_cast<void*>(3));
  ASSERT_EQ(reinterpret_cast<int64_t>(previous), 2);
}
// Verifies that Scrape() removes every live thread's value from a slot
// without invoking the UnrefHandler. After both slots are scraped, neither
// deleting tls2 nor the threads exiting may call the handler. Repeated for
// 1, 2, 4, ... 128 threads.
TEST_F(ThreadLocalTest, Scrape) {
  // Counts handler invocations; the stored value is always the owning Params.
  auto unref = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);
    p.mu->Lock();
    ++(*p.unref);
    p.mu->Unlock();
  };
  // Stores a value in both slots, reports completion, then stays alive until
  // the main thread resets `completed` to 0.
  auto func = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);
    ASSERT_TRUE(p.tls1.Get() == nullptr);
    ASSERT_TRUE(p.tls2->Get() == nullptr);
    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);
    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);
    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();
    while (p.completed != 0) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };
  for (int th = 1; th <= 128; th += th) {
    port::Mutex mu;
    port::CondVar cv(&mu);
    int unref_count = 0;
    Params p(&mu, &cv, &unref_count, th, unref);
    p.tls2 = new ThreadLocalPtr(unref);
    for (int i = 0; i < p.total; ++i) {
      env_->StartThreadTyped(func, &p);
    }
    mu.Lock();
    while (p.completed != p.total) {
      cv.Wait();
    }
    mu.Unlock();
    // EXPECT rather than ASSERT: the workers are still parked on cv. A fatal
    // failure here would return from the test body and destroy mu, cv and p
    // while they wait on them.
    EXPECT_EQ(unref_count, 0);
    // Collect every thread's value and replace it with nullptr; the collected
    // pointers are intentionally ignored.
    autovector<void*> ptrs;
    p.tls1.Scrape(&ptrs, nullptr);
    p.tls2->Scrape(&ptrs, nullptr);
    delete p.tls2;
    mu.Lock();
    p.completed = 0;
    cv.SignalAll();
    mu.Unlock();
    env_->WaitForJoin();
    ASSERT_EQ(unref_count, 0);
  }
}
// Verifies that Fold() visits the value stored by every live thread. Each of
// kNumThreads workers publishes its own heap-allocated counter, increments it
// kItersPerThread times, and then stays alive until the main thread has
// folded all counters into a single sum.
TEST_F(ThreadLocalTest, Fold) {
  // The unref handler deletes the per-thread counter.
  auto unref = [](void* ptr) {
    delete static_cast<std::atomic<int64_t>*>(ptr);
  };
  static const int kNumThreads = 16;
  static const int kItersPerThread = 10;
  port::Mutex mu;
  port::CondVar cv(&mu);
  Params params(&mu, &cv, nullptr, kNumThreads, unref);
  auto func = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);
    ASSERT_TRUE(p.tls1.Get() == nullptr);
    p.tls1.Reset(new std::atomic<int64_t>(0));
    for (int i = 0; i < kItersPerThread; ++i) {
      static_cast<std::atomic<int64_t>*>(p.tls1.Get())->fetch_add(1);
    }
    // Report completion, then keep the thread (and its value) alive until
    // the main thread resets `completed` to 0 after folding.
    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();
    while (p.completed != 0) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };
  for (int th = 0; th < params.total; ++th) {
    env_->StartThread(func, &params);
  }
  mu.Lock();
  while (params.completed != params.total) {
    cv.Wait();
  }
  mu.Unlock();
  int64_t sum = 0;
  params.tls1.Fold(
      [](void* ptr, void* res) {
        auto sum_ptr = static_cast<int64_t*>(res);
        *sum_ptr += static_cast<std::atomic<int64_t>*>(ptr)->load();
      },
      &sum);
  // EXPECT rather than ASSERT: the workers are still blocked on cv. A fatal
  // failure here would return from the test body and destroy mu, cv and
  // params underneath them.
  EXPECT_EQ(sum, kNumThreads * kItersPerThread);
  mu.Lock();
  params.completed = 0;
  cv.SignalAll();
  mu.Unlock();
  env_->WaitForJoin();
}
// CompareAndSwap(desired, expected) installs `desired` only when the current
// value equals `expected`. On a mismatch it reports failure and writes the
// current value back into `expected`.
TEST_F(ThreadLocalTest, CompareAndSwap) {
  void* const one = reinterpret_cast<void*>(1);
  void* const two = reinterpret_cast<void*>(2);
  void* const three = reinterpret_cast<void*>(3);
  void* const hundred = reinterpret_cast<void*>(100);

  ThreadLocalPtr slot;
  ASSERT_TRUE(slot.Swap(one) == nullptr);

  void* expected = one;
  ASSERT_TRUE(slot.CompareAndSwap(two, expected));

  // Mismatch: nothing is stored and `expected` receives the current value.
  expected = hundred;
  ASSERT_FALSE(slot.CompareAndSwap(two, expected));
  ASSERT_EQ(expected, two);

  expected = two;
  ASSERT_TRUE(slot.CompareAndSwap(three, expected));
  ASSERT_EQ(slot.Get(), three);
}
namespace {
void* AccessThreadLocal(void* ) {
TEST_SYNC_POINT("AccessThreadLocal:Start");
ThreadLocalPtr tlp;
tlp.Reset(new std::string("hello RocksDB"));
TEST_SYNC_POINT("AccessThreadLocal:End");
return nullptr;
}
}
// Disabled test: starts a detached thread running AccessThreadLocal and lets
// this test body finish, so the thread-local access overlaps with process
// shutdown. The two sync-point pairs order the thread's sync points relative
// to "MainThreadDiesFirst:End" and "PosixEnv::~PosixEnv():End".
// NOTE(review): assuming LoadDependency pairs are {predecessor, successor},
// "MainThreadDiesFirst:End" waits for the thread to start, and the thread's
// final sync point waits for PosixEnv's destructor to finish. Also, no
// EnableProcessing() call is visible here — confirm how processing is
// enabled.
TEST_F(ThreadLocalTest, DISABLED_MainThreadDiesFirst) {
ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
{{"AccessThreadLocal:Start", "MainThreadDiesFirst:End"},
{"PosixEnv::~PosixEnv():End", "AccessThreadLocal:End"}});
// Called only for its side effect; the returned Env is discarded.
Env::Default();
try {
ROCKSDB_NAMESPACE::port::Thread th(&AccessThreadLocal, nullptr);
th.detach();
TEST_SYNC_POINT("MainThreadDiesFirst:End");
} catch (const std::system_error& ex) {
// Failure to create the thread is reported and fails the test.
std::cerr << "Start thread: " << ex.code() << std::endl;
FAIL();
}
}
}
// Test entry point: install the stack-trace handler so crashes are
// diagnosable, initialise GoogleTest from the command line, and return its
// result as the process exit status.
int main(int argc, char** argv) {
  ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
  ::testing::InitGoogleTest(&argc, argv);
  const int test_result = RUN_ALL_TESTS();
  return test_result;
}