use futures::{future::BoxFuture, FutureExt};
use std::{
mem,
sync::{Arc, Mutex, Weak},
};
use tokio::sync::{
mpsc::{self, error::TrySendError},
oneshot,
};
use super::{multiplexer::PortEvt, MultiplexError, SendError};
#[derive(Default)]
pub(crate) struct AssignedCredits {
port: u32,
port_inner: Weak<Mutex<ChannelCreditsInner>>,
}
impl AssignedCredits {
fn new(port: u32, port_inner: Weak<Mutex<ChannelCreditsInner>>) -> Self {
Self { port, port_inner }
}
pub fn is_empty(&self) -> bool {
self.port == 0
}
pub fn available(&self) -> u32 {
self.port
}
pub fn take(&mut self, credits: u32) {
if self.port >= credits {
self.port -= credits;
} else {
panic!("unsufficient AssignedCredits")
}
}
}
impl Drop for AssignedCredits {
fn drop(&mut self) {
if self.port > 0 {
if let Some(port) = self.port_inner.upgrade() {
let mut port = port.lock().unwrap();
port.credits += self.port;
}
}
}
}
/// Shared state of a channel's send-side credit pool.
///
/// Owned (strongly) by `CreditProvider`; `CreditUser` and `AssignedCredits` only
/// hold weak references to it.
#[derive(Debug)]
struct ChannelCreditsInner {
    /// Credits currently available to be taken by `CreditUser::request` / `try_request`.
    credits: u32,
    /// `None` while open; `Some(gracefully)` once `CreditProvider::close` has been called.
    closed: Option<bool>,
    /// One-shot wakers for tasks waiting in `CreditUser::request`; drained and signalled
    /// when credits are provided or the channel is closed.
    notify: Vec<oneshot::Sender<()>>,
}
#[derive(Debug)]
pub(crate) struct CreditProvider(Arc<Mutex<ChannelCreditsInner>>);
impl CreditProvider {
pub fn provide<SinkError, StreamError>(
&self, credits: u32,
) -> Result<(), MultiplexError<SinkError, StreamError>> {
let notify = {
let mut inner = self.0.lock().unwrap();
match inner.credits.checked_add(credits) {
Some(new_credits) => inner.credits = new_credits,
None => return Err(MultiplexError::Protocol("credits overflow".to_string())),
};
mem::take(&mut inner.notify)
};
for tx in notify {
let _ = tx.send(());
}
Ok(())
}
pub fn close(&self, gracefully: bool) {
let notify = {
let mut inner = self.0.lock().unwrap();
inner.closed = Some(gracefully);
mem::take(&mut inner.notify)
};
for tx in notify {
let _ = tx.send(());
}
}
}
/// Consuming side of a channel's credit pool.
///
/// Holds only a weak reference: once the `CreditProvider` is dropped, requests fail
/// with `SendError::Multiplexer`.
pub(crate) struct CreditUser {
    channel: Weak<Mutex<ChannelCreditsInner>>,
}

impl CreditUser {
    /// Waits until at least `min_req` credits are available, then takes up to `req` of them.
    ///
    /// # Errors
    /// - `SendError::Multiplexer` if the credit pool no longer exists.
    /// - `SendError::Closed` if the pool has been closed.
    pub async fn request(&self, req: u32, min_req: u32) -> Result<AssignedCredits, SendError> {
        debug_assert!(req > 0);
        loop {
            // The lock and the strong reference are confined to this scope so that
            // neither is held across the `.await` below.
            let rx_channel = {
                let channel = match self.channel.upgrade() {
                    Some(channel) => channel,
                    None => return Err(SendError::Multiplexer),
                };
                let mut channel = channel.lock().unwrap();
                if let Some(gracefully) = channel.closed {
                    return Err(SendError::Closed { gracefully });
                }
                if channel.credits >= min_req {
                    let channel_taken = channel.credits.min(req);
                    channel.credits -= channel_taken;
                    return Ok(AssignedCredits::new(channel_taken, self.channel.clone()));
                }
                // Discard wakers whose `request` future was cancelled before being
                // notified; otherwise repeated cancel/retry grows this list unboundedly.
                channel.notify.retain(|tx| !tx.is_closed());
                let (tx_channel, rx_channel) = oneshot::channel();
                channel.notify.push(tx_channel);
                rx_channel
            };
            // A receive error means the waker was dropped (e.g. with the pool itself);
            // either way the next iteration re-checks the pool state.
            let _ = rx_channel.await;
        }
    }

