// mini_redis/connection.rs

1use crate::frame::{self, Frame};
2
3use bytes::{Buf, BytesMut};
4use std::io::{self, Cursor};
5use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
6use tokio::net::TcpStream;
7
8/// Send and receive `Frame` values from a remote peer.
9///
10/// When implementing networking protocols, a message on that protocol is
11/// often composed of several smaller messages known as frames. The purpose of
12/// `Connection` is to read and write frames on the underlying `TcpStream`.
13///
14/// To read frames, the `Connection` uses an internal buffer, which is filled
15/// up until there are enough bytes to create a full frame. Once this happens,
16/// the `Connection` creates the frame and returns it to the caller.
17///
18/// When sending frames, the frame is first encoded into the write buffer.
19/// The contents of the write buffer are then written to the socket.
20#[derive(Debug)]
21pub struct Connection {
22    // The `TcpStream`. It is decorated with a `BufWriter`, which provides write
23    // level buffering. The `BufWriter` implementation provided by Tokio is
24    // sufficient for our needs.
25    stream: BufWriter<TcpStream>,
26
27    // The buffer for reading frames.
28    buffer: BytesMut,
29}
30
31impl Connection {
32    /// Create a new `Connection`, backed by `socket`. Read and write buffers
33    /// are initialized.
34    pub fn new(socket: TcpStream) -> Connection {
35        Connection {
36            stream: BufWriter::new(socket),
37            // Default to a 4KB read buffer. For the use case of mini redis,
38            // this is fine. However, real applications will want to tune this
39            // value to their specific use case. There is a high likelihood that
40            // a larger read buffer will work better.
41            buffer: BytesMut::with_capacity(4 * 1024),
42        }
43    }
44
45    /// Read a single `Frame` value from the underlying stream.
46    ///
47    /// The function waits until it has retrieved enough data to parse a frame.
48    /// Any data remaining in the read buffer after the frame has been parsed is
49    /// kept there for the next call to `read_frame`.
50    ///
51    /// # Returns
52    ///
53    /// On success, the received frame is returned. If the `TcpStream`
54    /// is closed in a way that doesn't break a frame in half, it returns
55    /// `None`. Otherwise, an error is returned.
56    pub async fn read_frame(&mut self) -> crate::Result<Option<Frame>> {
57        loop {
58            // Attempt to parse a frame from the buffered data. If enough data
59            // has been buffered, the frame is returned.
60            if let Some(frame) = self.parse_frame()? {
61                return Ok(Some(frame));
62            }
63
64            // There is not enough buffered data to read a frame. Attempt to
65            // read more data from the socket.
66            //
67            // On success, the number of bytes is returned. `0` indicates "end
68            // of stream".
69            if 0 == self.stream.read_buf(&mut self.buffer).await? {
70                // The remote closed the connection. For this to be a clean
71                // shutdown, there should be no data in the read buffer. If
72                // there is, this means that the peer closed the socket while
73                // sending a frame.
74                if self.buffer.is_empty() {
75                    return Ok(None);
76                } else {
77                    return Err("connection reset by peer".into());
78                }
79            }
80        }
81    }
82
83    /// Tries to parse a frame from the buffer. If the buffer contains enough
84    /// data, the frame is returned and the data removed from the buffer. If not
85    /// enough data has been buffered yet, `Ok(None)` is returned. If the
86    /// buffered data does not represent a valid frame, `Err` is returned.
87    fn parse_frame(&mut self) -> crate::Result<Option<Frame>> {
88        use frame::Error::Incomplete;
89
90        // Cursor is used to track the "current" location in the
91        // buffer. Cursor also implements `Buf` from the `bytes` crate
92        // which provides a number of helpful utilities for working
93        // with bytes.
94        let mut buf = Cursor::new(&self.buffer[..]);
95
96        // The first step is to check if enough data has been buffered to parse
97        // a single frame. This step is usually much faster than doing a full
98        // parse of the frame, and allows us to skip allocating data structures
99        // to hold the frame data unless we know the full frame has been
100        // received.
101        match Frame::check(&mut buf) {
102            Ok(_) => {
103                // The `check` function will have advanced the cursor until the
104                // end of the frame. Since the cursor had position set to zero
105                // before `Frame::check` was called, we obtain the length of the
106                // frame by checking the cursor position.
107                let len = buf.position() as usize;
108
109                // Reset the position to zero before passing the cursor to
110                // `Frame::parse`.
111                buf.set_position(0);
112
113                // Parse the frame from the buffer. This allocates the necessary
114                // structures to represent the frame and returns the frame
115                // value.
116                //
117                // If the encoded frame representation is invalid, an error is
118                // returned. This should terminate the **current** connection
119                // but should not impact any other connected client.
120                let frame = Frame::parse(&mut buf)?;
121
122                // Discard the parsed data from the read buffer.
123                //
124                // When `advance` is called on the read buffer, all of the data
125                // up to `len` is discarded. The details of how this works is
126                // left to `BytesMut`. This is often done by moving an internal
127                // cursor, but it may be done by reallocating and copying data.
128                self.buffer.advance(len);
129
130                // Return the parsed frame to the caller.
131                Ok(Some(frame))
132            }
133            // There is not enough data present in the read buffer to parse a
134            // single frame. We must wait for more data to be received from the
135            // socket. Reading from the socket will be done in the statement
136            // after this `match`.
137            //
138            // We do not want to return `Err` from here as this "error" is an
139            // expected runtime condition.
140            Err(Incomplete) => Ok(None),
141            // An error was encountered while parsing the frame. The connection
142            // is now in an invalid state. Returning `Err` from here will result
143            // in the connection being closed.
144            Err(e) => Err(e.into()),
145        }
146    }
147
148    /// Write a single `Frame` value to the underlying stream.
149    ///
150    /// The `Frame` value is written to the socket using the various `write_*`
151    /// functions provided by `AsyncWrite`. Calling these functions directly on
152    /// a `TcpStream` is **not** advised, as this will result in a large number of
153    /// syscalls. However, it is fine to call these functions on a *buffered*
154    /// write stream. The data will be written to the buffer. Once the buffer is
155    /// full, it is flushed to the underlying socket.
156    pub async fn write_frame(&mut self, frame: &Frame) -> io::Result<()> {
157        // Arrays are encoded by encoding each entry. All other frame types are
158        // considered literals. For now, mini-redis is not able to encode
159        // recursive frame structures. See below for more details.
160        match frame {
161            Frame::Array(val) => {
162                // Encode the frame type prefix. For an array, it is `*`.
163                self.stream.write_u8(b'*').await?;
164
165                // Encode the length of the array.
166                self.write_decimal(val.len() as u64).await?;
167
168                // Iterate and encode each entry in the array.
169                for entry in &**val {
170                    self.write_value(entry).await?;
171                }
172            }
173            // The frame type is a literal. Encode the value directly.
174            _ => self.write_value(frame).await?,
175        }
176
177        // Ensure the encoded frame is written to the socket. The calls above
178        // are to the buffered stream and writes. Calling `flush` writes the
179        // remaining contents of the buffer to the socket.
180        self.stream.flush().await
181    }
182
183    /// Write a frame literal to the stream
184    async fn write_value(&mut self, frame: &Frame) -> io::Result<()> {
185        match frame {
186            Frame::Simple(val) => {
187                self.stream.write_u8(b'+').await?;
188                self.stream.write_all(val.as_bytes()).await?;
189                self.stream.write_all(b"\r\n").await?;
190            }
191            Frame::Error(val) => {
192                self.stream.write_u8(b'-').await?;
193                self.stream.write_all(val.as_bytes()).await?;
194                self.stream.write_all(b"\r\n").await?;
195            }
196            Frame::Integer(val) => {
197                self.stream.write_u8(b':').await?;
198                self.write_decimal(*val).await?;
199            }
200            Frame::Null => {
201                self.stream.write_all(b"$-1\r\n").await?;
202            }
203            Frame::Bulk(val) => {
204                let len = val.len();
205
206                self.stream.write_u8(b'$').await?;
207                self.write_decimal(len as u64).await?;
208                self.stream.write_all(val).await?;
209                self.stream.write_all(b"\r\n").await?;
210            }
211            // Encoding an `Array` from within a value cannot be done using a
212            // recursive strategy. In general, async fns do not support
213            // recursion. Mini-redis has not needed to encode nested arrays yet,
214            // so for now it is skipped.
215            Frame::Array(_val) => unreachable!(),
216        }
217
218        Ok(())
219    }
220
221    /// Write a decimal frame to the stream
222    async fn write_decimal(&mut self, val: u64) -> io::Result<()> {
223        use std::io::Write;
224
225        // Convert the value to a string
226        let mut buf = [0u8; 20];
227        let mut buf = Cursor::new(&mut buf[..]);
228        write!(&mut buf, "{}", val)?;
229
230        let pos = buf.position() as usize;
231        self.stream.write_all(&buf.get_ref()[..pos]).await?;
232        self.stream.write_all(b"\r\n").await?;
233
234        Ok(())
235    }
236}