contained_core/connect/
frame.rs

1/*
2    Appellation: frame <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4    Description: A frame is used to describe units of data shared between two peers. Implementing a custom framing layer is useful for managing the various types of data that can be sent between peers.
5        This module provides a `Frame` enum that can be used to describe the various types of data that can be sent between peers. The `Frame` enum is used to implement a custom framing layer for
6        the `Connection` type.
7*/
8use crate::Error;
9use bytes::{Buf, Bytes};
10use std::convert::TryInto;
11use std::io::Cursor;
12use std::num::TryFromIntError;
13use std::string::FromUtf8Error;
14
15/// A frame in the Redis protocol.
16#[derive(Clone, Debug)]
17pub enum Frame {
18    Simple(String),
19    Error(String),
20    Integer(u64),
21    Bulk(Bytes),
22    Null,
23    Array(Vec<Frame>),
24}
25
26#[derive(Debug)]
27pub enum FrameError {
28    /// Not enough data is available to parse a message
29    Incomplete,
30
31    /// Invalid message encoding
32    Other(Error),
33}
34
35impl Frame {
36    /// Returns an empty array
37    pub fn array() -> Frame {
38        Frame::Array(vec![])
39    }
40
41    /// Push a "bulk" frame into the array. `self` must be an Array frame.
42    ///
43    /// # Panics
44    ///
45    /// panics if `self` is not an array
46    pub fn push_bulk(&mut self, bytes: Bytes) {
47        match self {
48            Frame::Array(vec) => {
49                vec.push(Frame::Bulk(bytes));
50            }
51            _ => panic!("not an array frame"),
52        }
53    }
54
55    /// Push an "integer" frame into the array. `self` must be an Array frame.
56    ///
57    /// # Panics
58    ///
59    /// panics if `self` is not an array
60    pub fn push_int(&mut self, value: u64) {
61        match self {
62            Frame::Array(vec) => {
63                vec.push(Frame::Integer(value));
64            }
65            _ => panic!("not an array frame"),
66        }
67    }
68
69    /// Checks if an entire message can be decoded from `src`
70    pub fn check(src: &mut Cursor<&[u8]>) -> Result<(), FrameError> {
71        match get_u8(src)? {
72            b'+' => {
73                get_line(src)?;
74                Ok(())
75            }
76            b'-' => {
77                get_line(src)?;
78                Ok(())
79            }
80            b':' => {
81                let _ = get_decimal(src)?;
82                Ok(())
83            }
84            b'$' => {
85                if b'-' == peek_u8(src)? {
86                    // Skip '-1\r\n'
87                    skip(src, 4)
88                } else {
89                    // Read the bulk string
90                    let len: usize = get_decimal(src)?.try_into()?;
91
92                    // skip that number of bytes + 2 (\r\n).
93                    skip(src, len + 2)
94                }
95            }
96            b'*' => {
97                let len = get_decimal(src)?;
98
99                for _ in 0..len {
100                    Frame::check(src)?;
101                }
102
103                Ok(())
104            }
105            actual => Err(format!("protocol error; invalid frame type byte `{}`", actual).into()),
106        }
107    }
108
109    /// The message has already been validated with `scan`.
110    pub fn parse(src: &mut Cursor<&[u8]>) -> Result<Frame, FrameError> {
111        match get_u8(src)? {
112            b'+' => {
113                // Read the line and convert it to `Vec<u8>`
114                let line = get_line(src)?.to_vec();
115
116                // Convert the line to a String
117                let string = String::from_utf8(line)?;
118
119                Ok(Frame::Simple(string))
120            }
121            b'-' => {
122                // Read the line and convert it to `Vec<u8>`
123                let line = get_line(src)?.to_vec();
124
125                // Convert the line to a String
126                let string = String::from_utf8(line)?;
127
128                Ok(Frame::Error(string))
129            }
130            b':' => {
131                let len = get_decimal(src)?;
132                Ok(Frame::Integer(len))
133            }
134            b'$' => {
135                if b'-' == peek_u8(src)? {
136                    let line = get_line(src)?;
137
138                    if line != b"-1" {
139                        return Err("protocol error; invalid frame format".into());
140                    }
141
142                    Ok(Frame::Null)
143                } else {
144                    // Read the bulk string
145                    let len = get_decimal(src)?.try_into()?;
146                    let n = len + 2;
147
148                    if src.remaining() < n {
149                        return Err(FrameError::Incomplete);
150                    }
151
152                    let data = Bytes::copy_from_slice(&src.get_ref()[..len]);
153
154                    // skip that number of bytes + 2 (\r\n).
155                    skip(src, n)?;
156
157                    Ok(Frame::Bulk(data))
158                }
159            }
160            b'*' => {
161                let len = get_decimal(src)?.try_into()?;
162                let mut out = Vec::with_capacity(len);
163
164                for _ in 0..len {
165                    out.push(Frame::parse(src)?);
166                }
167
168                Ok(Frame::Array(out))
169            }
170            _ => unimplemented!(),
171        }
172    }
173
174    /// Converts the frame to an "unexpected frame" error
175    pub fn to_error(&self) -> crate::Error {
176        format!("unexpected frame: {}", self).into()
177    }
178}
179
180impl PartialEq<&str> for Frame {
181    fn eq(&self, other: &&str) -> bool {
182        match self {
183            Frame::Simple(s) => s.eq(other),
184            Frame::Bulk(s) => s.eq(other),
185            _ => false,
186        }
187    }
188}
189
190impl std::fmt::Display for Frame {
191    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
192        use std::str;
193
194        match self {
195            Frame::Simple(response) => response.fmt(fmt),
196            Frame::Error(msg) => write!(fmt, "error: {}", msg),
197            Frame::Integer(num) => num.fmt(fmt),
198            Frame::Bulk(msg) => match str::from_utf8(msg) {
199                Ok(string) => string.fmt(fmt),
200                Err(_) => write!(fmt, "{:?}", msg),
201            },
202            Frame::Null => "(nil)".fmt(fmt),
203            Frame::Array(parts) => {
204                for (i, part) in parts.iter().enumerate() {
205                    if i > 0 {
206                        write!(fmt, " ")?;
207                        part.fmt(fmt)?;
208                    }
209                }
210
211                Ok(())
212            }
213        }
214    }
215}
216
217fn peek_u8(src: &mut Cursor<&[u8]>) -> Result<u8, FrameError> {
218    if !src.has_remaining() {
219        return Err(FrameError::Incomplete);
220    }
221
222    Ok(src.get_ref()[0])
223}
224
225fn get_u8(src: &mut Cursor<&[u8]>) -> Result<u8, FrameError> {
226    if !src.has_remaining() {
227        return Err(FrameError::Incomplete);
228    }
229
230    Ok(src.get_u8())
231}
232
233fn skip(src: &mut Cursor<&[u8]>, n: usize) -> Result<(), FrameError> {
234    if src.remaining() < n {
235        return Err(FrameError::Incomplete);
236    }
237
238    src.advance(n);
239    Ok(())
240}
241
242/// Read a new-line terminated decimal
243fn get_decimal(src: &mut Cursor<&[u8]>) -> Result<u64, FrameError> {
244    use atoi::atoi;
245
246    let line = get_line(src)?;
247
248    atoi::<u64>(line).ok_or_else(|| "protocol error; invalid frame format".into())
249}
250
251/// Find a line
252fn get_line<'a>(src: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], FrameError> {
253    // Scan the bytes directly
254    let start = src.position() as usize;
255    // Scan to the second to last byte
256    let end = src.get_ref().len() - 1;
257
258    for i in start..end {
259        if src.get_ref()[i] == b'\r' && src.get_ref()[i + 1] == b'\n' {
260            // We found a line, update the position to be *after* the \n
261            src.set_position((i + 2) as u64);
262
263            // Return the line
264            return Ok(&src.get_ref()[start..i]);
265        }
266    }
267
268    Err(FrameError::Incomplete)
269}
270
271impl From<String> for FrameError {
272    fn from(src: String) -> FrameError {
273        FrameError::Other(src.into())
274    }
275}
276
277impl From<&str> for FrameError {
278    fn from(src: &str) -> FrameError {
279        src.to_string().into()
280    }
281}
282
283impl From<FrameError> for crate::Error {
284    fn from(src: FrameError) -> crate::Error {
285        crate::Error::from(src.to_string())
286    }
287}
288
289impl From<Box<dyn std::error::Error>> for FrameError {
290    fn from(src: Box<dyn std::error::Error>) -> FrameError {
291        FrameError::Other(src.into())
292    }
293}
294
295impl From<Box<dyn std::error::Error + Send + Sync>> for FrameError {
296    fn from(src: Box<dyn std::error::Error + Send + Sync>) -> FrameError {
297        FrameError::Other(src.into())
298    }
299}
300
301impl From<FromUtf8Error> for FrameError {
302    fn from(_src: FromUtf8Error) -> FrameError {
303        "protocol error; invalid frame format".into()
304    }
305}
306
307impl From<TryFromIntError> for FrameError {
308    fn from(_src: TryFromIntError) -> FrameError {
309        "protocol error; invalid frame format".into()
310    }
311}
312
313impl std::error::Error for FrameError {}
314
315impl std::fmt::Display for FrameError {
316    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
317        match self {
318            FrameError::Incomplete => "stream ended early".fmt(fmt),
319            FrameError::Other(err) => err.fmt(fmt),
320        }
321    }
322}