use std::{
collections::BTreeMap,
pin::Pin,
sync::{Arc, Weak},
};
use moire::sync::SyncMutex;
use tokio::sync::{Semaphore, watch};
use moire::task::FutureExt as _;
use roam_types::{
Caller, ChannelBinder, ChannelBody, ChannelClose, ChannelCreditReplenisher,
ChannelCreditReplenisherHandle, ChannelId, ChannelItem, ChannelLivenessHandle, ChannelMessage,
ChannelSink, CreditSink, Handler, IdAllocator, IncomingChannelMessage, MaybeSend, Payload,
ReplySink, RequestBody, RequestCall, RequestId, RequestMessage, RequestResponse, RoamError,
SelfRef, TxError,
};
use crate::session::{ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest};
use moire::sync::mpsc;
/// One-shot slot through which the driver hands a peer's response message to
/// the `DriverCaller::call` future that is waiting for it.
type ResponseSlot = moire::sync::oneshot::Sender<SelfRef<RequestMessage<'static>>>;
/// State shared between the `Driver` and every caller/binder it hands out.
struct DriverShared {
    /// Response slots of outgoing calls still awaiting a reply, keyed by request id.
    /// Entries are removed when the response arrives or the request fails.
    pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
    /// Allocator for the ids of requests initiated by this side.
    request_ids: SyncMutex<IdAllocator<RequestId>>,
    /// Allocator for the ids of channels created by this side.
    channel_ids: SyncMutex<IdAllocator<ChannelId>>,
    /// Queues feeding registered inbound channel receivers, keyed by channel id.
    channel_senders:
        SyncMutex<BTreeMap<ChannelId, tokio::sync::mpsc::Sender<IncomingChannelMessage>>>,
    /// Messages that arrived for an inbound channel before its receiver was
    /// registered; drained into the receiver's queue on registration.
    channel_buffers: SyncMutex<BTreeMap<ChannelId, Vec<IncomingChannelMessage>>>,
    /// Credit semaphores of outbound channels; `GrantCredit` messages from the
    /// peer add permits to the matching entry.
    channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
}
/// Sends `request` on `control_tx` exactly once, when the guard is dropped.
///
/// The driver shares one guard (behind an `Arc`) among its callers, so the
/// request goes out when the last holder lets go of it.
struct CallerDropGuard {
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    request: DropControlRequest,
}

impl Drop for CallerDropGuard {
    fn drop(&mut self) {
        let request = self.request;
        // Best effort: if the receiving side is already gone there is nobody
        // left to act on the request, so a send error is deliberately ignored.
        let _send_result = self.control_tx.send(request);
    }
}
#[cfg(test)]
mod tests {
    use super::{DriverChannelCreditReplenisher, DriverLocalControl};
    use roam_types::{ChannelCreditReplenisher, ChannelId};
    use tokio::sync::mpsc::error::TryRecvError;

