#include <grpc/support/port_platform.h>
#include "src/core/lib/channel/connected_channel.h"
#include <inttypes.h>
#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include <grpc/grpc.h>
#include <grpc/status.h>
#include <grpc/support/alloc.h>
#include <grpc/support/log.h>
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/channel_fwd.h"
#include "src/core/lib/channel/channel_stack.h"
#include "src/core/lib/debug/trace.h"
#include "src/core/lib/experiments/experiments.h"
#include "src/core/lib/gpr/alloc.h"
#include "src/core/lib/gprpp/debug_location.h"
#include "src/core/lib/gprpp/orphanable.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/time.h"
#include "src/core/lib/iomgr/call_combiner.h"
#include "src/core/lib/iomgr/closure.h"
#include "src/core/lib/iomgr/error.h"
#include "src/core/lib/iomgr/polling_entity.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/detail/basic_seq.h"
#include "src/core/lib/promise/detail/status.h"
#include "src/core/lib/promise/for_each.h"
#include "src/core/lib/promise/if.h"
#include "src/core/lib/promise/latch.h"
#include "src/core/lib/promise/loop.h"
#include "src/core/lib/promise/map.h"
#include "src/core/lib/promise/party.h"
#include "src/core/lib/promise/pipe.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/promise/race.h"
#include "src/core/lib/promise/seq.h"
#include "src/core/lib/promise/try_seq.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/slice/slice.h"
#include "src/core/lib/slice/slice_buffer.h"
#include "src/core/lib/surface/call.h"
#include "src/core/lib/surface/call_trace.h"
#include "src/core/lib/surface/channel_stack_type.h"
#include "src/core/lib/transport/batch_builder.h"
#include "src/core/lib/transport/error_utils.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/transport.h"
#include "src/core/lib/transport/transport_fwd.h"
#include "src/core/lib/transport/transport_impl.h"
// Per-channel state of the connected filter: the transport at the bottom of
// this channel stack.
struct connected_channel_channel_data {
  grpc_transport* transport;
};
using channel_data = connected_channel_channel_data;

// State used to bounce a transport callback back through the call combiner
// before the closure originally supplied with the batch is run.
struct callback_state {
  // Closure given to the transport in place of `original_closure`.
  grpc_closure closure;
  // The closure the batch originally carried; scheduled via the call combiner.
  grpc_closure* original_closure;
  // Call combiner on which `original_closure` is started.
  grpc_core::CallCombiner* call_combiner;
  // Debug reason passed to GRPC_CALL_COMBINER_START.
  const char* reason;
};

// Per-call state of the connected filter. The transport stream is located
// right after this struct (see TRANSPORT_STREAM_FROM_CALL_DATA).
struct connected_channel_call_data {
  grpc_core::CallCombiner* call_combiner;
  // One on_complete slot per op type, indexed as in get_state_for_batch().
  callback_state on_complete[6];
  callback_state recv_initial_metadata_ready;
  callback_state recv_message_ready;
  callback_state recv_trailing_metadata_ready;
};
using call_data = connected_channel_call_data;
// Closure body installed by intercept_callback(): starts the original closure
// on the call combiner, forwarding `error` and the recorded debug reason.
static void run_in_call_combiner(void* arg, grpc_error_handle error) {
  auto* cb_state = static_cast<callback_state*>(arg);
  GRPC_CALL_COMBINER_START(cb_state->call_combiner, cb_state->original_closure,
                           error, cb_state->reason);
}
static void run_cancel_in_call_combiner(void* arg, grpc_error_handle error) {
run_in_call_combiner(arg, error);
gpr_free(arg);
}
// Replaces `*original_closure` with a closure backed by `state` that, when run
// by the transport, re-enters the call combiner and starts the original
// closure there. If `free_when_done` is true, `state` must have been allocated
// with gpr_malloc(); it is freed once the callback has fired.
static void intercept_callback(call_data* calld, callback_state* state,
                               bool free_when_done, const char* reason,
                               grpc_closure** original_closure) {
  state->original_closure = *original_closure;
  state->call_combiner = calld->call_combiner;
  state->reason = reason;
  auto* const trampoline =
      free_when_done ? run_cancel_in_call_combiner : run_in_call_combiner;
  *original_closure = GRPC_CLOSURE_INIT(&state->closure, trampoline, state,
                                        grpc_schedule_on_exec_ctx);
}
// Picks the on_complete slot for `batch`: the slot of the first op present,
// in the order send_initial_metadata, send_message, send_trailing_metadata,
// recv_initial_metadata, recv_message, recv_trailing_metadata. A batch with
// none of these ops is a programming error.
static callback_state* get_state_for_batch(
    call_data* calld, grpc_transport_stream_op_batch* batch) {
  const bool ops_by_slot[] = {
      batch->send_initial_metadata, batch->send_message,
      batch->send_trailing_metadata, batch->recv_initial_metadata,
      batch->recv_message,          batch->recv_trailing_metadata,
  };
  int slot = 0;
  for (bool present : ops_by_slot) {
    if (present) return &calld->on_complete[slot];
    ++slot;
  }
  GPR_UNREACHABLE_CODE(return nullptr);
}
// The transport's grpc_stream lives at call_data's address plus
// sizeof(call_data) rounded up to the platform alignment; these macros convert
// between the two addresses.
#define TRANSPORT_STREAM_FROM_CALL_DATA(calld) \
  (reinterpret_cast<grpc_stream*>(            \
      reinterpret_cast<char*>(calld) +        \
      GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data))))
