use std::{
collections::HashMap,
future::Future,
pin::Pin,
time::Duration,
};
use super::*;
use parking_lot::Mutex;
use tokio::{
sync::{
mpsc::{
self,
error::TrySendError,
},
OwnedSemaphorePermit,
Semaphore,
},
time::Instant,
};
use tokio_stream::{
wrappers::ReceiverStream,
Stream,
};
use tx_status_stream::TxUpdateStream;
#[cfg(test)]
mod tests;
mod tx_status_stream;
/// Capacity handed to `mpsc::channel` when `MpscChannel` builds a subscriber channel.
const BUFFER_SIZE: usize = 2;
/// Distributes transaction status updates to the subscribers registered per transaction id.
///
/// Clones share the same subscriber map and permit source, since both are held behind `Arc`s.
#[derive(Debug)]
pub struct UpdateSender {
/// Subscribers grouped by transaction id, guarded by a mutex.
senders: Arc<Mutex<SenderMap<Permit, Tx>>>,
/// Source of permits; `try_subscribe` has to obtain one before it registers a subscriber.
permits: GetPermit,
/// Maximum age of a subscriber; older ones are dropped during cleanup.
ttl: Duration,
}
/// Reasons a status message could not be handed to a subscriber.
#[derive(Debug)]
pub enum SendError {
    /// The channel has no free capacity at the moment.
    Full,
    /// The channel has been closed.
    Closed,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            SendError::Full => "subscriber channel is full",
            SendError::Closed => "subscriber channel is closed",
        };
        f.write_str(text)
    }
}

// Public error types are expected to be usable as `dyn Error` (e.g. with `?`).
impl std::error::Error for SendError {}
/// Marker trait for values that can be stored as a subscription permit.
pub trait PermitTrait: Send + Sync {}
/// Boxed, type-erased permit kept alive by a subscriber.
pub type Permit = Box<dyn PermitTrait + Send + Sync + 'static>;
/// Boxed, type-erased sending half of a subscriber channel.
pub type Tx = Box<dyn SendStatus + Send + Sync + 'static>;
/// Subscribers grouped by the transaction id they were registered for.
type SenderMap<P, Tx> = HashMap<Bytes32, Vec<Sender<P, Tx>>>;
/// Pinned, boxed stream of status messages handed out to a subscriber.
pub type TxStatusStream = Pin<Box<dyn Stream<Item = TxStatusMessage> + Send + Sync>>;
/// Shared, type-erased permit source.
type GetPermit = Arc<dyn PermitsDebug + Send + Sync>;
/// One registered subscriber: its permit, stream state, channel sender and creation time.
struct Sender<P = OwnedSemaphorePermit, Tx = mpsc::Sender<TxStatusMessage>> {
/// Never read; it is only held so the permit lives exactly as long as this subscriber.
_permit: P,
/// Stream state consulted by `try_send` to pick the message to forward and to track closure.
stream: TxUpdateStream,
/// Sending half the messages are pushed into.
tx: Tx,
/// Creation time, compared against the TTL during cleanup.
created: Instant,
}
/// Sending half used to deliver status messages to one subscriber.
#[cfg_attr(test, mockall::automock)]
pub trait SendStatus {
/// Tries to hand `msg` over, failing with `SendError::Full` or `SendError::Closed`.
fn try_send(&mut self, msg: TxStatusMessage) -> Result<(), SendError>;
/// Reports whether this sender considers itself closed.
fn is_closed(&self) -> bool;
/// Reports whether this sender currently has no room for another message.
fn is_full(&self) -> bool;
}
/// Factory for the sender/stream pair that backs a single subscription.
pub trait CreateChannel {
/// Creates a new channel and returns its sending half and the stream for the subscriber.
fn channel() -> (Tx, TxStatusStream);
}
/// Source of subscription permits.
#[cfg_attr(test, mockall::automock(type P = ();))]
trait Permits {
/// Attempts to take a permit right away; `None` means no permit was obtained.
fn try_acquire(self: Arc<Self>) -> Option<Permit>;
/// Returns a future that resolves to a permit.
fn acquire(self: Arc<Self>) -> Pin<Box<dyn Future<Output = Permit> + Send + Sync>>;
}
/// Combines `Permits` and `Debug` into a single trait so both are available
/// through one trait object (see the `GetPermit` alias).
trait PermitsDebug: Permits + std::fmt::Debug {}
// Blanket impl: every `Permits` implementor that is also `Debug` qualifies automatically.
impl<T: Permits + std::fmt::Debug> PermitsDebug for T {}
/// Channel factory backed by a bounded tokio `mpsc` channel.
pub struct MpscChannel;