    /// Pops the next control message, asserting that it is a credit grant, and
    /// returns its `(channel_id, additional)` pair. Panics with `context` otherwise.
    fn expect_grant(
        rx: &mut moire::sync::mpsc::UnboundedReceiver<DriverLocalControl>,
        context: &str,
    ) -> (ChannelId, u32) {
        match rx.try_recv() {
            Ok(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => (channel_id, additional),
            _ => panic!("{context}"),
        }
    }

    #[test]
    fn replenisher_batches_at_half_the_initial_window() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(7), 16, tx);
        for _ in 0..7 {
            replenisher.on_item_consumed();
        }
        assert!(
            matches!(rx.try_recv(), Err(TryRecvError::Empty)),
            "should not emit credit before reaching the batch threshold"
        );
        replenisher.on_item_consumed();
        let (channel_id, additional) = expect_grant(&mut rx, "expected batched credit grant");
        assert_eq!(channel_id, ChannelId(7));
        assert_eq!(additional, 8);
    }

    #[test]
    fn replenisher_grants_one_by_one_for_single_credit_windows() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher.single");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(9), 1, tx);
        replenisher.on_item_consumed();
        let (channel_id, additional) = expect_grant(&mut rx, "expected immediate credit grant");
        assert_eq!(channel_id, ChannelId(9));
        assert_eq!(additional, 1);
    }

    /// After a grant is emitted the pending counter must start from zero again,
    /// so every batch has the same size.
    #[test]
    fn replenisher_resets_pending_count_after_each_grant() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher.reset");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(3), 16, tx);
        for round in 0..2 {
            for _ in 0..7 {
                replenisher.on_item_consumed();
            }
            assert!(
                matches!(rx.try_recv(), Err(TryRecvError::Empty)),
                "round {round}: should not emit credit before reaching the batch threshold"
            );
            replenisher.on_item_consumed();
            let (channel_id, additional) = expect_grant(&mut rx, "expected batched credit grant");
            assert_eq!(channel_id, ChannelId(3));
            assert_eq!(additional, 8);
        }
    }
}
/// Reply side of one incoming call, handed to the handler together with the call.
///
/// If it is dropped while it still holds its sender, the request is reported
/// as failed (see its `Drop` impl).
pub struct DriverReplySink {
    /// `Some` until a reply path takes responsibility for the request.
    sender: Option<ConnectionSender>,
    /// Id of the incoming request this sink answers.
    request_id: RequestId,
    /// Binder used for channels carried in the reply.
    binder: DriverChannelBinder,
}
impl ReplySink for DriverReplySink {
    /// Sends `response` as the reply to this sink's request.
    ///
    /// The sender stays inside `self` while the send is in flight. If this
    /// future is dropped before the send completes (for example because the
    /// handler task was aborted), `Drop` still reports the request as failed
    /// instead of leaving the peer without any answer. If the send itself
    /// fails, the failure is reported with the reason `"send_response failed"`.
    async fn send_reply(mut self, response: RequestResponse<'_>) {
        let send_result = self
            .sender
            .as_ref()
            .expect("unreachable: send_reply takes self by value")
            .send_response(self.request_id, response)
            .await;
        // The outcome is now known: disarm the `Drop` fallback and, on error,
        // report the failure explicitly with a more specific reason.
        let sender = self
            .sender
            .take()
            .expect("unreachable: send_reply takes self by value");
        if send_result.is_err() {
            sender.mark_failure(self.request_id, "send_response failed");
        }
    }

    /// Exposes the binder for channels that appear in the reply.
    fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
        Some(&self.binder)
    }
}
impl Drop for DriverReplySink {
    /// Reports the request as failed if the sink still owns its sender.
    ///
    /// `send_reply` removes the sender when it takes over responsibility for
    /// the request, so this only fires when no reply path claimed it.
    fn drop(&mut self) {
        let Some(sender) = self.sender.take() else {
            return;
        };
        sender.mark_failure(self.request_id, "no reply sent");
    }
}
/// Outbound half of a channel: writes items and close messages for
/// `channel_id` onto the connection.
pub struct DriverChannelSink {
    /// Connection the channel's messages are sent on.
    sender: ConnectionSender,
    /// Id of the outbound channel.
    channel_id: ChannelId,
    /// Used to ask the driver to close the channel from a synchronous context.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
impl ChannelSink for DriverChannelSink {
    /// Sends one item on this channel.
    ///
    /// Fails with `TxError::Transport` if the connection is closed.
    fn send_payload<'payload>(
        &self,
        payload: Payload<'payload>,
    ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'payload>> {
        let message = ConnectionMessage::Channel(ChannelMessage {
            id: self.channel_id,
            body: ChannelBody::Item(ChannelItem { item: payload }),
        });
        let connection = self.sender.clone();
        Box::pin(async move {
            match connection.send(message).await {
                Ok(()) => Ok(()),
                Err(()) => Err(TxError::Transport("connection closed".into())),
            }
        })
    }

    /// Sends a `Close` for this channel.
    ///
    /// NOTE(review): the caller-supplied metadata is currently discarded; the
    /// `Close` message always carries default metadata.
    fn close_channel(
        &self,
        _metadata: roam_types::Metadata,
    ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'static>> {
        let message = ConnectionMessage::Channel(ChannelMessage {
            id: self.channel_id,
            body: ChannelBody::Close(ChannelClose {
                metadata: Default::default(),
            }),
        });
        let connection = self.sender.clone();
        Box::pin(async move {
            match connection.send(message).await {
                Ok(()) => Ok(()),
                Err(()) => Err(TxError::Transport("connection closed".into())),
            }
        })
    }

