1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::prelude::*;
use std::io::{Cursor, Error, ErrorKind, Result};

/// Command code carried in a frame header.
///
/// Each variant corresponds to a fixed one-byte wire code:
/// `Sync` = 0, `Finish` = 1, `Push` = 2, `Nop` = 3.
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
pub enum Command {
    Sync,
    Finish,
    Push,
    Nop,
}

impl Command {
    /// Decodes a one-byte wire code into a [`Command`].
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::InvalidData`] error ("invalid command") when
    /// `val` is not one of the known codes `0..=3`.
    #[inline]
    fn from_u8(val: u8) -> Result<Self> {
        match val {
            0 => Ok(Command::Sync),
            1 => Ok(Command::Finish),
            2 => Ok(Command::Push),
            3 => Ok(Command::Nop),
            _ => Err(Error::new(ErrorKind::InvalidData, "invalid command")),
        }
    }

    /// Returns the one-byte wire code for this command.
    ///
    /// Takes `self` by value because `Command` is `Copy`; existing
    /// method-call sites (`cmd.to_u8()`) keep working unchanged.
    #[inline]
    fn to_u8(self) -> u8 {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Command::Sync => 0,
            Command::Finish => 1,
            Command::Push => 2,
            Command::Nop => 3,
        }
    }
}

/// Fallible conversion from a wire code; same mapping and error as
/// `Command::from_u8`.
impl std::convert::TryFrom<u8> for Command {
    type Error = Error;

    fn try_from(val: u8) -> std::result::Result<Self, Self::Error> {
        Self::from_u8(val)
    }
}

/// Infallible conversion to the wire code; same mapping as `Command::to_u8`.
impl From<Command> for u8 {
    fn from(command: Command) -> Self {
        command.to_u8()
    }
}

/// A single protocol frame: an 8-byte header followed by an optional payload.
///
/// On the wire the header is laid out as `version` (1 byte), `command`
/// (1 byte), `length` (2 bytes) and `stream_id` (4 bytes), with the
/// multi-byte integers encoded big-endian by the `bytes` accessors used in
/// `read_from` / `write_to`.
#[derive(Debug, PartialEq, Eq)]
pub struct Frame {
    /// Protocol version byte; `read_from` rejects any value other than `1`.
    pub version: u8,
    /// Command code of this frame.
    pub command: Command,
    /// Payload size in bytes as carried in the header.
    pub length: u16,
    /// Identifier of the stream this frame belongs to.
    pub stream_id: u32,
    /// Payload bytes that follow the header (empty when `length` is 0).
    pub payload: Bytes,
}

impl Frame {
    /// Size in bytes of the fixed frame header:
    /// version (1) + command (1) + length (2) + stream id (4).
    const HEADER_LEN: usize = 8;

    /// The only protocol version accepted when decoding.
    const VERSION: u8 = 1;

    /// Reads one complete frame (header plus payload) from `reader`.
    ///
    /// # Errors
    ///
    /// * Any I/O error from `reader`, including `UnexpectedEof` when the
    ///   stream ends before the header or the announced payload is complete.
    /// * [`ErrorKind::InvalidData`] when the version byte is not `1` or the
    ///   command byte is not a known [`Command`].
    pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Self> {
        let mut header_buf = [0u8; Self::HEADER_LEN];
        reader.read_exact(&mut header_buf[..]).await?;

        // `Buf` accessors decode multi-byte integers as big-endian.
        let mut cursor = Cursor::new(&header_buf);
        let version = cursor.get_u8();
        if version != Self::VERSION {
            return Err(Error::new(ErrorKind::InvalidData, "invalid protocol"));
        }
        let command = Command::from_u8(cursor.get_u8())?;
        let length = cursor.get_u16();
        let stream_id = cursor.get_u32();

        // Widening u16 -> usize is lossless; `usize::from` makes that explicit.
        let payload_len = usize::from(length);
        let mut payload = BytesMut::with_capacity(payload_len);
        if payload_len != 0 {
            // `read_exact` fills an initialised slice, so grow the buffer first.
            payload.resize(payload_len, 0);
            reader.read_exact(&mut payload).await?;
        }

        Ok(Self {
            version,
            command,
            length,
            stream_id,
            payload: payload.freeze(),
        })
    }

    /// Writes this frame (header, then payload if non-empty) to `writer`.
    ///
    /// The writer is not flushed; callers that need the bytes on the wire
    /// immediately must flush it themselves.
    ///
    /// # Errors
    ///
    /// * [`ErrorKind::InvalidInput`] when `length` does not equal
    ///   `payload.len()`. Nothing is written in that case, because a header
    ///   that misstates the payload size would desynchronise the reader.
    /// * Any I/O error from `writer`.
    pub async fn write_to<W: AsyncWrite + Unpin>(&self, writer: &mut W) -> Result<()> {
        // The reader trusts the header length, so refuse to emit a frame
        // whose header and payload disagree (this also covers payloads that
        // do not fit in a u16).
        if self.payload.len() != usize::from(self.length) {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "frame length does not match payload size",
            ));
        }

        // `BufMut` accessors encode multi-byte integers as big-endian.
        let mut header_buf = BytesMut::with_capacity(Self::HEADER_LEN);
        header_buf.put_u8(self.version);
        header_buf.put_u8(self.command.to_u8());
        header_buf.put_u16(self.length);
        header_buf.put_u32(self.stream_id);
        writer.write_all(&header_buf).await?;

        if !self.payload.is_empty() {
            writer.write_all(&self.payload).await?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod test {
    use super::{Command, Frame};
    use bytes::Bytes;
    use smol::io::Cursor;

    /// Encodes `command` to its wire code and decodes it back again.
    fn round_trip(command: Command) -> Command {
        let code = command.to_u8();
        Command::from_u8(code).unwrap()
    }

    /// Every command survives an encode/decode round trip and an unknown
    /// code is rejected.
    #[test]
    fn test_command() {
        assert_eq!(round_trip(Command::Push), Command::Push);
        assert_eq!(round_trip(Command::Sync), Command::Sync);
        assert_eq!(round_trip(Command::Finish), Command::Finish);
        assert_eq!(round_trip(Command::Nop), Command::Nop);

        assert!(Command::from_u8(100).is_err());
    }

    /// A frame written with `write_to` is read back unchanged by
    /// `read_from`, and malformed headers are rejected.
    #[test]
    fn test_frame() {
        smol::block_on(async {
            let payload = Bytes::from_static(b"payload12345678");
            let frame = Frame {
                version: 1,
                command: Command::Sync,
                length: payload.len() as u16,
                stream_id: 1234,
                payload,
            };

            // Serialise into an in-memory buffer.
            let mut sink = Cursor::new(Vec::<u8>::new());
            frame.write_to(&mut sink).await.unwrap();

            // Parse the bytes that were just written.
            let mut source = Cursor::new(sink.get_ref());
            let decoded = Frame::read_from(&mut source).await.unwrap();

            assert_eq!(decoded.version, frame.version);
            assert_eq!(decoded.stream_id, frame.stream_id);
            assert_eq!(decoded.command.to_u8(), frame.command.to_u8());
            assert_eq!(decoded.length, frame.length);
            assert_eq!(decoded.payload.to_vec(), frame.payload.to_vec());

            // Version byte is valid (1) but command byte 4 is unknown.
            let bad_command = [1u8, 4, 5, 6, 7, 1, 1, 1, 1, 1, 1, 1];
            let mut source = Cursor::new(&bad_command);
            assert!(Frame::read_from(&mut source).await.is_err());

            // Version byte 3 is not the supported protocol version.
            let bad_version = [3u8, 4, 8, 6, 7, 1, 1, 1, 1, 1, 11, 1, 1];
            let mut source = Cursor::new(&bad_version);
            assert!(Frame::read_from(&mut source).await.is_err());
        });
    }
}