mini_redis/connection.rs
1use crate::frame::{self, Frame};
2
3use bytes::{Buf, BytesMut};
4use std::io::{self, Cursor};
5use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
6use tokio::net::TcpStream;
7
/// Send and receive `Frame` values from a remote peer.
///
/// When implementing networking protocols, a message on that protocol is
/// often composed of several smaller messages known as frames. The purpose of
/// `Connection` is to read and write frames on the underlying `TcpStream`.
///
/// To read frames, the `Connection` accumulates incoming bytes in an internal
/// buffer until the buffer holds at least one complete frame. Once this
/// happens, the `Connection` builds the frame, removes its bytes from the
/// buffer, and returns the frame to the caller. Bytes that belong to a
/// following frame stay in the buffer for the next read.
///
/// When sending frames, the frame is first encoded into the write buffer.
/// The contents of the write buffer are then flushed to the socket.
#[derive(Debug)]
pub struct Connection {
    // The `TcpStream`, wrapped in Tokio's `BufWriter` so that the many small
    // writes issued while encoding a frame are collected in memory and sent
    // to the socket together rather than one write call at a time.
    stream: BufWriter<TcpStream>,

    // Bytes that have been read from the socket but not yet consumed by a
    // parsed frame.
    buffer: BytesMut,
}
30
31impl Connection {
32 /// Create a new `Connection`, backed by `socket`. Read and write buffers
33 /// are initialized.
34 pub fn new(socket: TcpStream) -> Connection {
35 Connection {
36 stream: BufWriter::new(socket),
37 // Default to a 4KB read buffer. For the use case of mini redis,
38 // this is fine. However, real applications will want to tune this
39 // value to their specific use case. There is a high likelihood that
40 // a larger read buffer will work better.
41 buffer: BytesMut::with_capacity(4 * 1024),
42 }
43 }
44
45 /// Read a single `Frame` value from the underlying stream.
46 ///
47 /// The function waits until it has retrieved enough data to parse a frame.
48 /// Any data remaining in the read buffer after the frame has been parsed is
49 /// kept there for the next call to `read_frame`.
50 ///
51 /// # Returns
52 ///
53 /// On success, the received frame is returned. If the `TcpStream`
54 /// is closed in a way that doesn't break a frame in half, it returns
55 /// `None`. Otherwise, an error is returned.
56 pub async fn read_frame(&mut self) -> crate::Result<Option<Frame>> {
57 loop {
58 // Attempt to parse a frame from the buffered data. If enough data
59 // has been buffered, the frame is returned.
60 if let Some(frame) = self.parse_frame()? {
61 return Ok(Some(frame));
62 }
63
64 // There is not enough buffered data to read a frame. Attempt to
65 // read more data from the socket.
66 //
67 // On success, the number of bytes is returned. `0` indicates "end
68 // of stream".
69 if 0 == self.stream.read_buf(&mut self.buffer).await? {
70 // The remote closed the connection. For this to be a clean
71 // shutdown, there should be no data in the read buffer. If
72 // there is, this means that the peer closed the socket while
73 // sending a frame.
74 if self.buffer.is_empty() {
75 return Ok(None);
76 } else {
77 return Err("connection reset by peer".into());
78 }
79 }
80 }
81 }
82
83 /// Tries to parse a frame from the buffer. If the buffer contains enough
84 /// data, the frame is returned and the data removed from the buffer. If not
85 /// enough data has been buffered yet, `Ok(None)` is returned. If the
86 /// buffered data does not represent a valid frame, `Err` is returned.
87 fn parse_frame(&mut self) -> crate::Result<Option<Frame>> {
88 use frame::Error::Incomplete;
89
90 // Cursor is used to track the "current" location in the
91 // buffer. Cursor also implements `Buf` from the `bytes` crate
92 // which provides a number of helpful utilities for working
93 // with bytes.
94 let mut buf = Cursor::new(&self.buffer[..]);
95
96 // The first step is to check if enough data has been buffered to parse
97 // a single frame. This step is usually much faster than doing a full
98 // parse of the frame, and allows us to skip allocating data structures
99 // to hold the frame data unless we know the full frame has been
100 // received.
101 match Frame::check(&mut buf) {
102 Ok(_) => {
103 // The `check` function will have advanced the cursor until the
104 // end of the frame. Since the cursor had position set to zero
105 // before `Frame::check` was called, we obtain the length of the
106 // frame by checking the cursor position.
107 let len = buf.position() as usize;
108
109 // Reset the position to zero before passing the cursor to
110 // `Frame::parse`.
111 buf.set_position(0);
112
113 // Parse the frame from the buffer. This allocates the necessary
114 // structures to represent the frame and returns the frame
115 // value.
116 //
117 // If the encoded frame representation is invalid, an error is
118 // returned. This should terminate the **current** connection
119 // but should not impact any other connected client.
120 let frame = Frame::parse(&mut buf)?;
121
122 // Discard the parsed data from the read buffer.
123 //
124 // When `advance` is called on the read buffer, all of the data
125 // up to `len` is discarded. The details of how this works is
126 // left to `BytesMut`. This is often done by moving an internal
127 // cursor, but it may be done by reallocating and copying data.
128 self.buffer.advance(len);
129
130 // Return the parsed frame to the caller.
131 Ok(Some(frame))
132 }
133 // There is not enough data present in the read buffer to parse a
134 // single frame. We must wait for more data to be received from the
135 // socket. Reading from the socket will be done in the statement
136 // after this `match`.
137 //
138 // We do not want to return `Err` from here as this "error" is an
139 // expected runtime condition.
140 Err(Incomplete) => Ok(None),
141 // An error was encountered while parsing the frame. The connection
142 // is now in an invalid state. Returning `Err` from here will result
143 // in the connection being closed.
144 Err(e) => Err(e.into()),
145 }
146 }
147
148 /// Write a single `Frame` value to the underlying stream.
149 ///
150 /// The `Frame` value is written to the socket using the various `write_*`
151 /// functions provided by `AsyncWrite`. Calling these functions directly on
152 /// a `TcpStream` is **not** advised, as this will result in a large number of
153 /// syscalls. However, it is fine to call these functions on a *buffered*
154 /// write stream. The data will be written to the buffer. Once the buffer is
155 /// full, it is flushed to the underlying socket.
156 pub async fn write_frame(&mut self, frame: &Frame) -> io::Result<()> {
157 // Arrays are encoded by encoding each entry. All other frame types are
158 // considered literals. For now, mini-redis is not able to encode
159 // recursive frame structures. See below for more details.
160 match frame {
161 Frame::Array(val) => {
162 // Encode the frame type prefix. For an array, it is `*`.
163 self.stream.write_u8(b'*').await?;
164
165 // Encode the length of the array.
166 self.write_decimal(val.len() as u64).await?;
167
168 // Iterate and encode each entry in the array.
169 for entry in &**val {
170 self.write_value(entry).await?;
171 }
172 }
173 // The frame type is a literal. Encode the value directly.
174 _ => self.write_value(frame).await?,
175 }
176
177 // Ensure the encoded frame is written to the socket. The calls above
178 // are to the buffered stream and writes. Calling `flush` writes the
179 // remaining contents of the buffer to the socket.
180 self.stream.flush().await
181 }
182
183 /// Write a frame literal to the stream
184 async fn write_value(&mut self, frame: &Frame) -> io::Result<()> {
185 match frame {
186 Frame::Simple(val) => {
187 self.stream.write_u8(b'+').await?;
188 self.stream.write_all(val.as_bytes()).await?;
189 self.stream.write_all(b"\r\n").await?;
190 }
191 Frame::Error(val) => {
192 self.stream.write_u8(b'-').await?;
193 self.stream.write_all(val.as_bytes()).await?;
194 self.stream.write_all(b"\r\n").await?;
195 }
196 Frame::Integer(val) => {
197 self.stream.write_u8(b':').await?;
198 self.write_decimal(*val).await?;
199 }
200 Frame::Null => {
201 self.stream.write_all(b"$-1\r\n").await?;
202 }
203 Frame::Bulk(val) => {
204 let len = val.len();
205
206 self.stream.write_u8(b'$').await?;
207 self.write_decimal(len as u64).await?;
208 self.stream.write_all(val).await?;
209 self.stream.write_all(b"\r\n").await?;
210 }
211 // Encoding an `Array` from within a value cannot be done using a
212 // recursive strategy. In general, async fns do not support
213 // recursion. Mini-redis has not needed to encode nested arrays yet,
214 // so for now it is skipped.
215 Frame::Array(_val) => unreachable!(),
216 }
217
218 Ok(())
219 }
220
221 /// Write a decimal frame to the stream
222 async fn write_decimal(&mut self, val: u64) -> io::Result<()> {
223 use std::io::Write;
224
225 // Convert the value to a string
226 let mut buf = [0u8; 20];
227 let mut buf = Cursor::new(&mut buf[..]);
228 write!(&mut buf, "{}", val)?;
229
230 let pos = buf.position() as usize;
231 self.stream.write_all(&buf.get_ref()[..pos]).await?;
232 self.stream.write_all(b"\r\n").await?;
233
234 Ok(())
235 }
236}