    /// Synchronous close path: hands the close to the driver's control loop,
    /// which sends the `Close` message asynchronously.
    fn close_channel_on_drop(&self) {
        let request = DriverLocalControl::CloseChannel {
            channel_id: self.channel_id,
        };
        // Best effort: if the driver is gone, the connection is gone too.
        let _ = self.local_control_tx.send(request);
    }
}
/// A [`DriverCaller`] kept only for its side effect of holding the
/// connection's drop guard, without being used to make calls.
#[must_use = "Dropping NoopCaller may close the connection if it is the last caller."]
#[derive(Clone)]
// `dead_code` is allowed because the inner caller is never read: it is stored
// only so that its drop guard (if any) lives as long as this value does.
pub struct NoopCaller(#[allow(dead_code)] DriverCaller);

impl From<DriverCaller> for NoopCaller {
    fn from(inner: DriverCaller) -> NoopCaller {
        NoopCaller(inner)
    }
}
/// Channel binder handed to incoming calls (via `DriverReplySink`).
///
/// It shares the driver's channel bookkeeping but, unlike `DriverCaller`,
/// only carries a drop guard if one already existed when it was created.
#[derive(Clone)]
struct DriverChannelBinder {
    /// Connection that outbound channel messages are written to.
    sender: ConnectionSender,
    /// Driver state shared with callers and the driver loop.
    shared: Arc<DriverShared>,
    /// Control queue into the driver (channel closes, credit grants).
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Connection drop guard, handed out as the liveness handle of bound receivers.
    drop_guard: Option<Arc<CallerDropGuard>>,
}
impl DriverChannelBinder {
    /// Allocates a fresh channel id and creates a credit-limited outbound sink for it.
    ///
    /// The sink's credit semaphore is registered in `channel_credits` so that
    /// `GrantCredit` messages from the peer can replenish it.
    fn create_tx_channel(
        &self,
        initial_credit: u32,
    ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
        let channel_id = self.shared.channel_ids.lock().alloc();
        let inner = DriverChannelSink {
            sender: self.sender.clone(),
            channel_id,
            local_control_tx: self.local_control_tx.clone(),
        };
        let sink = Arc::new(CreditSink::new(inner, initial_credit));
        self.shared
            .channel_credits
            .lock()
            .insert(channel_id, Arc::clone(sink.credit()));
        (channel_id, sink)
    }

    /// Binds the inbound channel `channel_id` and returns its receiving end.
    ///
    /// Messages the peer sent before the channel was bound (held in
    /// `channel_buffers`) are replayed into the new queue first. If they
    /// already contain a terminal `Close`/`Reset`, the channel is complete:
    /// no sender is registered and no credit replenisher is attached.
    /// Otherwise the sender is registered so the driver can forward further
    /// messages, and a replenisher returns credit to the peer as items are
    /// consumed.
    ///
    /// NOTE(review): `Driver::handle_channel` checks `channel_senders` and only
    /// afterwards appends to `channel_buffers`, without holding a lock across
    /// both steps. A message arriving between the buffer removal below and the
    /// sender registration can therefore be buffered after the drain and is
    /// never delivered.
    fn register_rx_channel(
        &self,
        channel_id: ChannelId,
        initial_credit: u32,
    ) -> roam_types::BoundChannelReceiver {
        let buffered = self
            .shared
            .channel_buffers
            .lock()
            .remove(&channel_id)
            .unwrap_or_default();
        // The driver delivers with `try_send`, so anything that does not fit in
        // the queue is silently dropped. Size it for everything the peer may
        // legitimately have in flight: the whole credit window plus one terminal
        // Close/Reset, or the pre-bind backlog. Never go below the previous fixed
        // capacity of 64, and never above what a tokio channel supports.
        let capacity = usize::try_from(initial_credit)
            .unwrap_or(usize::MAX)
            .saturating_add(1)
            .max(buffered.len())
            .max(64)
            .min(tokio::sync::Semaphore::MAX_PERMITS);
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
        let mut terminal_buffered = false;
        for msg in buffered {
            let is_terminal = matches!(
                msg,
                IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
            );
            let _ = tx.try_send(msg);
            if is_terminal {
                terminal_buffered = true;
                break;
            }
        }
        if terminal_buffered {
            self.shared.channel_credits.lock().remove(&channel_id);
            return roam_types::BoundChannelReceiver {
                receiver: rx,
                liveness: self.channel_liveness(),
                replenisher: None,
            };
        }
        self.shared.channel_senders.lock().insert(channel_id, tx);
        roam_types::BoundChannelReceiver {
            receiver: rx,
            liveness: self.channel_liveness(),
            replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
                channel_id,
                initial_credit,
                self.local_control_tx.clone(),
            )) as ChannelCreditReplenisherHandle),
        }
    }
}
impl ChannelBinder for DriverChannelBinder {
    /// Creates a new outbound channel with a freshly allocated id.
    fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
        let (channel_id, credit_sink) = self.create_tx_channel(initial_credit);
        let sink: Arc<dyn ChannelSink> = credit_sink;
        (channel_id, sink)
    }

    /// Creates a new inbound channel with a freshly allocated id.
    fn create_rx(&self, initial_credit: u32) -> (ChannelId, roam_types::BoundChannelReceiver) {
        let channel_id = self.shared.channel_ids.lock().alloc();
        (channel_id, self.register_rx_channel(channel_id, initial_credit))
    }

    /// Creates an outbound sink for a channel whose id was chosen elsewhere,
    /// registering its credit semaphore for peer `GrantCredit` messages.
    fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
        let sink = Arc::new(CreditSink::new(
            DriverChannelSink {
                sender: self.sender.clone(),
                channel_id,
                local_control_tx: self.local_control_tx.clone(),
            },
            initial_credit,
        ));
        let credit = Arc::clone(sink.credit());
        self.shared.channel_credits.lock().insert(channel_id, credit);
        sink
    }

    /// Binds an inbound channel whose id was chosen elsewhere.
    fn register_rx(
        &self,
        channel_id: ChannelId,
        initial_credit: u32,
    ) -> roam_types::BoundChannelReceiver {
        self.register_rx_channel(channel_id, initial_credit)
    }

    /// Shares the connection drop guard (if any) as a liveness handle.
    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
        let guard = self.drop_guard.as_ref()?;
        Some(guard.clone() as ChannelLivenessHandle)
    }
}
/// Client handle for making calls over the driver's connection.
///
/// Clones share the same state. If the driver had a control sender, every
/// clone holds the shared connection drop guard. When the last holder of that
/// guard goes away, the guard sends the connection's drop-control request.
#[derive(Clone)]
pub struct DriverCaller {
    /// Connection that requests and channel messages are written to.
    sender: ConnectionSender,
    /// Driver state shared with the driver loop and other handles.
    shared: Arc<DriverShared>,
    /// Control queue into the driver (channel closes, credit grants).
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Becomes `true` once the connection is closed.
    closed_rx: watch::Receiver<bool>,
    /// Held for its `Drop` side effect; also shared as the channel liveness handle.
    _drop_guard: Option<Arc<CallerDropGuard>>,
}
impl DriverCaller {
    /// Allocates a fresh channel id and creates a credit-limited outbound sink for it.
    ///
    /// The sink's credit semaphore is registered in `channel_credits` so that
    /// `GrantCredit` messages from the peer can replenish it.
    pub fn create_tx_channel(
        &self,
        initial_credit: u32,
    ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
        let channel_id = self.shared.channel_ids.lock().alloc();
        let inner = DriverChannelSink {
            sender: self.sender.clone(),
            channel_id,
            local_control_tx: self.local_control_tx.clone(),
        };
        let sink = Arc::new(CreditSink::new(inner, initial_credit));
        self.shared
            .channel_credits
            .lock()
            .insert(channel_id, Arc::clone(sink.credit()));
        (channel_id, sink)
    }

