use crate::error::FlightError;
use futures::{
channel::oneshot::{Receiver, Sender},
FutureExt, Stream, StreamExt,
};
use std::pin::Pin;
use std::task::{ready, Poll};
/// Adapter that turns a stream of `Result<T, E>` into a stream of plain `T`.
///
/// `Ok` items are passed through unchanged. The first `Err` produced by the
/// wrapped stream is not yielded; instead it is handed to the oneshot
/// `sender` (if that has not been used yet) and the adapter reports
/// end-of-stream.
pub(crate) struct FallibleRequestStream<T, E> {
    /// Channel used to report the first error; becomes `None` once it has
    /// been consumed.
    sender: Option<Sender<E>>,
    /// The wrapped stream whose errors are diverted to `sender`.
    fallible_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + 'static>>,
}

impl<T, E> FallibleRequestStream<T, E> {
    /// Wraps `fallible_stream`, arranging for its first error to be sent on
    /// `sender` instead of being yielded.
    pub(crate) fn new(
        sender: Sender<E>,
        fallible_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + 'static>>,
    ) -> Self {
        let sender = Some(sender);
        Self {
            fallible_stream,
            sender,
        }
    }
}
impl<T, E> Stream for FallibleRequestStream<T, E> {
type Item = T;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let pinned = self.get_mut();
let mut request_streams = pinned.fallible_stream.as_mut();
match ready!(request_streams.poll_next_unpin(cx)) {
Some(Ok(data)) => Poll::Ready(Some(data)),
Some(Err(e)) => {
if let Some(sender) = pinned.sender.take() {
let _ = sender.send(e);
}
Poll::Ready(None)
}
None => Poll::Ready(None),
}
}
}
/// Stream of responses that also surfaces errors delivered on a oneshot
/// channel.
///
/// An error arriving on `receiver` is yielded ahead of any further item from
/// `response_stream`. `tonic::Status` errors coming from the response stream
/// are wrapped as `FlightError::Tonic`.
pub(crate) struct FallibleTonicResponseStream<T> {
    /// Out-of-band error channel. NOTE(review): presumably fed by the sender
    /// half given to a `FallibleRequestStream`; that pairing happens at the
    /// call site, not here.
    receiver: Receiver<FlightError>,
    /// The underlying stream of responses (or `tonic::Status` errors).
    response_stream: Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send + 'static>>,
}

impl<T> FallibleTonicResponseStream<T> {
    /// Combines an error `receiver` with a `response_stream` into one stream.
    pub(crate) fn new(
        receiver: Receiver<FlightError>,
        response_stream: Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send + 'static>>,
    ) -> Self {
        Self {
            response_stream,
            receiver,
        }
    }
}
impl<T> Stream for FallibleTonicResponseStream<T> {
    type Item = Result<T, FlightError>;

    /// Yields the next response, giving priority to errors on the channel.
    ///
    /// The error `receiver` is polled first on every call. This also keeps
    /// the current waker registered with it. If it holds an error, that
    /// error is returned. Otherwise the response stream is polled and any
    /// `tonic::Status` it produces is converted to `FlightError::Tonic`.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        match this.receiver.poll_unpin(cx) {
            Poll::Ready(Ok(channel_error)) => return Poll::Ready(Some(Err(channel_error))),
            // No error has arrived (yet), or the sender is gone: fall through
            // to the response stream.
            Poll::Ready(Err(_)) | Poll::Pending => {}
        }

        this.response_stream
            .poll_next_unpin(cx)
            .map(|next| next.map(|item| item.map_err(FlightError::Tonic)))
    }
}