1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
use anyhow::Context;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::stream::{StreamExt, TryStreamExt};
use serde::de::DeserializeOwned;
use serde::Serialize;

use crate::{Result, Error};

/// Progress of `decode_chunk` through the current length-prefixed frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DecodeState {
    /// Waiting for the 4-byte big-endian length header of the next frame.
    ReadHeader,
    /// Header consumed; waiting until `len` body bytes are buffered.
    ReadBody { len: usize },
}

/// Tries to decode one length-prefixed MessagePack frame from the front of `buf`.
///
/// A frame is a big-endian `u32` body length followed by that many bytes of
/// MessagePack. Decoding is resumable: `state` records whether the current frame's
/// header has already been consumed, so the function can be called again after more
/// bytes have been appended to `buf`.
///
/// Returns `Ok(None)` when `buf` does not yet hold a complete header or body, and
/// `Ok(Some(msg))` when a frame was decoded (its bytes are removed from `buf`).
///
/// # Errors
///
/// Returns an error when a complete body cannot be deserialized into `T`. The failed
/// frame's bytes are still consumed and `state` is reset to `ReadHeader`, so the next
/// call starts at the following frame's header instead of misreading it as body.
fn decode_chunk<T>(buf: &mut BytesMut, state: &mut DecodeState) -> Result<Option<T>>
where
    T: DeserializeOwned + Send,
{
    let len = match *state {
        DecodeState::ReadHeader => {
            if buf.len() < 4 {
                return Ok(None);
            }
            let len = buf.get_u32() as usize;
            // Reserve only what is still missing for the body; part of it may already
            // be buffered.
            // NOTE(review): `len` comes straight from the peer, so a huge prefix causes a
            // correspondingly huge reservation — consider enforcing a maximum frame size.
            buf.reserve(len.saturating_sub(buf.len()));
            *state = DecodeState::ReadBody { len };
            len
        }
        DecodeState::ReadBody { len } => len,
    };

    if buf.len() < len {
        return Ok(None);
    }

    let body = buf.split_to(len).freeze();
    // Reset before deserializing: the body is consumed whether or not decoding
    // succeeds, and the next call must begin at the following header.
    *state = DecodeState::ReadHeader;

    let msg = rmp_serde::from_read_ref(&body).context("Could not deserialize message")?;
    Ok(Some(msg))
}

/// Decodes a stream of length-prefixed MessagePack frames into messages of type `T`.
///
/// Incoming byte chunks are accumulated in an internal buffer, and every complete
/// frame is yielded as soon as it is available, independent of how the frames are
/// split across chunks.
///
/// # Errors
///
/// Yields an `Err` item (and keeps going) when a frame fails to deserialize or the
/// underlying stream reports an error. If the underlying stream ends while a frame is
/// only partially received, a final `Err` is yielded instead of silently dropping it.
pub fn decode_stream<T>(
    stream: impl futures::Stream<Item = Result<Bytes, hyper::Error>> + Send + 'static,
) -> impl futures::Stream<Item = Result<T>> + Send
where
    T: DeserializeOwned + Send,
{
    let mut buf = BytesMut::with_capacity(8 * 1024);
    let mut state = DecodeState::ReadHeader;

    async_stream::stream! {
        futures::pin_mut!(stream);

        loop {
            // Drain every complete frame already buffered before reading more input.
            match decode_chunk(&mut buf, &mut state) {
                Ok(Some(msg)) => {
                    yield Ok(msg);
                    continue;
                }
                Err(e) => {
                    // The failed frame's body has already been consumed; resume at the
                    // next length header rather than treating it as body bytes.
                    state = DecodeState::ReadHeader;
                    yield Err(e);
                    continue;
                }
                Ok(None) => {}
            }

            match stream.next().await {
                Some(Ok(bytes)) => buf.put(bytes),
                Some(Err(e)) => yield Err(Error::from(e)),
                None => {
                    // Leftover bytes or a consumed header mean the peer stopped mid-frame.
                    if matches!(state, DecodeState::ReadBody { .. }) || !buf.is_empty() {
                        yield Err(anyhow::anyhow!(
                            "Stream ended with an incomplete message ({} bytes buffered)",
                            buf.len()
                        ));
                    }
                    break;
                }
            }
        }
    }
}

pub fn encode_stream<T>(
    stream: impl futures::Stream<Item = Result<T>> + Send + 'static,
) -> impl futures::Stream<Item = Result<Bytes>> + Send + 'static
where
    T: Serialize,
{
    let mut buf = BytesMut::with_capacity(8 * 1024).writer();
    stream.and_then(move |msg| {
        futures::future::ready({
            buf.get_mut().reserve(4);
            unsafe {
                buf.get_mut().advance_mut(4);
            }

            if let Err(e) =
                rmp_serde::encode::write(&mut buf, &msg).context("Could not serialize message")
            {
                return futures::future::ready(Result::Err(e));
            }

            let len = buf.get_ref().len() - 4;
            assert!(len <= std::u32::MAX as usize);
            {
                let buf = buf.get_mut();
                let mut buf = &mut buf[..4];
                buf.put_u32(len as u32);
            }

            let encoded = buf.get_mut().split_to(len + 4).freeze();
            Ok(encoded)
        })
    })
}