use std::fmt;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, TryLockError};
use async_std::stream::Stream;
use async_std::future::Future;
use async_std::task::{Poll, Context, Waker};
use crate::byte_stream::ByteStream;
/// State shared (through an `Arc`) by a limiter's `Sender`, `Receiver` and `Token`s.
struct Inner {
/// Number of live tokens: incremented by `Sender::token` / `Receiver::token`,
/// decremented in `Token::drop`.
active: AtomicUsize,
/// Threshold: `Receiver::poll` reports capacity only while `active < limit`.
active: AtomicUsize,
limit: AtomicUsize,
/// Waker stored by `Receiver::poll` when there is no capacity; taken and woken
/// by `Token::drop` / `Sender::set_limit` when capacity may have appeared.
task: Mutex<Option<Waker>>,
}
pub struct BackpressureToken<S>(Backpressure<S>);
pub struct BackpressureWrapper<S>(Backpressure<S>);
pub struct Backpressure<S> {
stream: S,
backpressure: Receiver,
}
pub struct Receiver {
inner: Arc<Inner>,
}
pub struct HasCapacity<'a> {
recv: &'a mut Receiver,
}
#[derive(Clone)]
pub struct Sender {
inner: Arc<Inner>,
}
#[derive(Clone)]
pub struct Token {
inner: Arc<Inner>,
}
impl<S: Unpin> Unpin for Backpressure<S> {}
impl<S: Unpin> Unpin for BackpressureToken<S> {}
impl<S: Unpin> Unpin for BackpressureWrapper<S> {}
impl Sender {
pub fn token(&self) -> Token {
self.inner.active.fetch_add(1, Ordering::SeqCst);
Token {
inner: self.inner.clone(),
}
}
pub fn set_limit(&self, new_limit: usize) {
let old_limit = self.inner.limit.swap(new_limit, Ordering::SeqCst);
if old_limit < new_limit {
match self.inner.task.try_lock() {
Ok(mut guard) => {
guard.take().map(|w| w.wake());
}
Err(TryLockError::WouldBlock) => {
}
Err(TryLockError::Poisoned(_)) => {
unreachable!("backpressure lock should never be poisoned");
}
}
}
}
pub fn get_active_tokens(&self) -> usize {
self.inner.active.load(Ordering::Relaxed)
}
}
impl Receiver {
fn token(&self) -> Token {
self.inner.active.fetch_add(1, Ordering::SeqCst);
Token {
inner: self.inner.clone(),
}
}
pub fn has_capacity(&mut self) -> HasCapacity {
HasCapacity { recv: self }
}
fn poll(&mut self, cx: &mut Context) -> Poll<()> {
let limit = self.inner.limit.load(Ordering::Acquire);
loop {
let active = self.inner.active.load(Ordering::Acquire);
if active < limit {
return Poll::Ready(());
}
match self.inner.task.try_lock() {
Ok(mut guard) => {
*guard = Some(cx.waker().clone());
break;
}
Err(TryLockError::WouldBlock) => {
continue;
}
Err(TryLockError::Poisoned(_)) => {
unreachable!("backpressure lock should never be poisoned");
}
}
}
let active = self.inner.active.load(Ordering::Acquire);
if active < limit {
Poll::Ready(())
} else {
Poll::Pending
}
}
}
impl Drop for Token {
fn drop(&mut self) {
let old_ref = self.inner.active.fetch_sub(1, Ordering::SeqCst);
let limit = self.inner.limit.load(Ordering::SeqCst);
if old_ref == limit {
match self.inner.task.try_lock() {
Ok(mut guard) => {
guard.take().map(|w| w.wake());
}
Err(TryLockError::WouldBlock) => {
}
Err(TryLockError::Poisoned(_)) => {
unreachable!("backpressure lock should never be poisoned");
}
}
}
}
}
impl<S> BackpressureToken<S> {
pub(crate) fn new(stream: S, backpressure: Receiver)
-> BackpressureToken<S>
{
BackpressureToken(Backpressure::new(stream, backpressure))
}
pub fn get_ref(&self) -> &S {
self.0.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut()
}
pub fn into_inner(self) -> S {
self.0.into_inner()
}
}
impl<S> BackpressureWrapper<S> {
pub(crate) fn new(stream: S, backpressure: Receiver)
-> BackpressureWrapper<S>
{
BackpressureWrapper(Backpressure::new(stream, backpressure))
}
pub fn get_ref(&self) -> &S {
self.0.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut()
}
pub fn into_inner(self) -> S {
self.0.into_inner()
}
}
impl<S> Backpressure<S> {
pub(crate) fn new(stream: S, backpressure: Receiver) -> Backpressure<S> {
Backpressure { stream, backpressure }
}
pub fn get_ref(&self) -> &S {
&self.stream
}
pub fn get_mut(&mut self) -> &mut S {
&mut self.stream
}
pub fn into_inner(self) -> S {
self.stream
}
}
pub fn new(initial_limit: usize) -> (Sender, Receiver) {
let inner = Arc::new(Inner {
limit: AtomicUsize::new(initial_limit),
active: AtomicUsize::new(0),
task: Mutex::new(None),
});
return (
Sender {
inner: inner.clone(),
},
Receiver {
inner: inner.clone(),
},
)
}
impl fmt::Debug for Token {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
debug("Token", &self.inner, f)
}
}
impl fmt::Debug for Sender {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
debug("Sender", &self.inner, f)
}
}
impl fmt::Debug for Receiver {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
debug("Receiver", &self.inner, f)
}
}
impl<'a> fmt::Debug for HasCapacity<'a> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
debug("HasCapacity", &self.recv.inner, f)
}
}
fn debug(name: &str, inner: &Arc<Inner>, f: &mut fmt::Formatter)
-> fmt::Result
{
let active = inner.active.load(Ordering::Relaxed);
let limit = inner.limit.load(Ordering::Relaxed);
write!(f, "<{} {}/{}>", name, active, limit)
}
impl<S: fmt::Debug> fmt::Debug for Backpressure<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Backpressure")
.field("stream", &self.stream)
.field("backpressure", &self.backpressure)
.finish()
}
}
impl<S: fmt::Debug> fmt::Debug for BackpressureToken<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("BackpressureToken")
.field("stream", &self.0.stream)
.field("backpressure", &self.0.backpressure)
.finish()
}
}
impl<S: fmt::Debug> fmt::Debug for BackpressureWrapper<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("BackpressureWrapper")
.field("stream", &self.0.stream)
.field("backpressure", &self.0.backpressure)
.finish()
}
}
impl<I, S> Stream for Backpressure<S>
where S: Stream<Item=I> + Unpin
{
type Item = I;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context)
-> Poll<Option<Self::Item>>
{
match self.backpressure.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(()) => Pin::new(&mut self.stream).poll_next(cx),
}
}
}
impl<'a> Future for HasCapacity<'a> {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
self.recv.poll(cx)
}
}
impl<I, S> Stream for BackpressureToken<S>
where S: Stream<Item=I> + Unpin
{
type Item = (Token, I);
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context)
-> Poll<Option<Self::Item>>
{
Pin::new(&mut self.0)
.poll_next(cx)
.map(|opt| opt.map(|conn| (self.0.backpressure.token(), conn)))
}
}
impl<I, S> Stream for BackpressureWrapper<S>
where S: Stream<Item=I> + Unpin,
ByteStream: From<(Token, I)>,
{
type Item = ByteStream;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context)
-> Poll<Option<Self::Item>>
{
Pin::new(&mut self.0)
.poll_next(cx)
.map(|opt| opt.map(|conn| {
ByteStream::from((self.0.backpressure.token(), conn))
}))
}
}