#define CALL_DATA_FROM_TRANSPORT_STREAM(transport_stream) \
  (reinterpret_cast<call_data*>(                          \
      reinterpret_cast<char*>(transport_stream) -         \
      GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data))))
// Bottom-of-stack batch handler. Every callback the transport will run for
// this batch is wrapped so that it re-enters the call combiner. The batch is
// then passed to the transport and the call combiner is released.
static void connected_channel_start_transport_stream_op_batch(
    grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
  auto* calld = static_cast<call_data*>(elem->call_data);
  auto* chand = static_cast<channel_data*>(elem->channel_data);
  auto* payload = batch->payload;
  if (batch->recv_initial_metadata) {
    intercept_callback(
        calld, &calld->recv_initial_metadata_ready, false,
        "recv_initial_metadata_ready",
        &payload->recv_initial_metadata.recv_initial_metadata_ready);
  }
  if (batch->recv_message) {
    intercept_callback(calld, &calld->recv_message_ready, false,
                       "recv_message_ready",
                       &payload->recv_message.recv_message_ready);
  }
  if (batch->recv_trailing_metadata) {
    intercept_callback(
        calld, &calld->recv_trailing_metadata_ready, false,
        "recv_trailing_metadata_ready",
        &payload->recv_trailing_metadata.recv_trailing_metadata_ready);
  }
  if (batch->cancel_stream) {
    // Cancellation gets its own heap-allocated state, released by
    // run_cancel_in_call_combiner() after the callback runs.
    auto* cancel_state =
        static_cast<callback_state*>(gpr_malloc(sizeof(callback_state)));
    intercept_callback(calld, cancel_state, true,
                       "on_complete (cancel_stream)", &batch->on_complete);
  } else if (batch->on_complete != nullptr) {
    intercept_callback(calld, get_state_for_batch(calld, batch), false,
                       "on_complete", &batch->on_complete);
  }
  grpc_transport_perform_stream_op(
      chand->transport, TRANSPORT_STREAM_FROM_CALL_DATA(calld), batch);
  GRPC_CALL_COMBINER_STOP(calld->call_combiner, "passed batch to transport");
}
// Forwards channel-level transport ops straight to the underlying transport.
static void connected_channel_start_transport_op(grpc_channel_element* elem,
                                                 grpc_transport_op* op) {
  grpc_transport_perform_op(
      static_cast<channel_data*>(elem->channel_data)->transport, op);
}
// Records the call combiner and asks the transport to initialize the stream
// stored just past call_data. A non-zero result from the transport is
// reported as an error.
static grpc_error_handle connected_channel_init_call_elem(
    grpc_call_element* elem, const grpc_call_element_args* args) {
  auto* calld = static_cast<call_data*>(elem->call_data);
  auto* chand = static_cast<channel_data*>(elem->channel_data);
  calld->call_combiner = args->call_combiner;
  const int init_result = grpc_transport_init_stream(
      chand->transport, TRANSPORT_STREAM_FROM_CALL_DATA(calld),
      &args->call_stack->refcount, args->server_transport_data, args->arena);
  if (init_result != 0) {
    return GRPC_ERROR_CREATE("transport stream initialization failed");
  }
  return absl::OkStatus();
}
// Passes the call's polling entity to the transport for this call's stream.
static void set_pollset_or_pollset_set(grpc_call_element* elem,
                                       grpc_polling_entity* pollent) {
  auto* chand = static_cast<channel_data*>(elem->channel_data);
  auto* calld = static_cast<call_data*>(elem->call_data);
  grpc_transport_set_pops(chand->transport,
                          TRANSPORT_STREAM_FROM_CALL_DATA(calld), pollent);
}
// Asks the transport to destroy this call's stream; `then_schedule_closure`
// is passed through to grpc_transport_destroy_stream().
static void connected_channel_destroy_call_elem(
    grpc_call_element* elem, const grpc_call_final_info* /*final_info*/,
    grpc_closure* then_schedule_closure) {
  auto* transport = static_cast<channel_data*>(elem->channel_data)->transport;
  auto* calld = static_cast<call_data*>(elem->call_data);
  grpc_transport_destroy_stream(
      transport, TRANSPORT_STREAM_FROM_CALL_DATA(calld), then_schedule_closure);
}
// Stores the transport taken from the channel args. This filter must be the
// last element of the channel stack (asserted).
static grpc_error_handle connected_channel_init_channel_elem(
    grpc_channel_element* elem, grpc_channel_element_args* args) {
  GPR_ASSERT(args->is_last);
  auto* chand = static_cast<channel_data*>(elem->channel_data);
  chand->transport = args->channel_args.GetObject<grpc_transport>();
  return absl::OkStatus();
}
// Destroys the transport held by this channel element, if one was set.
static void connected_channel_destroy_channel_elem(grpc_channel_element* elem) {
  auto* transport = static_cast<channel_data*>(elem->channel_data)->transport;
  if (transport != nullptr) grpc_transport_destroy(transport);
}
// The connected filter reports no channel info: this hook does nothing.
static void connected_channel_get_channel_info(
    grpc_channel_element* /*elem*/, const grpc_channel_info* /*channel_info*/) {}
