#include <grpc/support/port_platform.h>
#include "src/core/lib/promise/party.h"
#include <atomic>
#include <initializer_list>
#include "absl/base/thread_annotations.h"
#include "absl/strings/str_format.h"
#include <grpc/support/log.h>
#include "src/core/lib/debug/trace.h"
#include "src/core/lib/gprpp/sync.h"
#include "src/core/lib/iomgr/exec_ctx.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/trace.h"
#ifdef GRPC_MAXIMIZE_THREADYNESS
#include "src/core/lib/gprpp/thd.h"
#include "src/core/lib/iomgr/exec_ctx.h"
#endif
namespace grpc_core {
// Takes one additional reference, but only if the reference count has not
// already dropped to zero.
//
// Returns true when a reference was added (the caller now owns it), false
// when no reference was outstanding and nothing was changed.
GRPC_MUST_USE_RESULT bool PartySyncUsingAtomics::RefIfNonZero() {
  auto state = state_.load(std::memory_order_relaxed);
  do {
    // state_ is a packed word: besides the ref count it also carries the
    // kLocked / kDestroying flags (set by UnreffedLast) and the wakeup bits
    // (set by ScheduleWakeup). Those bits keep the word non-zero after the
    // last reference is gone, so comparing the whole word against zero would
    // hand out a reference to a party that is already being torn down. Only
    // the ref-count field may decide.
    // NOTE(review): relies on kRefMask being declared next to kOneRef in the
    // class definition — confirm against the header.
    if ((state & kRefMask) == 0) {
      return false;
    }
    // On CAS failure `state` is reloaded with the current value and the
    // ref-count check above is repeated.
  } while (!state_.compare_exchange_weak(state, state + kOneRef,
                                         std::memory_order_acq_rel,
                                         std::memory_order_relaxed));
  return true;
}
// Atomically ORs kDestroying | kLocked into state_ and reports whether the
// kLocked bit was clear beforehand.
// Returns true when this call is the one that set kLocked (nobody held the
// lock bit before), false when kLocked was already set.
// NOTE(review): the name suggests this runs once the final reference has been
// released — confirm against the caller, which is not visible here.
bool PartySyncUsingAtomics::UnreffedLast() {
uint64_t prev_state =
state_.fetch_or(kDestroying | kLocked, std::memory_order_acq_rel);
return (prev_state & kLocked) == 0;
}
// Records the requested wakeups and tries to grab the lock bit in a single
// atomic step.
//
// `mask` is restricted to kWakeupMask before being merged into state_.
// Returns true when kLocked was clear before this call, i.e. this caller is
// the one that just set it.
bool PartySyncUsingAtomics::ScheduleWakeup(WakeupMask mask) {
  const uint64_t bits_to_set = (mask & kWakeupMask) | kLocked;
  const uint64_t before =
      state_.fetch_or(bits_to_set, std::memory_order_acq_rel);
  const bool lock_was_free = (before & kLocked) == 0;
  return lock_was_free;
}
// Mutex-based counterpart of the atomic ScheduleWakeup: under mu_, merges
// `mask` into wakeups_ and marks the party as locked.
//
// Returns true when locked_ was false on entry, i.e. this caller is the one
// that just flipped it to true.
bool PartySyncUsingMutex::ScheduleWakeup(WakeupMask mask) {
  MutexLock lock(&mu_);
  wakeups_ |= mask;
  const bool was_locked = locked_;
  locked_ = true;
  return !was_locked;
}
// Ref-counted Wakeable that points back at a Party without owning it.
// The back pointer is cleared by DropActivity(), after which wakeups through
// this handle become no-ops; the handle deletes itself when its own reference
// count reaches zero.
class Party::Handle final : public Wakeable {
 public:
  explicit Handle(Party* party) : party_(party) {}

  // Adds a reference to the handle itself (not to the party).
  void Ref() { refs_.fetch_add(1, std::memory_order_relaxed); }

  // Severs the link to the party and releases one handle reference.
  // Asserts that the link has not been severed before.
  void DropActivity() ABSL_LOCKS_EXCLUDED(mu_) {
    {
      MutexLock lock(&mu_);
      GPR_ASSERT(party_ != nullptr);
      party_ = nullptr;
    }
    Unref();
  }

