use super::InStream;
use bytes::{Buf, BufMut, BytesMut};
use futures::prelude::*;
use std::io;
use std::marker::PhantomData;
use std::pin::Pin;
use std::str::from_utf8;
use std::task::{Context, Poll};
use tokio_serde::Deserializer;
/// Wraps an [`InStream`] and appends an identifying tag to every buffer it yields.
///
/// The tag layout is `id bytes || id length (u32, big-endian)`. Because the length
/// field sits at the very end, a receiver can strip the tag from the tail of a
/// buffer without knowing the id length in advance (see [`TaggedInStream::get_id`]).
pub(crate) struct TaggedInStream {
    /// Underlying stream of received buffers.
    recv: InStream,
    /// Pre-built trailer appended to every successfully received buffer.
    tag: Vec<u8>,
}

impl TaggedInStream {
    /// Size in bytes of the trailing id-length field.
    const LEN_FIELD: usize = 4;

    /// Builds a tagging wrapper around `recv` that labels every buffer with `id`.
    ///
    /// The tag is computed once here: the bytes of `id` followed by its byte length
    /// written as a `u32`.
    pub(crate) fn new(recv: InStream, id: String) -> Self {
        let id_len = id.len();
        let mut tag = id.into_bytes();
        // NOTE(review): `as u32` silently truncates for ids of 4 GiB or more; such
        // ids would produce a tag that `get_id` cannot decode correctly.
        tag.put_u32(id_len as u32);
        Self { recv, tag }
    }

    /// Splits the trailing tag off `buf` and returns the id bytes.
    ///
    /// On success `buf` is left holding only the bytes that preceded the tag, and
    /// the returned buffer holds exactly the id bytes (the length field is dropped).
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error with message `"BufLength"` when `buf` is too
    /// short to contain the length field, and `"IdLength"` when the encoded id
    /// length exceeds the bytes available in front of the length field.
    fn get_id(buf: &mut BytesMut) -> io::Result<BytesMut> {
        use std::io::{Error, ErrorKind::InvalidData};

        let buf_len = buf.len();
        if buf_len < Self::LEN_FIELD {
            return Err(Error::new(InvalidData, "BufLength"));
        }
        // Number of bytes in front of the length field (payload plus id).
        let body_len = buf_len - Self::LEN_FIELD;
        let id_len = (&buf[body_len..]).get_u32() as usize;
        // Compare using subtraction only: `id_len` comes from the wire, and
        // `id_len + 4` could overflow `usize` on 32-bit targets.
        if body_len < id_len {
            return Err(Error::new(InvalidData, "IdLength"));
        }
        let mut id_buf = buf.split_off(body_len - id_len);
        // Drop the trailing length field so only the id bytes remain.
        id_buf.truncate(id_len);
        Ok(id_buf)
    }
}
/// Yields the buffers of the wrapped stream with the tag appended to each one.
impl Stream for TaggedInStream {
    type Item = Result<BytesMut, io::Error>;

    /// Polls the wrapped stream once.
    ///
    /// A successfully received buffer gets the tag bytes appended before it is
    /// handed on; every other poll outcome is forwarded untouched.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // Plain mutable access lets `recv` and `tag` be borrowed independently.
        let this = self.get_mut();
        match this.recv.poll_next_unpin(cx) {
            Poll::Ready(Some(Ok(mut frame))) => {
                frame.put(&this.tag[..]);
                Poll::Ready(Some(Ok(frame)))
            }
            other => other,
        }
    }
}
/// Stream adapter around an inner stream `T`; `E` is the error type used by the
/// adapter's `Stream` implementation and is carried only as a marker.
///
/// NOTE(review): the `*const E` marker also opts this type out of the automatic
/// `Send`/`Sync` implementations; confirm that this is intended.
pub struct TaglessBroadcastInStream<T, E>(T, PhantomData<*const E>);

// The adapter holds no self-references, so it is `Unpin` whenever `T` is.
impl<T: Unpin, E> Unpin for TaglessBroadcastInStream<T, E> {}

impl<T, E> TaglessBroadcastInStream<T, E> {
    /// Wraps `recv` in a new adapter.
    pub fn new(recv: T) -> Self {
        TaglessBroadcastInStream(recv, PhantomData)
    }

    /// Consumes the adapter and hands back the wrapped stream.
    pub fn into_inner(self) -> T {
        let TaglessBroadcastInStream(inner, _) = self;
        inner
    }
}
/// Yields the buffers of the inner stream with their trailing tag removed.
impl<T, E> Stream for TaglessBroadcastInStream<T, E>
where
    T: Stream<Item = Result<BytesMut, E>> + Unpin,
    io::Error: Into<E>,
{
    type Item = Result<BytesMut, E>;

    /// Polls the inner stream until it produces something worth reporting.
    ///
    /// Empty buffers are skipped. For a non-empty buffer the trailing tag is split
    /// off via `TaggedInStream::get_id` and discarded; the remaining bytes are
    /// yielded. A tag that cannot be decoded is reported as an error converted
    /// into `E`. All other poll outcomes are forwarded unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let inner = &mut self.get_mut().0;
        loop {
            match inner.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(mut frame))) => {
                    if frame.is_empty() {
                        // Nothing to strip; ask the inner stream for the next buffer.
                        continue;
                    }
                    let stripped = match TaggedInStream::get_id(&mut frame) {
                        Ok(_) => Ok(frame),
                        Err(err) => Err(err.into()),
                    };
                    return Poll::Ready(Some(stripped));
                }
                other => return other,
            }
        }
    }
}
/// Stream adapter around an inner stream `T`; `E` is the error type used by the
/// adapter's `Stream` implementation and is carried only as a marker.
///
/// NOTE(review): the `*const E` marker also opts this type out of the automatic
/// `Send`/`Sync` implementations; confirm that this is intended.
pub struct TaggedBroadcastInStream<T, E>(T, PhantomData<*const E>);

