mini-telegram 0.1.17

mini-telegram is an unofficial, monolithic, idiomatic implementation of an MTProto (Telegram) server, built with Rust, that is compatible with all Telegram clients (web, Android, iOS, desktop).
Documentation
use crate::frame::{self, Frame};

use bytes::{Buf, BytesMut};
use std::io::{self, Cursor};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
use tokio::net::TcpStream;

/// Send and receive `Frame` values from a remote peer.
///
/// When implementing networking protocols, a message on that protocol is
/// often composed of several smaller messages known as frames. The purpose of
/// `Connection` is to read and write frames on the underlying `TcpStream`.
///
/// To read frames, the `Connection` uses an internal buffer, which is filled
/// up until there are enough bytes to create a full frame. Once this happens,
/// the `Connection` creates the frame and returns it to the caller.
///
/// When sending frames, the frame is first encoded into the write buffer
/// (the `BufWriter` wrapping the socket). The contents of the write buffer
/// are then written to the socket.
#[derive(Debug)]
pub struct Connection {
    // The `TcpStream`, decorated with a `BufWriter` that provides write-level
    // buffering, so that encoding a frame piece by piece does not turn into one
    // syscall per piece. The `BufWriter` implementation provided by Tokio is
    // sufficient for our needs. Note that `read_frame` also reads through
    // this `BufWriter`.
    stream: BufWriter<TcpStream>,

    // The buffer for reading frames. Bytes read from `stream` accumulate here
    // until a complete frame can be parsed; the parsed bytes are then
    // discarded from the front of the buffer.
    buffer: BytesMut,
}

impl Connection {
    /// Create a new `Connection`, backed by `socket`. Read and write buffers
    /// are initialized.
    pub fn new(socket: TcpStream) -> Connection {
        Connection {
            stream: BufWriter::new(socket),
            // Default to a 4KB read buffer. For the use case of mini telegram,
            // this is fine. However, real applications will want to tune this
            // value to their specific use case. There is a high likelihood that
            // a larger read buffer will work better.
            buffer: BytesMut::with_capacity(4 * 1024),
        }
    }

    /// Read a single `Frame` value from the underlying stream.
    ///
    /// The function waits until it has retrieved enough data to parse a frame.
    /// Any data remaining in the read buffer after the frame has been parsed is
    /// kept there for the next call to `read_frame`.
    ///
    /// # Returns
    ///
    /// On success, the received frame is returned. If the `TcpStream`
    /// is closed in a way that doesn't break a frame in half, it returns
    /// `None`. Otherwise, an error is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if reading from the socket fails, if the buffered
    /// bytes do not form a valid frame, or if the peer closes the connection
    /// while a partial frame is still buffered ("connection reset by peer").
    pub async fn read_frame(&mut self) -> crate::Result<Option<Frame>> {
        loop {
            // Attempt to parse a frame from the buffered data. If enough data
            // has been buffered, the frame is returned.
            if let Some(frame) = self.parse_frame()? {
                return Ok(Some(frame));
            }

            // There is not enough buffered data to read a frame. Attempt to
            // read more data from the socket.
            //
            // On success, the number of bytes is returned. `0` indicates "end
            // of stream".
            if 0 == self.stream.read_buf(&mut self.buffer).await? {
                // The remote closed the connection. For this to be a clean
                // shutdown, there should be no data in the read buffer. If
                // there is, this means that the peer closed the socket while
                // sending a frame.
                if self.buffer.is_empty() {
                    return Ok(None);
                } else {
                    return Err("connection reset by peer".into());
                }
            }
        }
    }

