use crate::*;
use futures_util::stream::Stream;
use futures_util::task::{AtomicWaker, Context, Poll};
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release};
use std::sync::atomic::{AtomicBool, AtomicUsize};
use std::sync::Arc;
/// State shared between a [`StreamDropControl`] handle and its [`StreamDropper`].
#[derive(Debug)]
struct Inner {
    /// When true, the pending request is a hand-over to `next_channel` rather than a plain drop.
    change_channel: AtomicBool,
    /// Target channel id for a hand-over; only consulted when `change_channel` is true.
    next_channel: AtomicUsize,
    /// Waker of the task polling the `StreamDropper`, woken when a request is made.
    waker: AtomicWaker,
    /// Raised once a drop or channel change has been requested; never cleared in this file.
    set: AtomicBool,
}
/// Cloneable handle used to ask a [`StreamDropper`] to drop its stream or move it elsewhere.
///
/// All clones share the same state through an `Arc`.
#[derive(Debug, Clone)]
pub(crate) struct StreamDropControl {
    inner: Arc<Inner>,
}
impl StreamDropControl {
    /// Requests that the wrapped stream be handed over to channel `channel_id`.
    ///
    /// The hand-over is not immediate. The paired [`StreamDropper`] sends the stream on its
    /// change channel the next time it is polled.
    pub(crate) fn change_channel(&self, channel_id: ChannelId) {
        // Publish the payload BEFORE raising `set`. The poller treats `set` as the signal that
        // `change_channel`/`next_channel` are ready to read. Raising it first let a concurrent
        // poll observe `set == true` with a stale `change_channel` (dropping the stream instead
        // of moving it) or a stale `next_channel` (moving it to the wrong channel).
        self.inner.next_channel.store(channel_id, Relaxed);
        self.inner.change_channel.store(true, Relaxed);
        // Release: a reader that loads `set` with Acquire is guaranteed to also see the two
        // stores above.
        self.inner.set.store(true, Release);
        self.inner.waker.wake();
    }

    /// Requests that the wrapped stream be dropped; the paired [`StreamDropper`] then ends.
    pub(crate) fn drop_stream(&self) {
        // Release for consistency with `change_channel`: `set` is the publication flag.
        self.inner.set.store(true, Release);
        self.inner.waker.wake();
    }

    /// Wraps `stream` (identified by `id`) and returns a control handle plus the adapter.
    ///
    /// Both values share one `Inner`. When a channel change is requested through the handle,
    /// the adapter sends the stream on `channel_change_tx`.
    pub(crate) fn wrap<T>(
        id: StreamId,
        stream: T,
        channel_change_tx: async_channel::Sender<ChannelChange<T>>,
    ) -> (Self, StreamDropper<T>) {
        let inner = Arc::new(Inner {
            change_channel: AtomicBool::new(false),
            next_channel: AtomicUsize::new(0),
            set: AtomicBool::new(false),
            waker: AtomicWaker::new(),
        });
        (
            Self {
                inner: Arc::clone(&inner),
            },
            StreamDropper {
                channel_change_tx,
                id,
                inner,
                stream: Some(stream),
            },
        )
    }
}
pin_project! {
    /// Stream adapter that yields items from `stream` until its paired
    /// `StreamDropControl` requests a drop or a channel change.
    #[derive(Debug)]
    pub(crate) struct StreamDropper<T> {
        // Identifier copied into the `ChannelChange` when the stream is handed off.
        pub id: StreamId,
        // The wrapped stream; becomes `None` once taken out by a drop or channel change.
        #[pin]
        pub stream: Option<T>,
        // State shared with the `StreamDropControl` handle created alongside this value.
        inner: Arc<Inner>,
        // Destination for the stream when a channel change is requested.
        channel_change_tx: async_channel::Sender<ChannelChange<T>>,
    }
}
impl<T> StreamDropper<T>
where
    T: Stream + Unpin,
{
    /// Takes the inner stream out of `self` and ends this adapter's output.
    ///
    /// If a channel change was requested, the taken stream is sent on `channel_change_tx`,
    /// tagged with this adapter's id and the requested channel id. A failed send is logged.
    /// Otherwise the stream is simply dropped. Always returns `Poll::Ready(None)`. If the
    /// stream was already taken by an earlier call, nothing else happens.
    fn drop_stream(self: Pin<&mut Self>) -> Poll<Option<T::Item>> {
        let projected = self.project();
        let slot: &mut Option<T> = projected.stream.get_mut();
        if let Some(taken) = slot.take() {
            if projected.inner.change_channel.load(Relaxed) {
                let request = ChannelChange {
                    next_channel_id: projected.inner.next_channel.load(Relaxed),
                    stream_id: *projected.id,
                    stream: taken,
                };
                if let Err(error) = projected.channel_change_tx.try_send(request) {
                    log::error!("Failed to send to change channel stream: {:?}", error);
                }
            }
        }
        Poll::Ready(None)
    }
}
impl<T> Stream for StreamDropper<T>
where
    T: Stream + Unpin,
{
    type Item = T::Item;

    /// Polls the wrapped stream unless a drop or channel change has been requested.
    ///
    /// Once a request is observed, the stream is taken out of the adapter by `drop_stream`,
    /// and `Poll::Ready(None)` is returned from then on.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Fast path: skip waker registration if a request is already visible. Acquire is
        // needed because `drop_stream` reads `change_channel`/`next_channel` right after this.
        // That read must synchronize with the store that raised `set`; a Relaxed load here
        // gave no such guarantee.
        if self.inner.set.load(Acquire) {
            return Self::drop_stream(self);
        }
        // Register before re-checking, so that a request arriving between the first check and
        // the registration still wakes this task instead of being lost.
        self.inner.waker.register(cx.waker());
        if self.inner.set.load(Acquire) {
            return Self::drop_stream(self);
        }
        self.project()
            .stream
            .as_pin_mut()
            .expect("Stream should exist, haven't shut down yet.")
            .poll_next(cx)
    }
}