//! agilulf_protocol 0.1.0
//!
//! Agilulf Protocol serializer and deserializer.
use super::{ProtocolError, Result};
use futures::io::{AsyncReadExt, AsyncWriteExt};
use futures::{AsyncRead, AsyncWrite};

/// Size in bytes (8 KiB) of the read buffer that `AsyncReadBuffer::new` allocates.
pub const DEFAULT_BUF_SIZE: usize = 8_192;

/// A read buffer over an asynchronous stream, providing `read_exact` and `read_line` methods.
///
/// At the time this was written the `futures` crate offered no buffered async reader, so this
/// struct fills the gap. Its implementation is modelled on [`std::io::BufReader`], and
/// `read_exact` / `read_line` follow the approach of the [`std::io::BufRead`] trait.
pub struct AsyncReadBuffer<T> {
    /// The wrapped stream that data is read from.
    stream: T,
    /// Backing storage; only `read_buffer[read_pos..read_cap]` holds unread data.
    read_buffer: Vec<u8>,

    /// Index of the first unread byte in `read_buffer`.
    read_pos: usize,
    /// Number of valid bytes placed in `read_buffer` by the most recent read.
    read_cap: usize,
}

/// Returns the length of the shortest prefix of `chunk` that ends with a CRLF terminator (the
/// index just past the `\n`), or `None` when `chunk` holds no terminator.
///
/// `prev_is_cr` reports whether the byte immediately before `chunk` (the last byte the caller
/// has already collected) was `\r`, so a `\r\n` split across two reads is still recognised.
/// A bare `\n` without a preceding `\r` does not end a line; the search continues past it.
fn crlf_line_end(chunk: &[u8], prev_is_cr: bool) -> Option<usize> {
    let mut start = 0;
    while let Some(offset) = memchr::memchr(b'\n', &chunk[start..]) {
        let newline = start + offset;
        let preceded_by_cr = if newline == 0 {
            prev_is_cr
        } else {
            chunk[newline - 1] == b'\r'
        };
        if preceded_by_cr {
            return Some(newline + 1);
        }
        // Bare LF: keep scanning the rest of the chunk instead of swallowing it whole, otherwise
        // a real "\r\n" later in the same chunk would be missed.
        start = newline + 1;
    }
    None
}

impl<T: AsyncRead + Unpin> AsyncReadBuffer<T> {
    /// Wraps `stream` with an empty read buffer of [`DEFAULT_BUF_SIZE`] bytes.
    pub fn new(stream: T) -> AsyncReadBuffer<T> {
        AsyncReadBuffer {
            stream,
            read_buffer: vec![0; DEFAULT_BUF_SIZE],

            read_cap: 0,
            read_pos: 0,
        }
    }

    /// Returns the buffered bytes that have not been consumed yet, reading from the underlying
    /// stream first when the buffer is exhausted.
    ///
    /// The returned slice is never empty. Bytes stay unread until [`Self::consume`] is called.
    ///
    /// # Errors
    ///
    /// Returns `ProtocolError::ConnectionClosed` when the stream read yields zero bytes
    /// (end of stream), and propagates any error from the underlying read via `?`.
    pub async fn fill_buf(&mut self) -> Result<&[u8]> {
        if self.read_pos >= self.read_cap {
            self.read_cap = self.stream.read(&mut self.read_buffer).await?;
            if self.read_cap == 0 {
                return Err(ProtocolError::ConnectionClosed);
            }
            self.read_pos = 0;
        }
        Ok(&self.read_buffer[self.read_pos..self.read_cap])
    }

    /// Marks `amt` buffered bytes as read.
    ///
    /// Consuming more than is currently buffered simply empties the buffer.
    pub fn consume(&mut self, amt: usize) {
        // `saturating_add` avoids an overflow panic for absurdly large `amt`.
        self.read_pos = self.read_pos.saturating_add(amt).min(self.read_cap);
    }

    /// Reads one line terminated by `"\r\n"`.
    ///
    /// **Note:** The return value of this method contains "\r\n".
    ///
    /// Only a `\n` immediately preceded by `\r` ends the line; a bare `\n` is kept as ordinary
    /// line data. The terminator is recognised even when `\r` and `\n` arrive in different
    /// reads. There is no limit on line length, so input that never contains `"\r\n"` makes
    /// the returned buffer grow without bound.
    ///
    /// # Errors
    ///
    /// Fails with the same errors as [`Self::fill_buf`]. In that case the bytes already
    /// consumed for the partial line are lost.
    pub async fn read_line(&mut self) -> Result<Vec<u8>> {
        let mut buf: Vec<u8> = Vec::new();

        loop {
            let prev_is_cr = buf.last() == Some(&b'\r');
            let available = self.fill_buf().await?;

            if let Some(end) = crlf_line_end(available, prev_is_cr) {
                buf.extend_from_slice(&available[..end]);
                self.consume(end);
                return Ok(buf);
            }

            // No terminator in this chunk: take all of it and read more.
            let len = available.len();
            buf.extend_from_slice(available);
            self.consume(len);
        }
    }