  // Shared implementation of Wakeup / WakeupAsync.
  // Under mu_, tries to take a reference on the party (if it is still
  // linked); with mu_ released, forwards the wakeup through `wakeup_method`.
  // Always releases one handle reference at the end.
  void WakeupGeneric(WakeupMask wakeup_mask,
                     void (Party::*wakeup_method)(WakeupMask))
      ABSL_LOCKS_EXCLUDED(mu_) {
    Party* target = nullptr;
    bool have_ref = false;
    {
      MutexLock lock(&mu_);
      target = party_;
      have_ref = target != nullptr && target->RefIfNonZero();
    }
    // The member call is made only after mu_ has been released.
    if (have_ref) {
      (target->*wakeup_method)(wakeup_mask);
    }
    Unref();
  }

  void Wakeup(WakeupMask wakeup_mask) override ABSL_LOCKS_EXCLUDED(mu_) {
    WakeupGeneric(wakeup_mask, &Party::Wakeup);
  }

  void WakeupAsync(WakeupMask wakeup_mask) override ABSL_LOCKS_EXCLUDED(mu_) {
    WakeupGeneric(wakeup_mask, &Party::WakeupAsync);
  }

  // Dropping a waker only releases its handle reference.
  void Drop(WakeupMask) override { Unref(); }

  // Debug tag of the linked party, or "<unknown>" once the link is severed.
  std::string ActivityDebugTag(WakeupMask) const override {
    MutexLock lock(&mu_);
    if (party_ == nullptr) {
      return "<unknown>";
    }
    return party_->DebugTag();
  }

 private:
  // Releases one handle reference; the last release destroys the handle.
  void Unref() {
    const size_t before = refs_.fetch_sub(1, std::memory_order_acq_rel);
    if (before == 1) {
      delete this;
    }
  }