impl CreateChannel for MpscChannel {
    /// Builds a channel sized by `BUFFER_SIZE` and returns its boxed sending
    /// half together with the receiving half wrapped as a pinned stream.
    fn channel() -> (Tx, TxStatusStream) {
        let (sender, receiver) = mpsc::channel(BUFFER_SIZE);
        let tx: Tx = Box::new(sender);
        let stream: TxStatusStream = Box::pin(ReceiverStream::from(receiver));
        (tx, stream)
    }
}
/// Lets a shared tokio `Semaphore` act as the permit source.
impl Permits for Semaphore {
    /// Takes an owned permit without waiting; any failure is reported as `None`.
    fn try_acquire(self: Arc<Self>) -> Option<Permit> {
        match self.try_acquire_owned() {
            Ok(owned) => Some(Box::new(owned) as Permit),
            Err(_) => None,
        }
    }

    /// Returns a boxed future that resolves to an owned permit.
    ///
    /// # Panics
    ///
    /// The future panics if acquiring the permit returns an error.
    fn acquire(self: Arc<Self>) -> Pin<Box<dyn Future<Output = Permit> + Send + Sync>> {
        Box::pin(async move {
            let owned = self
                .acquire_owned()
                .await
                .expect("Semaphore is not ever closed");
            Box::new(owned) as Permit
        })
    }
}
// Allows an owned tokio semaphore permit to be boxed as a `Permit`.
impl PermitTrait for OwnedSemaphorePermit {}
/// Routes every status message through the subscriber's stream state before it
/// reaches the underlying channel.
impl<P, Tx> SendStatus for Sender<P, Tx>
where
    Tx: SendStatus,
{
    /// Records `msg` in the stream state and forwards at most one message, the
    /// one the stream state yields next, into the channel.
    ///
    /// A full channel is recorded on the stream state as a failure; a closed
    /// channel closes the stream state. Returns `Err(SendError::Closed)` when
    /// the stream state reports closed afterwards, `Ok(())` otherwise.
    fn try_send(&mut self, msg: TxStatusMessage) -> Result<(), SendError> {
        self.stream.add_msg(msg);

        if let Some(pending) = self.stream.try_next() {
            if let Err(reason) = self.tx.try_send(pending) {
                match reason {
                    SendError::Full => self.stream.add_failure(),
                    SendError::Closed => self.stream.close_recv(),
                }
            }
        }

        if self.stream.is_closed() {
            return Err(SendError::Closed);
        }
        Ok(())
    }

    /// Closed-ness is taken from the stream state, not from the channel.
    fn is_closed(&self) -> bool {
        self.stream.is_closed()
    }

    /// Fullness is delegated to the underlying sending half.
    fn is_full(&self) -> bool {
        self.tx.is_full()
    }
}
/// Adapts tokio's bounded `mpsc::Sender` to the `SendStatus` interface.
impl SendStatus for mpsc::Sender<TxStatusMessage> {
    /// Delegates to the inherent, non-waiting `try_send` and maps its error
    /// variants onto `SendError`, discarding the rejected message.
    fn try_send(&mut self, msg: TxStatusMessage) -> Result<(), SendError> {
        // Path syntax selects the inherent method, so this cannot recurse
        // into the trait method being defined here.
        mpsc::Sender::try_send(&*self, msg).map_err(|err| match err {
            TrySendError::Full(_) => SendError::Full,
            TrySendError::Closed(_) => SendError::Closed,
        })
    }

    /// Delegates to the inherent `is_closed` of the channel sender.
    fn is_closed(&self) -> bool {
        mpsc::Sender::is_closed(self)
    }

    /// The channel counts as full when it reports zero remaining capacity.
    fn is_full(&self) -> bool {
        let remaining = self.capacity();
        remaining == 0
    }
}
impl UpdateSender {
    /// Creates an update sender whose permits come from a semaphore holding
    /// `capacity` permits.
    ///
    /// `ttl` is the maximum age a subscriber may reach before cleanup drops it.
    pub fn new(capacity: usize, ttl: Duration) -> UpdateSender {
        let permits = Arc::new(Semaphore::new(capacity));
        Self {
            senders: Default::default(),
            permits,
            ttl,
        }
    }

