mini_redis/
frame.rs

1//! Provides a type representing a Redis protocol frame as well as utilities for
2//! parsing frames from a byte array.
3
4use bytes::{Buf, Bytes};
5use std::convert::TryInto;
6use std::fmt;
7use std::io::Cursor;
8use std::num::TryFromIntError;
9use std::string::FromUtf8Error;
10
/// A frame in the Redis protocol.
///
/// Each variant corresponds to one leading type byte accepted by
/// `Frame::check` / `Frame::parse`.
#[derive(Clone, Debug)]
pub enum Frame {
    /// `+<text>\r\n` — a single line, required to be valid UTF-8 when parsed.
    Simple(String),
    /// `-<text>\r\n` — an error line, required to be valid UTF-8 when parsed.
    Error(String),
    /// `:<digits>\r\n` — parsed as an unsigned decimal (negative values are
    /// not representable here).
    Integer(u64),
    /// `$<len>\r\n<payload>\r\n` — an arbitrary byte payload of `len` bytes.
    Bulk(Bytes),
    /// `$-1\r\n` — the null bulk string.
    Null,
    /// `*<count>\r\n` followed by `count` nested frames.
    Array(Vec<Frame>),
}
21
/// Error produced while checking or parsing a frame.
#[derive(Debug)]
pub enum Error {
    /// Not enough data is available to parse a message.
    ///
    /// NOTE(review): presumably callers treat this as "buffer more input and
    /// retry" — confirm against the connection code.
    Incomplete,