// The adapter holds no self-references, so it is `Unpin` whenever `T` is.
impl<T: Unpin, E> Unpin for TaggedBroadcastInStream<T, E> {}

impl<T, E> TaggedBroadcastInStream<T, E> {
    /// Wraps `recv` in a new adapter.
    pub fn new(recv: T) -> Self {
        TaggedBroadcastInStream(recv, PhantomData)
    }

    /// Consumes the adapter and hands back the wrapped stream.
    pub fn into_inner(self) -> T {
        let TaggedBroadcastInStream(inner, _) = self;
        inner
    }
}
/// Yields `(id, payload)` pairs decoded from the tagged buffers of the inner stream.
impl<T, E> Stream for TaggedBroadcastInStream<T, E>
where
    T: Stream<Item = Result<BytesMut, E>> + Unpin,
    io::Error: Into<E>,
{
    type Item = Result<(String, BytesMut), E>;

    /// Polls the inner stream until it produces something worth reporting.
    ///
    /// Empty buffers are skipped. For a non-empty buffer the trailing tag is split
    /// off via `TaggedInStream::get_id`, the id bytes are decoded as UTF-8, and the
    /// id is yielded together with the remaining bytes. A malformed tag or a
    /// non-UTF-8 id is reported as an `InvalidData` error converted into `E`.
    /// Pending, end-of-stream and inner errors are forwarded unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let inner = &mut self.get_mut().0;
        loop {
            let mut frame = match inner.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(frame))) => frame,
                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => return Poll::Pending,
            };
            if frame.is_empty() {
                // Nothing to decode; ask the inner stream for the next buffer.
                continue;
            }

            let id_bytes = match TaggedInStream::get_id(&mut frame) {
                Ok(bytes) => bytes,
                Err(err) => return Poll::Ready(Some(Err(err.into()))),
            };
            let item = match from_utf8(&id_bytes[..]) {
                Ok(id) => Ok((id.to_owned(), frame)),
                Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "Utf8").into()),
            };
            return Poll::Ready(Some(item));
        }
    }
}
/// Pairs an inner stream `T` with a deserializer `D`.
///
/// `O`, `Et` and `Ec` are used only by the `Stream` implementation (output type,
/// stream error type and deserializer error type) and are carried as a marker.
///
/// NOTE(review): the raw-pointer marker also opts this type out of the automatic
/// `Send`/`Sync` implementations; confirm that this is intended.
pub struct TaggedDeserializer<T, D, O, Et, Ec>(T, D, PhantomData<*const (O, Et, Ec)>);

// No self-references are held, so the wrapper is `Unpin` whenever both parts are.
impl<T: Unpin, D: Unpin, O, Et, Ec> Unpin for TaggedDeserializer<T, D, O, Et, Ec> {}

impl<T, D: Unpin, O, Et, Ec> TaggedDeserializer<T, D, O, Et, Ec> {
    /// Combines the stream `recv` with `deserializer`.
    pub fn new(recv: T, deserializer: D) -> Self {
        TaggedDeserializer(recv, deserializer, PhantomData)
    }

    /// Consumes the wrapper and hands back the inner stream; the deserializer is dropped.
    pub fn into_inner(self) -> T {
        let TaggedDeserializer(inner, _, _) = self;
        inner
    }

    /// Borrows the stream and the deserializer at the same time, pinning the
    /// deserializer (sound because `D: Unpin`).
    fn get_td(&mut self) -> (&mut T, Pin<&mut D>) {
        let TaggedDeserializer(recv, deserializer, _) = self;
        (recv, Pin::new(deserializer))
    }
}
/// Yields `(id, value)` pairs by deserializing the payload of each `(id, bytes)`
/// item produced by the inner stream.
impl<T, D, O, Et, Ec> Stream for TaggedDeserializer<T, D, O, Et, Ec>
where
    T: Stream<Item = Result<(String, BytesMut), Et>> + Unpin,
    D: Deserializer<O, Error = Ec> + Unpin,
    io::Error: Into<Et>,
    Ec: Into<io::Error>,
{
    type Item = Result<(String, O), Et>;

    /// Polls the inner stream once and deserializes the payload it produced.
    ///
    /// Pending, end-of-stream and inner errors are forwarded unchanged. A
    /// deserializer error is converted to `io::Error` first and then into `Et`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let (source, decoder) = self.get_td();

        let (id, payload) = match source.poll_next_unpin(cx) {
            Poll::Ready(Some(Ok(tagged))) => tagged,
            Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Pending => return Poll::Pending,
        };

        let decoded = match decoder.deserialize(&payload) {
            Ok(value) => Ok((id, value)),
            Err(err) => {
                // Two-step conversion: Ec -> io::Error -> Et.
                let io_err: io::Error = err.into();
                Err(io_err.into())
            }
        };
        Poll::Ready(Some(decoded))
    }
}