    /// Tries to parse a frame from the buffer. If the buffer contains enough
    /// data, the frame is returned and the data removed from the buffer. If not
    /// enough data has been buffered yet, `Ok(None)` is returned. If the
    /// buffered data does not represent a valid frame, `Err` is returned.
    fn parse_frame(&mut self) -> crate::Result<Option<Frame>> {
        use frame::Error::Incomplete;

        // Cursor is used to track the "current" location in the
        // buffer. Cursor also implements `Buf` from the `bytes` crate
        // which provides a number of helpful utilities for working
        // with bytes.
        let mut buf = Cursor::new(&self.buffer[..]);

        // The first step is to check if enough data has been buffered to parse
        // a single frame. This step is usually much faster than doing a full
        // parse of the frame, and allows us to skip allocating data structures
        // to hold the frame data unless we know the full frame has been
        // received.
        match Frame::check(&mut buf) {
            Ok(_) => {
                // The `check` function will have advanced the cursor until the
                // end of the frame. Since the cursor had position set to zero
                // before `Frame::check` was called, we obtain the length of the
                // frame by checking the cursor position.
                let len = buf.position() as usize;

                // Reset the position to zero before passing the cursor to
                // `Frame::parse`.
                buf.set_position(0);

                // Parse the frame from the buffer. This allocates the necessary
                // structures to represent the frame and returns the frame
                // value.
                //
                // If the encoded frame representation is invalid, an error is
                // returned. This should terminate the **current** connection
                // but should not impact any other connected client.
                let frame = Frame::parse(&mut buf)?;

                // Discard the parsed data from the read buffer.
                //
                // When `advance` is called on the read buffer, all of the data
                // up to `len` is discarded. The details of how this works are
                // left to `BytesMut`. This is often done by moving an internal
                // cursor, but it may be done by reallocating and copying data.
                self.buffer.advance(len);

                // Return the parsed frame to the caller.
                Ok(Some(frame))
            }
            // There is not enough data present in the read buffer to parse a
            // single frame. We must wait for more data to be received from the
            // socket. Reading from the socket will be done by the caller,
            // `read_frame`, after this function returns `Ok(None)`.
            //
            // We do not want to return `Err` from here as this "error" is an
            // expected runtime condition.
            Err(Incomplete) => Ok(None),
            // An error was encountered while parsing the frame. The connection
            // is now in an invalid state. Returning `Err` from here will result
            // in the connection being closed.
            Err(e) => Err(e.into()),
        }
    }

    /// Write a single `Frame` value to the underlying stream.
    ///
    /// The `Frame` value is written to the socket using the various `write_*`
    /// functions provided by `AsyncWrite`. Calling these functions directly on
    /// a `TcpStream` is **not** advised, as this will result in a large number of
    /// syscalls. However, it is fine to call these functions on a *buffered*
    /// write stream. The data will be written to the buffer. Once the buffer is
    /// full, it is flushed to the underlying socket. The stream is flushed
    /// once more before this function returns.
    ///
    /// Arrays may be nested to any depth.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while writing or flushing the stream.
    pub async fn write_frame(&mut self, frame: &Frame) -> io::Result<()> {
        // An array is encoded as its `*<len>` header followed by each of its
        // entries. An `async fn` cannot call itself without boxing its future,
        // so instead of recursing, the frame tree is walked depth-first with an
        // explicit stack. An array's entries are pushed in reverse so they pop
        // in their original order; this yields exactly the byte sequence a
        // recursive encoder would produce.
        let mut pending: Vec<&Frame> = vec![frame];

        while let Some(next) = pending.pop() {
            match next {
                Frame::Array(entries) => {
                    // Encode the frame type prefix. For an array, it is `*`.
                    self.stream.write_u8(b'*').await?;

                    // Encode the length of the array.
                    self.write_decimal(entries.len() as u64).await?;

                    // Schedule the entries (which may themselves be arrays).
                    pending.extend(entries.iter().rev());
                }
                // The frame type is a literal. Encode the value directly.
                literal => self.write_value(literal).await?,
            }
        }

        // Ensure the encoded frame is written to the socket. The calls above
        // only write into the buffered stream; calling `flush` writes the
        // remaining contents of the buffer to the socket.
        self.stream.flush().await
    }