    /// Invalid message encoding.
    ///
    /// Wraps the crate-wide error type; the `From` conversions in this file
    /// build it from protocol error messages.
    Other(crate::Error),
}
30
31impl Frame {
32    /// Returns an empty array
33    pub(crate) fn array() -> Frame {
34        Frame::Array(vec![])
35    }
36
37    /// Push a "bulk" frame into the array. `self` must be an Array frame.
38    ///
39    /// # Panics
40    ///
41    /// panics if `self` is not an array
42    pub(crate) fn push_bulk(&mut self, bytes: Bytes) {
43        match self {
44            Frame::Array(vec) => {
45                vec.push(Frame::Bulk(bytes));
46            }
47            _ => panic!("not an array frame"),
48        }
49    }
50
51    /// Push an "integer" frame into the array. `self` must be an Array frame.
52    ///
53    /// # Panics
54    ///
55    /// panics if `self` is not an array
56    pub(crate) fn push_int(&mut self, value: u64) {
57        match self {
58            Frame::Array(vec) => {
59                vec.push(Frame::Integer(value));
60            }
61            _ => panic!("not an array frame"),
62        }
63    }
64
65    /// Checks if an entire message can be decoded from `src`
66    pub fn check(src: &mut Cursor<&[u8]>) -> Result<(), Error> {
67        match get_u8(src)? {
68            b'+' => {
69                get_line(src)?;
70                Ok(())
71            }
72            b'-' => {
73                get_line(src)?;
74                Ok(())
75            }
76            b':' => {
77                let _ = get_decimal(src)?;
78                Ok(())
79            }
80            b'$' => {
81                if b'-' == peek_u8(src)? {
82                    // Skip '-1\r\n'
83                    skip(src, 4)
84                } else {
85                    // Read the bulk string
86                    let len: usize = get_decimal(src)?.try_into()?;
87
88                    // skip that number of bytes + 2 (\r\n).
89                    skip(src, len + 2)
90                }
91            }
92            b'*' => {
93                let len = get_decimal(src)?;
94
95                for _ in 0..len {
96                    Frame::check(src)?;
97                }
98
99                Ok(())
100            }
101            actual => Err(format!("protocol error; invalid frame type byte `{}`", actual).into()),
102        }
103    }
104
105    /// The message has already been validated with `check`.
106    pub fn parse(src: &mut Cursor<&[u8]>) -> Result<Frame, Error> {
107        match get_u8(src)? {
108            b'+' => {
109                // Read the line and convert it to `Vec<u8>`
110                let line = get_line(src)?.to_vec();
111
112                // Convert the line to a String
113                let string = String::from_utf8(line)?;
114
115                Ok(Frame::Simple(string))
116            }
117            b'-' => {
118                // Read the line and convert it to `Vec<u8>`
119                let line = get_line(src)?.to_vec();
120
121                // Convert the line to a String
122                let string = String::from_utf8(line)?;
123
124                Ok(Frame::Error(string))
125            }
126            b':' => {
127                let len = get_decimal(src)?;
128                Ok(Frame::Integer(len))
129            }
130            b'$' => {
131                if b'-' == peek_u8(src)? {
132                    let line = get_line(src)?;
133
134                    if line != b"-1" {
135                        return Err("protocol error; invalid frame format".into());
136                    }
137
138                    Ok(Frame::Null)
139                } else {
140                    // Read the bulk string
141                    let len = get_decimal(src)?.try_into()?;
142                    let n = len + 2;
143
144                    if src.remaining() < n {
145                        return Err(Error::Incomplete);
146                    }
147
148                    let data = Bytes::copy_from_slice(&src.chunk()[..len]);
149
150                    // skip that number of bytes + 2 (\r\n).
151                    skip(src, n)?;
152
153                    Ok(Frame::Bulk(data))
154                }
155            }
156            b'*' => {
157                let len = get_decimal(src)?.try_into()?;
158                let mut out = Vec::with_capacity(len);
159
160                for _ in 0..len {
161                    out.push(Frame::parse(src)?);
162                }
163
164                Ok(Frame::Array(out))
165            }
166            _ => unimplemented!(),
167        }
168    }
169
170    /// Converts the frame to an "unexpected frame" error
171    pub(crate) fn to_error(&self) -> crate::Error {
172        format!("unexpected frame: {}", self).into()
173    }
174}
175
176impl PartialEq<&str> for Frame {
177    fn eq(&self, other: &&str) -> bool {
178        match self {
179            Frame::Simple(s) => s.eq(other),
180            Frame::Bulk(s) => s.eq(other),
181            _ => false,
182        }
183    }
184}
185
186impl fmt::Display for Frame {
187    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
188        use std::str;
189
190        match self {
191            Frame::Simple(response) => response.fmt(fmt),
192            Frame::Error(msg) => write!(fmt, "error: {}", msg),
193            Frame::Integer(num) => num.fmt(fmt),
194            Frame::Bulk(msg) => match str::from_utf8(msg) {
195                Ok(string) => string.fmt(fmt),
196                Err(_) => write!(fmt, "{:?}", msg),
197            },
198            Frame::Null => "(nil)".fmt(fmt),
199            Frame::Array(parts) => {
200                for (i, part) in parts.iter().enumerate() {
201                    if i > 0 {
202                        write!(fmt, " ")?;
203                        part.fmt(fmt)?;
204                    }
205                }
206
207                Ok(())
208            }
209        }
210    }
211}
212
213fn peek_u8(src: &mut Cursor<&[u8]>) -> Result<u8, Error> {
214    if !src.has_remaining() {
215        return Err(Error::Incomplete);
216    }
217
218    Ok(src.chunk()[0])
219}
220
221fn get_u8(src: &mut Cursor<&[u8]>) -> Result<u8, Error> {
222    if !src.has_remaining() {
223        return Err(Error::Incomplete);
224    }
225
226    Ok(src.get_u8())
227}
228
229fn skip(src: &mut Cursor<&[u8]>, n: usize) -> Result<(), Error> {
230    if src.remaining() < n {
231        return Err(Error::Incomplete);
232    }
233
234    src.advance(n);
235    Ok(())
236}
237
238/// Read a new-line terminated decimal
239fn get_decimal(src: &mut Cursor<&[u8]>) -> Result<u64, Error> {
240    use atoi::atoi;
241
242    let line = get_line(src)?;
243
244    atoi::<u64>(line).ok_or_else(|| "protocol error; invalid frame format".into())
245}
246
247/// Find a line
248fn get_line<'a>(src: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], Error> {
249    // Scan the bytes directly
250    let start = src.position() as usize;
251    // Scan to the second to last byte
252    let end = src.get_ref().len() - 1;
253
254    for i in start..end {
255        if src.get_ref()[i] == b'\r' && src.get_ref()[i + 1] == b'\n' {
256            // We found a line, update the position to be *after* the \n
257            src.set_position((i + 2) as u64);
258
259            // Return the line
260            return Ok(&src.get_ref()[start..i]);
261        }
262    }
263
264    Err(Error::Incomplete)
265}
266
267impl From<String> for Error {
268    fn from(src: String) -> Error {
269        Error::Other(src.into())
270    }
271}
272
273impl From<&str> for Error {
274    fn from(src: &str) -> Error {
275        src.to_string().into()
276    }
277}
278
279impl From<FromUtf8Error> for Error {
280    fn from(_src: FromUtf8Error) -> Error {
281        "protocol error; invalid frame format".into()
282    }
283}
284
285impl From<TryFromIntError> for Error {
286    fn from(_src: TryFromIntError) -> Error {
287        "protocol error; invalid frame format".into()
288    }
289}
290
291impl std::error::Error for Error {}
292
293impl fmt::Display for Error {
294    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
295        match self {
296            Error::Incomplete => "stream ended early".fmt(fmt),
297            Error::Other(err) => err.fmt(fmt),
298        }
299    }
300}