use std::{
fmt,
pin::Pin,
sync::{Arc, Mutex},
task::{self, Poll},
time::Duration,
};
use futures_core::{FusedStream, Stream};
use futures_util::{StreamExt, stream::TakeUntil};
use pin_project_lite::pin_project;
use serde::de::DeserializeOwned;
use socketioxide_core::{
Sid,
adapter::AckStreamItem,
adapter::remote_packet::{Response, ResponseType},
};
use tokio::time;
use crate::{ResponseHandlers, drivers::MessageStream};
pin_project! {
    /// Stream of acknowledgements that merges a `local` ack stream with acks decoded from
    /// `remote` response messages.
    ///
    /// Remote messages either announce how many acks will follow (`BroadcastAckCount`) or carry a
    /// single ack (`BroadcastAck`); see `poll_remote` for how they are interpreted.
    pub struct AckStream<S> {
        // Stream of local acks; it is polled before `remote` on every `poll_next`.
        #[pin]
        local: S,
        // Raw remote response messages. Reading stops once the timeout given to `new` elapses,
        // and dropping it removes the request's entry from the response handlers.
        #[pin]
        remote: DropStream<TakeUntil<MessageStream<Vec<u8>>, time::Sleep>>,
        // Remote acks announced by `BroadcastAckCount` messages that have not been yielded yet.
        ack_cnt: u32,
        // Sum of every ack count announced so far; never decremented.
        total_ack_cnt: usize,
        // Number of `BroadcastAckCount` messages still expected (presumably one per remote
        // server — confirm against callers).
        serv_cnt: u16,
    }
}
impl<S> AckStream<S> {
    /// Creates an ack stream that combines `local` with acks decoded from `remote`.
    ///
    /// * `remote` is only read until `timeout` elapses.
    /// * `serv_cnt` is the number of `BroadcastAckCount` messages to wait for on `remote`.
    /// * When the returned stream is dropped, the entry for `req_id` is removed from `handlers`.
    pub fn new(
        local: S,
        remote: MessageStream<Vec<u8>>,
        timeout: Duration,
        serv_cnt: u16,
        req_id: Sid,
        handlers: Arc<Mutex<ResponseHandlers>>,
    ) -> Self {
        let deadline = time::sleep(timeout);
        Self {
            local,
            remote: DropStream::new(remote.take_until(deadline), handlers, req_id),
            ack_cnt: 0,
            total_ack_cnt: 0,
            serv_cnt,
        }
    }

