#ifndef GRPC_SRC_CORE_LIB_PROMISE_PARTY_H
#define GRPC_SRC_CORE_LIB_PROMISE_PARTY_H
#include <grpc/support/port_platform.h>
#include <stddef.h>
#include <stdint.h>
#include <atomic>
#include <string>
#include <utility>
#include "absl/base/thread_annotations.h"
#include "absl/strings/string_view.h"
#include <grpc/event_engine/event_engine.h>
#include <grpc/support/log.h>
#include "src/core/lib/debug/trace.h"
#include "src/core/lib/gprpp/construct_destruct.h"
#include "src/core/lib/gprpp/crash.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/sync.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/detail/promise_factory.h"
#include "src/core/lib/promise/trace.h"
#include "src/core/lib/resource_quota/arena.h"
#define GRPC_PARTY_SYNC_USING_ATOMICS
#if defined(GRPC_PARTY_SYNC_USING_ATOMICS) + \
defined(GRPC_PARTY_SYNC_USING_MUTEX) != \
1
#error Must define a party sync mechanism
#endif
namespace grpc_core {
namespace party_detail {

// The maximum number of participant promises that may be spawned onto a
// single Party. Both sync implementations below pack per-participant state
// into 16-bit wakeup/allocation masks, so this limit is structural.
static constexpr size_t kMaxParticipants = 16;

}  // namespace party_detail
// Lock-free synchronization for a Party: ref count, lock bit, per-participant
// allocation bits, and per-participant wakeup bits are all packed into a
// single 64-bit atomic (layout documented with the constants below).
class PartySyncUsingAtomics {
 public:
  explicit PartySyncUsingAtomics(size_t initial_refs)
      : state_(kOneRef * initial_refs) {}

  void IncrementRefCount() {
    state_.fetch_add(kOneRef, std::memory_order_relaxed);
  }
  // Take a ref only if the count is currently non-zero. Defined out of line.
  GRPC_MUST_USE_RESULT bool RefIfNonZero();
  // Drop a ref; returns true if this was the last ref (caller should tear
  // the party down).
  GRPC_MUST_USE_RESULT bool Unref() {
    uint64_t prev_state = state_.fetch_sub(kOneRef, std::memory_order_acq_rel);
    if ((prev_state & kRefMask) == kOneRef) {
      return UnreffedLast();
    }
    return false;
  }
  // Mark the participants in `mask` as needing another poll before the
  // current run loop exits.
  void ForceImmediateRepoll(WakeupMask mask) {
    state_.fetch_or(mask, std::memory_order_relaxed);
  }

  // Run the party: must be called with the lock bit held. Calls
  // poll_one_participant(i) for each participant i with a pending wakeup,
  // repeating until no wakeups remain, then releases the lock.
  // Returns true if the party is being destroyed.
  template <typename F>
  GRPC_MUST_USE_RESULT bool RunParty(F poll_one_participant) {
    uint64_t prev_state;
    do {
      // Grab the current state and atomically clear the wakeup bits: we are
      // about to service them.
      prev_state = state_.fetch_and(kRefMask | kLocked | kAllocatedMask,
                                    std::memory_order_acquire);
      GPR_ASSERT(prev_state & kLocked);
      if (prev_state & kDestroying) return true;
      // Extract the wakeups we are responsible for this iteration.
      uint64_t wakeups = prev_state & kWakeupMask;
      // Normalize prev_state to what the CAS below expects to observe
      // (wakeup bits cleared).
      prev_state &= kRefMask | kLocked | kAllocatedMask;
      // Poll each woken participant in slot order.
      for (size_t i = 0; wakeups != 0; i++, wakeups >>= 1) {
        if ((wakeups & 1) == 0) continue;
        if (poll_one_participant(i)) {
          // Participant completed: free its allocation bit immediately so a
          // concurrent AddParticipantsAndRef can reuse the slot. Also clear
          // it from our expected state for the CAS.
          const uint64_t allocated_bit = (1u << i << kAllocatedShift);
          prev_state &= ~allocated_bit;
          state_.fetch_and(~allocated_bit, std::memory_order_release);
        }
      }
      // Try to release the lock (drop kLocked, keep refs + allocations). If
      // the CAS fails, a wakeup arrived while we were polling — loop again
      // to service it.
    } while (!state_.compare_exchange_weak(
        prev_state, (prev_state & (kRefMask | kAllocatedMask)),
        std::memory_order_acq_rel, std::memory_order_acquire));
    return false;
  }

