1use std::{
2 io::{self, Read},
3 str,
4};
5
6use crate::types::{make_extension_error, ErrorKind, RedisError, RedisResult, Value};
7
8use combine::{
9 any,
10 error::StreamError,
11 opaque,
12 parser::{
13 byte::{crlf, take_until_bytes},
14 combinator::{any_send_sync_partial_state, AnySendSyncPartialState},
15 range::{recognize, take},
16 },
17 stream::{PointerOffset, RangeStream, StreamErrorFor},
18 ParseError, Parser as _,
19};
20
/// Accumulator used to collect the elements of an array reply whose items are
/// each a `Result`.
///
/// Successful items are extended into the wrapped `T`. The first `Err` item
/// replaces the whole accumulator; later items (successful or not) are discarded.
struct ResultExtend<T, E>(Result<T, E>);

impl<T, E> Default for ResultExtend<T, E>
where
    T: Default,
{
    /// Starts as `Ok(T::default())`, i.e. an empty successful collection.
    fn default() -> Self {
        ResultExtend(Ok(T::default()))
    }
}

impl<T, U, E> Extend<Result<U, E>> for ResultExtend<T, E>
where
    T: Extend<U>,
{
    /// Extends the accumulator with `iter`, keeping only the first error seen.
    ///
    /// Every item of `iter` is consumed, even after an error has been recorded
    /// (and even if the accumulator was already `Err` on entry). The array parser
    /// feeds this from `combine::count_min_max`, which counts the items pulled
    /// from its iterator; stopping at the first error would leave it short of its
    /// required count and fail the whole array parse, leaving the remaining
    /// elements unconsumed.
    fn extend<I>(&mut self, iter: I)
    where
        I: IntoIterator<Item = Result<U, E>>,
    {
        // Fused so the drain below never polls the source again once it has
        // returned `None`.
        let mut iter = iter.into_iter().fuse();
        let mut returned_err = None;
        if let Ok(ref mut elems) = self.0 {
            elems.extend(iter.by_ref().scan((), |_, item| match item {
                Ok(item) => Some(item),
                Err(err) => {
                    returned_err = Some(err);
                    None
                }
            }));
        }
        // Consume (and drop) whatever was not taken above: items after the first
        // error, or everything if we were already in the error state.
        iter.for_each(drop);
        if let Some(err) = returned_err {
            self.0 = Err(err);
        }
    }
}
55
/// Builds the resumable `combine` parser for a single RESP value.
///
/// The first byte selects the reply type:
/// - `+` simple string: `"OK"` becomes `Value::Okay`, anything else `Value::Status`
/// - `:` integer: `Value::Int`
/// - `$` bulk string: `Value::Data`, or `Value::Nil` when the length is negative
/// - `*` array: `Value::Bulk` of nested values, or `Value::Nil` when the count is negative
/// - `-` error reply: returned as the inner `Err`
///
/// Any other leading byte is a parse error.
///
/// The parser's output is `RedisResult<Value>`, so a server error reply is a
/// successfully *parsed* value (the inner `Err`) rather than a parse failure.
/// The partial state is erased into `AnySendSyncPartialState` so callers can store
/// it and resume parsing when more input arrives.
fn value<'a, I>(
) -> impl combine::Parser<I, Output = RedisResult<Value>, PartialState = AnySendSyncPartialState>
where
    I: RangeStream<Token = u8, Range = &'a [u8]>,
    I::Error: combine::ParseError<u8, &'a [u8], I::Position>,
{
    // `opaque!` hides the concrete parser type; this is what allows `bulk` below
    // to call `value()` recursively.
    opaque!(any_send_sync_partial_state(any().then_partial(
        // `b` is the type byte that starts the reply.
        move |&mut b| {
            // One CRLF-terminated line. `recognize` yields the consumed bytes
            // including the trailing "\r\n", which the `- 2` slices off. Non-UTF-8
            // content is reported as a parse error.
            let line = || {
                recognize(take_until_bytes(&b"\r\n"[..]).with(take(2).map(|_| ()))).and_then(
                    |line: &[u8]| {
                        str::from_utf8(&line[..line.len() - 2]).map_err(StreamErrorFor::<I>::other)
                    },
                )
            };

            // `+<text>\r\n`
            let status = || {
                line().map(|line| {
                    if line == "OK" {
                        Value::Okay
                    } else {
                        Value::Status(line.into())
                    }
                })
            };

            // A line holding an `i64`; surrounding whitespace is trimmed first.
            // Also used for the length/count header of `$` and `*` replies.
            let int = || {
                line().and_then(|line| match line.trim().parse::<i64>() {
                    Err(_) => Err(StreamErrorFor::<I>::message_static_message(
                        "Expected integer, got garbage",
                    )),
                    Ok(value) => Ok(value),
                })
            };

            // `$<len>\r\n<len bytes>\r\n`; any negative length is a null reply.
            let data = || {
                int().then_partial(move |size| {
                    if *size < 0 {
                        combine::value(Value::Nil).left()
                    } else {
                        // The cast is lossless: `size` is non-negative here.
                        take(*size as usize)
                            .map(|bs: &[u8]| Value::Data(bs.to_vec()))
                            .skip(crlf())
                            .right()
                    }
                })
            };

            // `*<count>\r\n` followed by exactly `count` nested values, each parsed
            // recursively. The per-element `RedisResult`s are combined through
            // `ResultExtend` into a single `RedisResult` for the array.
            let bulk = || {
                int().then_partial(|&mut length| {
                    if length < 0 {
                        combine::value(Value::Nil).map(Ok).left()
                    } else {
                        let length = length as usize;
                        combine::count_min_max(length, length, value())
                            .map(|result: ResultExtend<_, _>| result.0.map(Value::Bulk))
                            .right()
                    }
                })
            };

            // `-<CODE> <detail>\r\n`. The first word selects the `ErrorKind`; the
            // remainder of the line (if any) becomes the error detail.
            let error = || {
                line().map(|line: &str| {
                    let desc = "An error was signalled by the server";
                    let mut pieces = line.splitn(2, ' ');
                    // `splitn` always yields at least one piece, so this cannot panic.
                    let kind = match pieces.next().unwrap() {
                        "ERR" => ErrorKind::ResponseError,
                        "EXECABORT" => ErrorKind::ExecAbortError,
                        "LOADING" => ErrorKind::BusyLoadingError,
                        "NOSCRIPT" => ErrorKind::NoScriptError,
                        "MOVED" => ErrorKind::Moved,
                        "ASK" => ErrorKind::Ask,
                        "TRYAGAIN" => ErrorKind::TryAgain,
                        "CLUSTERDOWN" => ErrorKind::ClusterDown,
                        "CROSSSLOT" => ErrorKind::CrossSlot,
                        "MASTERDOWN" => ErrorKind::MasterDown,
                        "READONLY" => ErrorKind::ReadOnly,
                        // Unrecognized codes are handed to `make_extension_error`
                        // together with the rest of the line.
                        code => return make_extension_error(code, pieces.next()),
                    };
                    match pieces.next() {
                        Some(detail) => RedisError::from((kind, desc, detail.to_string())),
                        None => RedisError::from((kind, desc)),
                    }
                })
            };

            // Select the sub-parser by type byte; an unknown byte is reported as an
            // unexpected token.
            combine::dispatch!(b;
                b'+' => status().map(Ok),
                b':' => int().map(|i| Ok(Value::Int(i))),
                b'$' => data().map(Ok),
                b'*' => bulk(),
                b'-' => error().map(Err),
                b => combine::unexpected_any(combine::error::Token(b))
            )
        }
    )))
}
153
#[cfg(feature = "aio")]
mod aio_support {
    use super::*;

