1use bytes::{Buf, Bytes};
5use std::convert::TryInto;
6use std::fmt;
7use std::io::Cursor;
8use std::num::TryFromIntError;
9use std::string::FromUtf8Error;
10
/// A single protocol frame, as produced by [`Frame::parse`].
///
/// The leading type bytes handled by `check`/`parse` (`+`, `-`, `:`, `$`, `*`)
/// match the RESP (Redis serialization protocol) encoding.
#[derive(Clone, Debug)]
pub enum Frame {
    /// Simple string: `+<text>\r\n`.
    Simple(String),
    /// Error string: `-<text>\r\n`.
    Error(String),
    /// Unsigned integer: `:<digits>\r\n` (negative values are not representable here).
    Integer(u64),
    /// Bulk string: `$<len>\r\n<len bytes>\r\n`; the payload may be arbitrary bytes.
    Bulk(Bytes),
    /// Null bulk string: `$-1\r\n`.
    Null,
    /// Array: `*<count>\r\n` followed by `count` nested frames.
    Array(Vec<Frame>),
}
21
/// Errors produced while checking or parsing a [`Frame`].
#[derive(Debug)]
pub enum Error {
    /// The buffer ends before a complete frame could be read; the caller should
    /// read more data and retry.
    Incomplete,

    /// Any other failure, e.g. a malformed frame ("protocol error; ...") built from
    /// a string via the `From` impls in this module.
    Other(crate::Error),
}
30
31impl Frame {
32 pub(crate) fn array() -> Frame {
34 Frame::Array(vec![])
35 }
36
37 pub(crate) fn push_bulk(&mut self, bytes: Bytes) {
43 match self {
44 Frame::Array(vec) => {
45 vec.push(Frame::Bulk(bytes));
46 }
47 _ => panic!("not an array frame"),
48 }
49 }
50
51 pub(crate) fn push_int(&mut self, value: u64) {
57 match self {
58 Frame::Array(vec) => {
59 vec.push(Frame::Integer(value));
60 }
61 _ => panic!("not an array frame"),
62 }
63 }
64
65 pub fn check(src: &mut Cursor<&[u8]>) -> Result<(), Error> {
67 match get_u8(src)? {
68 b'+' => {
69 get_line(src)?;
70 Ok(())
71 }
72 b'-' => {
73 get_line(src)?;
74 Ok(())
75 }
76 b':' => {
77 let _ = get_decimal(src)?;
78 Ok(())
79 }
80 b'$' => {
81 if b'-' == peek_u8(src)? {
82 skip(src, 4)
84 } else {
85 let len: usize = get_decimal(src)?.try_into()?;
87
88 skip(src, len + 2)
90 }
91 }
92 b'*' => {
93 let len = get_decimal(src)?;
94
95 for _ in 0..len {
96 Frame::check(src)?;
97 }
98
99 Ok(())
100 }
101 actual => Err(format!("protocol error; invalid frame type byte `{}`", actual).into()),
102 }
103 }
104
105 pub fn parse(src: &mut Cursor<&[u8]>) -> Result<Frame, Error> {
107 match get_u8(src)? {
108 b'+' => {
109 let line = get_line(src)?.to_vec();
111
112 let string = String::from_utf8(line)?;
114
115 Ok(Frame::Simple(string))
116 }
117 b'-' => {
118 let line = get_line(src)?.to_vec();
120
121 let string = String::from_utf8(line)?;
123
124 Ok(Frame::Error(string))
125 }
126 b':' => {
127 let len = get_decimal(src)?;
128 Ok(Frame::Integer(len))
129 }
130 b'$' => {
131 if b'-' == peek_u8(src)? {
132 let line = get_line(src)?;
133
134 if line != b"-1" {
135 return Err("protocol error; invalid frame format".into());
136 }
137
138 Ok(Frame::Null)
139 } else {
140 let len = get_decimal(src)?.try_into()?;
142 let n = len + 2;
143
144 if src.remaining() < n {
145 return Err(Error::Incomplete);
146 }
147
148 let data = Bytes::copy_from_slice(&src.chunk()[..len]);
149
150 skip(src, n)?;
152
153 Ok(Frame::Bulk(data))
154 }
155 }
156 b'*' => {
157 let len = get_decimal(src)?.try_into()?;
158 let mut out = Vec::with_capacity(len);
159
160 for _ in 0..len {
161 out.push(Frame::parse(src)?);
162 }
163
164 Ok(Frame::Array(out))
165 }
166 _ => unimplemented!(),
167 }
168 }
169
170 pub(crate) fn to_error(&self) -> crate::Error {
172 format!("unexpected frame: {}", self).into()
173 }
174}
175
176impl PartialEq<&str> for Frame {
177 fn eq(&self, other: &&str) -> bool {
178 match self {
179 Frame::Simple(s) => s.eq(other),
180 Frame::Bulk(s) => s.eq(other),
181 _ => false,
182 }
183 }
184}
185
186impl fmt::Display for Frame {
187 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
188 use std::str;
189
190 match self {
191 Frame::Simple(response) => response.fmt(fmt),
192 Frame::Error(msg) => write!(fmt, "error: {}", msg),
193 Frame::Integer(num) => num.fmt(fmt),
194 Frame::Bulk(msg) => match str::from_utf8(msg) {
195 Ok(string) => string.fmt(fmt),
196 Err(_) => write!(fmt, "{:?}", msg),
197 },
198 Frame::Null => "(nil)".fmt(fmt),
199 Frame::Array(parts) => {
200 for (i, part) in parts.iter().enumerate() {
201 if i > 0 {
202 write!(fmt, " ")?;
203 part.fmt(fmt)?;
204 }
205 }
206
207 Ok(())
208 }
209 }
210 }
211}
212
213fn peek_u8(src: &mut Cursor<&[u8]>) -> Result<u8, Error> {
214 if !src.has_remaining() {
215 return Err(Error::Incomplete);
216 }
217
218 Ok(src.chunk()[0])
219}
220
221fn get_u8(src: &mut Cursor<&[u8]>) -> Result<u8, Error> {
222 if !src.has_remaining() {
223 return Err(Error::Incomplete);
224 }
225
226 Ok(src.get_u8())
227}
228
229fn skip(src: &mut Cursor<&[u8]>, n: usize) -> Result<(), Error> {
230 if src.remaining() < n {
231 return Err(Error::Incomplete);
232 }
233
234 src.advance(n);
235 Ok(())
236}
237
238fn get_decimal(src: &mut Cursor<&[u8]>) -> Result<u64, Error> {
240 use atoi::atoi;
241
242 let line = get_line(src)?;
243
244 atoi::<u64>(line).ok_or_else(|| "protocol error; invalid frame format".into())
245}
246
247fn get_line<'a>(src: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], Error> {
249 let start = src.position() as usize;
251 let end = src.get_ref().len() - 1;
253
254 for i in start..end {
255 if src.get_ref()[i] == b'\r' && src.get_ref()[i + 1] == b'\n' {
256 src.set_position((i + 2) as u64);
258
259 return Ok(&src.get_ref()[start..i]);
261 }
262 }
263
264 Err(Error::Incomplete)
265}
266
267impl From<String> for Error {
268 fn from(src: String) -> Error {
269 Error::Other(src.into())
270 }
271}
272
273impl From<&str> for Error {
274 fn from(src: &str) -> Error {
275 src.to_string().into()
276 }
277}
278
279impl From<FromUtf8Error> for Error {
280 fn from(_src: FromUtf8Error) -> Error {
281 "protocol error; invalid frame format".into()
282 }
283}
284
285impl From<TryFromIntError> for Error {
286 fn from(_src: TryFromIntError) -> Error {
287 "protocol error; invalid frame format".into()
288 }
289}
290
// Marks `Error` as a standard error type using the default trait methods; its
// message comes from the `Display` impl in this module.
impl std::error::Error for Error {}
292
293impl fmt::Display for Error {
294 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
295 match self {
296 Error::Incomplete => "stream ended early".fmt(fmt),
297 Error::Other(err) => err.fmt(fmt),
298 }
299 }
300}