    /// Reads exactly `size` bytes and returns them as a new vector.
    ///
    /// A `size` of zero returns an empty vector without touching the stream.
    ///
    /// # Errors
    ///
    /// Fails with the same errors as [`Self::fill_buf`], for example when the stream ends
    /// before `size` bytes arrive. In that case the bytes already consumed are lost.
    pub async fn read_exact(&mut self, size: usize) -> Result<Vec<u8>> {
        // `size` may be peer-controlled (e.g. a length prefix), so cap the up-front allocation
        // and let the vector grow only as data actually arrives.
        let mut buf = Vec::with_capacity(size.min(DEFAULT_BUF_SIZE));

        while buf.len() < size {
            let available = self.fill_buf().await?;
            let take = (size - buf.len()).min(available.len());
            buf.extend_from_slice(&available[..take]);
            self.consume(take);
        }

        Ok(buf)
    }
}

/// A thin wrapper around an asynchronous writer.
///
/// Despite its name, this struct holds no buffer of its own. It exists only so that helper
/// methods can be defined on the wrapped stream.
pub struct AsyncWriteBuffer<T> {
    /// The wrapped writer, public so callers can also use it directly.
    pub stream: T,
}

impl<T: AsyncWrite + Unpin> AsyncWriteBuffer<T> {
    /// Wraps `stream` without adding any buffering.
    pub fn new(stream: T) -> AsyncWriteBuffer<T> {
        Self { stream }
    }

    /// Writes the whole of `data` to the wrapped stream.
    ///
    /// # Errors
    ///
    /// Any error reported by the underlying write, converted into the crate's error type by `?`.
    pub async fn write_all(&mut self, data: Vec<u8>) -> Result<()> {
        self.stream.write_all(&data).await?;
        Ok(())
    }
}

// Integration tests: a background TCP server writes a fixed payload to every client, and each
// test reads it back through `AsyncReadBuffer`.
#[cfg(test)]
mod tests {
    use crate::AsyncReadBuffer;
    use futures::executor::{self, ThreadPool};
    use futures::task::SpawnExt;
    use futures::{AsyncWriteExt, StreamExt};
    use romio::{TcpListener, TcpStream};
    use std::net::SocketAddr;
    use std::sync::Once;

    // Fixed loopback address shared by every test. `START_SERVER` ensures the listener is started
    // at most once per test binary, so all tests connect to the same server.
    const ADDRESS: &str = "127.0.0.1:7999";
    static START_SERVER: Once = Once::new();

    // Starts the shared server on first use and returns a new client connection to it.
    //
    // Every accepted connection is sent b"TEST LINE 1\r\nTESTTESTTEST\r\n".
    async fn start_server() -> TcpStream {
        START_SERVER.call_once(|| {
            // The accept loop runs on its own OS thread with its own executor, so `call_once`
            // returns right after spawning it.
            std::thread::spawn(|| {
                executor::block_on(async {
                    let mut thread_pool = ThreadPool::new().unwrap();

                    let addr = ADDRESS.parse::<SocketAddr>().unwrap();
                    let mut listener = TcpListener::bind(&addr).unwrap();

                    let mut incoming = listener.incoming();

                    while let Some(stream) = incoming.next().await {
                        let mut stream: TcpStream = stream.unwrap();

                        // Each connection is served on the thread pool so the accept loop is
                        // never blocked by a write.
                        thread_pool
                            .spawn(async move {
                                stream
                                    .write_all(b"TEST LINE 1\r\nTESTTESTTEST\r\n")
                                    .await
                                    .unwrap();
                                // `forget` skips the stream's destructor, so the server side is
                                // never closed (the socket is leaked); presumably this keeps the
                                // client from observing end-of-stream during a test.
                                std::mem::forget(stream);
                            })
                            .unwrap();
                    }
                });
            });
        });
        // Retry until the listener thread has bound ADDRESS. Connection errors are ignored and
        // there is no delay between attempts (busy-wait).
        let addr = ADDRESS.parse::<SocketAddr>().unwrap();
        loop {
            match TcpStream::connect(&addr).await {
                Err(_) => {}
                Ok(stream) => return stream,
            }
        }
    }

    // `read_line` returns the first line including its "\r\n" terminator.
    #[test]
    fn read_line() {
        let future = async {
            let stream = start_server().await;
            let mut buffer = AsyncReadBuffer::new(stream);

            let line = buffer.read_line().await.unwrap();
            let line = std::str::from_utf8(line.as_slice()).unwrap();
            assert_eq!(line, "TEST LINE 1\r\n");
        };

        futures::executor::block_on(future);
    }

    // `read_exact` returns exactly the requested bytes and interleaves with `read_line`: after
    // the first 8 bytes, `read_line` discards the rest of line 1, and the next 8 bytes come from
    // line 2.
    #[test]
    fn read_exact() {
        let future = async {
            let stream = start_server().await;
            let mut buffer = AsyncReadBuffer::new(stream);

            let exact = buffer.read_exact(8).await.unwrap();
            let exact = std::str::from_utf8(exact.as_slice()).unwrap();
            assert_eq!(exact, "TEST LIN");

            buffer.read_line().await.unwrap();

            let exact = buffer.read_exact(8).await.unwrap();
            let exact = std::str::from_utf8(exact.as_slice()).unwrap();
            assert_eq!(exact, "TESTTEST");
        };

        futures::executor::block_on(future);
    }
}