    use bytes::{Buf, BytesMut};
    use tokio::io::AsyncRead;
    use tokio_util::codec::{Decoder, Encoder};

    /// Tokio codec that decodes RESP replies into `RedisResult<Value>` and writes
    /// already-encoded command bytes through unchanged.
    #[derive(Default)]
    pub struct ValueCodec {
        // Partial parser state, kept across `decode` calls so a reply that arrives
        // in several chunks resumes where it stopped.
        state: AnySendSyncPartialState,
    }

    impl ValueCodec {
        /// Tries to decode one reply from the front of `bytes`.
        ///
        /// When `eof` is false the input is treated as partial, so an incomplete
        /// reply yields `Ok(None)` and parsing resumes on the next call. On success,
        /// the consumed bytes are removed from `bytes`.
        ///
        /// When `eof` is true:
        /// - an empty buffer yields `Ok(None)`;
        /// - a truncated reply yields an `io::ErrorKind::UnexpectedEof` error, matching
        ///   `parse_redis_value_async` and `Parser::parse_value`.
        ///
        /// Any other malformed input yields an `ErrorKind::ResponseError` "parse error".
        fn decode_stream(
            &mut self,
            bytes: &mut BytesMut,
            eof: bool,
        ) -> RedisResult<Option<RedisResult<Value>>> {
            let (opt, removed_len) = {
                let buffer = &bytes[..];
                let mut stream =
                    combine::easy::Stream(combine::stream::MaybePartialStream(buffer, !eof));
                match combine::stream::decode_tokio(value(), &mut stream, &mut self.state) {
                    Ok(x) => x,
                    Err(err) => {
                        // Input that ends in the middle of a reply means the stream was
                        // cut off, not that the server sent garbage.
                        if err.is_unexpected_end_of_input() {
                            return Err(RedisError::from(io::Error::from(
                                io::ErrorKind::UnexpectedEof,
                            )));
                        }
                        let err = err
                            .map_position(|pos| pos.translate_position(buffer))
                            .map_range(|range| format!("{:?}", range))
                            .to_string();
                        return Err(RedisError::from((
                            ErrorKind::ResponseError,
                            "parse error",
                            err,
                        )));
                    }
                }
            };

            bytes.advance(removed_len);
            Ok(opt)
        }
    }

