contained_core/connect/
frame.rs1use crate::Error;
9use bytes::{Buf, Bytes};
10use std::convert::TryInto;
11use std::io::Cursor;
12use std::num::TryFromIntError;
13use std::string::FromUtf8Error;
14
15#[derive(Clone, Debug)]
17pub enum Frame {
18 Simple(String),
19 Error(String),
20 Integer(u64),
21 Bulk(Bytes),
22 Null,
23 Array(Vec<Frame>),
24}
25
26#[derive(Debug)]
27pub enum FrameError {
28 Incomplete,
30
31 Other(Error),
33}
34
35impl Frame {
36 pub fn array() -> Frame {
38 Frame::Array(vec![])
39 }
40
41 pub fn push_bulk(&mut self, bytes: Bytes) {
47 match self {
48 Frame::Array(vec) => {
49 vec.push(Frame::Bulk(bytes));
50 }
51 _ => panic!("not an array frame"),
52 }
53 }
54
55 pub fn push_int(&mut self, value: u64) {
61 match self {
62 Frame::Array(vec) => {
63 vec.push(Frame::Integer(value));
64 }
65 _ => panic!("not an array frame"),
66 }
67 }
68
69 pub fn check(src: &mut Cursor<&[u8]>) -> Result<(), FrameError> {
71 match get_u8(src)? {
72 b'+' => {
73 get_line(src)?;
74 Ok(())
75 }
76 b'-' => {
77 get_line(src)?;
78 Ok(())
79 }
80 b':' => {
81 let _ = get_decimal(src)?;
82 Ok(())
83 }
84 b'$' => {
85 if b'-' == peek_u8(src)? {
86 skip(src, 4)
88 } else {
89 let len: usize = get_decimal(src)?.try_into()?;
91
92 skip(src, len + 2)
94 }
95 }
96 b'*' => {
97 let len = get_decimal(src)?;
98
99 for _ in 0..len {
100 Frame::check(src)?;
101 }
102
103 Ok(())
104 }
105 actual => Err(format!("protocol error; invalid frame type byte `{}`", actual).into()),
106 }
107 }
108
109 pub fn parse(src: &mut Cursor<&[u8]>) -> Result<Frame, FrameError> {
111 match get_u8(src)? {
112 b'+' => {
113 let line = get_line(src)?.to_vec();
115
116 let string = String::from_utf8(line)?;
118
119 Ok(Frame::Simple(string))
120 }
121 b'-' => {
122 let line = get_line(src)?.to_vec();
124
125 let string = String::from_utf8(line)?;
127
128 Ok(Frame::Error(string))
129 }
130 b':' => {
131 let len = get_decimal(src)?;
132 Ok(Frame::Integer(len))
133 }
134 b'$' => {
135 if b'-' == peek_u8(src)? {
136 let line = get_line(src)?;
137
138 if line != b"-1" {
139 return Err("protocol error; invalid frame format".into());
140 }
141
142 Ok(Frame::Null)
143 } else {
144 let len = get_decimal(src)?.try_into()?;
146 let n = len + 2;
147
148 if src.remaining() < n {
149 return Err(FrameError::Incomplete);
150 }
151
152 let data = Bytes::copy_from_slice(&src.get_ref()[..len]);
153
154 skip(src, n)?;
156
157 Ok(Frame::Bulk(data))
158 }
159 }
160 b'*' => {
161 let len = get_decimal(src)?.try_into()?;
162 let mut out = Vec::with_capacity(len);
163
164 for _ in 0..len {
165 out.push(Frame::parse(src)?);
166 }
167
168 Ok(Frame::Array(out))
169 }
170 _ => unimplemented!(),
171 }
172 }
173
174 pub fn to_error(&self) -> crate::Error {
176 format!("unexpected frame: {}", self).into()
177 }
178}
179
180impl PartialEq<&str> for Frame {
181 fn eq(&self, other: &&str) -> bool {
182 match self {
183 Frame::Simple(s) => s.eq(other),
184 Frame::Bulk(s) => s.eq(other),
185 _ => false,
186 }
187 }
188}
189
190impl std::fmt::Display for Frame {
191 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
192 use std::str;
193
194 match self {
195 Frame::Simple(response) => response.fmt(fmt),
196 Frame::Error(msg) => write!(fmt, "error: {}", msg),
197 Frame::Integer(num) => num.fmt(fmt),
198 Frame::Bulk(msg) => match str::from_utf8(msg) {
199 Ok(string) => string.fmt(fmt),
200 Err(_) => write!(fmt, "{:?}", msg),
201 },
202 Frame::Null => "(nil)".fmt(fmt),
203 Frame::Array(parts) => {
204 for (i, part) in parts.iter().enumerate() {
205 if i > 0 {
206 write!(fmt, " ")?;
207 part.fmt(fmt)?;
208 }
209 }
210
211 Ok(())
212 }
213 }
214 }
215}
216
217fn peek_u8(src: &mut Cursor<&[u8]>) -> Result<u8, FrameError> {
218 if !src.has_remaining() {
219 return Err(FrameError::Incomplete);
220 }
221
222 Ok(src.get_ref()[0])
223}
224
225fn get_u8(src: &mut Cursor<&[u8]>) -> Result<u8, FrameError> {
226 if !src.has_remaining() {
227 return Err(FrameError::Incomplete);
228 }
229
230 Ok(src.get_u8())
231}
232
233fn skip(src: &mut Cursor<&[u8]>, n: usize) -> Result<(), FrameError> {
234 if src.remaining() < n {
235 return Err(FrameError::Incomplete);
236 }
237
238 src.advance(n);
239 Ok(())
240}
241
242fn get_decimal(src: &mut Cursor<&[u8]>) -> Result<u64, FrameError> {
244 use atoi::atoi;
245
246 let line = get_line(src)?;
247
248 atoi::<u64>(line).ok_or_else(|| "protocol error; invalid frame format".into())
249}
250
251fn get_line<'a>(src: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], FrameError> {
253 let start = src.position() as usize;
255 let end = src.get_ref().len() - 1;
257
258 for i in start..end {
259 if src.get_ref()[i] == b'\r' && src.get_ref()[i + 1] == b'\n' {
260 src.set_position((i + 2) as u64);
262
263 return Ok(&src.get_ref()[start..i]);
265 }
266 }
267
268 Err(FrameError::Incomplete)
269}
270
271impl From<String> for FrameError {
272 fn from(src: String) -> FrameError {
273 FrameError::Other(src.into())
274 }
275}
276
277impl From<&str> for FrameError {
278 fn from(src: &str) -> FrameError {
279 src.to_string().into()
280 }
281}
282
283impl From<FrameError> for crate::Error {
284 fn from(src: FrameError) -> crate::Error {
285 crate::Error::from(src.to_string())
286 }
287}
288
289impl From<Box<dyn std::error::Error>> for FrameError {
290 fn from(src: Box<dyn std::error::Error>) -> FrameError {
291 FrameError::Other(src.into())
292 }
293}
294
295impl From<Box<dyn std::error::Error + Send + Sync>> for FrameError {
296 fn from(src: Box<dyn std::error::Error + Send + Sync>) -> FrameError {
297 FrameError::Other(src.into())
298 }
299}
300
301impl From<FromUtf8Error> for FrameError {
302 fn from(_src: FromUtf8Error) -> FrameError {
303 "protocol error; invalid frame format".into()
304 }
305}
306
307impl From<TryFromIntError> for FrameError {
308 fn from(_src: TryFromIntError) -> FrameError {
309 "protocol error; invalid frame format".into()
310 }
311}
312
313impl std::error::Error for FrameError {}
314
315impl std::fmt::Display for FrameError {
316 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
317 match self {
318 FrameError::Incomplete => "stream ended early".fmt(fmt),
319 FrameError::Other(err) => err.fmt(fmt),
320 }
321 }
322}