#ifndef GRPC_SRC_CORE_LIB_PROMISE_LOOP_H
#define GRPC_SRC_CORE_LIB_PROMISE_LOOP_H
#include <grpc/support/port_platform.h>
#include <type_traits>
#include <utility>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/variant.h"
#include "src/core/lib/gprpp/construct_destruct.h"
#include "src/core/lib/promise/detail/promise_factory.h"
#include "src/core/lib/promise/poll.h"
namespace grpc_core {
// Sentinel value a loop body resolves to in order to request another
// iteration of the loop.
struct Continue {};
// Per-iteration control value for Loop: either Continue (construct and poll a
// fresh promise from the factory) or a T that ends the loop with that value.
template <typename T>
using LoopCtl = absl::variant<Continue, T>;
namespace promise_detail {
// Maps the result type of one loop iteration's promise onto a LoopCtl.
// The primary template is only declared here; just the specializations that
// follow provide Result / ToLoopCtl.
template <typename T>
struct LoopTraits;

// Loop bodies whose promises resolve directly to LoopCtl<T>: the control value
// is forwarded unchanged and the loop as a whole resolves to T.
template <typename T>
struct LoopTraits<LoopCtl<T>> {
  using Result = T;
  static auto ToLoopCtl(LoopCtl<T> value) -> LoopCtl<T> { return value; }
};
// Loop bodies whose promises resolve to absl::StatusOr<LoopCtl<T>>:
//   - a non-OK StatusOr ends the loop with that error status,
//   - Continue requests another iteration,
//   - a T ends the loop successfully with that value.
// The loop as a whole resolves to absl::StatusOr<T>.
template <typename T>
struct LoopTraits<absl::StatusOr<LoopCtl<T>>> {
  using Result = absl::StatusOr<T>;
  static LoopCtl<Result> ToLoopCtl(absl::StatusOr<LoopCtl<T>> value) {
    if (!value.ok()) return value.status();
    auto& inner = *value;
    if (absl::holds_alternative<Continue>(inner)) return Continue{};
    // `value` is owned by this function, so move the terminal T out instead of
    // copying it: this supports move-only T and avoids duplicating large values.
    return absl::get<T>(std::move(inner));
  }
};
// Loop bodies whose promises resolve to absl::StatusOr<LoopCtl<absl::Status>>.
// Here the loop resolves to a plain absl::Status (not a StatusOr wrapping a
// Status): a non-OK StatusOr and a terminating Status value both become the
// loop's Status result, while Continue requests another iteration.
template <>
struct LoopTraits<absl::StatusOr<LoopCtl<absl::Status>>> {
  using Result = absl::Status;
  static LoopCtl<Result> ToLoopCtl(
      absl::StatusOr<LoopCtl<absl::Status>> value) {
    if (!value.ok()) return value.status();
    auto& inner = *value;
    if (absl::holds_alternative<Continue>(inner)) return Continue{};
    // `value` is owned by this function; move the Status out rather than
    // copying it.
    return absl::get<absl::Status>(std::move(inner));
  }
};
// Promise combinator that repeatedly makes a promise from a factory and polls
// it, until one iteration resolves to something other than Continue.
//
// Each ready iteration result is mapped through LoopTraits::ToLoopCtl:
//   - Continue: the current promise is destroyed, a fresh one is made by the
//     factory, and it is polled immediately within the same call.
//   - anything else: that value is returned as the loop's result.
// If the current iteration's promise is pending, the loop returns Pending()
// and resumes polling that same promise on the next call.
template <typename F>
class Loop {
 private:
  using Factory = promise_detail::RepeatedPromiseFactory<void, F>;
  using PromiseType = decltype(std::declval<Factory>().Make());
  using PromiseResult = typename PromiseType::Result;

 public:
  using Result = typename LoopTraits<PromiseResult>::Result;

  explicit Loop(F f) : factory_(std::move(f)) {}

  ~Loop() {
    // promise_ lives in an anonymous union: it only exists (and must only be
    // destroyed) once the first poll has constructed it.
    if (started_) Destruct(&promise_);
  }

  // Transfers the factory and, if polling already began, the in-flight
  // promise, so the moved-to loop resumes the current iteration instead of
  // silently restarting it. The moved-from loop still destroys its own
  // (moved-from) promise in its destructor.
  Loop(Loop&& loop) noexcept
      : factory_(std::move(loop.factory_)), started_(loop.started_) {
    if (started_) Construct(&promise_, std::move(loop.promise_));
  }
  Loop(const Loop& loop) = delete;
  Loop& operator=(const Loop& loop) = delete;

  // Polls the loop. The first call lazily constructs the first iteration's
  // promise. Returns Pending() while the current iteration is pending,
  // otherwise the loop's final Result.
  Poll<Result> operator()() {
    if (!started_) {
      started_ = true;
      Construct(&promise_, factory_.Make());
    }
    while (true) {
      auto promise_result = promise_();
      auto* p = promise_result.value_if_ready();
      if (p == nullptr) return Pending();
      // Move the ready value out: it is not used again, and copying would
      // reject move-only results.
      auto lc = LoopTraits<PromiseResult>::ToLoopCtl(std::move(*p));
      if (absl::holds_alternative<Continue>(lc)) {
        Destruct(&promise_);
        Construct(&promise_, factory_.Make());
        continue;
      }
      return std::move(absl::get<Result>(lc));
    }
  }

 private:
  GPR_NO_UNIQUE_ADDRESS Factory factory_;
  GPR_NO_UNIQUE_ADDRESS union {
    GPR_NO_UNIQUE_ADDRESS PromiseType promise_;
  };
  // True once promise_ has been constructed.
  bool started_ = false;
};
}
// Builds a Loop promise around the factory `f`. Each iteration's promise
// resolves to a LoopCtl (optionally wrapped in absl::StatusOr); Continue
// starts a new iteration and any other value finishes the loop.
template <typename F>
promise_detail::Loop<F> Loop(F f) {
  return promise_detail::Loop<F>{std::move(f)};
}
}
#endif