use alloc::boxed::Box;
use core::{
fmt::Debug,
pin::Pin,
task::{Context, Poll},
};
use futures_util::stream::{Stream, unfold};
use serde::de::DeserializeOwned;
use crate::{
Result,
connection::{ReadConnection, socket::ReadHalf},
reply,
};
#[cfg(feature = "std")]
use std::os::fd::OwnedFd;
/// Payload of each successfully received item of [`ReplyStream`] (with the `std` feature).
///
/// The reply (or reply error) is paired with a `Vec` of owned file descriptors. Presumably
/// these are the descriptors delivered together with that reply — NOTE(review): confirm
/// against `ReadConnection::receive_reply`.
#[cfg(feature = "std")]
pub(crate) type ChainResult<Params, ReplyError> =
(reply::Result<Params, ReplyError>, alloc::vec::Vec<OwnedFd>);
/// Payload of each successfully received item of [`ReplyStream`] (without the `std`
/// feature): just the reply or reply error, with no file descriptors attached.
#[cfg(not(feature = "std"))]
pub(crate) type ChainResult<Params, ReplyError> = reply::Result<Params, ReplyError>;
/// A stream of replies received from a `ReadConnection`.
///
/// Each item is either an error from receiving, or the reply / reply error. With the `std`
/// feature, that reply is paired with a list of owned file descriptors.
///
/// The stream is lazy: nothing is received until it is polled, so dropping it unpolled is
/// almost certainly a mistake — hence `#[must_use]`.
#[must_use = "streams do nothing unless polled"]
pub struct ReplyStream<'c, Params, ReplyError> {
    /// Boxed, type-erased stream that performs the actual receives.
    inner: InnerStream<'c, Params, ReplyError>,
}
impl<'c, Params, ReplyError> ReplyStream<'c, Params, ReplyError>
where
    Params: DeserializeOwned + Debug,
    ReplyError: DeserializeOwned + Debug,
{
    /// Creates a stream that receives replies from `connection`.
    ///
    /// The stream ends once `reply_count` replies have completed. A reply completes when:
    /// - it arrives with `continues()` anything other than `Some(true)`, or
    /// - a reply error arrives.
    ///
    /// Replies with `continues() == Some(true)` are yielded without being counted. If
    /// receiving itself fails, that error is yielded and the stream then ends. With
    /// `reply_count == 0` the stream is empty.
    pub fn new<Read>(connection: &'c mut ReadConnection<Read>, reply_count: usize) -> Self
    where
        Read: ReadHalf + 'c,
    {
        let stream = unfold((connection, 0), move |(conn, completed)| async move {
            if completed >= reply_count {
                return None;
            }

            let item = conn.receive_reply::<Params, ReplyError>().await;

            // Inspect the reply by reference so `item` can still be yielded as-is. With
            // `std`, only the first tuple element (the reply itself) matters here.
            #[cfg(feature = "std")]
            let outcome = item.as_ref().map(|(result, _fds)| result);
            #[cfg(not(feature = "std"))]
            let outcome = item.as_ref();

            let completed = match outcome {
                // More replies follow for the same call: do not count this one.
                Ok(Ok(r)) if r.continues() == Some(true) => completed,
                // Final reply or reply error: one more reply has completed.
                Ok(_) => completed + 1,
                // Receiving failed: yield the error, then end the stream.
                Err(_) => reply_count,
            };

            Some((item, (conn, completed)))
        });

        Self {
            inner: Box::pin(stream),
        }
    }
}
impl<Params, ReplyError> Debug for ReplyStream<'_, Params, ReplyError> {
    // The inner stream is an opaque trait object, so no fields are printed; the output
    // is marked non-exhaustive instead.
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = formatter.debug_struct("ReplyStream");
        builder.finish_non_exhaustive()
    }
}
impl<Params, ReplyError> Stream for ReplyStream<'_, Params, ReplyError> {
    type Item = Result<ChainResult<Params, ReplyError>>;

    /// Polls the boxed inner stream for the next item.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.inner.as_mut().poll_next(cx)
    }

    /// Forwards the inner stream's hint rather than falling back to the trait default of
    /// `(0, None)`, so this wrapper stays transparent to combinators that consult it.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// Heap-pinned, type-erased stream backing [`ReplyStream`].
///
/// Boxing hides the concrete `unfold` stream type built in `ReplyStream::new`. The trait
/// object carries only the `'c` bound (no `Send`), so `ReplyStream` is not `Send`.
type InnerStream<'c, Params, ReplyError> =
Pin<Box<dyn Stream<Item = Result<ChainResult<Params, ReplyError>>> + 'c>>;