kade_proto/pkg/frame.rs

1use bytes::{Buf, Bytes};
2use std::convert::TryInto;
3use std::fmt;
4use std::io::Cursor;
5use std::num::TryFromIntError;
6use std::string::FromUtf8Error;
7
8#[derive(Clone, Debug)]
9pub enum Frame {
10    Simple(String),
11    Error(String),
12    Integer(u64),
13    Bulk(Bytes),
14    Null,
15    Array(Vec<Frame>),
16}
17
18#[derive(Debug)]
19pub enum Error {
20    Incomplete,
21    Other(crate::Error),
22}
23
24impl Frame {
25    pub fn array() -> Frame { Frame::Array(vec![]) }
26
27    pub fn push_bulk(&mut self, bytes: Bytes) {
28        match self {
29            Frame::Array(vec) => {
30                vec.push(Frame::Bulk(bytes));
31            }
32            _ => panic!("not an array frame"),
33        }
34    }
35
36    pub fn push_int(&mut self, value: u64) {
37        match self {
38            Frame::Array(vec) => {
39                vec.push(Frame::Integer(value));
40            }
41            _ => panic!("not an array frame"),
42        }
43    }
44
45    pub fn check(src: &mut Cursor<&[u8]>) -> Result<(), Error> {
46        match get_u8(src)? {
47            b'+' => {
48                get_line(src)?;
49                Ok(())
50            }
51            b'-' => {
52                get_line(src)?;
53                Ok(())
54            }
55            b':' => {
56                let _ = get_decimal(src)?;
57                Ok(())
58            }
59            b'$' => {
60                if b'-' == peek_u8(src)? {
61                    // Skip '-1\r\n'
62                    skip(src, 4)
63                } else {
64                    // Read the bulk string
65                    let len: usize = get_decimal(src)?.try_into()?;
66
67                    // skip that number of bytes + 2 (\r\n).
68                    skip(src, len + 2)
69                }
70            }
71            b'*' => {
72                let len = get_decimal(src)?;
73
74                for _ in 0..len {
75                    Frame::check(src)?;
76                }
77
78                Ok(())
79            }
80            actual => Err(format!("protocol error; invalid frame type byte `{}`", actual).into()),
81        }
82    }
83
84    pub fn parse(src: &mut Cursor<&[u8]>) -> Result<Frame, Error> {
85        match get_u8(src)? {
86            b'+' => {
87                let line = get_line(src)?.to_vec();
88                let string = String::from_utf8(line)?;
89
90                Ok(Frame::Simple(string))
91            }
92            b'-' => {
93                let line = get_line(src)?.to_vec();
94                let string = String::from_utf8(line)?;
95
96                Ok(Frame::Error(string))
97            }
98            b':' => {
99                let len = get_decimal(src)?;
100                Ok(Frame::Integer(len))
101            }
102            b'$' => {
103                if b'-' == peek_u8(src)? {
104                    let line = get_line(src)?;
105
106                    if line != b"-1" {
107                        return Err("protocol error; invalid frame format".into());
108                    }
109
110                    Ok(Frame::Null)
111                } else {
112                    let len = get_decimal(src)?.try_into()?;
113                    let n = len + 2;
114
115                    if src.remaining() < n {
116                        return Err(Error::Incomplete);
117                    }
118
119                    let data = Bytes::copy_from_slice(&src.chunk()[..len]);
120
121                    skip(src, n)?;
122                    Ok(Frame::Bulk(data))
123                }
124            }
125            b'*' => {
126                let len = get_decimal(src)?.try_into()?;
127                let mut out = Vec::with_capacity(len);
128
129                for _ in 0..len {
130                    out.push(Frame::parse(src)?);
131                }
132
133                Ok(Frame::Array(out))
134            }
135            _ => unimplemented!(),
136        }
137    }
138
139    pub fn to_error(&self) -> crate::Error { format!("unexpected frame: {}", self).into() }
140}
141
142impl PartialEq<&str> for Frame {
143    fn eq(&self, other: &&str) -> bool {
144        match self {
145            Frame::Simple(s) => s.eq(other),
146            Frame::Bulk(s) => s.eq(other),
147            _ => false,
148        }
149    }
150}
151
152impl fmt::Display for Frame {
153    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
154        use std::str;
155
156        match self {
157            Frame::Simple(response) => response.fmt(fmt),
158            Frame::Error(msg) => write!(fmt, "error: {}", msg),
159            Frame::Integer(num) => num.fmt(fmt),
160            Frame::Bulk(msg) => match str::from_utf8(msg) {
161                Ok(string) => string.fmt(fmt),
162                Err(_) => write!(fmt, "{:?}", msg),
163            },
164            Frame::Null => "(nil)".fmt(fmt),
165            Frame::Array(parts) => {
166                for (i, part) in parts.iter().enumerate() {
167                    if i > 0 {
168                        write!(fmt, " ")?;
169                    }
170
171                    part.fmt(fmt)?;
172                }
173
174                Ok(())
175            }
176        }
177    }
178}
179
180fn peek_u8(src: &mut Cursor<&[u8]>) -> Result<u8, Error> {
181    if !src.has_remaining() {
182        return Err(Error::Incomplete);
183    }
184
185    Ok(src.chunk()[0])
186}
187
188fn get_u8(src: &mut Cursor<&[u8]>) -> Result<u8, Error> {
189    if !src.has_remaining() {
190        return Err(Error::Incomplete);
191    }
192
193    Ok(src.get_u8())
194}
195
196fn skip(src: &mut Cursor<&[u8]>, n: usize) -> Result<(), Error> {
197    if src.remaining() < n {
198        return Err(Error::Incomplete);
199    }
200
201    src.advance(n);
202    Ok(())
203}
204
205fn get_decimal(src: &mut Cursor<&[u8]>) -> Result<u64, Error> {
206    use atoi::atoi;
207    let line = get_line(src)?;
208
209    atoi::<u64>(line).ok_or_else(|| "protocol error; invalid frame format".into())
210}
211
212fn get_line<'a>(src: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], Error> {
213    let start = src.position() as usize;
214    let end = src.get_ref().len() - 1;
215
216    for i in start..end {
217        if src.get_ref()[i] == b'\r' && src.get_ref()[i + 1] == b'\n' {
218            src.set_position((i + 2) as u64);
219            return Ok(&src.get_ref()[start..i]);
220        }
221    }
222
223    Err(Error::Incomplete)
224}
225
226impl From<String> for Error {
227    fn from(src: String) -> Error { Error::Other(src.into()) }
228}
229
230impl From<&str> for Error {
231    fn from(src: &str) -> Error { src.to_string().into() }
232}
233
234impl From<FromUtf8Error> for Error {
235    fn from(_src: FromUtf8Error) -> Error { "protocol error; invalid frame format".into() }
236}
237
238impl From<TryFromIntError> for Error {
239    fn from(_src: TryFromIntError) -> Error { "protocol error; invalid frame format".into() }
240}
241
242impl std::error::Error for Error {}
243
244impl fmt::Display for Error {
245    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
246        match self {
247            Error::Incomplete => "stream ended early".fmt(fmt),
248            Error::Other(err) => err.fmt(fmt),
249        }
250    }
251}