namespace grpc_core {
namespace {
#if defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL) || \
defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL)
// Wraps a stream of a batch-based (vtable) transport for use by the
// promise-based call code in this file.
//
// Lifetime: governed by `stream_refcount_`, which starts at 1 (that ref is
// dropped in Orphan()). When the count reaches zero, BeginDestroy() runs. It
// asks the transport to destroy the stream (if one was installed). The
// destructor is then run in place by StreamDestroyed(), inside the call
// context. Instances are created with Arena::New by the Make*CallPromise
// functions, so nothing here deallocates the object's memory.
class ConnectedChannelStream : public Orphanable {
 public:
  explicit ConnectedChannelStream(grpc_transport* transport)
      : transport_(transport), stream_(nullptr, StreamDeleter(this)) {
    // Refcount starts at 1; its destroy callback is BeginDestroy().
    GRPC_STREAM_REF_INIT(
        &stream_refcount_, 1,
        [](void* p, grpc_error_handle) {
          static_cast<ConnectedChannelStream*>(p)->BeginDestroy();
        },
        this, "ConnectedChannelStream");
  }
  grpc_transport* transport() { return transport_; }
  // Closure passed to grpc_transport_destroy_stream(); runs StreamDestroyed().
  grpc_closure* stream_destroyed_closure() { return &stream_destroyed_; }
  // Transport/stream/refcount triple that addresses BatchBuilder operations
  // at this stream.
  BatchBuilder::Target batch_target() {
    return BatchBuilder::Target{transport_, stream_.get(), &stream_refcount_};
  }
  // Adds a ref to the stream refcount. `reason` is only forwarded in debug
  // (!NDEBUG) builds.
  void IncrementRefCount(const char* reason = "smartptr") {
#ifndef NDEBUG
    grpc_stream_ref(&stream_refcount_, reason);
#else
    (void)reason;
    grpc_stream_ref(&stream_refcount_);
#endif
  }
  // Drops a ref from the stream refcount. `reason` is only forwarded in debug
  // (!NDEBUG) builds.
  void Unref(const char* reason = "smartptr") {
#ifndef NDEBUG
    grpc_stream_unref(&stream_refcount_, reason);
#else
    (void)reason;
    grpc_stream_unref(&stream_refcount_);
#endif
  }
  // Takes a new ref and wraps `this` in a RefCountedPtr holding that ref.
  RefCountedPtr<ConnectedChannelStream> InternalRef() {
    IncrementRefCount("smartptr");
    return RefCountedPtr<ConnectedChannelStream>(this);
  }
  // Orphanable hook, run when the owning OrphanablePtr releases the object.
  // If the stream was not yet marked finished, it does two things: spawns a
  // party task that sets `finished_`, and cancels the stream at the transport.
  // It then drops the initial ref.
  void Orphan() final {
    bool finished = finished_.IsSet();
    if (grpc_call_trace.enabled()) {
      gpr_log(GPR_DEBUG, "%s[connected] Orphan stream, finished: %d",
              party_->DebugTag().c_str(), finished);
    }
    if (!finished) {
      party_->Spawn(
          "finish",
          [self = InternalRef()]() {
            if (!self->finished_.IsSet()) {
              self->finished_.Set();
            }
            return Empty{};
          },
          [](Empty) {});
      GetContext<BatchBuilder>()->Cancel(batch_target(),
                                         absl::CancelledError());
    }
    Unref("orphan connected stream");
  }
  // Defined out of line: promise that receives messages from the transport
  // and pushes them into `incoming_messages`.
  auto RecvMessages(PipeSender<MessageHandle>* incoming_messages,
                    bool cancel_on_error);
  // Defined out of line: promise that sends each message read from
  // `outgoing_messages` to the transport.
  auto SendMessages(PipeReceiver<MessageHandle>* outgoing_messages);
  // Installs the transport stream. On release (BeginDestroy), StreamDeleter
  // destroys it via grpc_transport_destroy_stream.
  void SetStream(grpc_stream* stream) { stream_.reset(stream); }
  grpc_stream* stream() { return stream_.get(); }
  grpc_stream_refcount* stream_refcount() { return &stream_refcount_; }
  // Sets the `finished_` latch, resolving WaitFinished() promises.
  void set_finished() { finished_.Set(); }
  // Promise that resolves once `finished_` is set.
  auto WaitFinished() { return finished_.Wait(); }

