1use bytes::{Buf, Bytes};
2use std::convert::TryInto;
3use std::fmt;
4use std::io::Cursor;
5use std::num::TryFromIntError;
6use std::string::FromUtf8Error;
7
/// A single protocol frame.
///
/// Each variant corresponds to one of the type bytes handled by `Frame::check` /
/// `Frame::parse` (`+`, `-`, `:`, `$`, `*`). The wire layout appears to follow the
/// Redis serialization protocol (RESP) — NOTE(review): confirm which RESP version
/// this is meant to support.
#[derive(Clone, Debug)]
pub enum Frame {
    /// Simple string: a `+`-prefixed, CRLF-terminated line decoded as UTF-8.
    Simple(String),
    /// Error string: a `-`-prefixed, CRLF-terminated line decoded as UTF-8.
    Error(String),
    /// Integer: a `:`-prefixed decimal line. Stored as `u64`, so only
    /// non-negative values can be represented.
    Integer(u64),
    /// Bulk string: a `$<len>\r\n` header followed by `len` raw bytes and `\r\n`.
    Bulk(Bytes),
    /// Null bulk string, encoded on the wire as `$-1\r\n`.
    Null,
    /// Array: a `*<count>\r\n` header followed by `count` nested frames.
    Array(Vec<Frame>),
}
17
/// Error produced while checking or parsing a frame.
#[derive(Debug)]
pub enum Error {
    /// The buffer ran out of bytes before a complete frame could be read.
    /// Displayed as "stream ended early".
    Incomplete,
    /// Any other failure, such as a malformed frame
    /// ("protocol error; invalid frame format").
    Other(crate::Error),
}
23
24impl Frame {
25 pub fn array() -> Frame { Frame::Array(vec![]) }
26
27 pub fn push_bulk(&mut self, bytes: Bytes) {
28 match self {
29 Frame::Array(vec) => {
30 vec.push(Frame::Bulk(bytes));
31 }
32 _ => panic!("not an array frame"),
33 }
34 }
35
36 pub fn push_int(&mut self, value: u64) {
37 match self {
38 Frame::Array(vec) => {
39 vec.push(Frame::Integer(value));
40 }
41 _ => panic!("not an array frame"),
42 }
43 }
44
45 pub fn check(src: &mut Cursor<&[u8]>) -> Result<(), Error> {
46 match get_u8(src)? {
47 b'+' => {
48 get_line(src)?;
49 Ok(())
50 }
51 b'-' => {
52 get_line(src)?;
53 Ok(())
54 }
55 b':' => {
56 let _ = get_decimal(src)?;
57 Ok(())
58 }
59 b'$' => {
60 if b'-' == peek_u8(src)? {
61 skip(src, 4)
63 } else {
64 let len: usize = get_decimal(src)?.try_into()?;
66
67 skip(src, len + 2)
69 }
70 }
71 b'*' => {
72 let len = get_decimal(src)?;
73
74 for _ in 0..len {
75 Frame::check(src)?;
76 }
77
78 Ok(())
79 }
80 actual => Err(format!("protocol error; invalid frame type byte `{}`", actual).into()),
81 }
82 }
83
84 pub fn parse(src: &mut Cursor<&[u8]>) -> Result<Frame, Error> {
85 match get_u8(src)? {
86 b'+' => {
87 let line = get_line(src)?.to_vec();
88 let string = String::from_utf8(line)?;
89
90 Ok(Frame::Simple(string))
91 }
92 b'-' => {
93 let line = get_line(src)?.to_vec();
94 let string = String::from_utf8(line)?;
95
96 Ok(Frame::Error(string))
97 }
98 b':' => {
99 let len = get_decimal(src)?;
100 Ok(Frame::Integer(len))
101 }
102 b'$' => {
103 if b'-' == peek_u8(src)? {
104 let line = get_line(src)?;
105
106 if line != b"-1" {
107 return Err("protocol error; invalid frame format".into());
108 }
109
110 Ok(Frame::Null)
111 } else {
112 let len = get_decimal(src)?.try_into()?;
113 let n = len + 2;
114
115 if src.remaining() < n {
116 return Err(Error::Incomplete);
117 }
118
119 let data = Bytes::copy_from_slice(&src.chunk()[..len]);
120
121 skip(src, n)?;
122 Ok(Frame::Bulk(data))
123 }
124 }
125 b'*' => {
126 let len = get_decimal(src)?.try_into()?;
127 let mut out = Vec::with_capacity(len);
128
129 for _ in 0..len {
130 out.push(Frame::parse(src)?);
131 }
132
133 Ok(Frame::Array(out))
134 }
135 _ => unimplemented!(),
136 }
137 }
138
139 pub fn to_error(&self) -> crate::Error { format!("unexpected frame: {}", self).into() }
140}
141
142impl PartialEq<&str> for Frame {
143 fn eq(&self, other: &&str) -> bool {
144 match self {
145 Frame::Simple(s) => s.eq(other),
146 Frame::Bulk(s) => s.eq(other),
147 _ => false,
148 }
149 }
150}
151
152impl fmt::Display for Frame {
153 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
154 use std::str;
155
156 match self {
157 Frame::Simple(response) => response.fmt(fmt),
158 Frame::Error(msg) => write!(fmt, "error: {}", msg),
159 Frame::Integer(num) => num.fmt(fmt),
160 Frame::Bulk(msg) => match str::from_utf8(msg) {
161 Ok(string) => string.fmt(fmt),
162 Err(_) => write!(fmt, "{:?}", msg),
163 },
164 Frame::Null => "(nil)".fmt(fmt),
165 Frame::Array(parts) => {
166 for (i, part) in parts.iter().enumerate() {
167 if i > 0 {
168 write!(fmt, " ")?;
169 }
170
171 part.fmt(fmt)?;
172 }
173
174 Ok(())
175 }
176 }
177 }
178}
179
180fn peek_u8(src: &mut Cursor<&[u8]>) -> Result<u8, Error> {
181 if !src.has_remaining() {
182 return Err(Error::Incomplete);
183 }
184
185 Ok(src.chunk()[0])
186}
187
188fn get_u8(src: &mut Cursor<&[u8]>) -> Result<u8, Error> {
189 if !src.has_remaining() {
190 return Err(Error::Incomplete);
191 }
192
193 Ok(src.get_u8())
194}
195
196fn skip(src: &mut Cursor<&[u8]>, n: usize) -> Result<(), Error> {
197 if src.remaining() < n {
198 return Err(Error::Incomplete);
199 }
200
201 src.advance(n);
202 Ok(())
203}
204
205fn get_decimal(src: &mut Cursor<&[u8]>) -> Result<u64, Error> {
206 use atoi::atoi;
207 let line = get_line(src)?;
208
209 atoi::<u64>(line).ok_or_else(|| "protocol error; invalid frame format".into())
210}
211
212fn get_line<'a>(src: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], Error> {
213 let start = src.position() as usize;
214 let end = src.get_ref().len() - 1;
215
216 for i in start..end {
217 if src.get_ref()[i] == b'\r' && src.get_ref()[i + 1] == b'\n' {
218 src.set_position((i + 2) as u64);
219 return Ok(&src.get_ref()[start..i]);
220 }
221 }
222
223 Err(Error::Incomplete)
224}
225
226impl From<String> for Error {
227 fn from(src: String) -> Error { Error::Other(src.into()) }
228}
229
230impl From<&str> for Error {
231 fn from(src: &str) -> Error { src.to_string().into() }
232}
233
234impl From<FromUtf8Error> for Error {
235 fn from(_src: FromUtf8Error) -> Error { "protocol error; invalid frame format".into() }
236}
237
238impl From<TryFromIntError> for Error {
239 fn from(_src: TryFromIntError) -> Error { "protocol error; invalid frame format".into() }
240}
241
242impl std::error::Error for Error {}
243
244impl fmt::Display for Error {
245 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
246 match self {
247 Error::Incomplete => "stream ended early".fmt(fmt),
248 Error::Other(err) => err.fmt(fmt),
249 }
250 }
251}