#include "absl/synchronization/notification.h"
#include <thread>
#include <tuple>
#include <vector>
#include "gtest/gtest.h"
#include "absl/base/attributes.h"
#include "absl/base/config.h"
#include "absl/base/internal/tracing.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
namespace absl {
ABSL_NAMESPACE_BEGIN
// An integer counter guarded by `mutex_`. Every accessor acquires the mutex,
// so the test thread and worker threads can share one instance.
class ThreadSafeCounter {
 public:
  ThreadSafeCounter() = default;

  // Adds one to the count.
  void Increment() {
    MutexLock guard(mutex_);
    count_ += 1;
  }

  // Returns the count as seen while holding the mutex.
  int Get() const {
    MutexLock guard(mutex_);
    return count_;
  }

  // Blocks the calling thread until the count is at least `n`.
  void WaitUntilGreaterOrEqual(int n) {
    MutexLock guard(mutex_);
    auto reached_target = [this, n]() { return count_ >= n; };
    mutex_.Await(Condition(&reached_target));
  }

 private:
  // Mutable so that the const Get() can lock it.
  mutable Mutex mutex_;
  int count_ = 0;
};
// Entry point of each worker thread started by BasicTests.
// Records that the worker has started, blocks until `notification` fires,
// then records that it was released. The thread index is not used here.
static void RunWorker(int /*i*/, ThreadSafeCounter* ready_counter,
                      Notification* notification,
                      ThreadSafeCounter* done_counter) {
  // Check in first so the main thread can wait for every worker to start.
  ready_counter->Increment();
  // Blocks until Notify() has been called (or returns at once if it already was).
  notification->WaitForNotification();
  // Only reachable after the notification has fired.
  done_counter->Increment();
}
// Exercises `notification` end to end. The checks expect `notification` to be
// un-notified on entry.
//
// Phase 1: every query on the un-notified object reports false, and a bounded
//          wait does not return noticeably before its timeout.
// Phase 2: kNumThreads workers (RunWorker) block on the notification. When
//          `notify_before_waiting` is true, Notify() is called before the
//          workers are spawned. Otherwise it is called only after every worker
//          has checked in and none has been released yet.
// Phase 3: every query reports true, and every worker was released.
static void BasicTests(bool notify_before_waiting, Notification* notification) {
  EXPECT_FALSE(notification->HasBeenNotified());
  EXPECT_FALSE(
      notification->WaitForNotificationWithTimeout(absl::Milliseconds(0)));
  EXPECT_FALSE(notification->WaitForNotificationWithDeadline(absl::Now()));

  // A timed wait on an un-notified object should block for about `delay`.
  const absl::Duration delay = absl::Milliseconds(50);
  const absl::Time start = absl::Now();
  EXPECT_FALSE(notification->WaitForNotificationWithTimeout(delay));
  const absl::Duration elapsed = absl::Now() - start;

  // Tolerate a little clock slop; MSVC builds get a wider margin.
  absl::Duration slop = absl::Milliseconds(5);
#ifdef _MSC_VER
  slop = absl::Milliseconds(15);
#endif
  EXPECT_LE(delay - slop, elapsed)
      << "WaitForNotificationWithTimeout returned " << delay - elapsed
      << " early (with " << slop << " slop), start time was " << start;

  ThreadSafeCounter ready_counter;
  ThreadSafeCounter done_counter;

  if (notify_before_waiting) {
    notification->Notify();
  }

  constexpr int kNumThreads = 10;
  std::vector<std::thread> workers;
  workers.reserve(kNumThreads);
  for (int i = 0; i < kNumThreads; ++i) {
    workers.emplace_back(&RunWorker, i, &ready_counter, notification,
                         &done_counter);
  }

  if (!notify_before_waiting) {
    // Every worker has checked in, but Notify() has not been called yet, so
    // none of them can have passed WaitForNotification().
    ready_counter.WaitUntilGreaterOrEqual(kNumThreads);
    EXPECT_EQ(0, done_counter.Get());
    notification->Notify();
  }

  // Notify() has been called on both paths, so this must not block.
  notification->WaitForNotification();
  EXPECT_TRUE(notification->HasBeenNotified());
  EXPECT_TRUE(notification->WaitForNotificationWithTimeout(absl::Seconds(0)));
  EXPECT_TRUE(notification->WaitForNotificationWithDeadline(absl::Now()));

  for (std::thread& worker : workers) {
    worker.join();
  }
  EXPECT_EQ(kNumThreads, ready_counter.Get());
  EXPECT_EQ(kNumThreads, done_counter.Get());
}
// Runs BasicTests on two fresh Notification objects. The first run notifies
// after the workers have started; the second notifies before spawning them.
TEST(NotificationTest, SanityTest) {
  Notification notified_after_workers_start;
  BasicTests(/*notify_before_waiting=*/false, &notified_after_workers_start);

  Notification notified_up_front;
  BasicTests(/*notify_before_waiting=*/true, &notified_up_front);
}
#if ABSL_HAVE_ATTRIBUTE_WEAK
namespace base_internal {
namespace {

// The (object, kind) pair handed to a tracing hook.
using TraceRecord = std::tuple<const void*, ObjectKind>;

// The most recent arguments each tracing hook received on the current thread.
// Tests reset a record to its default value before the operation under test,
// then compare it against the expected record.
thread_local TraceRecord tls_signal{};
thread_local TraceRecord tls_wait{};
thread_local TraceRecord tls_continue{};
thread_local TraceRecord tls_observed{};

}  // namespace
// Test definitions of the tracing hooks. Each hook stores its arguments in the
// matching thread-local TraceRecord so the tests can inspect them.
// NOTE(review): the ABSL_HAVE_ATTRIBUTE_WEAK guard suggests these override weak
// defaults in the library; confirm against absl/base/internal/tracing.h.
extern "C" {

void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceSignal)(const void* object,
                                                      ObjectKind kind) {
  tls_signal = TraceRecord(object, kind);
}

void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceWait)(const void* object,
                                                    ObjectKind kind) {
  tls_wait = TraceRecord(object, kind);
}

void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceContinue)(const void* object,
                                                        ObjectKind kind) {
  tls_continue = TraceRecord(object, kind);
}

