mco_redis_rs/
parser.rs

1use std::{
2    io::{self, Read},
3    str,
4};
5
6use crate::types::{make_extension_error, ErrorKind, RedisError, RedisResult, Value};
7
8use combine::{
9    any,
10    error::StreamError,
11    opaque,
12    parser::{
13        byte::{crlf, take_until_bytes},
14        combinator::{any_send_sync_partial_state, AnySendSyncPartialState},
15        range::{recognize, take},
16    },
17    stream::{PointerOffset, RangeStream, StreamErrorFor},
18    ParseError, Parser as _,
19};
20
/// Collects an iterator of `Result`s into a single `Result<T, E>`, keeping the
/// first error that was seen.
///
/// This is the output collection of `count_min_max` when parsing a multi-bulk
/// (array) reply. Every element is itself a `RedisResult<Value>`, and one
/// error element turns the whole array into that error.
struct ResultExtend<T, E>(Result<T, E>);

impl<T, E> Default for ResultExtend<T, E>
where
    T: Default,
{
    /// Starts in the `Ok` state with an empty `T`.
    fn default() -> Self {
        ResultExtend(Ok(T::default()))
    }
}

impl<T, U, E> Extend<Result<U, E>> for ResultExtend<T, E>
where
    T: Extend<U>,
{
    /// Appends the `Ok` items of `iter` and switches to `Err` on the first error.
    ///
    /// The iterator is always driven to completion. This holds even after an
    /// error was seen, whether in this call or in an earlier one (partial
    /// parsing can call `extend` several times on the same collection).
    /// When the iterator comes from a parser, pulling an item is what parses
    /// that element. Stopping early would leave the rest of the array
    /// unconsumed, and combine's `count_min_max` would then see too few
    /// elements and report a parse error instead of the server error.
    fn extend<I>(&mut self, iter: I)
    where
        I: IntoIterator<Item = Result<U, E>>,
    {
        let mut first_err = None;
        let mut oks = iter.into_iter().filter_map(|item| match item {
            // Once an error is known, values would be thrown away anyway.
            Ok(value) if first_err.is_none() => Some(value),
            Ok(_) => None,
            Err(err) => {
                // Only the first error is kept; later ones are dropped.
                first_err.get_or_insert(err);
                None
            }
        });

        if let Ok(ref mut elems) = self.0 {
            elems.extend(oks.by_ref());
        }
        // Drain whatever was not pulled above (everything, if we were already
        // in the error state) so that every element gets consumed.
        oks.for_each(drop);

        if let Some(err) = first_err {
            // Never overwrite an error recorded by an earlier call.
            if self.0.is_ok() {
                self.0 = Err(err);
            }
        }
    }
}
55
/// Builds the combine parser for a single redis (RESP) value.
///
/// The first byte selects the kind of value:
/// * `+` simple string: `"OK"` becomes `Value::Okay`, anything else `Value::Status`;
/// * `:` integer: `Value::Int`;
/// * `$` bulk string: `Value::Data`, or `Value::Nil` for a negative length;
/// * `*` array: `Value::Bulk`, or `Value::Nil` for a negative length;
/// * `-` server error: produced as the inner `Err`.
///
/// Any other leading byte is reported as an unexpected-token parse error.
///
/// The parser's output is itself a `RedisResult`. A failed parse means
/// malformed input. A successful parse yields either a value or a server error
/// reply. An error element inside an array makes the whole array that error
/// (see `ResultExtend`).
///
/// The parser is type-erased with `opaque!` because arrays call `value()`
/// recursively. Its partial state is stored as `AnySendSyncPartialState` so
/// the decoders in this file can resume it across reads.
///
/// NOTE(review): array nesting depth is not limited here, so each nested `*`
/// recurses into `value()` again. Consider a depth cap if the peer is untrusted.
fn value<'a, I>(
) -> impl combine::Parser<I, Output = RedisResult<Value>, PartialState = AnySendSyncPartialState>
where
    I: RangeStream<Token = u8, Range = &'a [u8]>,
    I::Error: combine::ParseError<u8, &'a [u8], I::Position>,
{
    opaque!(any_send_sync_partial_state(any().then_partial(
        move |&mut b| {
            // Consumes everything up to and including the next CRLF and yields
            // the text before it. Invalid UTF-8 is a parse error.
            let line = || {
                recognize(take_until_bytes(&b"\r\n"[..]).with(take(2).map(|_| ()))).and_then(
                    |line: &[u8]| {
                        str::from_utf8(&line[..line.len() - 2]).map_err(StreamErrorFor::<I>::other)
                    },
                )
            };

            // Simple string reply; a bare "OK" is special-cased.
            let status = || {
                line().map(|line| {
                    if line == "OK" {
                        Value::Okay
                    } else {
                        Value::Status(line.into())
                    }
                })
            };

            // Integer line. Surrounding whitespace is trimmed before parsing as i64.
            let int = || {
                line().and_then(|line| match line.trim().parse::<i64>() {
                    Err(_) => Err(StreamErrorFor::<I>::message_static_message(
                        "Expected integer, got garbage",
                    )),
                    Ok(value) => Ok(value),
                })
            };

            // Bulk string: a length line, then exactly that many raw bytes,
            // then a CRLF. A negative length means nil.
            let data = || {
                int().then_partial(move |size| {
                    if *size < 0 {
                        combine::value(Value::Nil).left()
                    } else {
                        take(*size as usize)
                            .map(|bs: &[u8]| Value::Data(bs.to_vec()))
                            .skip(crlf())
                            .right()
                    }
                })
            };

            // Array: an element-count line followed by that many nested values.
            // A negative count means nil.
            let bulk = || {
                int().then_partial(|&mut length| {
                    if length < 0 {
                        combine::value(Value::Nil).map(Ok).left()
                    } else {
                        let length = length as usize;
                        combine::count_min_max(length, length, value())
                            .map(|result: ResultExtend<_, _>| result.0.map(Value::Bulk))
                            .right()
                    }
                })
            };

            // Error reply. The first word selects the ErrorKind and the rest,
            // if any, becomes the detail. Unrecognised codes go through
            // `make_extension_error`.
            let error = || {
                line().map(|line: &str| {
                    let desc = "An error was signalled by the server";
                    let mut pieces = line.splitn(2, ' ');
                    // `splitn` always yields at least one piece, so this cannot panic.
                    let kind = match pieces.next().unwrap() {
                        "ERR" => ErrorKind::ResponseError,
                        "EXECABORT" => ErrorKind::ExecAbortError,
                        "LOADING" => ErrorKind::BusyLoadingError,
                        "NOSCRIPT" => ErrorKind::NoScriptError,
                        "MOVED" => ErrorKind::Moved,
                        "ASK" => ErrorKind::Ask,
                        "TRYAGAIN" => ErrorKind::TryAgain,
                        "CLUSTERDOWN" => ErrorKind::ClusterDown,
                        "CROSSSLOT" => ErrorKind::CrossSlot,
                        "MASTERDOWN" => ErrorKind::MasterDown,
                        "READONLY" => ErrorKind::ReadOnly,
                        code => return make_extension_error(code, pieces.next()),
                    };
                    match pieces.next() {
                        Some(detail) => RedisError::from((kind, desc, detail.to_string())),
                        None => RedisError::from((kind, desc)),
                    }
                })
            };

            // Dispatch on the type byte. Server errors become the inner `Err`.
            combine::dispatch!(b;
                b'+' => status().map(Ok),
                b':' => int().map(|i| Ok(Value::Int(i))),
                b'$' => data().map(Ok),
                b'*' => bulk(),
                b'-' => error().map(Err),
                b => combine::unexpected_any(combine::error::Token(b))
            )
        }
    )))
}
153
#[cfg(feature = "aio")]
mod aio_support {
    use super::*;