    /// Test-only access to the underlying connection sender.
    #[cfg(test)]
    pub(crate) fn connection_sender(&self) -> &ConnectionSender {
        &self.sender
    }

    /// Binds the inbound channel `channel_id` and returns its receiving end.
    ///
    /// Messages the peer sent before the channel was bound (held in
    /// `channel_buffers`) are replayed into the new queue first. If they
    /// already contain a terminal `Close`/`Reset`, the channel is complete:
    /// no sender is registered and no credit replenisher is attached.
    /// Otherwise the sender is registered so the driver can forward further
    /// messages, and a replenisher returns credit to the peer as items are
    /// consumed.
    ///
    /// NOTE(review): `Driver::handle_channel` checks `channel_senders` and only
    /// afterwards appends to `channel_buffers`, without holding a lock across
    /// both steps. A message arriving between the buffer removal below and the
    /// sender registration can therefore be buffered after the drain and is
    /// never delivered.
    pub fn register_rx_channel(
        &self,
        channel_id: ChannelId,
        initial_credit: u32,
    ) -> roam_types::BoundChannelReceiver {
        let buffered = self
            .shared
            .channel_buffers
            .lock()
            .remove(&channel_id)
            .unwrap_or_default();
        // The driver delivers with `try_send`, so anything that does not fit in
        // the queue is silently dropped. Size it for everything the peer may
        // legitimately have in flight: the whole credit window plus one terminal
        // Close/Reset, or the pre-bind backlog. Never go below the previous fixed
        // capacity of 64, and never above what a tokio channel supports.
        let capacity = usize::try_from(initial_credit)
            .unwrap_or(usize::MAX)
            .saturating_add(1)
            .max(buffered.len())
            .max(64)
            .min(tokio::sync::Semaphore::MAX_PERMITS);
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
        let mut terminal_buffered = false;
        for msg in buffered {
            let is_terminal = matches!(
                msg,
                IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
            );
            let _ = tx.try_send(msg);
            if is_terminal {
                terminal_buffered = true;
                break;
            }
        }
        if terminal_buffered {
            self.shared.channel_credits.lock().remove(&channel_id);
            return roam_types::BoundChannelReceiver {
                receiver: rx,
                liveness: self.channel_liveness(),
                replenisher: None,
            };
        }
        self.shared.channel_senders.lock().insert(channel_id, tx);
        roam_types::BoundChannelReceiver {
            receiver: rx,
            liveness: self.channel_liveness(),
            replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
                channel_id,
                initial_credit,
                self.local_control_tx.clone(),
            )) as ChannelCreditReplenisherHandle),
        }
    }
}
impl ChannelBinder for DriverCaller {
    /// Creates a new outbound channel with a freshly allocated id.
    fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
        let (channel_id, credit_sink) = self.create_tx_channel(initial_credit);
        let sink: Arc<dyn ChannelSink> = credit_sink;
        (channel_id, sink)
    }

    /// Creates a new inbound channel with a freshly allocated id.
    fn create_rx(&self, initial_credit: u32) -> (ChannelId, roam_types::BoundChannelReceiver) {
        let channel_id = self.shared.channel_ids.lock().alloc();
        (channel_id, self.register_rx_channel(channel_id, initial_credit))
    }

    /// Creates an outbound sink for a channel whose id was chosen elsewhere,
    /// registering its credit semaphore for peer `GrantCredit` messages.
    fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
        let sink = Arc::new(CreditSink::new(
            DriverChannelSink {
                sender: self.sender.clone(),
                channel_id,
                local_control_tx: self.local_control_tx.clone(),
            },
            initial_credit,
        ));
        let credit = Arc::clone(sink.credit());
        self.shared.channel_credits.lock().insert(channel_id, credit);
        sink
    }

    /// Binds an inbound channel whose id was chosen elsewhere.
    fn register_rx(
        &self,
        channel_id: ChannelId,
        initial_credit: u32,
    ) -> roam_types::BoundChannelReceiver {
        self.register_rx_channel(channel_id, initial_credit)
    }

    /// Shares the connection drop guard (if any) as a liveness handle.
    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
        let guard = self._drop_guard.as_ref()?;
        Some(guard.clone() as ChannelLivenessHandle)
    }
}
impl Caller for DriverCaller {
fn call<'a>(
&'a self,
call: RequestCall<'a>,
) -> impl std::future::Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>>
+ MaybeSend
+ 'a {
async {
let req_id = self.shared.request_ids.lock().alloc();
let (tx, rx) = moire::sync::oneshot::channel("driver.response");
self.shared.pending_responses.lock().insert(req_id, tx);
let send_result = self
.sender
.send(ConnectionMessage::Request(RequestMessage {
id: req_id,
body: RequestBody::Call(call),
}))
.await;
if send_result.is_err() {
self.shared.pending_responses.lock().remove(&req_id);
return Err(RoamError::Cancelled);
}
let response_msg: SelfRef<RequestMessage<'static>> = rx
.named("awaiting_response")
.await
.map_err(|_| RoamError::Cancelled)?;
let response = response_msg.map(|m| match m.body {
RequestBody::Response(r) => r,
_ => unreachable!("pending_responses only gets Response variants"),
});
Ok(response)
}
.named("Caller::call")
}
fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
Box::pin(async move {
if *self.closed_rx.borrow() {
return;
}
let mut rx = self.closed_rx.clone();
while rx.changed().await.is_ok() {
if *rx.borrow() {
return;
}
}
})
}
fn is_connected(&self) -> bool {
!*self.closed_rx.borrow()
}
fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
Some(self)
}
}
/// Per-connection driver: dispatches incoming requests to `handler`, routes
/// responses to waiting callers, and multiplexes channel traffic.
///
/// Incoming traffic is processed by `Driver::run`.
pub struct Driver<H: Handler<DriverReplySink>> {
    /// Outbound side of the connection.
    sender: ConnectionSender,
    /// Inbound messages for this connection.
    rx: mpsc::Receiver<SelfRef<ConnectionMessage<'static>>>,
    /// `(request id, reason)` failure reports from the connection handle.
    /// NOTE(review): presumably fed by `ConnectionSender::mark_failure` — confirm.
    failures_rx: mpsc::UnboundedReceiver<(RequestId, &'static str)>,
    /// Connection-closed flag, cloned into every `DriverCaller`.
    closed_rx: watch::Receiver<bool>,
    /// Control requests from channel sinks and credit replenishers.
    local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
    /// Handler serving incoming calls; shared with each spawned handler task.
    handler: Arc<H>,
    /// State shared with callers and binders.
    shared: Arc<DriverShared>,
    /// Spawned handler tasks for incoming calls, keyed by request id.
    in_flight_handlers: BTreeMap<RequestId, moire::task::JoinHandle<()>>,
    /// Sender half of `local_control_rx`, cloned into sinks and replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Control sender used to build the connection drop guard; when `None`,
    /// no drop guard is ever created.
    drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Request the drop guard sends when released (`DropControlRequest::Close`
    /// for this connection).
    drop_control_request: DropControlRequest,
    /// Weak reference to the currently shared drop guard, if one was created.
    drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
}
/// Requests sent to the driver loop from synchronous contexts (sink drop,
/// item consumption) that cannot await the connection sender themselves.
enum DriverLocalControl {
    /// Send a `Close` for the locally owned outbound channel `channel_id`.
    CloseChannel {
        channel_id: ChannelId,
    },
    /// Grant the peer `additional` credit on inbound channel `channel_id`.
    GrantCredit {
        channel_id: ChannelId,
        additional: u32,
    },
}
/// Returns credit to the peer for an inbound channel as its items are consumed,
/// batching grants to reduce control traffic.
struct DriverChannelCreditReplenisher {
    /// Inbound channel the credit is granted for.
    channel_id: ChannelId,
    /// Number of consumed items that triggers a grant: half the initial
    /// credit, but at least 1.
    threshold: u32,
    /// Where `GrantCredit` requests are sent for the driver loop to transmit.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Items consumed since the last grant.
    pending: std::sync::Mutex<u32>,
}
impl DriverChannelCreditReplenisher {
    /// Creates a replenisher for `channel_id` whose window starts at `initial_credit`.
    ///
    /// Credit is handed back in batches of half the initial window, and never
    /// in batches of fewer than one item.
    fn new(
        channel_id: ChannelId,
        initial_credit: u32,
        local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    ) -> Self {
        let threshold = std::cmp::max(initial_credit / 2, 1);
        let pending = std::sync::Mutex::new(0);
        Self {
            channel_id,
            threshold,
            local_control_tx,
            pending,
        }
    }
}
impl ChannelCreditReplenisher for DriverChannelCreditReplenisher {
    /// Records one consumed item. Once `threshold` items have accumulated,
    /// it asks the driver to grant all of them back to the peer and resets the count.
    fn on_item_consumed(&self) {
        let mut consumed = self.pending.lock().expect("pending credit mutex poisoned");
        *consumed += 1;
        if *consumed >= self.threshold {
            // Reset and report the batch while still holding the lock, so
            // concurrent consumers never count an item twice.
            let additional = std::mem::take(&mut *consumed);
            let _ = self.local_control_tx.send(DriverLocalControl::GrantCredit {
                channel_id: self.channel_id,
                additional,
            });
        }
    }
}
impl<H: Handler<DriverReplySink>> Driver<H> {
pub fn new(handle: ConnectionHandle, handler: H) -> Self {
let conn_id = handle.connection_id();
let ConnectionHandle {
sender,
rx,
failures_rx,
control_tx,
closed_rx,
parity,
} = handle;
let drop_control_request = DropControlRequest::Close(conn_id);
let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
Self {
sender,
rx,
failures_rx,
closed_rx,
local_control_rx,
handler: Arc::new(handler),
shared: Arc::new(DriverShared {
pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
channel_buffers: SyncMutex::new("driver.channel_buffers", BTreeMap::new()),
channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
}),
in_flight_handlers: BTreeMap::new(),
local_control_tx,
drop_control_seed: control_tx,
drop_control_request,
drop_guard: SyncMutex::new("driver.drop_guard", None),
}
}
fn existing_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
self.drop_guard.lock().as_ref().and_then(Weak::upgrade)
}
fn connection_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
let drop_guard = if let Some(existing) = self.existing_drop_guard() {
Some(existing)
} else if let Some(seed) = &self.drop_control_seed {
let mut guard = self.drop_guard.lock();
if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
Some(existing)
} else {
let arc = Arc::new(CallerDropGuard {
control_tx: seed.clone(),
request: self.drop_control_request,
});
*guard = Some(Arc::downgrade(&arc));
Some(arc)
}
} else {
None
};
drop_guard
}
pub fn caller(&self) -> DriverCaller {
let drop_guard = self.connection_drop_guard();
DriverCaller {
sender: self.sender.clone(),
shared: Arc::clone(&self.shared),
local_control_tx: self.local_control_tx.clone(),
closed_rx: self.closed_rx.clone(),
_drop_guard: drop_guard,
}
}
fn internal_binder(&self) -> DriverChannelBinder {
DriverChannelBinder {
sender: self.sender.clone(),
shared: Arc::clone(&self.shared),
local_control_tx: self.local_control_tx.clone(),
drop_guard: self.existing_drop_guard(),
}
}
pub async fn run(&mut self) {
loop {
tokio::select! {
msg = self.rx.recv() => {
match msg {
Some(msg) => self.handle_msg(msg),
None => break,
}
}
Some((req_id, _reason)) = self.failures_rx.recv() => {
self.in_flight_handlers.remove(&req_id);
if self.shared.pending_responses.lock().remove(&req_id).is_none() {
let error: Result<(), RoamError<core::convert::Infallible>> =
Err(RoamError::Cancelled);
let _ = self.sender.send_response(req_id, RequestResponse {
ret: Payload::outgoing(&error),
channels: vec![],
metadata: Default::default(),
}).await;
}
}
Some(ctrl) = self.local_control_rx.recv() => {
self.handle_local_control(ctrl).await;
}
}
}
for (_, handle) in std::mem::take(&mut self.in_flight_handlers) {
handle.abort();
}
self.shared.pending_responses.lock().clear();
self.shared.channel_senders.lock().clear();
self.shared.channel_buffers.lock().clear();
self.shared.channel_credits.lock().clear();
}
async fn handle_local_control(&mut self, control: DriverLocalControl) {
match control {
DriverLocalControl::CloseChannel { channel_id } => {
let _ = self
.sender
.send(ConnectionMessage::Channel(ChannelMessage {
id: channel_id,
body: ChannelBody::Close(ChannelClose {
metadata: Default::default(),
}),
}))
.await;
}
DriverLocalControl::GrantCredit {
channel_id,
additional,
} => {
let _ = self
.sender
.send(ConnectionMessage::Channel(ChannelMessage {
id: channel_id,
body: ChannelBody::GrantCredit(roam_types::ChannelGrantCredit {
additional,
}),
}))
.await;
}
}
}
fn handle_msg(&mut self, msg: SelfRef<ConnectionMessage<'static>>) {
let is_request = matches!(&*msg, ConnectionMessage::Request(_));
if is_request {
let msg = msg.map(|m| match m {
ConnectionMessage::Request(r) => r,
_ => unreachable!(),
});
self.handle_request(msg);
} else {
let msg = msg.map(|m| match m {
ConnectionMessage::Channel(c) => c,
_ => unreachable!(),
});
self.handle_channel(msg);
}
}
fn handle_request(&mut self, msg: SelfRef<RequestMessage<'static>>) {
let req_id = msg.id;
let is_call = matches!(&msg.body, RequestBody::Call(_));
let is_response = matches!(&msg.body, RequestBody::Response(_));
let is_cancel = matches!(&msg.body, RequestBody::Cancel(_));
if is_call {
let reply = DriverReplySink {
sender: Some(self.sender.clone()),
request_id: req_id,
binder: self.internal_binder(),
};
let call = msg.map(|m| match m.body {
RequestBody::Call(c) => c,
_ => unreachable!(),
});
let handler = Arc::clone(&self.handler);
let join_handle = moire::task::spawn(
async move {
handler.handle(call, reply).await;
}
.named("handler"),
);
self.in_flight_handlers.insert(req_id, join_handle);
} else if is_response {
if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
let _: Result<(), _> = tx.send(msg);
}
} else if is_cancel {
if let Some(handle) = self.in_flight_handlers.remove(&req_id) {
handle.abort();
}
}
}
fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
let chan_id = msg.id;
let sender = self.shared.channel_senders.lock().get(&chan_id).cloned();
match &msg.body {
ChannelBody::Item(_item) => {
if let Some(tx) = &sender {
let item = msg.map(|m| match m.body {
ChannelBody::Item(item) => item,
_ => unreachable!(),
});
let _ = tx.try_send(IncomingChannelMessage::Item(item));
} else {
let item = msg.map(|m| match m.body {
ChannelBody::Item(item) => item,
_ => unreachable!(),
});
self.shared
.channel_buffers
.lock()
.entry(chan_id)
.or_default()
.push(IncomingChannelMessage::Item(item));
}
}
ChannelBody::Close(_close) => {
if let Some(tx) = &sender {
let close = msg.map(|m| match m.body {
ChannelBody::Close(close) => close,
_ => unreachable!(),
});
let _ = tx.try_send(IncomingChannelMessage::Close(close));
} else {
let close = msg.map(|m| match m.body {
ChannelBody::Close(close) => close,
_ => unreachable!(),
});
self.shared
.channel_buffers
.lock()
.entry(chan_id)
.or_default()
.push(IncomingChannelMessage::Close(close));
}
self.shared.channel_senders.lock().remove(&chan_id);
self.shared.channel_credits.lock().remove(&chan_id);
}
ChannelBody::Reset(_reset) => {
if let Some(tx) = &sender {
let reset = msg.map(|m| match m.body {
ChannelBody::Reset(reset) => reset,
_ => unreachable!(),
});
let _ = tx.try_send(IncomingChannelMessage::Reset(reset));
} else {
let reset = msg.map(|m| match m.body {
ChannelBody::Reset(reset) => reset,
_ => unreachable!(),
});
self.shared
.channel_buffers
.lock()
.entry(chan_id)
.or_default()
.push(IncomingChannelMessage::Reset(reset));
}
self.shared.channel_senders.lock().remove(&chan_id);
self.shared.channel_credits.lock().remove(&chan_id);
}
ChannelBody::GrantCredit(grant) => {
if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
semaphore.add_permits(grant.additional as usize);
}
}
}
}
}