    /// Write a frame literal (any non-array frame) to the stream.
    ///
    /// Only `write_frame` calls this, and it expands `Frame::Array` values
    /// itself, so an array never reaches this function.
    async fn write_value(&mut self, frame: &Frame) -> io::Result<()> {
        match frame {
            Frame::Simple(val) => {
                self.stream.write_u8(b'+').await?;
                self.stream.write_all(val.as_bytes()).await?;
                self.stream.write_all(b"\r\n").await?;
            }
            Frame::Error(val) => {
                self.stream.write_u8(b'-').await?;
                self.stream.write_all(val.as_bytes()).await?;
                self.stream.write_all(b"\r\n").await?;
            }
            Frame::Integer(val) => {
                self.stream.write_u8(b':').await?;
                self.write_decimal(*val).await?;
            }
            Frame::Null => {
                self.stream.write_all(b"$-1\r\n").await?;
            }
            Frame::Bulk(val) => {
                let len = val.len();

                self.stream.write_u8(b'$').await?;
                self.write_decimal(len as u64).await?;
                self.stream.write_all(val).await?;
                self.stream.write_all(b"\r\n").await?;
            }
            // Arrays (nested or not) are expanded iteratively by `write_frame`.
            Frame::Array(_) => {
                unreachable!("array frames are expanded by `write_frame`")
            }
        }

        Ok(())
    }