  // Starts at two: one reference is returned to the caller of
  // MakeNonOwningWakeable, the other is released through DropActivity().
  std::atomic<size_t> refs_{2};
  mutable Mutex mu_;
  Party* party_ ABSL_GUARDED_BY(mu_);
};
// Returns the participant's Handle as a Wakeable, creating it on first use.
// A freshly created Handle is returned as-is; an existing one gets an extra
// reference for the new caller.
Wakeable* Party::Participant::MakeNonOwningWakeable(Party* party) {
  if (handle_ != nullptr) {
    handle_->Ref();
  } else {
    handle_ = new Handle(party);
  }
  return handle_;
}
// If a Handle was ever created for this participant, tell it the participant
// is going away so it severs its back pointer and drops a reference.
Party::Participant::~Participant() {
  if (handle_ == nullptr) {
    return;
  }
  handle_->DropActivity();
}
// Nothing to do beyond member destruction; defined out of line.
Party::~Party() = default;
// Empties every participant slot and destroys whatever was still stored
// there. Runs with this party installed as the current activity and with
// arena_ installed as the Arena context.
void Party::CancelRemainingParticipants() {
  ScopedActivity activity(this);
  promise_detail::Context<Arena> arena_ctx(arena_);
  for (size_t slot = 0; slot != party_detail::kMaxParticipants; ++slot) {
    // Swap the slot to null first so the pointer is claimed exactly once.
    auto* leftover =
        participants_[slot].exchange(nullptr, std::memory_order_acquire);
    if (leftover == nullptr) {
      continue;
    }
    leftover->Destroy();
  }
}
// Builds a debug string of the form "<DebugTag()> [parts:<mask>]", where the
// wakeup mask is printed in hexadecimal.
std::string Party::ActivityDebugTag(WakeupMask wakeup_mask) const {
return absl::StrFormat("%s [parts:%x]", DebugTag(), wakeup_mask);
}
// Creates a Waker that holds a reference on this party and targets the
// participant currently being polled.
// Debug builds assert that a participant is being polled.
Waker Party::MakeOwningWaker() {
  GPR_DEBUG_ASSERT(currently_polling_ != kNotPolling);
  // The waker owns the reference taken here.
  IncrementRefCount();
  const auto participant_bit = 1u << currently_polling_;
  return Waker(this, participant_bit);
}
// Creates a Waker for the participant currently being polled that goes
// through the participant's non-owning wakeable instead of referencing the
// party directly.
// Debug builds assert that a participant is being polled.
Waker Party::MakeNonOwningWaker() {
  GPR_DEBUG_ASSERT(currently_polling_ != kNotPolling);
  auto* polled =
      participants_[currently_polling_].load(std::memory_order_relaxed);
  Wakeable* wakeable = polled->MakeNonOwningWakeable(this);
  return Waker(wakeable, 1u << currently_polling_);
}
// Forwards `mask` to sync_.ForceImmediateRepoll().
// Debug builds assert is_current() first.
void Party::ForceImmediateRepoll(WakeupMask mask) {
GPR_DEBUG_ASSERT(is_current());
sync_.ForceImmediateRepoll(mask);
}
// Runs the party's work loop and, if RunParty() reports true, finishes the
// party by calling PartyOver() with this party installed as the current
// activity.
//
// With GRPC_MAXIMIZE_THREADYNESS defined, that work is handed to a newly
// started, non-joinable Thread named "RunParty", which sets up its own
// ApplicationCallbackExecCtx and ExecCtx before running it. Otherwise the
// work runs inline on the calling thread.
void Party::RunLocked() {
// The actual work, shared by both build configurations.
auto body = [this]() {
if (RunParty()) {
ScopedActivity activity(this);
PartyOver();
}
};
#ifdef GRPC_MAXIMIZE_THREADYNESS
// `body` is captured by copy since the thread runs after this function
// returns.
Thread thd(
"RunParty",
[body]() {
ApplicationCallbackExecCtx app_exec_ctx;
ExecCtx exec_ctx;
body();
},
nullptr, Thread::Options().set_joinable(false));
thd.Start();
#else
body();
#endif
}
bool Party::RunParty() {
ScopedActivity activity(this);
promise_detail::Context<Arena> arena_ctx(arena_);
return sync_.RunParty([this](int i) {
auto* participant = participants_[i].load(std::memory_order_acquire);
if (participant == nullptr) {
if (grpc_trace_promise_primitives.enabled()) {
gpr_log(GPR_DEBUG, "%s[party] wakeup %d already complete",
DebugTag().c_str(), i);
}
return false;
}
absl::string_view name;
if (grpc_trace_promise_primitives.enabled()) {
name = participant->name();
gpr_log(GPR_DEBUG, "%s[%s] begin job %d", DebugTag().c_str(),
std::string(name).c_str(), i);
}
currently_polling_ = i;
bool done = participant->Poll();
currently_polling_ = kNotPolling;
if (done) {
if (!name.empty()) {
gpr_log(GPR_DEBUG, "%s[%s] end poll and finish job %d",
DebugTag().c_str(), std::string(name).c_str(), i);
}
participants_[i].store(nullptr, std::memory_order_relaxed);
} else if (!name.empty()) {
gpr_log(GPR_DEBUG, "%s[%s] end poll", DebugTag().c_str(),
std::string(name).c_str());
}
return done;
});
}
// Registers `count` participants with the party.
// sync_.AddParticipantsAndRef() supplies the slot indices through the
// callback; each participant is stored into its assigned slot. If that call
// returns true the party is run right away. One reference is released before
// returning.
void Party::AddParticipants(Participant** participants, size_t count) {
  auto install = [this, participants, count](size_t* slots) {
    for (size_t n = 0; n != count; ++n) {
      participants_[slots[n]].store(participants[n],
                                    std::memory_order_release);
    }
  };
  const bool run_party = sync_.AddParticipantsAndRef(count, install);
  if (run_party) {
    RunLocked();
  }
  Unref();
}
// Schedules the wakeups in `wakeup_mask`; if sync_.ScheduleWakeup() returns
// true, runs the party on the calling thread via RunLocked().
// Always releases one reference afterwards.
void Party::Wakeup(WakeupMask wakeup_mask) {
if (sync_.ScheduleWakeup(wakeup_mask)) RunLocked();
Unref();
}
// Like Wakeup(), but when sync_.ScheduleWakeup() returns true the party is
// run from a callback submitted to event_engine() instead of on the calling
// thread. One reference is released either way: immediately when nothing
// needs to run, otherwise at the end of the callback.
void Party::WakeupAsync(WakeupMask wakeup_mask) {
  if (!sync_.ScheduleWakeup(wakeup_mask)) {
    Unref();
    return;
  }
  event_engine()->Run([this]() {
    // The callback sets up its own exec contexts before running the party.
    ApplicationCallbackExecCtx app_exec_ctx;
    ExecCtx exec_ctx;
    RunLocked();
    Unref();
  });
}
// Releases one reference via Unref(); the wakeup mask is ignored.
void Party::Drop(WakeupMask) { Unref(); }
// Calls PartyOver() with this party installed as the current activity for
// the duration of the call.
void Party::PartyIsOver() {
ScopedActivity activity(this);
PartyOver();
}
}