    /// Takes exactly `req` credits if that many are available right now, without waiting.
    ///
    /// Returns `Ok(None)` if fewer than `req` credits are available.
    ///
    /// # Errors
    /// - `SendError::Multiplexer` if the credit pool no longer exists.
    /// - `SendError::Closed` if the pool has been closed.
    pub fn try_request(&self, req: u32) -> Result<Option<AssignedCredits>, SendError> {
        debug_assert!(req > 0);
        let channel = match self.channel.upgrade() {
            Some(channel) => channel,
            None => return Err(SendError::Multiplexer),
        };
        let mut channel = channel.lock().unwrap();
        if let Some(gracefully) = channel.closed {
            return Err(SendError::Closed { gracefully });
        }
        if channel.credits >= req {
            channel.credits -= req;
            Ok(Some(AssignedCredits::new(req, self.channel.clone())))
        } else {
            Ok(None)
        }
    }
}
/// Creates a connected credit pool that starts with `initial_credits` credits.
///
/// The returned `CreditProvider` holds the only strong reference to the pool;
/// the `CreditUser` refers to it weakly.
pub(crate) fn credit_send_pair(initial_credits: u32) -> (CreditProvider, CreditUser) {
    let provider = CreditProvider(Arc::new(Mutex::new(ChannelCreditsInner {
        credits: initial_credits,
        closed: None,
        notify: Vec::new(),
    })));
    let user = CreditUser { channel: Arc::downgrade(&provider.0) };
    (provider, user)
}
/// Credits recorded as used by `ChannelCreditMonitor::use_credits`.
///
/// Passed to `ChannelCreditReturner::start_return` to release them from the monitor
/// and schedule their return to the remote endpoint.
pub(crate) struct UsedCredit(u32);
/// Shared state of a `ChannelCreditMonitor` / `ChannelCreditReturner` pair.
#[derive(Debug)]
struct ChannelCreditMonitorInner {
    /// Credits the remote endpoint has used that have not yet been released by
    /// `ChannelCreditReturner::start_return`.
    used: u32,
    /// Maximum value `used` may reach; exceeding it is reported as a protocol error.
    limit: u32,
}
/// Receive-side check that the remote endpoint stays within its channel credit limit.
#[derive(Debug)]
pub(crate) struct ChannelCreditMonitor(Arc<Mutex<ChannelCreditMonitorInner>>);

impl ChannelCreditMonitor {
    /// Records that the remote endpoint used `credits` credits.
    ///
    /// # Errors
    /// Returns `MultiplexError::Protocol` if the total used credits would exceed the
    /// limit (or overflow `u32`). The recorded usage is left unchanged in that case.
    pub fn use_credits<SinkError, StreamError>(
        &self, credits: u32,
    ) -> Result<UsedCredit, MultiplexError<SinkError, StreamError>> {
        let mut guard = self.0.lock().unwrap();
        let limit = guard.limit;
        let within_limit = guard.used.checked_add(credits).filter(|&total| total <= limit);
        if let Some(total) = within_limit {
            guard.used = total;
            Ok(UsedCredit(credits))
        } else {
            Err(MultiplexError::Protocol("remote endpoint used too many channel flow credits".to_string()))
        }
    }
}
/// Releases used credits from the monitor and returns them to the remote endpoint.
///
/// Returned credits are batched and sent as `PortEvt::ReturnCredits` only once a
/// threshold is reached.
pub(crate) struct ChannelCreditReturner {
    /// Monitor whose `used` count is decreased when credits are released.
    monitor: Weak<Mutex<ChannelCreditMonitorInner>>,
    /// Released credits accumulated but not yet sent to the remote endpoint.
    to_return: u32,
    /// Deferred send of a credit-return event, created when the event channel was full.
    return_fut: Option<BoxFuture<'static, ()>>,
}

impl ChannelCreditReturner {
    /// Releases `credit` from the monitor and queues it for return to `remote_port`.
    ///
    /// The accumulated credits are sent once they reach half the monitor's limit
    /// (or 1 when the limit is below 8). If `tx` is full, the send is deferred and
    /// must be completed with [`return_flush`](Self::return_flush) before the next
    /// call. If `tx` is closed, or the monitor no longer exists, the credits are
    /// discarded.
    ///
    /// # Panics
    /// Panics if a deferred return from a previous call has not been flushed yet.
    pub fn start_return(&mut self, credit: UsedCredit, remote_port: u32, tx: &mpsc::Sender<PortEvt>) {
        assert!(self.return_fut.is_none(), "start_return called without return_flush");

        let monitor = match self.monitor.upgrade() {
            Some(monitor) => monitor,
            None => return,
        };

        // Only the monitor's counters need the lock; release it before sending.
        let threshold = {
            let mut monitor = monitor.lock().unwrap();
            monitor.used -= credit.0;
            if monitor.limit >= 8 {
                monitor.limit / 2
            } else {
                1
            }
        };

        self.to_return += credit.0;
        if self.to_return < threshold {
            return;
        }

        let msg = PortEvt::ReturnCredits { remote_port, credits: self.to_return };
        self.to_return = 0;
        if let Err(TrySendError::Full(msg)) = tx.try_send(msg) {
            let tx = tx.clone();
            self.return_fut = Some(
                async move {
                    let _ = tx.send(msg).await;
                }
                .boxed(),
            );
        }
    }

    /// Completes a credit return that `start_return` had to defer; no-op if none is pending.
    ///
    /// The pending future is cleared only after it finishes. A cancelled flush
    /// therefore resumes on the next call.
    pub async fn return_flush(&mut self) {
        if let Some(return_fut) = &mut self.return_fut {
            return_fut.await;
            self.return_fut = None;
        }
    }
}
/// Creates a connected credit monitor/returner pair that enforces `limit` credits.
///
/// The monitor holds the only strong reference to the shared state; the returner
/// refers to it weakly.
pub(crate) fn credit_monitor_pair(limit: u32) -> (ChannelCreditMonitor, ChannelCreditReturner) {
    let shared = Arc::new(Mutex::new(ChannelCreditMonitorInner { used: 0, limit }));
    let returner = ChannelCreditReturner {
        monitor: Arc::downgrade(&shared),
        to_return: 0,
        return_fut: None,
    };
    (ChannelCreditMonitor(shared), returner)
}