    /// Write `val` as ASCII decimal digits followed by `\r\n`.
    async fn write_decimal(&mut self, val: u64) -> io::Result<()> {
        use std::io::Write;

        // Format the value into a stack buffer instead of allocating a
        // `String`. 20 bytes is enough for every `u64`: `u64::MAX` is
        // 18446744073709551615, which has 20 digits.
        let mut buf = [0u8; 20];
        let mut buf = Cursor::new(&mut buf[..]);
        write!(&mut buf, "{}", val)?;

        let pos = buf.position() as usize;
        self.stream.write_all(&buf.get_ref()[..pos]).await?;
        self.stream.write_all(b"\r\n").await?;

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::{BufMut, BytesMut};
    use futures::future::join_all;
    use std::str;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
    use tokio::net::{TcpListener, TcpStream};
    use tokio::time::Instant;
    use tokio::try_join;

    #[tokio::test]
    async fn test_tcp_stream() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut stream, _addr) = listener.accept().await.unwrap();
            let mut buf = [0];
            // `read_exact` fails loudly if the peer closes before the byte
            // arrives, instead of leaving `buf` untouched.
            stream.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf[0], 144);
        });

        let client = tokio::spawn(async move {
            let mut stream = TcpStream::connect(addr).await.unwrap();
            stream.write_all(&[144]).await.unwrap();
        });

        try_join!(server, client).unwrap();
    }

    #[tokio::test]
    async fn test_tcp_stream_buf_writer() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        const N: usize = 10240;

        let server = tokio::spawn(async move {
            let mut handles: Vec<tokio::task::JoinHandle<()>> = Vec::new();
            for _ in 0..2 {
                let (mut stream, _addr) = listener.accept().await.unwrap();
                handles.push(tokio::spawn(async move {
                    let mut buf = [0; 10];
                    for _ in 0..N {
                        // TCP does not preserve write boundaries: a plain
                        // `read` may return fewer than 10 bytes and misalign
                        // every later comparison, and on EOF it would leave
                        // stale data in `buf`. `read_exact` avoids both.
                        stream.read_exact(&mut buf).await.unwrap();
                        assert_eq!(str::from_utf8(&buf).unwrap(), "some bytes");
                    }
                }));
            }
            // Propagate panics (failed assertions) from the handler tasks
            // instead of silently discarding their `JoinError`s.
            for result in join_all(handles).await {
                result.unwrap();
            }
        });

        let client_tcp_stream = tokio::spawn(async move {
            let mut stream = TcpStream::connect(addr).await.unwrap();
            let now = Instant::now();
            for _ in 0..N {
                stream.write_all(b"some bytes").await.unwrap();
            }
            now.elapsed()
        });

        let client_buf_writer = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            // `BufWriter` can improve the speed of programs that make *small* and
            // *repeated* write calls to the same file or network socket. It does not
            // help when writing very large amounts at once, or writing just one or a few
            // times. It also provides no advantage when writing to a destination that is
            // in memory, like a `Vec<u8>`.
            let mut stream = BufWriter::new(stream);
            let now = Instant::now();
            for _ in 0..N {
                stream.write_all(b"some bytes").await.unwrap();
            }
            // Tokio's `BufWriter` does not flush on drop; without this call the
            // tail of the buffer never reaches the server. Flushing inside the
            // timed region also keeps the comparison fair.
            stream.flush().await.unwrap();
            now.elapsed()
        });

        let (_, tcp_stream_time_consumption, buf_writer_time_consumption) =
            try_join!(server, client_tcp_stream, client_buf_writer).unwrap();

        assert!(buf_writer_time_consumption < tcp_stream_time_consumption);
    }

    #[test]
    fn test_bytes_mut_growth() {
        // BytesMut’s BufMut implementation will implicitly grow its buffer
        // as necessary. However, explicitly reserving the required space
        // up-front before a series of inserts will be more efficient.
        let mut buf = BytesMut::with_capacity(10);
        let initial_capacity = buf.capacity();
        let ptr_before = buf.as_ptr();

        buf.put(&b"yumcoder"[..]);
        // Writing within the reserved capacity must not reallocate.
        assert_eq!(buf.as_ptr(), ptr_before);

        buf.put(&b"more content to expand the current buffer!"[..]);
        // Exceeding the capacity grows the buffer. Whether the data moves is
        // up to the allocator (`realloc` may extend in place), so check the
        // capacity and contents rather than the address.
        assert!(buf.capacity() > initial_capacity);
        assert_eq!(
            &buf[..],
            &b"yumcodermore content to expand the current buffer!"[..]
        );
    }

    #[tokio::test]
    async fn test_read_frame() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (stream, _addr) = listener.accept().await.unwrap();
            let mut connection = Connection::new(stream);

            let frame = connection
                .read_frame()
                .await
                .unwrap()
                .expect("expected a frame before the stream closed");
            assert_eq!(frame.to_string(), "OK");

            let err = connection.read_frame().await.unwrap_err();
            let expected = frame::Error::from("protocol error; invalid frame type byte `33`");
            assert_eq!(err.to_string(), expected.to_string());
        });

        let client = tokio::spawn(async move {
            let mut stream = TcpStream::connect(addr).await.unwrap();
            stream.write_all(&b"+OK\r\n"[..]).await.unwrap();
            stream.write_all(&b"!"[..]).await.unwrap();
        });

        try_join!(server, client).unwrap();
    }

    #[tokio::test]
    async fn test_write_frame() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (stream, _addr) = listener.accept().await.unwrap();
            let mut connection = Connection::new(stream);
            let frame = connection
                .read_frame()
                .await
                .unwrap()
                .expect("expected a frame before the stream closed");
            assert_eq!(frame.to_string(), "OK");
            connection.write_frame(&frame).await.unwrap();
        });

        let client = tokio::spawn(async move {
            let mut stream = TcpStream::connect(addr).await.unwrap();
            // For simplicity, `Connection` is only used on the server side.
            stream.write_all(&b"+OK\r\n"[..]).await.unwrap();
            let mut buf = [0; 5];
            // The echoed frame may arrive in several segments; wait for all of it.
            stream.read_exact(&mut buf).await.unwrap();
            assert_eq!(str::from_utf8(&buf).unwrap(), "+OK\r\n");
        });

        try_join!(server, client).unwrap();
    }
}