    use bytes::{Buf, BytesMut};
    use tokio::io::AsyncRead;
    use tokio_util::codec::{Decoder, Encoder};

    /// Codec for `tokio_util` framed transports.
    ///
    /// It passes already-serialized request bytes through unchanged and
    /// decodes redis reply values. The combine partial-parse state lives in
    /// the codec, so a reply split across several reads resumes where it
    /// stopped.
    #[derive(Default)]
    pub struct ValueCodec {
        state: AnySendSyncPartialState,
    }

    impl ValueCodec {
        /// Tries to decode one value from the front of `bytes`.
        ///
        /// `eof` tells the parser whether more input may still arrive. While it
        /// is `false`, the buffer is parsed as a partial stream, so an
        /// incomplete reply is expected to yield `Ok(None)` rather than an error.
        ///
        /// Returns `Ok(Some(reply))` once a full value has been parsed.
        /// `reply` is itself `Err` when the server sent an error reply. The
        /// bytes consumed by the parser are removed from `bytes`.
        ///
        /// # Errors
        ///
        /// Malformed input yields an `ErrorKind::ResponseError` with the
        /// description "parse error". The rendered combine error is the detail.
        fn decode_stream(
            &mut self,
            bytes: &mut BytesMut,
            eof: bool,
        ) -> RedisResult<Option<RedisResult<Value>>> {
            let (opt, removed_len) = {
                let buffer = &bytes[..];
                let mut stream =
                    combine::easy::Stream(combine::stream::MaybePartialStream(buffer, !eof));
                match combine::stream::decode_tokio(value(), &mut stream, &mut self.state) {
                    Ok(x) => x,
                    Err(err) => {
                        let err = err
                            .map_position(|pos| pos.translate_position(buffer))
                            .map_range(|range| format!("{:?}", range))
                            .to_string();
                        return Err(RedisError::from((
                            ErrorKind::ResponseError,
                            "parse error",
                            err,
                        )));
                    }
                }
            };

            // Drop only what the parser consumed. Any bytes after the decoded
            // value stay buffered for the next call.
            bytes.advance(removed_len);
            Ok(opt)
        }
    }