    /// Registers a subscriber for `tx_id` if a permit can be obtained right away.
    ///
    /// Returns the stream for the new subscriber, or `None` when no permit was
    /// available.
    pub fn try_subscribe<C>(&self, tx_id: Bytes32) -> Option<TxStatusStream>
    where
        C: CreateChannel,
    {
        // Purge stale subscribers first so the permits they hold are dropped
        // before we try to take one. The guard is scoped so the lock is
        // released again before `subscribe_inner` takes it.
        {
            let mut registry = self.senders.lock();
            remove_closed_and_expired(&mut registry, self.ttl);
        }

        Arc::clone(&self.permits)
            .try_acquire()
            .map(|permit| self.subscribe_inner::<C>(tx_id, permit))
    }

    /// Registers a subscriber for `tx_id` using an already acquired `permit`.
    ///
    /// Closed and expired subscribers are purged under the same lock before
    /// the new one is added.
    fn subscribe_inner<C>(&self, tx_id: Bytes32, permit: Permit) -> TxStatusStream
    where
        C: CreateChannel,
    {
        let mut registry = self.senders.lock();
        remove_closed_and_expired(&mut registry, self.ttl);
        subscribe::<_, C>(tx_id, &mut *registry, permit)
    }

    /// Delivers `update` to every subscriber registered for its transaction id.
    ///
    /// Subscribers whose `try_send` fails are dropped, and the map entry is
    /// removed once its subscriber list becomes empty.
    pub fn send(&self, update: TxUpdate) {
        let mut registry = self.senders.lock();
        remove_closed_and_expired(&mut registry, self.ttl);

        let drained = match registry.get_mut(update.tx_id()) {
            Some(subscribers) => {
                // Each subscriber gets its own copy of the update.
                subscribers.retain_mut(|subscriber| {
                    subscriber.try_send(update.clone().into_msg()).is_ok()
                });
                subscribers.is_empty()
            }
            None => false,
        };

        if drained {
            registry.remove(update.tx_id());
        }
    }
}
/// Creates a channel through `C`, stores its sending half as a new subscriber
/// under `tx_id`, and returns the stream for the subscriber.
///
/// `permit` is moved into the stored subscriber.
fn subscribe<P, C>(
    tx_id: Bytes32,
    senders: &mut SenderMap<P, Tx>,
    permit: P,
) -> TxStatusStream
where
    C: CreateChannel,
{
    let (tx, stream_for_caller) = C::channel();

    // Look the entry up first, then build the subscriber, so the creation
    // timestamp is taken right before the push.
    let subscribers = senders.entry(tx_id).or_default();
    let subscriber = Sender {
        _permit: permit,
        stream: TxUpdateStream::new(),
        tx,
        created: Instant::now(),
    };
    subscribers.push(subscriber);

    stream_for_caller
}
/// Drops every subscriber that is closed or whose age has reached `ttl`, then
/// removes the map entries left without any subscriber.
fn remove_closed_and_expired<P, Tx>(senders: &mut SenderMap<P, Tx>, ttl: Duration)
where
    Tx: SendStatus,
{
    senders.retain(|_, subscribers| {
        subscribers.retain(|subscriber| {
            // Closed subscribers go first; their age is not even looked at.
            if subscriber.is_closed() {
                return false;
            }
            subscriber.created.elapsed() < ttl
        });
        !subscribers.is_empty()
    });
}
/// Forwards `SendStatus` through a `Box`, including boxed trait objects
/// (`T` may be unsized).
impl<T> SendStatus for Box<T>
where
    T: SendStatus + ?Sized,
{
    // Each method names `T` explicitly so the call targets the boxed value's
    // implementation rather than this impl.
    fn try_send(&mut self, msg: TxStatusMessage) -> Result<(), SendError> {
        T::try_send(&mut **self, msg)
    }

    fn is_closed(&self) -> bool {
        T::is_closed(&**self)
    }

    fn is_full(&self) -> bool {
        T::is_full(&**self)
    }
}
impl Clone for UpdateSender {
fn clone(&self) -> Self {
Self {
senders: self.senders.clone(),
permits: self.permits.clone(),
ttl: self.ttl,
}
}
}
/// Debug output for a subscriber. Only the stream state is printed, because
/// this impl places no `Debug` bound on `P` or `Tx`.
impl<P, Tx> std::fmt::Debug for Sender<P, Tx> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sender")
            .field("stream", &self.stream)
            // `_permit`, `tx` and `created` are not shown; mark the output as
            // partial so it is not mistaken for the complete struct.
            .finish_non_exhaustive()
    }
}