    impl Encoder<Vec<u8>> for ValueCodec {
        type Error = RedisError;

        /// Appends the (already RESP-encoded) command bytes to `dst` unchanged.
        fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), Self::Error> {
            dst.extend_from_slice(item.as_ref());
            Ok(())
        }
    }

    impl Decoder for ValueCodec {
        type Item = RedisResult<Value>;
        type Error = RedisError;

        /// Decodes one reply, treating `bytes` as possibly incomplete.
        fn decode(&mut self, bytes: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            self.decode_stream(bytes, false)
        }

        /// Decodes one reply, treating `bytes` as everything that will ever arrive.
        fn decode_eof(&mut self, bytes: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            self.decode_stream(bytes, true)
        }
    }

    /// Reads from `read` until one complete RESP value has been parsed, and returns it.
    ///
    /// `decoder` holds the read buffer and partial parse state. Reusing the same
    /// decoder for a given stream presumably keeps any bytes read past the end of
    /// the previous value (TODO confirm against `combine::stream::Decoder` docs).
    ///
    /// # Errors
    /// - I/O errors from `read` are converted into a `RedisError`.
    /// - Input ending in the middle of a value yields an `io::ErrorKind::UnexpectedEof` error.
    /// - Malformed input yields an `ErrorKind::ResponseError` "parse error".
    /// - A server error reply is also returned as `Err`.
    pub async fn parse_redis_value_async<R>(
        decoder: &mut combine::stream::Decoder<AnySendSyncPartialState, PointerOffset<[u8]>>,
        read: &mut R,
    ) -> RedisResult<Value>
    where
        R: AsyncRead + std::marker::Unpin,
    {
        let result = combine::decode_tokio!(*decoder, *read, value(), |input, _| {
            combine::stream::easy::Stream::from(input)
        });
        match result {
            Err(err) => Err(match err {
                combine::stream::decoder::Error::Io { error, .. } => error.into(),
                combine::stream::decoder::Error::Parse(err) => {
                    if err.is_unexpected_end_of_input() {
                        RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof))
                    } else {
                        // Ranges are stringified before positions are translated
                        // against `decoder.buffer()`.
                        let err = err
                            .map_range(|range| format!("{:?}", range))
                            .map_position(|pos| pos.translate_position(decoder.buffer()))
                            .to_string();
                        RedisError::from((ErrorKind::ResponseError, "parse error", err))
                    }
                }
            }),
            Ok(result) => result,
        }
    }
}
253#[cfg(feature = "aio")]
254#[cfg_attr(docsrs, doc(cfg(feature = "aio")))]
255pub use self::aio_support::*;
256
257pub struct Parser {
259 decoder: combine::stream::decoder::Decoder<AnySendSyncPartialState, PointerOffset<[u8]>>,
260}
261
262impl Default for Parser {
263 fn default() -> Self {
264 Parser::new()
265 }
266}
267
268impl Parser {
273 pub fn new() -> Parser {
278 Parser {
279 decoder: combine::stream::decoder::Decoder::new(),
280 }
281 }
282
283 pub fn parse_value<T: Read>(&mut self, mut reader: T) -> RedisResult<Value> {
287 let mut decoder = &mut self.decoder;
288 let result = combine::decode!(decoder, reader, value(), |input, _| {
289 combine::stream::easy::Stream::from(input)
290 });
291 match result {
292 Err(err) => Err(match err {
293 combine::stream::decoder::Error::Io { error, .. } => error.into(),
294 combine::stream::decoder::Error::Parse(err) => {
295 if err.is_unexpected_end_of_input() {
296 RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof))
297 } else {
298 let err = err
299 .map_range(|range| format!("{:?}", range))
300 .map_position(|pos| pos.translate_position(decoder.buffer()))
301 .to_string();
302 RedisError::from((ErrorKind::ResponseError, "parse error", err))
303 }
304 }
305 }),
306 Ok(result) => result,
307 }
308 }
309}
310
311pub fn parse_redis_value(bytes: &[u8]) -> RedisResult<Value> {
316 let mut parser = Parser::new();
317 parser.parse_value(bytes)
318}
319
#[cfg(test)]
mod tests {
    #[cfg(feature = "aio")]
    use super::*;

    /// After a complete reply has been consumed, further `decode_eof` calls on the
    /// now-empty buffer report `Ok(None)` instead of an error.
    #[cfg(feature = "aio")]
    #[test]
    fn decode_eof_returns_none_at_eof() {
        use tokio_util::codec::Decoder;

        let wire: &[u8] = b"+GET 123\r\n";
        let mut codec = ValueCodec::default();
        let mut buffer = bytes::BytesMut::from(wire);

        let first = codec.decode_eof(&mut buffer);
        assert_eq!(first, Ok(Some(Ok(parse_redis_value(wire).unwrap()))));

        for _ in 0..2 {
            assert_eq!(codec.decode_eof(&mut buffer), Ok(None));
        }
    }
}