use std::{
collections::BTreeMap,
pin::Pin,
sync::{
Arc, Weak,
atomic::{AtomicU64, Ordering},
},
};
use moire::sync::{Semaphore, SyncMutex};
use tokio::sync::watch;
use moire::task::FutureExt as _;
use vox_types::{
BoxFut, CallResult, ChannelBinder, ChannelBody, ChannelClose, ChannelCreditReplenisher,
ChannelCreditReplenisherHandle, ChannelId, ChannelItem, ChannelLivenessHandle, ChannelMessage,
ChannelRetryMode, ChannelSink, CreditSink, Handler, IdAllocator, IncomingChannelMessage,
MaybeSend, MaybeSync, Payload, ReplySink, RequestBody, RequestCall, RequestId, RequestMessage,
RequestResponse, SelfRef, TxError, VoxError, ensure_operation_id, metadata_channel_retry_mode,
metadata_operation_id,
};
use crate::session::{
ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest, FailureDisposition,
};
use crate::{InMemoryOperationStore, OperationStore};
use moire::sync::mpsc;
use vox_types::{OperationId, PostcardPayload, SchemaHash, TypeRef};
/// A decoded response message paired with the schema tracker needed to
/// interpret its payload; delivered to the caller awaiting this request.
struct PendingResponse {
/// The raw response message (body is always `RequestBody::Response`).
msg: SelfRef<RequestMessage<'static>>,
/// Tracker for schemas received alongside/for this response.
schemas: Arc<vox_types::SchemaRecvTracker>,
}
/// One-shot slot used to hand a `PendingResponse` back to the waiting caller.
type ResponseSlot = moire::sync::oneshot::Sender<PendingResponse>;
/// Book-keeping for one spawned handler task serving an incoming request.
struct InFlightHandler {
/// Join handle used to abort the task (e.g. after a session resume).
handle: moire::task::JoinHandle<()>,
method_id: vox_types::MethodId,
retry: vox_types::RetryPolicy,
/// True when the request's arguments carried channels; such handlers are
/// aborted when channel runtime state is invalidated.
has_channels: bool,
/// Present when the request participates in operation dedup/replay.
operation_id: Option<OperationId>,
}
/// Deduplicates concurrent requests sharing an `OperationId`: the first
/// request starts the operation, identical later requests attach as waiters.
struct LiveOperationTracker {
/// Operations currently executing, keyed by operation id.
live: HashMap<OperationId, LiveOperation>,
/// Reverse index from each interested request to its operation.
request_to_operation: HashMap<RequestId, OperationId>,
}
/// One in-flight operation plus every request waiting on its outcome.
struct LiveOperation {
method_id: vox_types::MethodId,
/// Hash of (method id, argument bytes) used to detect conflicting reuse
/// of the same operation id.
args_hash: u64,
/// The request that started the operation (always also in `waiters`).
owner_request_id: RequestId,
waiters: Vec<RequestId>,
retry: vox_types::RetryPolicy,
}
/// Outcome of admitting a request into the live-operation tracker.
enum AdmitResult {
/// New operation: this request owns it and must run the handler.
Start,
/// Identical operation already running: request was added as a waiter.
Attached,
/// Operation id reused with a different method or arguments.
Conflict,
}
impl LiveOperationTracker {
    /// Creates an empty tracker with no live operations or request bindings.
    fn new() -> Self {
        Self {
            live: HashMap::new(),
            request_to_operation: HashMap::new(),
        }
    }
    /// Removes the live entry for `operation_id` (if any) and drops every
    /// waiter's request binding. Shared by `seal`, `release`, and the abort
    /// path of `cancel`; callers emit their own trace events.
    fn remove_live(&mut self, operation_id: OperationId) -> Option<LiveOperation> {
        let live = self.live.remove(&operation_id)?;
        for waiter in &live.waiters {
            self.request_to_operation.remove(waiter);
        }
        Some(live)
    }
    /// Admits `request_id` against `operation_id`.
    ///
    /// Returns `Start` when this request owns a brand-new operation,
    /// `Attached` when an identical operation (same method id and args hash)
    /// is already in flight and the request became a waiter on it, and
    /// `Conflict` when the operation id is reused with different inputs.
    fn admit(
        &mut self,
        operation_id: OperationId,
        method_id: vox_types::MethodId,
        args: &[u8],
        retry: vox_types::RetryPolicy,
        request_id: RequestId,
    ) -> AdmitResult {
        use std::hash::{Hash, Hasher};
        // Fingerprint (method, args) so a reused operation id carrying
        // different inputs can be detected as a conflict.
        let args_hash = {
            let mut h = std::collections::hash_map::DefaultHasher::new();
            method_id.hash(&mut h);
            args.hash(&mut h);
            h.finish()
        };
        let live_operations = self.live.len();
        if let Some(live) = self.live.get_mut(&operation_id) {
            if live.method_id != method_id || live.args_hash != args_hash {
                let request_bindings = self.request_to_operation.len();
                tracing::trace!(
                    %operation_id,
                    %request_id,
                    ?method_id,
                    live_operations,
                    request_bindings,
                    "live operation conflict"
                );
                return AdmitResult::Conflict;
            }
            live.waiters.push(request_id);
            self.request_to_operation.insert(request_id, operation_id);
            let waiters = live.waiters.len();
            let request_bindings = self.request_to_operation.len();
            tracing::trace!(
                %operation_id,
                %request_id,
                ?method_id,
                waiters,
                live_operations,
                request_bindings,
                "live operation attached"
            );
            return AdmitResult::Attached;
        }
        // First sight of this operation: the request becomes the owner and
        // its own first waiter.
        self.live.insert(
            operation_id,
            LiveOperation {
                method_id,
                args_hash,
                owner_request_id: request_id,
                waiters: vec![request_id],
                retry,
            },
        );
        self.request_to_operation.insert(request_id, operation_id);
        let live_operations = self.live.len();
        let request_bindings = self.request_to_operation.len();
        tracing::trace!(
            %operation_id,
            %request_id,
            ?method_id,
            live_operations,
            request_bindings,
            "live operation admitted"
        );
        AdmitResult::Start
    }
    /// Completes `operation_id` successfully and returns the waiters that
    /// should receive the sealed response (the owner is among them).
    fn seal(&mut self, operation_id: OperationId) -> Vec<RequestId> {
        match self.remove_live(operation_id) {
            Some(live) => {
                let waiters = live.waiters.len();
                let live_operations = self.live.len();
                let request_bindings = self.request_to_operation.len();
                tracing::trace!(
                    %operation_id,
                    waiters,
                    live_operations,
                    request_bindings,
                    "live operation sealed"
                );
                live.waiters
            }
            None => vec![],
        }
    }
    /// Drops `operation_id` without sealing it (e.g. the handler died) and
    /// returns the removed entry so the caller can fail its waiters.
    fn release(&mut self, operation_id: OperationId) -> Option<LiveOperation> {
        let live = self.remove_live(operation_id)?;
        let waiters = live.waiters.len();
        let live_operations = self.live.len();
        let request_bindings = self.request_to_operation.len();
        tracing::trace!(
            %operation_id,
            waiters,
            live_operations,
            request_bindings,
            "live operation released"
        );
        Some(live)
    }
    /// Cancels `request_id`'s interest in its operation.
    ///
    /// Persistent operations survive cancellation: a waiter is detached and
    /// the owner's cancel is ignored (`NotFound`). Volatile operations are
    /// aborted outright, returning the owner and all waiters to fail.
    fn cancel(&mut self, request_id: RequestId) -> CancelResult {
        let Some(&operation_id) = self.request_to_operation.get(&request_id) else {
            return CancelResult::NotFound;
        };
        let live_operations = self.live.len();
        let Some(live) = self.live.get_mut(&operation_id) else {
            // Stale binding with no live entry: clean it up.
            self.request_to_operation.remove(&request_id);
            return CancelResult::NotFound;
        };
        if live.retry.persist {
            if live.owner_request_id == request_id {
                // The owner of a persistent operation cannot cancel it.
                return CancelResult::NotFound;
            }
            live.waiters.retain(|w| *w != request_id);
            self.request_to_operation.remove(&request_id);
            let waiters = live.waiters.len();
            let request_bindings = self.request_to_operation.len();
            tracing::trace!(
                %operation_id,
                %request_id,
                waiters,
                live_operations,
                request_bindings,
                "live operation detached waiter"
            );
            CancelResult::Detached
        } else {
            let live = self
                .remove_live(operation_id)
                .expect("live entry existed above");
            let waiters = live.waiters.len();
            let live_operations = self.live.len();
            let request_bindings = self.request_to_operation.len();
            tracing::trace!(
                %operation_id,
                %request_id,
                waiters,
                live_operations,
                request_bindings,
                "live operation aborted"
            );
            CancelResult::Abort {
                owner_request_id: live.owner_request_id,
                waiters: live.waiters,
            }
        }
    }
}
/// Outcome of cancelling a single request's interest in an operation.
enum CancelResult {
/// The request was not bound to any live operation.
NotFound,
/// Persistent operation kept running; only this waiter was removed.
Detached,
/// Volatile operation torn down; all listed requests must be failed.
Abort {
owner_request_id: RequestId,
waiters: Vec<RequestId>,
},
}
use std::collections::HashMap;
/// State shared between the driver loop, callers, and reply sinks.
struct DriverShared {
/// Callers waiting for a response, keyed by outbound request id.
pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
/// Allocator for outbound request ids (parity-split with the peer).
request_ids: SyncMutex<IdAllocator<RequestId>>,
/// Monotonic source for locally generated operation ids.
next_operation_id: AtomicU64,
/// Store of sealed (completed) operation responses, used for replay.
operations: Arc<dyn OperationStore>,
/// Allocator for locally created channel ids.
channel_ids: SyncMutex<IdAllocator<ChannelId>>,
/// Queue senders for inbound channels with a bound receiver.
channel_senders: SyncMutex<BTreeMap<ChannelId, mpsc::Sender<IncomingChannelMessage>>>,
/// Messages that arrived before a channel's receiver was registered.
channel_buffers: SyncMutex<BTreeMap<ChannelId, Vec<IncomingChannelMessage>>>,
/// Credit semaphores gating outbound channel sends.
channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
/// Channels whose runtime state was torn down during a resume —
/// presumably consulted to recognize late closes; consumer not in view.
stale_close_channels: SyncMutex<std::collections::HashSet<ChannelId>>,
}
/// Guard shared by all callers of one connection; when the last clone drops,
/// it sends a `DropControlRequest` to the session control channel.
struct CallerDropGuard {
control_tx: mpsc::UnboundedSender<DropControlRequest>,
/// The request to deliver on drop (set up as `Close(conn_id)`).
request: DropControlRequest,
}
impl Drop for CallerDropGuard {
fn drop(&mut self) {
// Best-effort: if the session is already gone the send just fails.
let _ = self.control_tx.send(self.request);
}
}
#[cfg(test)]
mod tests {
use super::{DriverChannelCreditReplenisher, DriverLocalControl};
use vox_types::{ChannelCreditReplenisher, ChannelId};
/// With an initial window of 16 the replenisher batches at 16 / 2 = 8:
/// seven consumptions stay silent, the eighth flushes one grant of 8.
#[tokio::test]
async fn replenisher_batches_at_half_the_initial_window() {
let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher");
let replenisher = DriverChannelCreditReplenisher::new(ChannelId(7), 16, tx);
for _ in 0..7 {
replenisher.on_item_consumed();
}
// Below the threshold nothing may be emitted; give it a short grace
// period on the real clock to prove silence.
assert!(
tokio::time::timeout(std::time::Duration::from_millis(20), rx.recv())
.await
.is_err(),
"should not emit credit before reaching the batch threshold"
);
replenisher.on_item_consumed();
let Some(DriverLocalControl::GrantCredit {
channel_id,
additional,
}) = rx.recv().await
else {
panic!("expected batched credit grant");
};
assert_eq!(channel_id, ChannelId(7));
assert_eq!(additional, 8);
}
/// A one-credit window clamps the threshold to 1, so every consumed item
/// immediately produces a single-credit grant.
#[tokio::test]
async fn replenisher_grants_one_by_one_for_single_credit_windows() {
let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher.single");
let replenisher = DriverChannelCreditReplenisher::new(ChannelId(9), 1, tx);
replenisher.on_item_consumed();
let Some(DriverLocalControl::GrantCredit {
channel_id,
additional,
}) = rx.recv().await
else {
panic!("expected immediate credit grant");
};
assert_eq!(channel_id, ChannelId(9));
assert_eq!(additional, 1);
}
}
/// Reply sink handed to handlers for one incoming request.
///
/// Sending a reply consumes the sink; dropping it without replying reports a
/// failure for the request (see the `Drop` impl below).
pub struct DriverReplySink {
/// Taken by `send_reply` / `Drop`; `None` once the reply was handled.
sender: Option<ConnectionSender>,
request_id: RequestId,
method_id: vox_types::MethodId,
retry: vox_types::RetryPolicy,
/// Present when the request participates in operation dedup/replay.
operation_id: Option<OperationId>,
operations: Option<Arc<dyn OperationStore>>,
live_operations: Option<Arc<SyncMutex<LiveOperationTracker>>>,
/// Binder exposed to the handler for creating/binding channels.
binder: DriverChannelBinder,
}
/// Re-sends a previously sealed (stored) response to a waiter request.
///
/// Decodes the stored postcard bytes, lets the sender rebuild the schema
/// payload for replay, and forwards the response on the wire. `Err(())` on
/// either decode or send failure.
async fn replay_sealed_response(
    sender: ConnectionSender,
    request_id: RequestId,
    method_id: vox_types::MethodId,
    encoded_response: &[u8],
    root_type: TypeRef,
    operations: &dyn OperationStore,
) -> Result<(), ()> {
    let mut response: RequestResponse<'_> = match vox_postcard::from_slice_borrowed(encoded_response) {
        Ok(decoded) => decoded,
        Err(_) => return Err(()),
    };
    sender.prepare_replay_schemas(request_id, method_id, &root_type, operations, &mut response);
    sender.send_response(request_id, response).await
}
/// Pulls the root `TypeRef` out of an encoded schema payload, falling back
/// to a zero-hash concrete ref when no schemas were attached.
fn extract_root_type_ref(schemas_cbor: &vox_types::CborPayload) -> TypeRef {
    if !schemas_cbor.is_empty() {
        return vox_types::SchemaPayload::from_cbor(&schemas_cbor.0)
            .expect("schema CBOR must be valid")
            .root;
    }
    TypeRef::concrete(SchemaHash(0))
}
/// Borrows the raw postcard argument bytes from an incoming call.
///
/// Incoming requests are always decoded to `Payload::PostcardBytes`; hitting
/// `Payload::Value` here means a driver invariant was broken, so panic.
fn incoming_args_bytes<'a>(call: &'a RequestCall<'a>) -> &'a [u8] {
    let Payload::PostcardBytes(bytes) = &call.args else {
        panic!("incoming request payload should always be decoded as incoming bytes")
    };
    bytes
}
impl ReplySink for DriverReplySink {
    /// Delivers the handler's response for this request.
    ///
    /// For operation-tracked requests (operation id + store present) the
    /// response is serialized into the operation store, sealed, and replayed
    /// to every waiter attached to the same live operation; otherwise it is
    /// sent directly on the wire.
    ///
    /// Fix: the store-seal call previously read `®istry` (a mojibake of
    /// `&registry`), which does not compile.
    async fn send_reply(mut self, response: RequestResponse<'_>) {
        // Taking `sender` also disarms the failure-reporting Drop impl.
        let sender = self
            .sender
            .take()
            .expect("unreachable: send_reply takes self by value");
        vox_types::dlog!(
            "[driver] send_reply: conn={:?} req={:?} method={:?} payload={} operation_id={:?}",
            sender.connection_id(),
            self.request_id,
            self.method_id,
            match &response.ret {
                Payload::Value { .. } => "Value",
                Payload::PostcardBytes(_) => "PostcardBytes",
            },
            self.operation_id
        );
        if let Payload::Value { shape, .. } = &response.ret
            && let Ok(extracted) = vox_types::extract_schemas(shape)
        {
            vox_types::dlog!(
                "[schema] driver send_reply: method={:?} root={:?}",
                self.method_id,
                extracted.root
            );
        }
        if let (Some(operation_id), Some(operations)) = (self.operation_id, self.operations.take())
        {
            let mut response = response;
            sender.prepare_response_for_method(self.request_id, self.method_id, &mut response);
            let root_type = extract_root_type_ref(&response.schemas);
            // Encode for the store WITHOUT the schema payload (schemas are
            // rebuilt per replay), then restore them for the wire send.
            let schemas_for_wire = std::mem::take(&mut response.schemas);
            let encoded_for_store = PostcardPayload(
                vox_postcard::to_vec(&response).expect("serialize operation response for store"),
            );
            response.schemas = schemas_for_wire;
            vox_types::dlog!(
                "[driver] send_reply wire send: conn={:?} req={:?} method={:?} schemas={}",
                sender.connection_id(),
                self.request_id,
                self.method_id,
                response.schemas.0.len()
            );
            if let Err(_e) = sender.send_response(self.request_id, response).await {
                sender.mark_failure(self.request_id, FailureDisposition::Cancelled);
            }
            // Seal the stored response before fanning out so replays can read
            // it from the store.
            let registry = sender.schema_registry();
            operations.seal(operation_id, &encoded_for_store, &root_type, &registry);
            let waiters = self
                .live_operations
                .as_ref()
                .map(|lo| lo.lock().seal(operation_id))
                .unwrap_or_default();
            for waiter in waiters {
                // The owner already received the direct send above.
                if waiter == self.request_id {
                    continue;
                }
                if replay_sealed_response(
                    sender.clone(),
                    waiter,
                    self.method_id,
                    encoded_for_store.as_bytes(),
                    root_type.clone(),
                    operations.as_ref(),
                )
                .await
                .is_err()
                {
                    sender.mark_failure(waiter, FailureDisposition::Cancelled);
                }
            }
        } else {
            vox_types::dlog!(
                "[driver] send_reply direct send: conn={:?} req={:?} method={:?}",
                sender.connection_id(),
                self.request_id,
                self.method_id
            );
            if let Err(_e) = sender
                .send_response_for_method(self.request_id, self.method_id, response)
                .await
            {
                sender.mark_failure(self.request_id, FailureDisposition::Cancelled);
            }
        }
    }
    /// Channels created by the handler bind through this sink's binder.
    fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
        Some(&self.binder)
    }
    fn request_id(&self) -> Option<RequestId> {
        Some(self.request_id)
    }
    /// Connection id, if the reply has not been sent yet.
    fn connection_id(&self) -> Option<vox_types::ConnectionId> {
        self.sender.as_ref().map(|sender| sender.connection_id())
    }
}
impl Drop for DriverReplySink {
/// Fail-safe for handlers that drop the sink without replying: reports a
/// failure for the request and, for operation-tracked requests, for every
/// attached waiter after releasing the live operation.
fn drop(&mut self) {
// `sender` is `None` if `send_reply` already consumed it.
if let Some(sender) = self.sender.take() {
// Persistent operations may have completed remotely, so their
// outcome is indeterminate; volatile ones are simply cancelled.
let disposition = if self.retry.persist {
FailureDisposition::Indeterminate
} else {
FailureDisposition::Cancelled
};
if let Some(operation_id) = self.operation_id {
if let Some(live_ops) = self.live_operations.take()
&& let Some(live) = live_ops.lock().release(operation_id)
{
// `waiters` includes the owner request, so this covers it too.
for waiter in live.waiters {
sender.mark_failure(waiter, disposition);
}
return;
}
}
sender.mark_failure(self.request_id, disposition);
}
}
}
/// Raw outbound channel sink: forwards items and close frames over the
/// connection for one channel id.
pub struct DriverChannelSink {
sender: ConnectionSender,
channel_id: ChannelId,
/// Used to ask the driver loop to close the channel from drop contexts.
local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
impl ChannelSink for DriverChannelSink {
/// Sends one item on this channel; fails with a transport error when the
/// connection is gone.
fn send_payload<'payload>(
&self,
payload: Payload<'payload>,
) -> Pin<Box<dyn vox_types::MaybeSendFuture<Output = Result<(), TxError>> + 'payload>> {
let sender = self.sender.clone();
let channel_id = self.channel_id;
Box::pin(async move {
sender
.send(ConnectionMessage::Channel(ChannelMessage {
id: channel_id,
body: ChannelBody::Item(ChannelItem { item: payload }),
}))
.await
.map_err(|()| TxError::Transport("connection closed".into()))
})
}
/// Sends an explicit close frame for this channel.
///
/// NOTE(review): `_metadata` is discarded and empty metadata is sent
/// instead — confirm whether close metadata should be forwarded.
fn close_channel(
&self,
_metadata: vox_types::Metadata,
) -> Pin<Box<dyn vox_types::MaybeSendFuture<Output = Result<(), TxError>> + 'static>> {
let sender = self.sender.clone();
let channel_id = self.channel_id;
Box::pin(async move {
sender
.send(ConnectionMessage::Channel(ChannelMessage {
id: channel_id,
body: ChannelBody::Close(ChannelClose {
metadata: Default::default(),
}),
}))
.await
.map_err(|()| TxError::Transport("connection closed".into()))
})
}
/// Asks the driver loop to close this channel; usable where an async
/// send is impossible (drop paths). Errors are ignored.
fn close_channel_on_drop(&self) {
let _ = self
.local_control_tx
.send(DriverLocalControl::CloseChannel {
channel_id: self.channel_id,
});
}
}
/// Object-safe erasure of `Handler<DriverReplySink>` so the driver can hold
/// handlers as `Box<dyn ErasedHandler>`.
pub trait ErasedHandler: MaybeSend + MaybeSync + 'static {
/// Retry policy for `method_id`; defaults to volatile.
fn retry_policy(&self, method_id: vox_types::MethodId) -> vox_types::RetryPolicy {
let _ = method_id;
vox_types::RetryPolicy::VOLATILE
}
/// Whether the method's arguments embed channels; defaults to false.
fn args_have_channels(&self, method_id: vox_types::MethodId) -> bool {
let _ = method_id;
false
}
/// Facet shape of the method's wire response, when known.
fn response_wire_shape(&self, method_id: vox_types::MethodId) -> Option<&'static facet::Shape> {
let _ = method_id;
None
}
/// Type-erased dispatch of one incoming call (boxed future).
fn handle_erased(
&self,
call: SelfRef<RequestCall<'static>>,
reply: DriverReplySink,
schemas: std::sync::Arc<vox_types::SchemaRecvTracker>,
) -> BoxFut<'_, ()>;
}
/// Every concrete `Handler<DriverReplySink>` is usable through the erased
/// trait; `handle_erased` boxes the handler's future.
impl<H: Handler<DriverReplySink>> ErasedHandler for H {
fn retry_policy(&self, method_id: vox_types::MethodId) -> vox_types::RetryPolicy {
Handler::retry_policy(self, method_id)
}
fn args_have_channels(&self, method_id: vox_types::MethodId) -> bool {
Handler::args_have_channels(self, method_id)
}
fn response_wire_shape(&self, method_id: vox_types::MethodId) -> Option<&'static facet::Shape> {
Handler::response_wire_shape(self, method_id)
}
fn handle_erased(
&self,
call: SelfRef<RequestCall<'static>>,
reply: DriverReplySink,
schemas: std::sync::Arc<vox_types::SchemaRecvTracker>,
) -> BoxFut<'_, ()> {
Box::pin(Handler::handle(self, call, reply, schemas))
}
}
/// Lets a boxed erased handler be used wherever a concrete handler is
/// expected, forwarding every method through the vtable.
impl Handler<DriverReplySink> for Box<dyn ErasedHandler> {
fn retry_policy(&self, method_id: vox_types::MethodId) -> vox_types::RetryPolicy {
(**self).retry_policy(method_id)
}
fn args_have_channels(&self, method_id: vox_types::MethodId) -> bool {
(**self).args_have_channels(method_id)
}
fn response_wire_shape(&self, method_id: vox_types::MethodId) -> Option<&'static facet::Shape> {
(**self).response_wire_shape(method_id)
}
async fn handle(
&self,
call: SelfRef<RequestCall<'static>>,
reply: DriverReplySink,
schemas: std::sync::Arc<vox_types::SchemaRecvTracker>,
) {
(**self).handle_erased(call, reply, schemas).await
}
}
/// Client-side entry point: wraps the driver caller with optional service
/// metadata and client middleware.
#[must_use = "Dropping this caller may close the connection if it is the last caller."]
#[derive(Clone)]
pub struct Caller {
inner: Arc<DriverCaller>,
/// Pinned by the first `with_middleware` call; later middleware must
/// target the same service.
service: Option<&'static vox_types::ServiceDescriptor>,
middlewares: Vec<Arc<dyn vox_types::ClientMiddleware>>,
}
impl Caller {
    /// Wraps a driver caller with no service descriptor and no middleware.
    pub fn new(driver: DriverCaller) -> Self {
        Self {
            inner: Arc::new(driver),
            service: None,
            middlewares: vec![],
        }
    }
    #[cfg(test)]
    pub(crate) fn driver(&self) -> &DriverCaller {
        &self.inner
    }
    /// Registers a client middleware for `service`.
    ///
    /// The first registration pins the service descriptor; later ones must
    /// name the same service or this panics.
    pub fn with_middleware(
        mut self,
        service: &'static vox_types::ServiceDescriptor,
        middleware: impl vox_types::ClientMiddleware,
    ) -> Self {
        match self.service {
            Some(existing_service) => assert_eq!(
                existing_service.service_name, service.service_name,
                "Caller middleware service mismatch"
            ),
            None => self.service = Some(service),
        }
        self.middlewares.push(Arc::new(middleware));
        self
    }
    /// Issues `call`, running middleware `pre` hooks in registration order
    /// before the driver call and `post` hooks in reverse order afterwards.
    ///
    /// Without a pinned service descriptor the call bypasses middleware
    /// entirely. (Fix: dropped the redundant `is_empty` guards around the
    /// hook loops — iterating an empty list is already a no-op.)
    pub async fn call(&self, mut call: RequestCall<'_>) -> CallResult {
        use vox_types::{
            ClientCallOutcome, ClientContext, ClientRequest, Extensions, OwnedMetadata,
        };
        let Some(service) = self.service else {
            return self.inner.call_inner(call).await;
        };
        let extensions = Extensions::new();
        let method = service.by_id(call.method_id);
        let context = ClientContext::new(method, call.method_id, &extensions);
        let mut owned_metadata = OwnedMetadata::default();
        for middleware in &self.middlewares {
            let mut request = ClientRequest::new(&mut call, &mut owned_metadata);
            middleware.pre(&context, &mut request).await;
        }
        let result = self.inner.call_inner(call).await;
        let outcome = match &result {
            Ok(_) => ClientCallOutcome::Response,
            Err(error) => ClientCallOutcome::Error(error),
        };
        for middleware in self.middlewares.iter().rev() {
            middleware.post(&context, outcome).await;
        }
        result
    }
    /// Resolves when the connection transitions to closed (or the session is
    /// dropped); returns immediately if it is already closed.
    pub async fn closed(&self) {
        if *self.inner.closed_rx.borrow() {
            return;
        }
        let mut rx = self.inner.closed_rx.clone();
        while rx.changed().await.is_ok() {
            if *rx.borrow() {
                return;
            }
        }
    }
    /// True while the underlying connection is not closed.
    pub fn is_connected(&self) -> bool {
        !*self.inner.closed_rx.borrow()
    }
    /// Binder for creating/binding channels on this connection.
    pub fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
        Some(self.inner.as_ref())
    }
}
/// Constructor hook for typed clients built from a vox session.
pub trait FromVoxSession {
/// Service name the generated client binds to.
const SERVICE_NAME: &'static str;
/// Builds the client from a caller and an optional session handle.
fn from_vox_session(
caller: Caller,
session_handle: Option<crate::session::SessionHandle>,
) -> Self;
}
/// Minimal client exposing the raw caller and session handle directly.
#[must_use = "Dropping NoopClient may close the connection if it is the last caller."]
#[derive(Clone)]
pub struct NoopClient {
pub caller: Caller,
pub session: Option<crate::session::SessionHandle>,
}
impl FromVoxSession for NoopClient {
const SERVICE_NAME: &'static str = "Noop";
/// Stores the caller and session without any further setup.
fn from_vox_session(caller: Caller, session: Option<crate::session::SessionHandle>) -> Self {
Self { caller, session }
}
}
/// Channel binder used by reply sinks on the handler side; mirrors the
/// caller-side binder behavior for handler-created channels.
#[derive(Clone)]
struct DriverChannelBinder {
sender: ConnectionSender,
shared: Arc<DriverShared>,
local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
/// Shared caller drop guard, when one is alive; used for channel liveness.
drop_guard: Option<Arc<CallerDropGuard>>,
}
/// Initial flow-control credit granted on every channel (items that may be
/// in flight before the receiver must replenish).
const DEFAULT_CHANNEL_CREDIT: u32 = 16;
/// Binds a receiver for inbound channel `channel_id`.
///
/// Installs a bounded queue, replays any messages buffered before the bind
/// (stopping after a terminal Close/Reset), and — unless a terminal message
/// was already buffered — attaches a credit replenisher so consumed items
/// grant the peer more credit.
fn register_rx_channel_impl(
shared: &Arc<DriverShared>,
channel_id: ChannelId,
queue_name: &'static str,
liveness: Option<ChannelLivenessHandle>,
local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
) -> vox_types::BoundChannelReceiver {
let (tx, rx) = mpsc::channel(queue_name, 64);
let mut terminal_buffered = false;
{
// Hold the senders lock across the buffer drain so no incoming message
// can interleave between registration and replay.
let mut senders = shared.channel_senders.lock();
senders.insert(channel_id, tx.clone());
let buffered = shared.channel_buffers.lock().remove(&channel_id);
if let Some(buffered) = buffered {
for msg in buffered {
let is_terminal = matches!(
msg,
IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
);
// Best-effort: try_send drops the message if the fresh queue is full.
let _ = tx.try_send(msg);
if is_terminal {
terminal_buffered = true;
break;
}
}
}
if terminal_buffered {
// Channel already finished: undo the sender registration.
senders.remove(&channel_id);
}
}
if terminal_buffered {
shared.channel_credits.lock().remove(&channel_id);
// No replenisher: the channel is already terminated.
return vox_types::BoundChannelReceiver {
receiver: rx,
liveness,
replenisher: None,
};
}
vox_types::BoundChannelReceiver {
receiver: rx,
liveness,
replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
channel_id,
DEFAULT_CHANNEL_CREDIT,
local_control_tx,
)) as ChannelCreditReplenisherHandle),
}
}
impl DriverChannelBinder {
    /// Allocates a fresh outbound channel id and wires up a credit-gated
    /// sink for it, registering the credit semaphore in shared state.
    fn create_tx_channel(&self) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
        let id = self.shared.channel_ids.lock().alloc();
        let raw = DriverChannelSink {
            sender: self.sender.clone(),
            channel_id: id,
            local_control_tx: self.local_control_tx.clone(),
        };
        let sink = Arc::new(CreditSink::new(raw, DEFAULT_CHANNEL_CREDIT));
        let credit = Arc::clone(sink.credit());
        self.shared.channel_credits.lock().insert(id, credit);
        (id, sink)
    }
    /// Binds a receiver for an inbound channel id (see
    /// `register_rx_channel_impl` for the buffering semantics).
    fn register_rx_channel(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
        register_rx_channel_impl(
            &self.shared,
            channel_id,
            "driver.register_rx_channel",
            self.channel_liveness(),
            self.local_control_tx.clone(),
        )
    }
}
impl ChannelBinder for DriverChannelBinder {
    /// Creates an outbound channel and erases the sink type.
    fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
        let (id, sink) = self.create_tx_channel();
        (id, sink as Arc<dyn ChannelSink>)
    }
    /// Allocates an id locally and binds a receiver for it.
    fn create_rx(&self) -> (ChannelId, vox_types::BoundChannelReceiver) {
        let id = self.shared.channel_ids.lock().alloc();
        (id, self.register_rx_channel(id))
    }
    /// Builds a credit-gated sink for a peer-chosen channel id.
    fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink> {
        let raw = DriverChannelSink {
            sender: self.sender.clone(),
            channel_id,
            local_control_tx: self.local_control_tx.clone(),
        };
        let sink = Arc::new(CreditSink::new(raw, DEFAULT_CHANNEL_CREDIT));
        let credit = Arc::clone(sink.credit());
        self.shared.channel_credits.lock().insert(channel_id, credit);
        sink
    }
    /// Binds a receiver for a peer-chosen channel id.
    fn register_rx(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
        self.register_rx_channel(channel_id)
    }
    /// Liveness handle: a clone of the caller drop guard, when one is alive.
    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
        match &self.drop_guard {
            Some(guard) => Some(guard.clone() as ChannelLivenessHandle),
            None => None,
        }
    }
}
/// Cheap-to-clone handle for issuing requests on a driver's connection.
#[derive(Clone)]
pub struct DriverCaller {
sender: ConnectionSender,
shared: Arc<DriverShared>,
local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
/// Becomes `true` when the connection is permanently closed.
closed_rx: watch::Receiver<bool>,
/// Bumped by the session on every resume (generation counter).
resumed_rx: watch::Receiver<u64>,
/// Last resume generation fully processed by the driver loop.
resume_processed_rx: watch::Receiver<u64>,
peer_supports_retry: bool,
/// Shared among all callers; the last drop requests connection close.
_drop_guard: Option<Arc<CallerDropGuard>>,
}
impl DriverCaller {
    /// Allocates a new outbound channel id and returns a credit-gated sink
    /// bound to it; the credit semaphore is registered in shared state so
    /// the driver can replenish or close it.
    pub fn create_tx_channel(&self) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
        let id = self.shared.channel_ids.lock().alloc();
        let raw = DriverChannelSink {
            sender: self.sender.clone(),
            channel_id: id,
            local_control_tx: self.local_control_tx.clone(),
        };
        let sink = Arc::new(CreditSink::new(raw, DEFAULT_CHANNEL_CREDIT));
        let credit = Arc::clone(sink.credit());
        self.shared.channel_credits.lock().insert(id, credit);
        (id, sink)
    }
    #[cfg(test)]
    pub(crate) fn connection_sender(&self) -> &ConnectionSender {
        &self.sender
    }
    /// Binds a receiver for an inbound channel id (see
    /// `register_rx_channel_impl` for the buffering semantics).
    pub fn register_rx_channel(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
        register_rx_channel_impl(
            &self.shared,
            channel_id,
            "driver.caller.register_rx_channel",
            self.channel_liveness(),
            self.local_control_tx.clone(),
        )
    }
}
impl ChannelBinder for DriverCaller {
    /// Creates an outbound channel and erases the sink type.
    fn create_tx(&self) -> (ChannelId, Arc<dyn ChannelSink>) {
        let (id, sink) = self.create_tx_channel();
        (id, sink as Arc<dyn ChannelSink>)
    }
    /// Allocates an id locally and binds a receiver for it.
    fn create_rx(&self) -> (ChannelId, vox_types::BoundChannelReceiver) {
        let id = self.shared.channel_ids.lock().alloc();
        (id, self.register_rx_channel(id))
    }
    /// Builds a credit-gated sink for a peer-chosen channel id.
    fn bind_tx(&self, channel_id: ChannelId) -> Arc<dyn ChannelSink> {
        let raw = DriverChannelSink {
            sender: self.sender.clone(),
            channel_id,
            local_control_tx: self.local_control_tx.clone(),
        };
        let sink = Arc::new(CreditSink::new(raw, DEFAULT_CHANNEL_CREDIT));
        let credit = Arc::clone(sink.credit());
        self.shared.channel_credits.lock().insert(channel_id, credit);
        sink
    }
    /// Binds a receiver for a peer-chosen channel id.
    fn register_rx(&self, channel_id: ChannelId) -> vox_types::BoundChannelReceiver {
        self.register_rx_channel(channel_id)
    }
    /// Liveness handle: a clone of the caller drop guard, when one is alive.
    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
        match &self._drop_guard {
            Some(guard) => Some(guard.clone() as ChannelLivenessHandle),
            None => None,
        }
    }
}
impl DriverCaller {
/// Sends one request and awaits its response, transparently re-sending
/// after a session resume when the peer supports retry.
async fn call_inner(&self, mut call: RequestCall<'_>) -> CallResult {
// Tag the call with a fresh operation id so the server can dedupe
// re-sends of the same logical operation.
if self.peer_supports_retry {
let operation_id = OperationId(
self.shared
.next_operation_id
.fetch_add(1, Ordering::Relaxed),
);
ensure_operation_id(&mut call.metadata, operation_id);
}
let req_id = self.shared.request_ids.lock().alloc();
let (tx, rx) = moire::sync::oneshot::channel("driver.response");
// Register the response slot BEFORE sending so the response cannot
// race past us.
self.shared.pending_responses.lock().insert(req_id, tx);
if self
.sender
.send_with_binder(
ConnectionMessage::Request(RequestMessage {
id: req_id,
body: RequestBody::Call(RequestCall {
method_id: call.method_id,
args: call.args.reborrow(),
metadata: call.metadata.clone(),
schemas: Default::default(),
}),
}),
Some(self),
)
.await
.is_err()
{
self.shared.pending_responses.lock().remove(&req_id);
return Err(VoxError::SendFailed);
}
let mut resumed_rx = self.resumed_rx.clone();
let mut seen_resume_generation = *resumed_rx.borrow();
let mut resume_processed_rx = self.resume_processed_rx.clone();
let mut closed_rx = self.closed_rx.clone();
let mut response = std::pin::pin!(rx.named("awaiting_response"));
let pending: PendingResponse = loop {
tokio::select! {
// Normal path: the driver fulfilled our response slot.
result = &mut response => {
match result {
Ok(pending) => break pending,
Err(_) => {
return Err(VoxError::ConnectionClosed);
}
}
}
// Session resumed: decide whether to re-send the request.
changed = resumed_rx.changed(), if self.peer_supports_retry => {
vox_types::dlog!("[CALLER] resumed_rx fired");
if changed.is_err() {
self.shared.pending_responses.lock().remove(&req_id);
return Err(VoxError::SessionShutdown);
}
let generation = *resumed_rx.borrow();
if generation == seen_resume_generation {
continue;
}
seen_resume_generation = generation;
// Wait until the driver loop has finished processing this
// resume generation before re-sending.
while *resume_processed_rx.borrow() < generation {
if resume_processed_rx.changed().await.is_err() {
self.shared.pending_responses.lock().remove(&req_id);
return Err(VoxError::SessionShutdown);
}
}
match metadata_channel_retry_mode(&call.metadata) {
// Channel-carrying non-idempotent calls cannot safely be
// re-sent: surface an indeterminate outcome instead.
ChannelRetryMode::NonIdem => {
self.shared.pending_responses.lock().remove(&req_id);
return Err(VoxError::Indeterminate);
}
ChannelRetryMode::Idem | ChannelRetryMode::None => {}
}
// Re-send with the same request id; a failure here will
// surface via closed_rx rather than this send result.
let _ = self.sender.send_with_binder(
ConnectionMessage::Request(RequestMessage {
id: req_id,
body: RequestBody::Call(RequestCall {
method_id: call.method_id,
args: call.args.reborrow(),
metadata: call.metadata.clone(),
schemas: Default::default(),
}),
}),
Some(self),
).await;
}
changed = closed_rx.changed() => {
vox_types::dlog!("[CALLER] closed_rx fired, value={}", *closed_rx.borrow());
if changed.is_err() || *closed_rx.borrow() {
self.shared.pending_responses.lock().remove(&req_id);
return Err(VoxError::ConnectionClosed);
}
}
}
};
let PendingResponse {
msg: response_msg,
schemas: response_schemas,
} = pending;
// pending_responses is only ever fulfilled with Response bodies.
let response = response_msg.map(|m| match m.body {
RequestBody::Response(r) => r,
_ => unreachable!("pending_responses only gets Response variants"),
});
Ok(vox_types::WithTracker {
value: response,
tracker: response_schemas,
})
}
}
/// Connection driver: owns the receive loop, dispatches incoming requests to
/// `handler`, and maintains channel / operation bookkeeping shared with the
/// callers it hands out.
pub struct Driver<H: Handler<DriverReplySink>> {
sender: ConnectionSender,
/// Incoming messages from the session.
rx: mpsc::Receiver<crate::session::RecvMessage>,
/// Per-request failure notifications from the session layer.
failures_rx: mpsc::UnboundedReceiver<(RequestId, FailureDisposition)>,
closed_rx: watch::Receiver<bool>,
resumed_rx: watch::Receiver<u64>,
/// Publishes the last resume generation this loop finished processing.
resume_processed_tx: watch::Sender<u64>,
peer_supports_retry: bool,
/// Control messages from sinks/replenishers (close channel, grant
/// credit, handler completed).
local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
handler: Arc<H>,
shared: Arc<DriverShared>,
/// Handler tasks currently running, keyed by request id.
in_flight_handlers: BTreeMap<RequestId, InFlightHandler>,
live_operations: Arc<SyncMutex<LiveOperationTracker>>,
local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
/// Seed for lazily creating the shared caller drop guard.
drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
drop_control_request: DropControlRequest,
/// Weak cache of the caller drop guard (see `connection_drop_guard`).
drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
}
/// Control messages the driver loop receives from local components.
enum DriverLocalControl {
/// Close an outbound channel (sent when the user-side sink is dropped).
CloseChannel {
channel_id: ChannelId,
},
/// Grant the peer `additional` credits on an inbound channel.
GrantCredit {
channel_id: ChannelId,
additional: u32,
},
/// A spawned handler task finished for `request_id`.
HandlerCompleted {
request_id: RequestId,
},
}
/// Batches inbound-channel credit grants: counts consumed items and emits a
/// `GrantCredit` once the batch threshold is reached.
struct DriverChannelCreditReplenisher {
channel_id: ChannelId,
/// Batch size: `(initial_credit / 2).max(1)` (see `new`).
threshold: u32,
local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
/// Items consumed since the last grant was flushed.
pending: std::sync::Mutex<u32>,
}
impl DriverChannelCreditReplenisher {
    /// Builds a replenisher that fires once half of `initial_credit` has
    /// been consumed, clamped to a minimum threshold of one so one-credit
    /// windows still replenish on every item.
    fn new(
        channel_id: ChannelId,
        initial_credit: u32,
        local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    ) -> Self {
        let threshold = std::cmp::max(initial_credit / 2, 1);
        Self {
            channel_id,
            threshold,
            local_control_tx,
            pending: std::sync::Mutex::new(0),
        }
    }
}
impl ChannelCreditReplenisher for DriverChannelCreditReplenisher {
    /// Counts one consumed item; once the batch threshold is reached, the
    /// accumulated count is flushed as a single `GrantCredit` control
    /// message (send failure is ignored if the driver loop is gone).
    fn on_item_consumed(&self) {
        let mut pending = self.pending.lock().expect("pending credit mutex poisoned");
        *pending += 1;
        if *pending >= self.threshold {
            let additional = std::mem::replace(&mut *pending, 0);
            let _ = self.local_control_tx.send(DriverLocalControl::GrantCredit {
                channel_id: self.channel_id,
                additional,
            });
        }
    }
}
impl<H: Handler<DriverReplySink>> Driver<H> {
/// Tears down all channel runtime state (used on session resume): closes
/// every credit semaphore (waking blocked senders), remembers the channel
/// ids as stale, and drops all receiver registrations and buffers.
fn close_all_channel_runtime_state(&self) {
let mut credits = self.shared.channel_credits.lock();
for semaphore in credits.values() {
semaphore.close();
}
// Record the ids while still holding the credits lock so no new entry
// can interleave between the snapshot and the clear.
let mut stale = self.shared.stale_close_channels.lock();
stale.extend(credits.keys().copied());
credits.clear();
drop(credits);
self.shared.channel_senders.lock().clear();
self.shared.channel_buffers.lock().clear();
}
/// Closes the credit semaphore of one outbound channel, waking any senders
/// blocked on credit; no-op if the channel has no registered semaphore.
fn close_outbound_channel(&self, channel_id: ChannelId) {
    let mut credits = self.shared.channel_credits.lock();
    if let Some(semaphore) = credits.remove(&channel_id) {
        semaphore.close();
    }
}
/// Aborts every in-flight handler whose request carried channels (their
/// channel runtime state has been invalidated), dropping any stored and
/// live operation state for those requests first.
fn abort_channel_handlers(&mut self) {
for in_flight in self.in_flight_handlers.values() {
if in_flight.has_channels {
if let Some(operation_id) = in_flight.operation_id {
self.shared.operations.remove(operation_id);
self.live_operations.lock().release(operation_id);
}
in_flight.handle.abort();
}
}
}
/// Builds a driver with the default in-memory operation store.
pub fn new(handle: ConnectionHandle, handler: H) -> Self {
Self::with_operation_store(handle, handler, Arc::new(InMemoryOperationStore::default()))
}
/// Builds a driver around an established connection, wiring up the shared
/// state (request/channel id allocators, pending-response table, operation
/// store) and the local control channel used by sinks and replenishers.
pub fn with_operation_store(
handle: ConnectionHandle,
handler: H,
operation_store: Arc<dyn OperationStore>,
) -> Self {
let conn_id = handle.connection_id();
let ConnectionHandle {
sender,
rx,
failures_rx,
control_tx,
closed_rx,
resumed_rx,
parity,
peer_supports_retry,
} = handle;
// Dropping the last caller asks the session to close this connection.
let drop_control_request = DropControlRequest::Close(conn_id);
let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
// Generation 0 = "no resume processed yet"; callers wait on this
// before re-sending after a resume.
let (resume_processed_tx, _resume_processed_rx) = watch::channel(0_u64);
Self {
sender,
rx,
failures_rx,
closed_rx,
resumed_rx,
resume_processed_tx,
peer_supports_retry,
local_control_rx,
handler: Arc::new(handler),
shared: Arc::new(DriverShared {
pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
next_operation_id: AtomicU64::new(1),
operations: operation_store,
channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
channel_buffers: SyncMutex::new("driver.channel_buffers", BTreeMap::new()),
channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
stale_close_channels: SyncMutex::new(
"driver.stale_close_channels",
std::collections::HashSet::new(),
),
}),
in_flight_handlers: BTreeMap::new(),
live_operations: Arc::new(SyncMutex::new(
"driver.live_operations",
LiveOperationTracker::new(),
)),
local_control_tx,
drop_control_seed: control_tx,
drop_control_request,
drop_guard: SyncMutex::new("driver.drop_guard", None),
}
}
/// Upgrades the cached weak drop-guard reference, if a live guard exists.
fn existing_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
    let cached = self.drop_guard.lock();
    cached.as_ref().and_then(Weak::upgrade)
}
/// Returns the shared guard that asks the session to close this connection
/// when the last caller drops, creating and caching it on first use.
///
/// Returns `None` when no control sender is available to seed a guard.
///
/// Fix: the original checked the cache via `existing_drop_guard` (taking
/// the lock), released it, then re-took the lock to re-check before
/// creating — a redundant double-checked pattern. Holding the lock once
/// covers both the lookup and the create path.
fn connection_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
    let mut guard = self.drop_guard.lock();
    if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
        return Some(existing);
    }
    // No live guard cached: mint one if we hold a control sender.
    let seed = self.drop_control_seed.as_ref()?;
    let arc = Arc::new(CallerDropGuard {
        control_tx: seed.clone(),
        request: self.drop_control_request,
    });
    *guard = Some(Arc::downgrade(&arc));
    Some(arc)
}
/// Creates a caller bound to this driver's connection. Each caller shares
/// the connection drop guard; when the last guard holder drops, the
/// session is asked to close the connection.
pub fn caller(&self) -> DriverCaller {
let drop_guard = self.connection_drop_guard();
DriverCaller {
sender: self.sender.clone(),
shared: Arc::clone(&self.shared),
local_control_tx: self.local_control_tx.clone(),
closed_rx: self.closed_rx.clone(),
resumed_rx: self.resumed_rx.clone(),
resume_processed_rx: self.resume_processed_tx.subscribe(),
peer_supports_retry: self.peer_supports_retry,
_drop_guard: drop_guard,
}
}
/// Builds the binder used for handler reply sinks. Unlike `caller`, this
/// only reuses an already-live drop guard and never creates a new one.
fn internal_binder(&self) -> DriverChannelBinder {
DriverChannelBinder {
sender: self.sender.clone(),
shared: Arc::clone(&self.shared),
local_control_tx: self.local_control_tx.clone(),
drop_guard: self.existing_drop_guard(),
}
}
/// Main driver loop: multiplexes resume notifications, inbound messages,
/// handler failure reports, and local control until the inbound channel
/// closes or the resume watch sender is dropped.
pub async fn run(&mut self) {
    let mut resumed_rx = self.resumed_rx.clone();
    // Snapshot the generation already present in the watch so a value
    // published before `run` started is not mistaken for a fresh resume.
    let mut seen_resume_generation = *resumed_rx.borrow();
    loop {
        tracing::trace!("driver select loop top");
        tokio::select! {
            // `biased` makes resume handling win over ordinary traffic so
            // stale channel state is torn down before new messages run.
            biased;
            changed = resumed_rx.changed() => {
                if changed.is_err() {
                    // Watch sender dropped: session is going away.
                    break;
                }
                let generation = *resumed_rx.borrow();
                if generation != seen_resume_generation {
                    seen_resume_generation = generation;
                    // A resumed transport invalidates all channel runtime
                    // state and any handlers tied to channels.
                    self.close_all_channel_runtime_state();
                    self.abort_channel_handlers();
                    // Signal that this resume generation is fully processed.
                    let _ = self.resume_processed_tx.send(generation);
                }
            }
            recv = self.rx.recv() => {
                match recv {
                    Some(recv) => {
                        self.handle_recv(recv);
                    }
                    None => {
                        tracing::trace!("driver rx closed, exiting loop");
                        break;
                    }
                }
            }
            Some((req_id, disposition)) = self.failures_rx.recv() => {
                tracing::trace!(%req_id, ?disposition, "failures_rx fired");
                let in_flight_found = self.in_flight_handlers.contains_key(&req_id);
                let in_flight_method_id =
                    self.in_flight_handlers.get(&req_id).map(|in_flight| in_flight.method_id);
                // Decide what (if anything) to report to the peer:
                // - channel-using, non-idempotent handler: outcome unknown
                //   once channels are involved -> Indeterminate;
                // - channel-using, idempotent handler: stay silent (None);
                // - otherwise: forward the reported disposition unchanged.
                let reply_disposition = self
                    .in_flight_handlers
                    .get(&req_id)
                    .map(|in_flight| {
                        if in_flight.has_channels && !in_flight.retry.idem {
                            Some(FailureDisposition::Indeterminate)
                        } else if in_flight.has_channels && in_flight.retry.idem {
                            None
                        } else {
                            Some(disposition)
                        }
                    })
                    .unwrap_or(Some(disposition));
                tracing::trace!(%req_id, in_flight_found, ?reply_disposition, "failures_rx computed disposition");
                self.in_flight_handlers.remove(&req_id);
                tracing::trace!(%req_id, in_flight = self.in_flight_handlers.len(), "handler removed on failure");
                // Removing a pending slot drops the oneshot sender, which
                // presumably wakes the local waiter with a closed channel —
                // in that case no wire error response is synthesized.
                let had_pending = self.shared.pending_responses.lock().remove(&req_id).is_some();
                tracing::trace!(%req_id, had_pending, "failures_rx checked pending_responses");
                if !had_pending {
                    let Some(reply_disposition) = reply_disposition else {
                        tracing::trace!(%req_id, "failures_rx: no reply_disposition, skipping");
                        continue;
                    };
                    tracing::trace!(%req_id, ?reply_disposition, "failures_rx: sending error response");
                    let vox_error = match reply_disposition {
                        FailureDisposition::Cancelled => VoxError::Cancelled,
                        FailureDisposition::Indeterminate => VoxError::Indeterminate,
                    };
                    // Prefer a schema-annotated response when the method's
                    // response wire shape is known, so the peer can decode it.
                    if let Some(method_id) = in_flight_method_id
                        && let Some(response_shape) = self.handler.response_wire_shape(method_id)
                        && let Ok(extracted) = vox_types::extract_schemas(response_shape)
                    {
                        let registry = vox_types::build_registry(&extracted.schemas);
                        let error: Result<(), VoxError<core::convert::Infallible>> =
                            Err(vox_error);
                        let encoded = vox_postcard::to_vec(&error)
                            .expect("serialize runtime-generated error response");
                        // NOTE(review): Box::leak gives the bytes the 'static
                        // lifetime Payload::PostcardBytes requires, leaking one
                        // allocation per synthesized error response — confirm
                        // this is intended/acceptable.
                        let mut response = RequestResponse {
                            ret: Payload::PostcardBytes(Box::leak(encoded.into_boxed_slice())),
                            metadata: Default::default(),
                            schemas: Default::default(),
                        };
                        self.sender.prepare_response_from_source(
                            req_id,
                            method_id,
                            &extracted.root,
                            &registry,
                            &mut response,
                        );
                        let _ = self.sender.send_response(req_id, response).await;
                    } else {
                        // No shape info available: send a plain outgoing payload.
                        let error: Result<(), VoxError<core::convert::Infallible>> =
                            Err(vox_error);
                        let _ = self.sender.send_response(req_id, RequestResponse {
                            ret: Payload::outgoing(&error),
                            metadata: Default::default(),
                            schemas: Default::default(),
                        }).await;
                    }
                    tracing::trace!(%req_id, "failures_rx: error response sent");
                }
            }
            Some(ctrl) = self.local_control_rx.recv() => {
                self.handle_local_control(ctrl).await;
            }
        }
    }
    // Shutdown: abort non-persistent handlers; `persist` handlers keep
    // running (their JoinHandles are dropped, which presumably detaches
    // them — confirm moire's JoinHandle drop semantics). Then drop all
    // response waiters and tear down channel runtime state.
    for (_, in_flight) in std::mem::take(&mut self.in_flight_handlers) {
        if !in_flight.retry.persist {
            in_flight.handle.abort();
        }
    }
    self.shared.pending_responses.lock().clear();
    self.close_all_channel_runtime_state();
}
/// Processes a driver-local control message produced by handler tasks or
/// channel plumbing: close a channel on the wire, grant credit on the wire,
/// or retire a completed handler's in-flight entry.
async fn handle_local_control(&mut self, control: DriverLocalControl) {
    match control {
        DriverLocalControl::CloseChannel { channel_id } => {
            // A stale-marked channel swallows its close instead of
            // emitting it to the peer.
            let suppressed = self.shared.stale_close_channels.lock().remove(&channel_id);
            if suppressed {
                tracing::trace!(%channel_id, "suppressing ChannelClose for stale channel");
                return;
            }
            let body = ChannelBody::Close(ChannelClose {
                metadata: Default::default(),
            });
            let message = ConnectionMessage::Channel(ChannelMessage {
                id: channel_id,
                body,
            });
            let _ = self.sender.send(message).await;
        }
        DriverLocalControl::GrantCredit {
            channel_id,
            additional,
        } => {
            let body =
                ChannelBody::GrantCredit(vox_types::ChannelGrantCredit { additional });
            let message = ConnectionMessage::Channel(ChannelMessage {
                id: channel_id,
                body,
            });
            let _ = self.sender.send(message).await;
        }
        DriverLocalControl::HandlerCompleted { request_id } => {
            let was_tracked = self.in_flight_handlers.remove(&request_id).is_some();
            tracing::trace!(
                %request_id,
                removed = was_tracked,
                in_flight = self.in_flight_handlers.len(),
                "handler completion processed"
            );
        }
    }
}
/// Splits an inbound wire message into the request path or the channel path,
/// logging request traffic before dispatch.
fn handle_recv(&mut self, recv: crate::session::RecvMessage) {
    let crate::session::RecvMessage { schemas, msg } = recv;
    if matches!(msg.get(), ConnectionMessage::Request(_)) {
        if let ConnectionMessage::Request(request) = msg.get() {
            let body_kind = match &request.body {
                RequestBody::Call(_) => "Call",
                RequestBody::Response(_) => "Response",
                RequestBody::Cancel(_) => "Cancel",
            };
            let call_method = match &request.body {
                RequestBody::Call(call) => Some(call.method_id),
                RequestBody::Response(_) | RequestBody::Cancel(_) => None,
            };
            vox_types::dlog!(
                "[driver] handle_recv request: conn={:?} req={:?} body={} method={:?}",
                self.sender.connection_id(),
                request.id,
                body_kind,
                call_method
            );
            match &request.body {
                RequestBody::Call(call) => tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    req_id = request.id.0,
                    method_id = call.method_id.0,
                    "driver received call"
                ),
                RequestBody::Response(_) => tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    req_id = request.id.0,
                    "driver received response message"
                ),
                RequestBody::Cancel(_) => tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    req_id = request.id.0,
                    "driver received cancel message"
                ),
            }
        }
        let request_msg = msg.map(|inner| match inner {
            ConnectionMessage::Request(r) => r,
            _ => unreachable!(),
        });
        self.handle_request(request_msg, schemas);
    } else {
        let channel_msg = msg.map(|inner| match inner {
            ConnectionMessage::Channel(c) => c,
            _ => unreachable!(),
        });
        self.handle_channel(channel_msg);
    }
}
/// Dispatches one inbound request message.
///
/// Calls spawn a handler task (after operation-dedup admission when the
/// call carries an operation id), responses are routed to the matching
/// pending oneshot slot, and cancels abort or detach the in-flight handler.
fn handle_request(
    &mut self,
    msg: SelfRef<RequestMessage<'static>>,
    schemas: Arc<vox_types::SchemaRecvTracker>,
) {
    let msg_ref = msg.get();
    let req_id = msg_ref.id;
    let is_call = matches!(&msg_ref.body, RequestBody::Call(_));
    let is_response = matches!(&msg_ref.body, RequestBody::Response(_));
    let is_cancel = matches!(&msg_ref.body, RequestBody::Cancel(_));
    if is_call {
        let method_id = match &msg_ref.body {
            RequestBody::Call(call) => call.method_id,
            _ => unreachable!(),
        };
        vox_types::dlog!(
            "[driver] inbound call: conn={:?} req={:?} method={:?}",
            self.sender.connection_id(),
            req_id,
            method_id
        );
        let call = msg.map(|m| match m.body {
            RequestBody::Call(c) => c,
            _ => unreachable!(),
        });
        let call_ref = call.get();
        let handler = Arc::clone(&self.handler);
        let retry = handler.retry_policy(call_ref.method_id);
        // Operation ids are only honored for non-idempotent methods:
        // idempotent calls can simply be re-executed on retry.
        let operation_id = metadata_operation_id(&call_ref.metadata).filter(|_| !retry.idem);
        // Re-bound so the spawn closure below can capture it by value.
        let method_id = call_ref.method_id;
        if let Some(operation_id) = operation_id {
            // Admission into the live-operation tracker dedups concurrent
            // retries of the same operation (keyed by method + args hash).
            let admit = self.live_operations.lock().admit(
                operation_id,
                call_ref.method_id,
                incoming_args_bytes(call_ref),
                retry,
                req_id,
            );
            match admit {
                // Another request already owns this operation; this one was
                // recorded as a waiter and will be answered when the owner
                // completes (presumably by the completion path elsewhere).
                AdmitResult::Attached => return,
                // Same operation id with different method/args: reject.
                AdmitResult::Conflict => {
                    let sender = self.sender.clone();
                    moire::task::spawn(
                        async move {
                            let error: Result<(), VoxError<core::convert::Infallible>> =
                                Err(VoxError::InvalidPayload("operation ID conflict".into()));
                            let _ = sender
                                .send_response(
                                    req_id,
                                    RequestResponse {
                                        ret: Payload::outgoing(&error),
                                        metadata: Default::default(),
                                        schemas: Default::default(),
                                    },
                                )
                                .await;
                        }
                        .named("operation_reject"),
                    );
                    return;
                }
                AdmitResult::Start => {}
            }
            // We are the first live request for this operation; consult the
            // durable operation store for a prior outcome.
            match self.shared.operations.lookup(operation_id) {
                crate::OperationState::Sealed => {
                    // A completed response was stored: replay it instead of
                    // re-running the handler. If the sealed record is gone
                    // (get_sealed -> None), fall through and run normally.
                    if let Some(sealed) = self.shared.operations.get_sealed(operation_id) {
                        let sender = self.sender.clone();
                        let method_id = call_ref.method_id;
                        let operations = Arc::clone(&self.shared.operations);
                        self.live_operations.lock().seal(operation_id);
                        moire::task::spawn(
                            async move {
                                if replay_sealed_response(
                                    sender.clone(),
                                    req_id,
                                    method_id,
                                    sealed.response.as_bytes(),
                                    sealed.root_type,
                                    operations.as_ref(),
                                )
                                .await
                                .is_err()
                                {
                                    // Replay failed to send: mark the request
                                    // as cancelled so failure handling runs.
                                    sender.mark_failure(req_id, FailureDisposition::Cancelled);
                                }
                            }
                            .named("operation_replay"),
                        );
                        return;
                    }
                }
                crate::OperationState::Admitted => {
                    // The operation was admitted previously but never sealed:
                    // its outcome is unknown, so report Indeterminate rather
                    // than risking a second execution.
                    self.live_operations.lock().seal(operation_id);
                    let sender = self.sender.clone();
                    moire::task::spawn(
                        async move {
                            let error: Result<(), VoxError<core::convert::Infallible>> =
                                Err(VoxError::Indeterminate);
                            let _ = sender
                                .send_response(
                                    req_id,
                                    RequestResponse {
                                        ret: Payload::outgoing(&error),
                                        metadata: Default::default(),
                                        schemas: Default::default(),
                                    },
                                )
                                .await;
                        }
                        .named("operation_indeterminate"),
                    );
                    return;
                }
                crate::OperationState::Unknown => {
                    // First sighting: record admission for non-idempotent
                    // methods so a future retry can detect the attempt.
                    if !retry.idem {
                        self.shared.operations.admit(operation_id);
                    }
                }
            }
        }
        // Reply sink carries everything the handler needs to respond,
        // including operation bookkeeping (only when an operation id exists).
        let reply = DriverReplySink {
            sender: Some(self.sender.clone()),
            request_id: req_id,
            method_id: call_ref.method_id,
            retry,
            operation_id,
            operations: operation_id.map(|_| Arc::clone(&self.shared.operations)),
            live_operations: operation_id.map(|_| Arc::clone(&self.live_operations)),
            binder: self.internal_binder(),
        };
        let has_channels = handler.args_have_channels(call_ref.method_id);
        let local_control_tx = self.local_control_tx.clone();
        let join_handle = moire::task::spawn(
            async move {
                vox_types::dlog!(
                    "[driver] handler start: req={:?} method={:?}",
                    req_id,
                    method_id
                );
                handler.handle(call, reply, schemas).await;
                vox_types::dlog!(
                    "[driver] handler done: req={:?} method={:?}",
                    req_id,
                    method_id
                );
                // Tell the driver loop to drop the in-flight entry.
                let _ = local_control_tx
                    .send(DriverLocalControl::HandlerCompleted { request_id: req_id });
            }
            .named("handler"),
        );
        self.in_flight_handlers.insert(
            req_id,
            InFlightHandler {
                handle: join_handle,
                method_id,
                retry,
                has_channels,
                operation_id,
            },
        );
        tracing::trace!(%req_id, in_flight = self.in_flight_handlers.len(), "handler inserted");
    } else if is_response {
        vox_types::dlog!(
            "[driver] inbound response: conn={:?} req={:?}",
            self.sender.connection_id(),
            req_id
        );
        tracing::trace!(%req_id, "driver received response");
        // Route the response to whichever caller registered a oneshot slot
        // for this request id; unmatched responses are dropped.
        if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
            vox_types::dlog!("[driver] routing response to waiter: req={:?}", req_id);
            tracing::trace!(%req_id, "routing response to pending oneshot");
            let _: Result<(), _> = tx.send(PendingResponse { msg, schemas });
        } else {
            vox_types::dlog!("[driver] dropped unmatched response: req={:?}", req_id);
            tracing::trace!(%req_id, "no pending response slot for this req_id");
        }
    } else if is_cancel {
        vox_types::dlog!(
            "[driver] inbound cancel: conn={:?} req={:?}",
            self.sender.connection_id(),
            req_id
        );
        tracing::trace!(%req_id, in_flight = self.in_flight_handlers.contains_key(&req_id), "received cancel");
        match self.live_operations.lock().cancel(req_id) {
            // Not part of any live operation: abort the plain handler
            // unless its retry policy asks it to persist past cancel.
            CancelResult::NotFound => {
                let should_abort = self
                    .in_flight_handlers
                    .get(&req_id)
                    .map(|in_flight| !in_flight.retry.persist)
                    .unwrap_or(false);
                tracing::trace!(%req_id, should_abort, "cancel: not in live operations");
                if should_abort && let Some(in_flight) = self.in_flight_handlers.remove(&req_id)
                {
                    tracing::trace!(%req_id, "aborting handler");
                    in_flight.handle.abort();
                    tracing::trace!(%req_id, in_flight = self.in_flight_handlers.len(), "handler removed on cancel");
                }
            }
            // This request was a waiter on someone else's operation; the
            // operation itself keeps running.
            CancelResult::Detached => {}
            // This request owned the operation: abort the owner's handler,
            // forget the durable record, and fail all attached waiters.
            CancelResult::Abort {
                owner_request_id,
                waiters,
            } => {
                if let Some(in_flight) = self.in_flight_handlers.remove(&owner_request_id) {
                    if let Some(op_id) = in_flight.operation_id {
                        self.shared.operations.remove(op_id);
                    }
                    in_flight.handle.abort();
                    tracing::trace!(%owner_request_id, in_flight = self.in_flight_handlers.len(), "owner handler removed on abort");
                }
                for waiter in waiters {
                    self.sender
                        .mark_failure(waiter, FailureDisposition::Cancelled);
                }
            }
        }
    }
}
/// Routes one inbound channel message.
///
/// Item/Close/Reset bodies are delivered to the registered receiver for the
/// channel, or buffered in `channel_buffers` when no receiver has been
/// registered yet (registration happens elsewhere; the buffer is presumably
/// drained at that point — confirm against the registration path).
/// Close/Reset additionally deregister the channel and tear down the
/// outbound side. Delivery uses `try_send` with the result ignored, so a
/// full or closed receiver drops the message silently.
fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
    let msg_ref = msg.get();
    let chan_id = msg_ref.id;
    // Clone the sender out so the lock is not held while dispatching.
    let sender = self.shared.channel_senders.lock().get(&chan_id).cloned();
    match &msg_ref.body {
        ChannelBody::Item(_item) => {
            if let Some(tx) = &sender {
                tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    channel_id = chan_id.0,
                    registered = true,
                    "driver received channel item"
                );
                let item = msg.map(|m| match m.body {
                    ChannelBody::Item(item) => item,
                    _ => unreachable!(),
                });
                let _ = tx.try_send(IncomingChannelMessage::Item(item));
            } else {
                tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    channel_id = chan_id.0,
                    registered = false,
                    "driver buffered channel item before registration"
                );
                let item = msg.map(|m| match m.body {
                    ChannelBody::Item(item) => item,
                    _ => unreachable!(),
                });
                self.shared
                    .channel_buffers
                    .lock()
                    .entry(chan_id)
                    .or_default()
                    .push(IncomingChannelMessage::Item(item));
            }
        }
        ChannelBody::Close(_close) => {
            if let Some(tx) = &sender {
                tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    channel_id = chan_id.0,
                    registered = true,
                    "driver received channel close"
                );
                let close = msg.map(|m| match m.body {
                    ChannelBody::Close(close) => close,
                    _ => unreachable!(),
                });
                let _ = tx.try_send(IncomingChannelMessage::Close(close));
            } else {
                tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    channel_id = chan_id.0,
                    registered = false,
                    "driver buffered channel close before registration"
                );
                let close = msg.map(|m| match m.body {
                    ChannelBody::Close(close) => close,
                    _ => unreachable!(),
                });
                self.shared
                    .channel_buffers
                    .lock()
                    .entry(chan_id)
                    .or_default()
                    .push(IncomingChannelMessage::Close(close));
            }
            // Peer closed the channel: deregister and close our outbound half.
            self.shared.channel_senders.lock().remove(&chan_id);
            self.close_outbound_channel(chan_id);
        }
        ChannelBody::Reset(_reset) => {
            if let Some(tx) = &sender {
                tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    channel_id = chan_id.0,
                    registered = true,
                    "driver received channel reset"
                );
                let reset = msg.map(|m| match m.body {
                    ChannelBody::Reset(reset) => reset,
                    _ => unreachable!(),
                });
                let _ = tx.try_send(IncomingChannelMessage::Reset(reset));
            } else {
                tracing::trace!(
                    conn_id = self.sender.connection_id().0,
                    channel_id = chan_id.0,
                    registered = false,
                    "driver buffered channel reset before registration"
                );
                let reset = msg.map(|m| match m.body {
                    ChannelBody::Reset(reset) => reset,
                    _ => unreachable!(),
                });
                self.shared
                    .channel_buffers
                    .lock()
                    .entry(chan_id)
                    .or_default()
                    .push(IncomingChannelMessage::Reset(reset));
            }
            // Reset is terminal, same cleanup as Close.
            self.shared.channel_senders.lock().remove(&chan_id);
            self.close_outbound_channel(chan_id);
        }
        ChannelBody::GrantCredit(grant) => {
            tracing::trace!(
                conn_id = self.sender.connection_id().0,
                channel_id = chan_id.0,
                additional = grant.additional,
                "driver received channel credit"
            );
            // Credit grants top up the outbound semaphore; grants for
            // unknown channels are ignored.
            if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
                semaphore.add_permits(grant.additional as usize);
            }
        }
    }
}
}