void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceObserved)(const void* object,
                                                        ObjectKind kind) {
  tls_observed = TraceRecord(object, kind);
}

}  // extern "C"
// Notify() should record a signal trace carrying the notification's address
// and ObjectKind::kNotification.
TEST(NotificationTest, TracesNotify) {
  Notification notification;
  const TraceRecord expected_signal(&notification, ObjectKind::kNotification);

  tls_signal = TraceRecord();
  notification.Notify();

  EXPECT_EQ(tls_signal, expected_signal);
}
// Waiting on an already-notified object should record both a wait trace and a
// continue trace for that object.
TEST(NotificationTest, TracesWaitForNotification) {
  Notification notification;
  notification.Notify();

  const TraceRecord expected(&notification, ObjectKind::kNotification);
  tls_continue = TraceRecord();
  tls_wait = TraceRecord();
  notification.WaitForNotification();

  EXPECT_EQ(tls_wait, expected);
  EXPECT_EQ(tls_continue, expected);
}
// Checks the wait/continue traces emitted by WaitForNotificationWithTimeout().
// On a wait that times out, the test expects the wait trace to name the object
// and the continue trace to carry a null object. On an already-notified object,
// both traces name the object. The return value of each wait is checked too,
// so a wrong result cannot go unnoticed behind correct tracing.
TEST(NotificationTest, TracesWaitForNotificationWithTimeout) {
  Notification n;

  tls_wait = tls_continue = {};
  EXPECT_FALSE(n.WaitForNotificationWithTimeout(absl::Milliseconds(1)));
  EXPECT_EQ(tls_wait, TraceRecord(&n, ObjectKind::kNotification));
  EXPECT_EQ(tls_continue, TraceRecord(nullptr, ObjectKind::kNotification));

  n.Notify();
  tls_wait = tls_continue = {};
  EXPECT_TRUE(n.WaitForNotificationWithTimeout(absl::Milliseconds(1)));
  EXPECT_EQ(tls_wait, TraceRecord(&n, ObjectKind::kNotification));
  EXPECT_EQ(tls_continue, TraceRecord(&n, ObjectKind::kNotification));
}
// HasBeenNotified() on an un-notified object should leave the observed record
// at (nullptr, kUnknown). Once the object is notified, it should record an
// observation of that object.
TEST(NotificationTest, TracesHasBeenNotified) {
  Notification notification;

  tls_observed = TraceRecord();
  ASSERT_FALSE(notification.HasBeenNotified());
  EXPECT_EQ(tls_observed, TraceRecord(nullptr, ObjectKind::kUnknown));

  notification.Notify();

  tls_observed = TraceRecord();
  ASSERT_TRUE(notification.HasBeenNotified());
  EXPECT_EQ(tls_observed,
            TraceRecord(&notification, ObjectKind::kNotification));
}
}
#endif
ABSL_NAMESPACE_END
}