use futures_core::Stream;
use pin_project::{pin_project, pinned_drop};
use std::{
pin::Pin,
task::{Context, Poll},
};
/// A [`Stream`] adapter that yields every item of an inner stream unchanged and
/// invokes a one-shot callback when the adapter itself is dropped.
///
/// Build one with [`DropStream::new`] or [`DropStreamExt::on_drop`].
///
/// The callback runs when the adapter is dropped, whether or not it was ever
/// polled (it never runs if the value is leaked, e.g. via `std::mem::forget`).
/// Discarding the value right after construction would therefore fire the
/// callback immediately, which is almost never intended — hence `#[must_use]`.
#[pin_project(PinnedDrop)]
#[must_use = "streams do nothing unless polled, and dropping this one immediately runs its callback"]
pub struct DropStream<S: Stream<Item = T>, T, U: FnOnce()> {
    /// The wrapped stream; structurally pinned so it can be polled through `Pin`.
    #[pin]
    stream: S,
    /// The drop callback. Stored as `Some` by `new` and taken out (leaving
    /// `None`) when the callback is invoked during drop.
    dropper: Option<U>,
}
impl<S: Stream<Item = T>, T, U: FnOnce()> DropStream<S, T, U> {
pub fn new(stream: S, dropper: U) -> Self {
Self {
stream,
dropper: Some(dropper),
}
}
}
impl<S: Stream<Item = T>, T, U: FnOnce()> Stream for DropStream<S, T, U> {
    type Item = T;

    /// Polls the inner stream and returns its result unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().stream.poll_next(cx)
    }

    /// Reports the inner stream's bounds.
    ///
    /// `poll_next` forwards every item 1:1, so the inner stream's hint is exact
    /// for this adapter too. Without this override, the trait default
    /// `(0, None)` would hide useful bounds from consumers such as `collect`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}
#[pinned_drop]
impl<S: Stream<Item = T>, T, U: FnOnce()> PinnedDrop for DropStream<S, T, U> {
    /// Invokes the stored callback.
    ///
    /// This runs as part of the adapter's `Drop`, before its fields (including
    /// the inner stream) are dropped.
    fn drop(self: Pin<&mut Self>) {
        let this = self.project();
        match this.dropper.take() {
            Some(callback) => callback(),
            // `dropper` is private and `new` always stores `Some`; drop runs at
            // most once, so the callback cannot already have been taken.
            None => unreachable!(),
        }
    }
}
/// Extension trait that adds [`on_drop`](DropStreamExt::on_drop) to streams.
///
/// `U` is the type of the callback to run when the returned adapter is dropped.
pub trait DropStreamExt<U>: Stream + Sized
where
    U: FnOnce(),
{
    /// Consumes `self` and returns a [`DropStream`] that yields the same items
    /// and calls `dropper` when it is dropped.
    fn on_drop(self, dropper: U) -> DropStream<Self, Self::Item, U>;
}
/// Blanket implementation: every sized [`Stream`] gets `on_drop` for any
/// `FnOnce()` callback type.
impl<S, U> DropStreamExt<U> for S
where
    S: Stream + Sized,
    U: FnOnce(),
{
    fn on_drop(self, dropper: U) -> DropStream<Self, Self::Item, U> {
        DropStream::new(self, dropper)
    }
}
#[cfg(test)]
mod tests {
    use crate::{DropStream, DropStreamExt};
    use futures::{stream::repeat, Stream};
    use std::{pin::Pin, task::Poll};

    /// Polls `stream` exactly once with a no-op waker and returns the result.
    fn poll_once<S: Stream>(stream: Pin<&mut S>) -> Poll<Option<S::Item>> {
        let waker = futures::task::noop_waker();
        let mut cx = futures::task::Context::from_waker(&waker);
        stream.poll_next(&mut cx)
    }

    #[test]
    fn dropper_runs_on_drop() {
        let mut has_run = false;
        {
            let flag = &mut has_run;
            let _guard = DropStream::new(repeat(true), move || *flag = true);
        }
        assert!(has_run)
    }

    #[test]
    fn stream_passes_through_result() {
        let mut stream = Box::pin(DropStream::new(repeat(true), || {}));
        assert_eq!(poll_once(stream.as_mut()), Poll::Ready(Some(true)));
    }

    #[test]
    fn dropper_runs_on_drop_after_passing_result() {
        let mut has_run = false;
        {
            let flag = &mut has_run;
            let mut stream = Box::pin(DropStream::new(repeat(true), move || *flag = true));
            assert_eq!(poll_once(stream.as_mut()), Poll::Ready(Some(true)));
        }
        assert!(has_run)
    }

    #[test]
    fn stream_trait_is_implemented() {
        let mut has_run = false;
        {
            let flag = &mut has_run;
            let mut stream = Box::pin(repeat(true).on_drop(move || *flag = true));
            assert_eq!(poll_once(stream.as_mut()), Poll::Ready(Some(true)));
        }
        assert!(has_run)
    }
}