  // Allocate `count` free participant slots, take one ref, and publish the
  // slot indices through store(slots). Marks the new slots woken and
  // attempts to take the lock; returns true if the caller acquired the lock
  // (and so must run the party).
  template <typename F>
  GRPC_MUST_USE_RESULT bool AddParticipantsAndRef(size_t count, F store) {
    uint64_t state = state_.load(std::memory_order_acquire);
    uint64_t allocated;
    size_t slots[party_detail::kMaxParticipants];
    WakeupMask wakeup_mask;
    do {
      wakeup_mask = 0;
      allocated = (state & kAllocatedMask) >> kAllocatedShift;
      size_t n = 0;
      // Find `count` unallocated slot indices (lowest bits first).
      for (size_t bit = 0; n < count && bit < party_detail::kMaxParticipants;
           bit++) {
        if (allocated & (1 << bit)) continue;
        wakeup_mask |= (1 << bit);
        slots[n++] = bit;
        allocated |= 1 << bit;
      }
      // Party is full if we could not find enough free slots.
      GPR_ASSERT(n == count);
      // Claim the slots and add a ref in one CAS; retry from the freshly
      // observed state on failure.
    } while (!state_.compare_exchange_weak(
        state, (state | (allocated << kAllocatedShift)) + kOneRef,
        std::memory_order_acq_rel, std::memory_order_acquire));
    // Publish the participants before raising their wakeup bits.
    store(slots);
    // Wake the new participants and try to take the lock. If kLocked was
    // previously clear, we now own the lock.
    state = state_.fetch_or(wakeup_mask | kLocked, std::memory_order_release);
    return ((state & kLocked) == 0);
  }

  // Queue a wakeup for the participants in `mask`. Defined out of line.
  GRPC_MUST_USE_RESULT bool ScheduleWakeup(WakeupMask mask);

 private:
  // Called when the ref count hits zero; returns true if the caller should
  // destroy the party.
  bool UnreffedLast();

  // State layout (64 bits, low to high):
  // clang-format off
  // Bits used to store 16 bits of pending participant wakeups.
  static constexpr uint64_t kWakeupMask = 0x0000'0000'0000'ffff;
  // Bits used to store 16 bits of allocated participant slots.
  static constexpr uint64_t kAllocatedMask = 0x0000'0000'ffff'0000;
  // Bit set while the party is being destroyed.
  static constexpr uint64_t kDestroying = 0x0000'0001'0000'0000;
  // Bit indicating locked or not
  static constexpr uint64_t kLocked = 0x0000'0008'0000'0000;
  // Bits used to store the ref count.
  static constexpr uint64_t kRefMask = 0xffff'ff00'0000'0000;
  // clang-format on
  // Shift to get from a participant mask to an allocated mask.
  static constexpr size_t kAllocatedShift = 16;
  // How far to shift to get the refcount
  static constexpr size_t kRefShift = 40;
  // One ref count
  static constexpr uint64_t kOneRef = 1ull << kRefShift;

  std::atomic<uint64_t> state_;
};
// Mutex-based synchronization for a Party: same interface as
// PartySyncUsingAtomics, with wakeup/allocation masks guarded by a Mutex and
// the ref count held in a separate RefCount.
class PartySyncUsingMutex {
 public:
  explicit PartySyncUsingMutex(size_t initial_refs) : refs_(initial_refs) {}