 private:
  // unique_ptr deleter for the transport stream: asks the transport to destroy
  // it, with stream_destroyed_closure() as the completion closure.
  class StreamDeleter {
   public:
    explicit StreamDeleter(ConnectedChannelStream* impl) : impl_(impl) {}
    void operator()(grpc_stream* stream) const {
      if (stream == nullptr) return;
      grpc_transport_destroy_stream(impl_->transport(), stream,
                                    impl_->stream_destroyed_closure());
    }

   private:
    ConnectedChannelStream* impl_;
  };
  using StreamPtr = std::unique_ptr<grpc_stream, StreamDeleter>;
  // Final destruction step: runs the destructor in place (no deallocation)
  // inside the call context.
  void StreamDestroyed() {
    call_context_->RunInContext([this] { this->~ConnectedChannelStream(); });
  }
  // Destroy callback of `stream_refcount_`. If a stream is installed,
  // resetting it makes the transport run StreamDestroyed() via
  // stream_destroyed_closure(). Otherwise StreamDestroyed() runs immediately.
  void BeginDestroy() {
    if (stream_ != nullptr) {
      stream_.reset();
    } else {
      StreamDestroyed();
    }
  }
  grpc_transport* const transport_;
  // Ref on the creating call's context; StreamDestroyed() runs inside it.
  RefCountedPtr<CallContext> const call_context_{
      GetContext<CallContext>()->Ref()};
  grpc_closure stream_destroyed_ =
      MakeMemberClosure<ConnectedChannelStream,
                        &ConnectedChannelStream::StreamDestroyed>(
          this, DEBUG_LOCATION);
  grpc_stream_refcount stream_refcount_;
  StreamPtr stream_;
  // NOTE(review): not read anywhere in the visible code.
  Arena* arena_ = GetContext<Arena>();
  // Party that was current at construction; used for logging and spawning.
  Party* const party_ = static_cast<Party*>(Activity::current());
  ExternallyObservableLatch<void> finished_;
};
// Returns a promise that repeatedly receives a message from the transport and
// pushes it into `incoming_messages`. The sender is moved out of the caller
// into the loop. The loop ends in one of two ways:
//  - with OkStatus, if a push fails (the application side stopped reading);
//  - with the receive status, once the transport reports no further message
//    or a receive failure. In that case the pipe is closed so readers observe
//    end-of-stream. The close carries an error when `cancel_on_error` is set
//    and the receive failed; otherwise it is a normal close.
auto ConnectedChannelStream::RecvMessages(
    PipeSender<MessageHandle>* incoming_messages, bool cancel_on_error) {
  return Loop([self = InternalRef(), cancel_on_error,
               incoming_messages = std::move(*incoming_messages)]() mutable {
    return Seq(
        GetContext<BatchBuilder>()->ReceiveMessage(self->batch_target()),
        [cancel_on_error, &incoming_messages](
            absl::StatusOr<absl::optional<MessageHandle>> status) mutable {
          bool has_message = status.ok() && status->has_value();
          // Forward the received message; keep looping unless the push fails.
          auto publish_message = [&incoming_messages, &status]() {
            auto pending_message = std::move(**status);
            if (grpc_call_trace.enabled()) {
              gpr_log(GPR_INFO,
                      "%s[connected] RecvMessage: received payload of %" PRIdPTR
                      " bytes",
                      Activity::current()->DebugTag().c_str(),
                      pending_message->payload()->Length());
            }
            return Map(incoming_messages.Push(std::move(pending_message)),
                       [](bool ok) -> LoopCtl<absl::Status> {
                         if (!ok) {
                           if (grpc_call_trace.enabled()) {
                             gpr_log(
                                 GPR_INFO,
                                 "%s[connected] RecvMessage: failed to "
                                 "push message towards the application",
                                 Activity::current()->DebugTag().c_str());
                           }
                           return absl::OkStatus();
                         }
                         return Continue{};
                       });
          };
          // No more messages (or the receive failed): close the pipe so the
          // reader sees end-of-stream, then end the loop with the status.
          auto publish_close = [cancel_on_error, &incoming_messages,
                                &status]() mutable {
            if (grpc_call_trace.enabled()) {
              gpr_log(GPR_INFO,
                      "%s[connected] RecvMessage: reached end of stream with "
                      "status:%s",
                      Activity::current()->DebugTag().c_str(),
                      status.status().ToString().c_str());
            }
            if (cancel_on_error && !status.ok()) {
              incoming_messages.CloseWithError();
            } else {
              // Previously missing: without this, readers were only told of
              // end-of-stream when the sender was eventually destroyed.
              incoming_messages.Close();
            }
            return Immediate(LoopCtl<absl::Status>(status.status()));
          };
          return If(has_message, std::move(publish_message),
                    std::move(publish_close));
        });
  });
}
// Returns a promise that drains `outgoing_messages` (the receiver is moved out
// of the caller) and sends each message to the transport. The per-message
// action holds a stream ref for as long as the loop exists.
auto ConnectedChannelStream::SendMessages(
    PipeReceiver<MessageHandle>* outgoing_messages) {
  auto send_one = [self = InternalRef()](MessageHandle message) {
    return GetContext<BatchBuilder>()->SendMessage(self->batch_target(),
                                                   std::move(message));
  };
  return ForEach(std::move(*outgoing_messages), std::move(send_one));
}
#endif
#ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL
// Builds the client-side call promise for a batch-based transport.
//
// The transport stream is created and initialized here. Side tasks are then
// spawned on the current Party to:
//  - hand the polling entity to the transport;
//  - send client messages, followed by client trailing metadata;
//  - receive server initial metadata and push it up the
//    server_initial_metadata pipe.
// The returned promise sends client initial metadata, drains server messages,
// and then resolves to the server trailing metadata. If receiving the
// trailing metadata fails, metadata is synthesized from the error. Finally
// the promise marks the stream finished.
ArenaPromise<ServerMetadataHandle> MakeClientCallPromise(
    grpc_transport* transport, CallArgs call_args, NextPromiseFactory) {
  OrphanablePtr<ConnectedChannelStream> stream(
      GetContext<Arena>()->New<ConnectedChannelStream>(transport));
  stream->SetStream(static_cast<grpc_stream*>(
      GetContext<Arena>()->Alloc(transport->vtable->sizeof_stream)));
  grpc_transport_init_stream(transport, stream->stream(),
                             stream->stream_refcount(), nullptr,
                             GetContext<Arena>());
  auto* party = static_cast<Party*>(Activity::current());
  // Hand the polling entity to the transport once it is available.
  party->Spawn(
      "set_polling_entity", call_args.polling_entity->Wait(),
      [transport,
       stream = stream->InternalRef()](grpc_polling_entity polling_entity) {
        grpc_transport_set_pops(transport, stream->stream(), &polling_entity);
      });
  // Send every message from client_to_server_messages. Once that loop
  // completes successfully, send client trailing metadata.
  party->Spawn(
      "send_messages",
      TrySeq(stream->SendMessages(call_args.client_to_server_messages),
             [stream = stream->InternalRef()]() {
               return GetContext<BatchBuilder>()->SendClientTrailingMetadata(
                   stream->batch_target());
             }),
      [](absl::Status) {});
  // Receive server initial metadata and publish it through the
  // server_initial_metadata pipe. A failed push becomes CancelledError.
  party->Spawn(
      "recv_initial_metadata",
      TrySeq(GetContext<BatchBuilder>()->ReceiveServerInitialMetadata(
                 stream->batch_target()),
             [pipe = call_args.server_initial_metadata](
                 ServerMetadataHandle server_initial_metadata) {
               if (grpc_call_trace.enabled()) {
                 gpr_log(GPR_DEBUG,
                         "%s[connected] Publish server initial metadata: %s",
                         Activity::current()->DebugTag().c_str(),
                         server_initial_metadata->DebugString().c_str());
               }
               return Map(pipe->Push(std::move(server_initial_metadata)),
                          [](bool r) {
                            if (r) return absl::OkStatus();
                            return absl::CancelledError();
                          });
             }),
      [](absl::Status) {});
  // Send client initial metadata, then report success or failure through the
  // outstanding-initial-metadata token.
  auto send_initial_metadata = Seq(
      GetContext<BatchBuilder>()->SendClientInitialMetadata(
          stream->batch_target(), std::move(call_args.client_initial_metadata)),
      [sent_initial_metadata_token =
           std::move(call_args.client_initial_metadata_outstanding)](
          absl::Status status) mutable {
        sent_initial_metadata_token.Complete(status.ok());
        return status;
      });
  // Receive server trailing metadata. A receive error is converted into
  // metadata carrying the equivalent grpc-status and grpc-message.
  auto recv_trailing_metadata =
      Map(GetContext<BatchBuilder>()->ReceiveServerTrailingMetadata(
              stream->batch_target()),
          [](absl::StatusOr<ServerMetadataHandle> status) mutable {
            if (!status.ok()) {
              auto server_trailing_metadata =
                  GetContext<Arena>()->MakePooled<ServerMetadata>(
                      GetContext<Arena>());
              grpc_status_code status_code = GRPC_STATUS_UNKNOWN;
              std::string message;
              grpc_error_get_status(status.status(), Timestamp::InfFuture(),
                                    &status_code, &message, nullptr, nullptr);
              server_trailing_metadata->Set(GrpcStatusMetadata(), status_code);
              server_trailing_metadata->Set(GrpcMessageMetadata(),
                                            Slice::FromCopiedString(message));
              return server_trailing_metadata;
            } else {
              return std::move(*status);
            }
          });
  auto recv_messages =
      stream->RecvMessages(call_args.server_to_client_messages, false);
  // Main promise. On each poll it advances three steps in order:
  //  1. send_initial_metadata; if that fails, the failure is returned.
  //  2. recv_messages.
  //  3. Once the message loop is done, recv_trailing_metadata, whose result is
  //     the call's result.
  return Map(
      [send_initial_metadata = std::move(send_initial_metadata),
       recv_messages = std::move(recv_messages),
       recv_trailing_metadata = std::move(recv_trailing_metadata),
       done_send_initial_metadata = false, done_recv_messages = false,
       done_recv_trailing_metadata =
           false]() mutable -> Poll<ServerMetadataHandle> {
        if (!done_send_initial_metadata) {
          auto p = send_initial_metadata();
          if (auto* r = p.value_if_ready()) {
            done_send_initial_metadata = true;
            if (!r->ok()) return StatusCast<ServerMetadataHandle>(*r);
          }
        }
        if (!done_recv_messages) {
          auto p = recv_messages();
          if (!p.ready()) return Pending{};
          // The message loop's status is not used here; errors are expected
          // to surface through the trailing metadata.
          done_recv_messages = true;
        }
        if (!done_recv_trailing_metadata) {
          auto p = recv_trailing_metadata();
          if (auto* r = p.value_if_ready()) {
            done_recv_trailing_metadata = true;
            return std::move(*r);
          }
        }
        return Pending{};
      },
      [stream = std::move(stream)](ServerMetadataHandle result) {
        stream->set_finished();
        return result;
      });
}
#endif
#ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL
// Builds the server-side call promise for a batch-based transport.
//
// The transport stream is created and initialized here. Side tasks are then
// spawned on the current Party to:
//  - hand the polling entity to the transport;
//  - receive client messages;
//  - send server initial metadata, then server messages;
//  - receive client trailing metadata.
// The returned promise does the following:
//  1. Receives client initial metadata.
//  2. Runs the rest of the filter stack via `next_promise_factory`, raced
//     against `failure_latch`.
//  3. Sends the resulting server trailing metadata.
//  4. Marks the stream finished and resolves to that metadata.
ArenaPromise<ServerMetadataHandle> MakeServerCallPromise(
    grpc_transport* transport, CallArgs,
    NextPromiseFactory next_promise_factory) {
  OrphanablePtr<ConnectedChannelStream> stream(
      GetContext<Arena>()->New<ConnectedChannelStream>(transport));
  stream->SetStream(static_cast<grpc_stream*>(
      GetContext<Arena>()->Alloc(transport->vtable->sizeof_stream)));
  grpc_transport_init_stream(
      transport, stream->stream(), stream->stream_refcount(),
      GetContext<CallContext>()->server_call_context()->server_stream_data(),
      GetContext<Arena>());
  auto* party = static_cast<Party*>(Activity::current());
  // State shared between the spawned tasks and the main promise. It is
  // allocated with ManagedNew on the call arena.
  struct CallData {
    Pipe<MessageHandle> server_to_client;
    Pipe<MessageHandle> client_to_server;
    Pipe<ServerMetadataHandle> server_initial_metadata;
    // Setting this completes the main call promise early (see the Race below).
    Latch<ServerMetadataHandle> failure_latch;
    Latch<grpc_polling_entity> polling_entity_latch;
    bool sent_initial_metadata = false;
    // NOTE(review): not read or written anywhere in the visible code.
    bool sent_trailing_metadata = false;
  };
  auto* call_data = GetContext<Arena>()->ManagedNew<CallData>();
  party->Spawn(
      "set_polling_entity", call_data->polling_entity_latch.Wait(),
      [transport,
       stream = stream->InternalRef()](grpc_polling_entity polling_entity) {
        grpc_transport_set_pops(transport, stream->stream(), &polling_entity);
      });
  auto server_to_client_empty =
      call_data->server_to_client.receiver.AwaitEmpty();
  // Receive client initial metadata, then start the rest of the filter stack.
  // The stack's promise is raced against failure_latch. It is only polled
  // while the outgoing (server-to-client) message pipe is empty; presumably
  // this lets queued messages reach the transport before the call result is
  // observed.
  auto recv_initial_metadata_then_run_promise =
      TrySeq(GetContext<BatchBuilder>()->ReceiveClientInitialMetadata(
                 stream->batch_target()),
             [next_promise_factory = std::move(next_promise_factory),
              server_to_client_empty = std::move(server_to_client_empty),
              call_data](ClientMetadataHandle client_initial_metadata) {
               auto call_promise = next_promise_factory(CallArgs{
                   std::move(client_initial_metadata),
                   ClientInitialMetadataOutstandingToken::Empty(),
                   &call_data->polling_entity_latch,
                   &call_data->server_initial_metadata.sender,
                   &call_data->client_to_server.receiver,
                   &call_data->server_to_client.sender,
               });
               return Race(call_data->failure_latch.Wait(),
                           [call_promise = std::move(call_promise),
                            server_to_client_empty =
                                std::move(server_to_client_empty)]() mutable
                           -> Poll<ServerMetadataHandle> {
                             if (server_to_client_empty().pending()) {
                               return Pending{};
                             }
                             return call_promise();
                           });
             });
  // Promise factory that sends `server_trailing_metadata` to the client. The
  // flag passed to SendServerTrailingMetadata is true in two cases: the
  // metadata marks the call as cancelled, or initial metadata had not been
  // sent yet. sent_initial_metadata is set to true either way.
  auto send_trailing_metadata = [call_data, stream = stream->InternalRef()](
                                    ServerMetadataHandle
                                        server_trailing_metadata) {
    bool is_cancellation =
        server_trailing_metadata->get(GrpcCallWasCancelled()).value_or(false);
    return GetContext<BatchBuilder>()->SendServerTrailingMetadata(
        stream->batch_target(), std::move(server_trailing_metadata),
        is_cancellation ||
            !std::exchange(call_data->sent_initial_metadata, true));
  };
  // Receive client messages into client_to_server until the stream is
  // finished or the receive loop ends. A receive error sets failure_latch
  // (if it is not already set).
  party->Spawn(
      "recv_messages",
      Race(
          Map(stream->WaitFinished(), [](Empty) { return absl::OkStatus(); }),
          Map(stream->RecvMessages(&call_data->client_to_server.sender, true),
              [failure_latch = &call_data->failure_latch](absl::Status status) {
                if (!status.ok() && !failure_latch->is_set()) {
                  failure_latch->Set(ServerMetadataFromStatus(status));
                }
                return status;
              })),
      [](absl::Status) {});
  // Wait for server initial metadata from the filter stack, or for the stream
  // to finish. If metadata arrives and none has been sent, spawn a task that
  // sends it and resolve OK. Otherwise resolve CancelledError, which stops the
  // TrySeq below before any messages are sent.
  auto send_initial_metadata = Seq(
      Race(Map(stream->WaitFinished(),
               [](Empty) { return NextResult<ServerMetadataHandle>(true); }),
           call_data->server_initial_metadata.receiver.Next()),
      [call_data, stream = stream->InternalRef()](
          NextResult<ServerMetadataHandle> next_result) mutable {
        auto md = !call_data->sent_initial_metadata && next_result.has_value()
                      ? std::move(next_result.value())
                      : nullptr;
        if (md != nullptr) {
          call_data->sent_initial_metadata = true;
          auto* party = static_cast<Party*>(Activity::current());
          party->Spawn("connected/send_initial_metadata",
                       GetContext<BatchBuilder>()->SendServerInitialMetadata(
                           stream->batch_target(), std::move(md)),
                       [](absl::Status) {});
          return Immediate(absl::OkStatus());
        }
        return Immediate(absl::CancelledError());
      });
  // Send initial metadata, then stream server messages, unless the stream
  // finishes first.
  party->Spawn(
      "send_initial_metadata_then_messages",
      Race(Map(stream->WaitFinished(), [](Empty) { return absl::OkStatus(); }),
           TrySeq(std::move(send_initial_metadata),
                  stream->SendMessages(&call_data->server_to_client.receiver))),
      [](absl::Status) {});
  // Receive client trailing metadata. A receive failure is converted into
  // metadata carrying the equivalent status. Any non-OK grpc-status sets
  // failure_latch (if it is not already set).
  party->Spawn(
      "recv_trailing_metadata",
      Seq(GetContext<BatchBuilder>()->ReceiveClientTrailingMetadata(
              stream->batch_target()),
          [failure_latch = &call_data->failure_latch](
              absl::StatusOr<ClientMetadataHandle> status) mutable {
            if (grpc_call_trace.enabled()) {
              gpr_log(
                  GPR_DEBUG,
                  "%s[connected] Got trailing metadata; status=%s metadata=%s",
                  Activity::current()->DebugTag().c_str(),
                  status.status().ToString().c_str(),
                  status.ok() ? (*status)->DebugString().c_str() : "<none>");
            }
            ClientMetadataHandle trailing_metadata;
            if (status.ok()) {
              trailing_metadata = std::move(*status);
            } else {
              trailing_metadata =
                  GetContext<Arena>()->MakePooled<ClientMetadata>(
                      GetContext<Arena>());
              grpc_status_code status_code = GRPC_STATUS_UNKNOWN;
              std::string message;
              grpc_error_get_status(status.status(), Timestamp::InfFuture(),
                                    &status_code, &message, nullptr, nullptr);
              trailing_metadata->Set(GrpcStatusMetadata(), status_code);
              trailing_metadata->Set(GrpcMessageMetadata(),
                                     Slice::FromCopiedString(message));
            }
            if (trailing_metadata->get(GrpcStatusMetadata())
                    .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) {
              if (!failure_latch->is_set()) {
                failure_latch->Set(std::move(trailing_metadata));
              }
            }
            return Empty{};
          }),
      [](Empty) {});
  // Scope guards captured by the final Map callback below. They run when that
  // callback, and thus the returned promise, is destroyed:
  //  - the polling-entity latch is set to an empty entity if it never was;
  //  - the server-initial-metadata receiver is closed with an error.
  struct CleanupPollingEntityLatch {
    void operator()(Latch<grpc_polling_entity>* latch) {
      if (!latch->is_set()) latch->Set(grpc_polling_entity());
    }
  };
  auto cleanup_polling_entity_latch =
      std::unique_ptr<Latch<grpc_polling_entity>, CleanupPollingEntityLatch>(
          &call_data->polling_entity_latch);
  struct CleanupSendInitialMetadata {
    void operator()(CallData* call_data) {
      call_data->server_initial_metadata.receiver.CloseWithError();
    }
  };
  auto cleanup_send_initial_metadata =
      std::unique_ptr<CallData, CleanupSendInitialMetadata>(call_data);
  // Main promise: receive initial metadata and run the stack, then send the
  // resulting trailing metadata. Finally mark the stream finished, which
  // resolves the WaitFinished() races above, and return the metadata.
  return Map(
      Seq(std::move(recv_initial_metadata_then_run_promise),
          std::move(send_trailing_metadata)),
      [cleanup_polling_entity_latch = std::move(cleanup_polling_entity_latch),
       cleanup_send_initial_metadata = std::move(cleanup_send_initial_metadata),
       stream = std::move(stream)](ServerMetadataHandle md) {
        stream->set_finished();
        return md;
      });
}
#endif
// Builds the grpc_channel_filter for the connected channel.
// - If `make_call_promise` is null, the filter has no promise-based entry.
// - Otherwise, that entry looks up the transport in the channel data and
//   delegates to `make_call_promise`.
// The remaining hooks are the batch-based functions defined above.
template <ArenaPromise<ServerMetadataHandle> (*make_call_promise)(
    grpc_transport*, CallArgs, NextPromiseFactory)>
grpc_channel_filter MakeConnectedFilter() {
  auto call_promise_entry = +[](grpc_channel_element* elem, CallArgs call_args,
                                NextPromiseFactory next) {
    auto* chand = static_cast<channel_data*>(elem->channel_data);
    return make_call_promise(chand->transport, std::move(call_args),
                             std::move(next));
  };
  // Grow each call stack by the transport's per-stream size; the stream is
  // stored right after this filter's call_data.
  auto post_init = +[](grpc_channel_stack* channel_stack,
                       grpc_channel_element* elem) {
    auto* chand = static_cast<channel_data*>(elem->channel_data);
    channel_stack->call_stack_size +=
        grpc_transport_stream_size(chand->transport);
  };
  return {
      connected_channel_start_transport_stream_op_batch,
      make_call_promise != nullptr ? call_promise_entry : nullptr,
      connected_channel_start_transport_op,
      sizeof(call_data),
      connected_channel_init_call_elem,
      set_pollset_or_pollset_set,
      connected_channel_destroy_call_elem,
      sizeof(channel_data),
      connected_channel_init_channel_elem,
      post_init,
      connected_channel_destroy_channel_elem,
      connected_channel_get_channel_info,
      "connected",
  };
}
// Call-promise adapter for transports that implement promises natively. It
// delegates to the transport vtable's make_call_promise; the next-promise
// factory is not used.
ArenaPromise<ServerMetadataHandle> MakeTransportCallPromise(
    grpc_transport* transport, CallArgs call_args, NextPromiseFactory) {
  auto* const vtable = transport->vtable;
  return vtable->make_call_promise(transport, std::move(call_args));
}
// Connected filter for transports that provide their own make_call_promise.
const grpc_channel_filter kPromiseBasedTransportFilter =
    MakeConnectedFilter<MakeTransportCallPromise>();
// Connected filters for batch-based transports. Each gets a promise-based
// entry point only when the matching promise-based call experiment is
// compiled in.
const grpc_channel_filter kClientEmulatedFilter =
#ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL
    MakeConnectedFilter<MakeClientCallPromise>();
#else
    MakeConnectedFilter<nullptr>();
#endif
const grpc_channel_filter kServerEmulatedFilter =
#ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL
    MakeConnectedFilter<MakeServerCallPromise>();
#else
    MakeConnectedFilter<nullptr>();
#endif
} }
// Appends the connected filter that matches the builder's transport:
// - the promise-based filter if the transport implements make_call_promise;
// - otherwise the client or server emulated filter, chosen by channel stack
//   type.
// Always returns true.
bool grpc_add_connected_filter(grpc_core::ChannelStackBuilder* builder) {
  grpc_transport* transport = builder->transport();
  GPR_ASSERT(transport != nullptr);
  const auto* filter =
      transport->vtable->make_call_promise != nullptr
          ? &grpc_core::kPromiseBasedTransportFilter
      : grpc_channel_stack_type_is_client(builder->channel_stack_type())
          ? &grpc_core::kClientEmulatedFilter
          : &grpc_core::kServerEmulatedFilter;
  builder->AppendFilter(filter);
  return true;
}