use std::{marker::Unpin, pin::Pin};
use futures::{
channel::oneshot,
future::{self, FusedFuture},
stream::{self, FusedStream},
task::{Context, Poll},
FutureExt, Stream, StreamExt,
};
/// One-shot receiver whose completion ends a [`ShutdownStream`].
type Shutdown = oneshot::Receiver<()>;
/// The shutdown receiver wrapped in [`future::Fuse`], so it can be polled again after it has
/// completed and queried with `is_terminated`.
type FusedShutdown = future::Fuse<oneshot::Receiver<()>>;
/// A stream adapter that forwards items from `stream` until `shutdown` resolves.
///
/// Built with `ShutdownStream::new` (or `from_fused` for already-fused parts) and taken apart
/// again with `split`.
#[derive(Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct ShutdownStream<S> {
    /// Fused shutdown receiver; once it has resolved, this stream is terminated.
    shutdown: FusedShutdown,
    /// The wrapped stream whose items are forwarded.
    stream: S,
}
/// Construction and deconstruction of a `ShutdownStream` around a fused stream.
impl<S: Stream> ShutdownStream<stream::Fuse<S>> {
    /// Wraps `stream` so that it ends as soon as `shutdown` resolves.
    ///
    /// Both the receiver and the stream are fused here, so unfused values may be passed in.
    pub fn new(shutdown: Shutdown, stream: S) -> Self {
        Self::from_fused(shutdown.fuse(), stream.fuse())
    }

    /// Builds a `ShutdownStream` from a receiver and a stream that are already fused,
    /// for example the two halves previously returned by [`split`](Self::split).
    pub fn from_fused(shutdown: FusedShutdown, stream: stream::Fuse<S>) -> Self {
        Self { shutdown, stream }
    }

    /// Takes the adapter apart, returning the fused shutdown receiver and the fused stream.
    pub fn split(self) -> (FusedShutdown, stream::Fuse<S>) {
        let Self { shutdown, stream } = self;
        (shutdown, stream)
    }
}
/// Yields items from the wrapped stream until the shutdown signal resolves or the wrapped
/// stream is exhausted, whichever happens first.
///
/// The shutdown receiver is checked before the wrapped stream on every poll. Once shutdown has
/// been signalled, no further items are yielded, even if the wrapped stream has one ready.
/// A `oneshot::Receiver` also resolves (with `Err(Canceled)`) when its sender is dropped, so
/// dropping the sender terminates this stream as well.
impl<S: Stream<Item = T> + FusedStream + Unpin, T> Stream for ShutdownStream<S> {
    type Item = T;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // Once either half has finished, so has the combined stream. Checking the inner stream
        // before touching the receiver avoids polling the receiver after the stream is
        // exhausted. Such a poll would needlessly register this task's waker with it, and
        // would consume a pending shutdown signal that `split` could otherwise still hand back.
        if self.shutdown.is_terminated() || self.stream.is_terminated() {
            return Poll::Ready(None);
        }

        // Shutdown takes priority over any item the inner stream may have ready.
        if self.shutdown.poll_unpin(cx).is_ready() {
            return Poll::Ready(None);
        }

        self.stream.poll_next_unpin(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.shutdown.is_terminated() || self.stream.is_terminated() {
            // Nothing more will ever be yielded.
            (0, Some(0))
        } else {
            // Shutdown may cut the stream short at any point, so the lower bound is always
            // zero; the inner stream's upper bound still applies.
            (0, self.stream.size_hint().1)
        }
    }
}
/// A `ShutdownStream` is finished once either its shutdown receiver has resolved or its
/// wrapped stream has been exhausted.
impl<S: Stream<Item = T> + FusedStream + Unpin, T> FusedStream for ShutdownStream<S> {
    fn is_terminated(&self) -> bool {
        let Self { shutdown, stream } = self;
        shutdown.is_terminated() || stream.is_terminated()
    }
}