  void IncrementRefCount() { refs_.Ref(); }
  GRPC_MUST_USE_RESULT bool RefIfNonZero() { return refs_.RefIfNonZero(); }
  // Returns true if this was the last ref.
  GRPC_MUST_USE_RESULT bool Unref() { return refs_.Unref(); }
  // Mark the participants in `mask` as needing another poll before the
  // current run loop exits.
  void ForceImmediateRepoll(WakeupMask mask) {
    MutexLock lock(&mu_);
    wakeups_ |= mask;
  }
  // Run the party: must be called with locked_ already set. Polls each woken
  // participant (outside the mutex) until no wakeups remain, then clears
  // locked_. Note: always returns false — this variant has no destroying
  // flag, so the return only mirrors PartySyncUsingAtomics::RunParty.
  template <typename F>
  GRPC_MUST_USE_RESULT bool RunParty(F poll_one_participant) {
    WakeupMask freed = 0;
    while (true) {
      ReleasableMutexLock lock(&mu_);
      GPR_ASSERT(locked_);
      // Release the slots of participants that completed last iteration.
      allocated_ &= ~std::exchange(freed, 0);
      auto wakeup = std::exchange(wakeups_, 0);
      if (wakeup == 0) {
        // Nothing left to do: drop the logical lock and exit.
        locked_ = false;
        return false;
      }
      // Poll outside the mutex so participants may schedule further wakeups.
      lock.Release();
      for (size_t i = 0; wakeup != 0; i++, wakeup >>= 1) {
        if ((wakeup & 1) == 0) continue;
        // Completed participants have their slot freed next iteration.
        if (poll_one_participant(i)) freed |= 1 << i;
      }
    }
  }
  // Allocate `count` free participant slots, take one ref, and publish the
  // slot indices through store(slots). Marks the new slots woken; returns
  // true if the caller acquired the logical lock (and so must run the
  // party).
  template <typename F>
  GRPC_MUST_USE_RESULT bool AddParticipantsAndRef(size_t count, F store) {
    IncrementRefCount();
    MutexLock lock(&mu_);
    size_t slots[party_detail::kMaxParticipants];
    WakeupMask wakeup_mask = 0;
    size_t n = 0;
    // Find `count` unallocated slot indices (lowest bits first).
    for (size_t bit = 0; n < count && bit < party_detail::kMaxParticipants;
         bit++) {
      if (allocated_ & (1 << bit)) continue;
      slots[n++] = bit;
      wakeup_mask |= 1 << bit;
      allocated_ |= 1 << bit;
    }
    // Party is full if we could not find enough free slots.
    GPR_ASSERT(n == count);
    store(slots);
    wakeups_ |= wakeup_mask;
    // Take the logical lock; true means it was previously unheld.
    return !std::exchange(locked_, true);
  }
  // Queue a wakeup for the participants in `mask`. Defined out of line.
  GRPC_MUST_USE_RESULT bool ScheduleWakeup(WakeupMask mask);

 private:
  RefCount refs_;
  Mutex mu_;
  // Bitmask of participant slots currently in use.
  WakeupMask allocated_ ABSL_GUARDED_BY(mu_) = 0;
  // Bitmask of participants with a pending wakeup.
  WakeupMask wakeups_ ABSL_GUARDED_BY(mu_) = 0;
  // Logical run-loop lock (distinct from mu_, which only guards the masks).
  bool locked_ ABSL_GUARDED_BY(mu_) = false;
};
// A Party is an Activity with multiple participant promises.
// Each participant is polled when woken, and removed when it resolves.
class Party : public Activity, private Wakeable {
 private:
  // Non-owning wakeup handle.
  class Handle;

  // One participant in the party.
  class Participant {
   public:
    explicit Participant(absl::string_view name) : name_(name) {}
    // Poll the participant. Return true if complete.
    // Participant should take care of its own deallocation in this case.
    virtual bool Poll() = 0;
    // Destroy the participant before finishing.
    virtual void Destroy() = 0;
    // Return a Handle instance for this participant.
    Wakeable* MakeNonOwningWakeable(Party* party);
    absl::string_view name() const { return name_; }

   protected:
    // Non-virtual, protected: participants are destroyed via Destroy()
    // or their concrete type, never through this base pointer.
    ~Participant();

