use std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use futures::Stream;
use super::{ClientError, runtime::ClientStream};
use crate::{app::Packet, message::DecodeWith, serializer::Serializer};
/// Asynchronous stream of response packets belonging to one streaming request.
///
/// The stream holds an exclusive borrow of the client for `'a` and reads frames
/// from the client's transport. It yields decoded packets whose correlation
/// identifier matches the one supplied at construction. It ends when a
/// stream-terminator packet arrives. It also ends right after yielding an error:
/// decode failure, correlation mismatch, transport error, or disconnect.
pub struct ResponseStream<'a, P, S, T, C>
where
    P: Packet + DecodeWith<S>,
    S: Serializer + Send + Sync,
    T: ClientStream,
{
    /// Client whose framed transport, serializer and receive hooks are used.
    client: &'a mut super::WireframeClient<S, T, C>,
    /// Identifier that every data packet on this stream must carry.
    correlation_id: u64,
    /// Number of data packets yielded successfully (saturating counter).
    frame_count: usize,
    /// Set once the stream has ended; subsequent polls yield `None`.
    terminated: bool,
    /// Records `P` as the produced item type without storing a value. Using
    /// `fn() -> P` keeps the struct covariant in `P` and leaves its auto traits
    /// independent of `P`.
    _phantom: PhantomData<fn() -> P>,
}
impl<'a, P, S, T, C> ResponseStream<'a, P, S, T, C>
where
P: Packet + DecodeWith<S>,
S: Serializer + Send + Sync,
T: ClientStream,
{
pub(crate) fn new(
client: &'a mut super::WireframeClient<S, T, C>,
correlation_id: u64,
) -> Self {
Self {
client,
correlation_id,
terminated: false,
frame_count: 0,
_phantom: PhantomData,
}
}
#[must_use]
pub fn correlation_id(&self) -> u64 { self.correlation_id }
#[must_use]
pub fn is_terminated(&self) -> bool { self.terminated }
#[must_use]
pub fn frame_count(&self) -> usize { self.frame_count }
fn on_frame_received(&mut self, frame_bytes: usize, result: Option<&Result<P, ClientError>>) {
if let Some(Ok(_)) = result {
self.frame_count = self.frame_count.saturating_add(1);
tracing::debug!(
frame.bytes = frame_bytes,
stream.frames_received = self.frame_count,
correlation_id = self.correlation_id,
"stream frame received"
);
}
}
fn process_frame(&mut self, bytes: &[u8]) -> Option<Result<P, ClientError>> {
let (packet, _consumed) = match self.client.serializer.deserialize::<P>(bytes) {
Ok(result) => result,
Err(e) => {
self.terminated = true;
return Some(Err(ClientError::decode(e)));
}
};
if packet.is_stream_terminator() {
self.terminated = true;
tracing::debug!(
stream.frames_total = self.frame_count,
correlation_id = self.correlation_id,
"stream terminated"
);
return None;
}
let received_cid = packet.correlation_id();
if received_cid != Some(self.correlation_id) {
self.terminated = true;
return Some(Err(ClientError::StreamCorrelationMismatch {
expected: Some(self.correlation_id),
received: received_cid,
}));
}
Some(Ok(packet))
}
}
impl<P, S, T, C> Stream for ResponseStream<'_, P, S, T, C>
where
P: Packet + DecodeWith<S>,
S: Serializer + Send + Sync,
T: ClientStream,
{
type Item = Result<P, ClientError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if this.terminated {
return Poll::Ready(None);
}
match Pin::new(&mut this.client.framed).poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => {
this.terminated = true;
Poll::Ready(Some(Err(ClientError::disconnected())))
}
Poll::Ready(Some(Err(e))) => {
this.terminated = true;
Poll::Ready(Some(Err(ClientError::from(e))))
}
Poll::Ready(Some(Ok(mut bytes))) => {
let frame_bytes = bytes.len();
this.client.invoke_after_receive_hooks(&mut bytes);
let result = this.process_frame(&bytes);
this.on_frame_received(frame_bytes, result.as_ref());
Poll::Ready(result)
}
}
}
}