use std::pin::Pin;
use std::task::{Context, Poll};
use futures::Stream;
use serde_json::Value;
use tokio::sync::oneshot;
/// A heap-allocated, pinned, `Send` stream of JSON values.
///
/// The `'static` bound is spelled out explicitly; it is the same default
/// object lifetime that `Box<dyn Stream<..> + Send>` would get implicitly.
pub type ValueStream = Pin<Box<dyn Stream<Item = Value> + Send + 'static>>;
/// Wraps `stream` so that a single `()` is sent on `signal` when the wrapped
/// stream finishes.
///
/// "Finishes" means whichever happens first:
/// * the inner stream yields `None` while being polled, or
/// * the returned stream is dropped (e.g. the consumer stopped early).
///
/// Every item of `stream` is forwarded unchanged. The send is attempted at
/// most once, and a receiver that has already gone away is silently ignored.
pub fn wrap_stream_with_signal(stream: ValueStream, signal: oneshot::Sender<()>) -> ValueStream {
    let pending_signal = Some(signal);
    let adapter = SignalingStream {
        inner: stream,
        signal: pending_signal,
    };
    Box::pin(adapter)
}
/// Stream adapter (boxed and returned by `wrap_stream_with_signal`) that
/// forwards every item of `inner` and sends on `signal` once: either when
/// `inner` yields `None` or when the adapter is dropped, whichever is first.
struct SignalingStream {
    /// The wrapped stream; its items are passed through unchanged.
    inner: ValueStream,
    /// Completion notifier. `Some` until the send has been attempted, after
    /// which it is `None` so the signal can never be sent twice.
    signal: Option<oneshot::Sender<()>>,
}
impl Stream for SignalingStream {
    type Item = Value;

    /// Polls the wrapped stream and passes its result through unchanged.
    ///
    /// The first time the inner stream reports exhaustion (`Poll::Ready(None)`),
    /// the completion signal is sent. Later polls after exhaustion are forwarded
    /// to the inner stream but never re-send the signal.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `SignalingStream` is `Unpin` because its fields are a `Pin<Box<_>>` and
        // an `Option<Sender>`. We can therefore take a plain `&mut Self` and
        // re-pin the boxed inner stream with `as_mut()`. No unsafe pin
        // projection is required.
        let this = self.get_mut();
        let polled = this.inner.as_mut().poll_next(cx);

        if matches!(polled, Poll::Ready(None)) {
            // `take()` leaves `None` behind, so the signal is sent at most once
            // and `Drop` will not send it again.
            if let Some(signal) = this.signal.take() {
                // `Err` only means the receiver is gone; there is nobody to notify.
                let _ = signal.send(());
            }
        }

        polled
    }

    /// Reports the inner stream's size hint unchanged.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
impl Drop for SignalingStream {
    /// Sends the completion signal if it has not been sent yet.
    ///
    /// This covers consumers that drop the stream before it is exhausted. If
    /// `poll_next` already observed the end of the stream, `signal` is `None`
    /// here and nothing is sent.
    ///
    /// NOTE(review): `Drop::drop` runs before the struct's fields are dropped,
    /// so the receiver may be notified while `inner` is still alive. Confirm
    /// that no waiter depends on `inner`'s resources having been released.
    fn drop(&mut self) {
        let remaining = self.signal.take();
        if let Some(sender) = remaining {
            // Failure only means the receiver was dropped; ignore it.
            let _ = sender.send(());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Fully consuming the stream yields all items and then fires the signal.
    #[tokio::test]
    async fn test_signal_fires_on_completion() {
        let (tx, rx) = oneshot::channel();
        let inner: ValueStream = Box::pin(futures::stream::iter(vec![
            Value::from(1),
            Value::from(2),
        ]));
        let mut wrapped = wrap_stream_with_signal(inner, tx);
        let mut items = Vec::new();
        while let Some(item) = wrapped.next().await {
            items.push(item);
        }
        assert_eq!(items, vec![Value::from(1), Value::from(2)]);
        assert!(rx.await.is_ok());
    }

    /// Dropping a partially consumed stream still fires the signal.
    #[tokio::test]
    async fn test_signal_fires_on_drop() {
        let (tx, rx) = oneshot::channel();
        let inner: ValueStream = Box::pin(futures::stream::iter(vec![
            Value::from(1),
            Value::from(2),
            Value::from(3),
        ]));
        let mut wrapped = wrap_stream_with_signal(inner, tx);
        let _ = wrapped.next().await;
        drop(wrapped);
        assert!(rx.await.is_ok());
    }

    /// The signal must not fire while items remain, and polling after
    /// exhaustion must keep returning `None` without misbehaving.
    #[tokio::test]
    async fn test_signal_not_sent_before_completion() {
        let (tx, mut rx) = oneshot::channel();
        let inner: ValueStream = Box::pin(futures::stream::iter(vec![
            Value::from(1),
            Value::from(2),
        ]));
        let mut wrapped = wrap_stream_with_signal(inner, tx);

        assert_eq!(wrapped.next().await, Some(Value::from(1)));
        // Items remain and the stream is alive, so no notification yet.
        assert!(matches!(
            rx.try_recv(),
            Err(oneshot::error::TryRecvError::Empty)
        ));

        assert_eq!(wrapped.next().await, Some(Value::from(2)));
        assert_eq!(wrapped.next().await, None);
        // A second poll after exhaustion must be harmless.
        assert_eq!(wrapped.next().await, None);
        assert!(rx.await.is_ok());
    }

    /// An inner stream with no items signals on the very first poll.
    #[tokio::test]
    async fn test_signal_fires_for_empty_stream() {
        let (tx, rx) = oneshot::channel();
        let inner: ValueStream = Box::pin(futures::stream::empty::<Value>());
        let mut wrapped = wrap_stream_with_signal(inner, tx);
        assert_eq!(wrapped.next().await, None);
        assert!(rx.await.is_ok());
    }
}