   private:
    Handle* handle_ = nullptr;
    absl::string_view name_;
  };

 public:
  Party(const Party&) = delete;
  Party& operator=(const Party&) = delete;

  // Spawn one promise into the party.
  // The promise will be polled until it is resolved, or until the party is shut
  // down.
  // The on_complete callback will be called with the result of the promise if
  // it completes.
  // A maximum of sixteen promises can be spawned onto a party.
  template <typename Factory, typename OnComplete>
  void Spawn(absl::string_view name, Factory promise_factory,
             OnComplete on_complete);

  // Parties are not owned via Orphan(); lifetime is managed by Ref()/Unref().
  void Orphan() final { Crash("unused"); }

  // Activity implementation: not allowed to be overridden by derived types.
  void ForceImmediateRepoll(WakeupMask mask) final;
  WakeupMask CurrentParticipant() const final {
    GPR_DEBUG_ASSERT(currently_polling_ != kNotPolling);
    return 1u << currently_polling_;
  }
  Waker MakeOwningWaker() final;
  Waker MakeNonOwningWaker() final;
  std::string ActivityDebugTag(WakeupMask wakeup_mask) const final;

  void IncrementRefCount() { sync_.IncrementRefCount(); }
  // Drop a ref; tears the party down if this was the last one.
  void Unref() {
    if (sync_.Unref()) PartyIsOver();
  }
  RefCountedPtr<Party> Ref() {
    IncrementRefCount();
    return RefCountedPtr<Party>(this);
  }

  Arena* arena() const { return arena_; }

  // Collects several Spawn() calls and registers all of the resulting
  // participants with the party at once, when the BulkSpawner is destroyed.
  class BulkSpawner {
   public:
    explicit BulkSpawner(Party* party) : party_(party) {}
    ~BulkSpawner() {
      party_->AddParticipants(participants_, num_participants_);
    }

    template <typename Factory, typename OnComplete>
    void Spawn(absl::string_view name, Factory promise_factory,
               OnComplete on_complete);

   private:
    Party* const party_;
    size_t num_participants_ = 0;
    Participant* participants_[party_detail::kMaxParticipants];
  };

 protected:
  explicit Party(Arena* arena, size_t initial_refs)
      : sync_(initial_refs), arena_(arena) {}
  ~Party() override;

  // Main run loop. Must be locked.
  // Polls participants and drains the add queue until there is no work left to
  // be done.
  // Derived types will likely want to override this to set up their
  // contexts before polling.
  // Should not be called by derived types except as a tail call to the base
  // class RunParty when overriding this method to add custom context.
  // Returns true if the party is over.
  virtual bool RunParty() GRPC_MUST_USE_RESULT;

  bool RefIfNonZero() { return sync_.RefIfNonZero(); }

  // Destroy any remaining participants.
  // Should be called by derived types in response to PartyOver.
  // Needs to have normal context setup before calling.
  void CancelRemainingParticipants();

 private:
  // Concrete implementation of a participant for some promise & oncomplete
  // type. Pool-allocated from the party's arena; frees itself on completion
  // or destruction.
  template <typename SuppliedFactory, typename OnComplete>
  class ParticipantImpl final : public Participant {
    using Factory = promise_detail::OncePromiseFactory<void, SuppliedFactory>;
    using Promise = typename Factory::Promise;

   public:
    ParticipantImpl(absl::string_view name, SuppliedFactory promise_factory,
                    OnComplete on_complete)
        : Participant(name), on_complete_(std::move(on_complete)) {
      // The union starts life holding the factory; the promise replaces it
      // on first poll.
      Construct(&factory_, std::move(promise_factory));
    }
    ~ParticipantImpl() {
      // Destroy whichever union member is currently live.
      if (!started_) {
        Destruct(&factory_);
      } else {
        Destruct(&promise_);
      }
    }