    /// Creates an ack stream that only yields the items of `local`.
    ///
    /// The remote half is an empty message stream with a zero timeout and no expected servers.
    /// It is bound to a freshly created handler map, so dropping the stream only touches that
    /// private map.
    pub fn new_local(local: S) -> Self {
        let private_handlers = Arc::new(Mutex::new(ResponseHandlers::new()));
        Self::new(
            local,
            MessageStream::new_empty(),
            Duration::ZERO,
            0,
            Sid::ZERO,
            private_handlers,
        )
    }
}
impl<Err, S> AckStream<S>
where
    Err: DeserializeOwned + fmt::Debug,
    S: Stream<Item = AckStreamItem<Err>> + FusedStream,
{
    /// Polls the remote half of the stream for the next acknowledgement.
    ///
    /// Each message received from `remote` is decoded as a `(Sid, Response<E>)` pair:
    /// - `BroadcastAckCount(count)`, accepted while `serv_cnt > 0`, adds `count` to the
    ///   expected acks and decrements `serv_cnt`; polling then continues.
    /// - `BroadcastAck((sid, res))`, accepted while `ack_cnt > 0`, is yielded.
    /// - Any other response, and any message that fails to decode, is logged and skipped.
    ///
    /// Returns `Poll::Ready(None)` once the whole stream is terminated (see the `FusedStream`
    /// impl) or when the remote stream ends.
    fn poll_remote<E: DeserializeOwned + fmt::Debug>(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Option<AckStreamItem<E>>> {
        // Loop instead of recursing: a burst of ready messages that are all skipped (counts,
        // unexpected types, decode errors) must not grow the call stack without bound.
        loop {
            if FusedStream::is_terminated(&self) {
                return Poll::Ready(None);
            }
            let projection = self.as_mut().project();
            let item = match projection.remote.poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(item)) => item,
            };
            match rmp_serde::from_slice::<(Sid, Response<E>)>(&item) {
                Ok((
                    req_id,
                    Response {
                        node_id: uid,
                        r#type: ResponseType::BroadcastAckCount(count),
                    },
                )) if *projection.serv_cnt > 0 => {
                    tracing::trace!(?uid, ?req_id, "receiving broadcast ack count {count}");
                    // `count` comes from a remote node: saturate rather than panicking (debug)
                    // or silently wrapping (release) on an absurd value.
                    *projection.ack_cnt = projection.ack_cnt.saturating_add(count);
                    *projection.total_ack_cnt =
                        projection.total_ack_cnt.saturating_add(count as usize);
                    *projection.serv_cnt -= 1;
                }
                Ok((
                    req_id,
                    Response {
                        node_id: uid,
                        r#type: ResponseType::BroadcastAck((sid, res)),
                    },
                )) if *projection.ack_cnt > 0 => {
                    tracing::trace!(?uid, ?req_id, "receiving broadcast ack {sid} {:?}", res);
                    *projection.ack_cnt -= 1;
                    return Poll::Ready(Some((sid, res)));
                }
                Ok((req_id, Response { node_id: uid, .. })) => {
                    tracing::warn!(?uid, ?req_id, ?self, "unexpected response type");
                }
                Err(e) => {
                    tracing::warn!("error decoding ack response: {e}");
                }
            }
        }
    }
}
impl<E, S> Stream for AckStream<S>
where
    E: DeserializeOwned + fmt::Debug,
    S: Stream<Item = AckStreamItem<E>> + FusedStream,
{
    type Item = AckStreamItem<E>;

    /// Yields local acks first. Whenever the local stream is pending or exhausted, the remote
    /// stream is polled for acks from other servers.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        match self.as_mut().project().local.poll_next(cx) {
            Poll::Pending => match self.poll_remote(cx) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
                // The remote side is finished but the local stream is still pending, which has
                // registered the waker, so the whole stream stays pending.
                Poll::Ready(None) => Poll::Pending,
            },
            Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
            Poll::Ready(None) => self.poll_remote(cx),
        }
    }

    /// Bounds on the number of remaining items.
    ///
    /// The lower bound is the local stream's lower bound. An upper bound exists only when the
    /// local stream has one and no more remote acks can be announced: either every expected
    /// `BroadcastAckCount` has arrived (`serv_cnt == 0`), leaving at most `ack_cnt` remote acks,
    /// or the remote stream has terminated.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (lower, local_upper) = self.local.size_hint();
        let remote_upper = if self.remote.is_terminated() {
            Some(0)
        } else if self.serv_cnt == 0 {
            Some(self.ack_cnt as usize)
        } else {
            // Servers that have not reported yet may still announce any number of acks.
            None
        };
        let upper = local_upper
            .zip(remote_upper)
            .and_then(|(local, remote)| local.checked_add(remote));
        (lower, upper)
    }
}
impl<Err, S> FusedStream for AckStream<S>
where
    Err: DeserializeOwned + fmt::Debug,
    S: Stream<Item = AckStreamItem<Err>> + FusedStream,
{
    /// The stream is finished once the local stream is exhausted and no remote ack can still
    /// arrive. That holds when every expected ack count was received and every announced ack
    /// was consumed, or when the remote stream itself has terminated.
    fn is_terminated(&self) -> bool {
        if !self.local.is_terminated() {
            return false;
        }
        let remote_drained = self.ack_cnt == 0 && self.serv_cnt == 0;
        remote_drained || self.remote.is_terminated()
    }
}
impl<S> fmt::Debug for AckStream<S> {
    /// Shows only the ack/server counters; the inner streams are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            ack_cnt,
            total_ack_cnt,
            serv_cnt,
            ..
        } = self;
        let mut out = f.debug_struct("AckStream");
        out.field("ack_cnt", ack_cnt);
        out.field("total_ack_cnt", total_ack_cnt);
        out.field("serv_cnt", serv_cnt);
        out.finish()
    }
}
pin_project! {
    /// Stream wrapper that removes the `req_id` entry from `handlers` when it is dropped.
    pub struct DropStream<S> {
        // The wrapped stream; all polling is forwarded to it.
        #[pin]
        stream: S,
        // Key removed from `handlers` on drop.
        req_id: Sid,
        handlers: Arc<Mutex<ResponseHandlers>>
    }
    impl<S> PinnedDrop for DropStream<S> {
        fn drop(this: Pin<&mut Self>) {
            let stream = this.project();
            let chan = stream.req_id;
            tracing::debug!(?chan, "dropping stream");
            // A poisoned lock only means another thread panicked while holding it. Panicking
            // here would abort the process if this drop runs during unwinding, so recover the
            // guard and still perform the cleanup.
            stream
                .handlers
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .remove(chan);
        }
    }
}
impl<S> DropStream<S> {
pub fn new(stream: S, handlers: Arc<Mutex<ResponseHandlers>>, req_id: Sid) -> Self {
Self {
stream,
handlers,
req_id,
}
}
}
impl<S: Stream> Stream for DropStream<S> {
    type Item = S::Item;

    /// Forwards polling to the wrapped stream unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().stream.poll_next(cx)
    }

    /// Forwards the wrapped stream's size hint, since the wrapper neither adds nor removes
    /// items. Without this, the default `(0, None)` would hide the inner stream's bounds.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}
impl<S> FusedStream for DropStream<S>
where
    S: FusedStream,
{
    /// The wrapper is terminated exactly when the wrapped stream is.
    fn is_terminated(&self) -> bool {
        let inner = &self.stream;
        inner.is_terminated()
    }
}
#[cfg(test)]
mod tests {
    use futures_core::FusedStream;
    use futures_util::StreamExt;
    use socketioxide_core::{Sid, Value};

    use super::AckStream;

    /// An `AckStream` built with `new_local` expects no remote acks, so it must end as soon as
    /// its local stream does.
    #[tokio::test]
    async fn local_ack_stream_should_have_a_closed_remote() {
        let socket_id = Sid::new();
        let single_local_ack = futures_util::stream::once(async move {
            (socket_id, Ok::<_, ()>(Value::Str("local".into(), None)))
        });
        let acks = AckStream::new_local(single_local_ack);
        futures_util::pin_mut!(acks);

        // Fresh local-only stream: no counters set and nothing consumed yet.
        assert_eq!(acks.ack_cnt, 0);
        assert_eq!(acks.total_ack_cnt, 0);
        assert_eq!(acks.serv_cnt, 0);
        assert!(!acks.local.is_terminated());
        assert!(!acks.is_terminated());

        // The single local ack comes through untouched...
        let first = acks.next().await;
        let is_expected = matches!(
            first,
            Some((id, Ok(Value::Str(msg, None)))) if id == socket_id && msg == "local"
        );
        assert!(is_expected);

        // ...and then the stream ends without waiting on a remote.
        assert_eq!(acks.next().await, None);
        assert!(acks.is_terminated());
    }
}