    impl Encoder<Vec<u8>> for ValueCodec {
        type Error = RedisError;

        /// Appends the request bytes to `dst` unchanged.
        fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), Self::Error> {
            dst.extend_from_slice(&item);
            Ok(())
        }
    }

    impl Decoder for ValueCodec {
        type Item = RedisResult<Value>;
        type Error = RedisError;

        /// Decodes while more input may still arrive.
        fn decode(&mut self, bytes: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            self.decode_stream(bytes, false)
        }

        /// Decodes the remaining buffered input once the stream has ended.
        fn decode_eof(&mut self, bytes: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            self.decode_stream(bytes, true)
        }
    }

    /// Parses a redis value asynchronously.
    ///
    /// Reads from `read` through the caller-owned `decoder` until one complete
    /// value is parsed. The decoder's buffer and partial state therefore
    /// persist between calls.
    ///
    /// # Errors
    ///
    /// * I/O errors from `read` are converted into a `RedisError`.
    /// * Running out of input in the middle of a value yields an
    ///   `io::ErrorKind::UnexpectedEof` error.
    /// * Malformed input yields `ErrorKind::ResponseError` ("parse error").
    /// * A server error reply is returned as the `Err` produced by the parser.
    pub async fn parse_redis_value_async<R>(
        decoder: &mut combine::stream::Decoder<AnySendSyncPartialState, PointerOffset<[u8]>>,
        read: &mut R,
    ) -> RedisResult<Value>
    where
        R: AsyncRead + std::marker::Unpin,
    {
        let result = combine::decode_tokio!(*decoder, *read, value(), |input, _| {
            combine::stream::easy::Stream::from(input)
        });
        match result {
            Err(err) => Err(match err {
                combine::stream::decoder::Error::Io { error, .. } => error.into(),
                combine::stream::decoder::Error::Parse(err) => {
                    if err.is_unexpected_end_of_input() {
                        RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof))
                    } else {
                        // Ranges are rendered to owned strings before
                        // `decoder.buffer()` is borrowed. Presumably this is
                        // needed because the raw error borrows the decoder's
                        // buffer.
                        let err = err
                            .map_range(|range| format!("{:?}", range))
                            .map_position(|pos| pos.translate_position(decoder.buffer()))
                            .to_string();
                        RedisError::from((ErrorKind::ResponseError, "parse error", err))
                    }
                }
            }),
            Ok(result) => result,
        }
    }
}
252
// Expose `ValueCodec` and `parse_redis_value_async` only when the `aio`
// feature is enabled.
#[cfg(feature = "aio")]
#[cfg_attr(docsrs, doc(cfg(feature = "aio")))]
pub use self::aio_support::*;
256
/// The internal redis response parser.
pub struct Parser {
    // combine decoder that owns the read buffer and the partial-parse state
    // (`AnySendSyncPartialState`). It is kept on the parser so that it is
    // reused by every `parse_value` call.
    decoder: combine::stream::decoder::Decoder<AnySendSyncPartialState, PointerOffset<[u8]>>,
}
261
262impl Default for Parser {
263    fn default() -> Self {
264        Parser::new()
265    }
266}
267
268/// The parser can be used to parse redis responses into values.  Generally
269/// you normally do not use this directly as it's already done for you by
270/// the client but in some more complex situations it might be useful to be
271/// able to parse the redis responses.
272impl Parser {
273    /// Creates a new parser that parses the data behind the reader.  More
274    /// than one value can be behind the reader in which case the parser can
275    /// be invoked multiple times.  In other words: the stream does not have
276    /// to be terminated.
277    pub fn new() -> Parser {
278        Parser {
279            decoder: combine::stream::decoder::Decoder::new(),
280        }
281    }
282
283    // public api
284
285    /// Parses synchronously into a single value from the reader.
286    pub fn parse_value<T: Read>(&mut self, mut reader: T) -> RedisResult<Value> {
287        let mut decoder = &mut self.decoder;
288        let result = combine::decode!(decoder, reader, value(), |input, _| {
289            combine::stream::easy::Stream::from(input)
290        });
291        match result {
292            Err(err) => Err(match err {
293                combine::stream::decoder::Error::Io { error, .. } => error.into(),
294                combine::stream::decoder::Error::Parse(err) => {
295                    if err.is_unexpected_end_of_input() {
296                        RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof))
297                    } else {
298                        let err = err
299                            .map_range(|range| format!("{:?}", range))
300                            .map_position(|pos| pos.translate_position(decoder.buffer()))
301                            .to_string();
302                        RedisError::from((ErrorKind::ResponseError, "parse error", err))
303                    }
304                }
305            }),
306            Ok(result) => result,
307        }
308    }
309}
310
311/// Parses bytes into a redis value.
312///
313/// This is the most straightforward way to parse something into a low
314/// level redis value instead of having to use a whole parser.
315pub fn parse_redis_value(bytes: &[u8]) -> RedisResult<Value> {
316    let mut parser = Parser::new();
317    parser.parse_value(bytes)
318}
319
#[cfg(test)]
mod tests {
    #[cfg(feature = "aio")]
    use super::*;

    /// After one complete frame has been taken at EOF, further `decode_eof`
    /// calls on the now-empty buffer keep returning `Ok(None)`.
    #[cfg(feature = "aio")]
    #[test]
    fn decode_eof_returns_none_at_eof() {
        use tokio_util::codec::Decoder;

        const FRAME: &[u8] = b"+GET 123\r\n";
        let expected = parse_redis_value(FRAME).unwrap();

        let mut codec = ValueCodec::default();
        let mut bytes = bytes::BytesMut::from(FRAME);

        assert_eq!(codec.decode_eof(&mut bytes), Ok(Some(Ok(expected))));
        for _ in 0..2 {
            assert_eq!(codec.decode_eof(&mut bytes), Ok(None));
        }
    }
}