    bool Poll() override {
      if (!started_) {
        // First poll: materialize the promise from the factory, switching
        // the live union member.
        auto p = factory_.Make();
        Destruct(&factory_);
        Construct(&promise_, std::move(p));
        started_ = true;
      }
      auto p = promise_();
      if (auto* r = p.value_if_ready()) {
        // Resolved: deliver the result, then free ourselves back to the
        // arena pool.
        on_complete_(std::move(*r));
        GetContext<Arena>()->DeletePooled(this);
        return true;
      }
      return false;
    }

    // Tear down without delivering a result (party shutdown path).
    void Destroy() override { GetContext<Arena>()->DeletePooled(this); }

   private:
    // Exactly one member is live at a time, tracked by started_.
    union {
      GPR_NO_UNIQUE_ADDRESS Factory factory_;
      GPR_NO_UNIQUE_ADDRESS Promise promise_;
    };
    GPR_NO_UNIQUE_ADDRESS OnComplete on_complete_;
    bool started_ = false;
  };

  // Notification that the party has finished and this instance can be deleted.
  // Derived types should arrange to call CancelRemainingParticipants during
  // this sequence.
  virtual void PartyOver() = 0;

  // Run the locked part of the party until it is unlocked.
  void RunLocked();

  // Called in response to Unref() hitting zero - ultimately calls PartyOver,
  // but needs to set some stuff up.
  // Here so it gets compiled out of line.
  void PartyIsOver();

  // Wakeable implementation
  void Wakeup(WakeupMask wakeup_mask) final;
  void WakeupAsync(WakeupMask wakeup_mask) final;
  void Drop(WakeupMask wakeup_mask) final;

  // Add a participant (backs Spawn, after type erasure to ParticipantFactory).
  void AddParticipants(Participant** participant, size_t count);

  virtual grpc_event_engine::experimental::EventEngine* event_engine()
      const = 0;

  // Sentinel value for currently_polling_ when no participant is being polled.
  static constexpr uint8_t kNotPolling = 255;

#ifdef GRPC_PARTY_SYNC_USING_ATOMICS
  PartySyncUsingAtomics sync_;
#elif defined(GRPC_PARTY_SYNC_USING_MUTEX)
  PartySyncUsingMutex sync_;
#else
#error No synchronization method defined
#endif

  Arena* const arena_;
  // Index of the participant being polled, or kNotPolling.
  uint8_t currently_polling_ = kNotPolling;
  // All current participants, using a tagged format.
  // If the lower bit is unset, then this is a Participant*.
  // If the lower bit is set, then this is a ParticipantFactory*.
  std::atomic<Participant*> participants_[party_detail::kMaxParticipants] = {};
};
// Queue one promise for spawning; the participant is created now but only
// registered with the party when this BulkSpawner is destroyed.
template <typename Factory, typename OnComplete>
void Party::BulkSpawner::Spawn(absl::string_view name, Factory promise_factory,
                               OnComplete on_complete) {
  // Emit a trace line before constructing the participant, mirroring the
  // order of observable side effects.
  if (grpc_trace_promise_primitives.enabled()) {
    gpr_log(GPR_DEBUG, "%s[bulk_spawn] On %p queue %s",
            party_->DebugTag().c_str(), this, std::string(name).c_str());
  }
  using Impl = ParticipantImpl<Factory, OnComplete>;
  // Pool-allocate the participant from the party's arena and stash it until
  // the destructor hands the whole batch to the party.
  Participant* const participant = party_->arena_->NewPooled<Impl>(
      name, std::move(promise_factory), std::move(on_complete));
  participants_[num_participants_] = participant;
  ++num_participants_;
}
// Spawn a single promise by delegating to BulkSpawner; the spawner's
// destructor registers the participant with the party before this returns.
template <typename Factory, typename OnComplete>
void Party::Spawn(absl::string_view name, Factory promise_factory,
                  OnComplete on_complete) {
  BulkSpawner spawner(this);
  spawner.Spawn(name, std::move(promise_factory), std::move(on_complete));
}
} // namespace grpc_core
#endif // GRPC_SRC_CORE_LIB_PROMISE_PARTY_H