Skip to main content

destream_json/
de.rs

1//! Decode a JSON stream to a Rust data structure.
2
3use std::collections::HashSet;
4use std::fmt;
5use std::str::FromStr;
6
7use async_recursion::async_recursion;
8use bytes::{BufMut, Bytes};
9use destream::{de, FromStream, Visitor};
10use futures::{
11    stream::{Fuse, FusedStream, Stream, StreamExt, TryStreamExt},
12    FutureExt as _,
13};
14
15#[cfg(feature = "tokio-io")]
16use tokio::io::{AsyncRead, AsyncReadExt, BufReader};
17
18use crate::constants::*;
19
20const SNIPPET_LEN: usize = 50;
21
/// Methods common to any decodable [`Stream`]
///
/// A [`Decoder`] pulls its input through this trait, one chunk at a time.
#[trait_variant::make(Send)]
pub trait Read: Send + Unpin {
    /// Read the next chunk of [`Bytes`] in this [`Stream`].
    ///
    /// `Some(Err(_))` reports a failure to read from the underlying source.
    async fn next(&mut self) -> Option<Result<Bytes, Error>>;

    /// Return `true` if there are no more contents to be read from this [`Stream`].
    fn is_terminated(&self) -> bool;
}
31
/// A decodable [`Stream`]
///
/// The source is wrapped in a [`Fuse`] so that its termination can be queried.
pub struct SourceStream<S> {
    // the fused source stream
    source: Fuse<S>,
}
36
37impl<S: Stream<Item = Result<Bytes, Error>> + Send + Unpin> Read for SourceStream<S> {
38    async fn next(&mut self) -> Option<Result<Bytes, Error>> {
39        self.source.next().await
40    }
41
42    fn is_terminated(&self) -> bool {
43        self.source.is_terminated()
44    }
45}
46
47impl<S: Stream> From<S> for SourceStream<S> {
48    fn from(source: S) -> Self {
49        Self {
50            source: source.fuse(),
51        }
52    }
53}
54
/// A decodable source which reads from an [`AsyncRead`] through a [`BufReader`]
#[cfg(feature = "tokio-io")]
pub struct SourceReader<R: AsyncRead> {
    // buffered wrapper around the underlying reader
    reader: BufReader<R>,
    // set to `true` once a read has returned zero bytes
    terminated: bool,
}
60
#[cfg(feature = "tokio-io")]
impl<R: AsyncRead + Send + Unpin> Read for SourceReader<R> {
    /// Read one chunk from the underlying reader.
    ///
    /// A zero-length read marks this source as terminated and yields an empty chunk.
    /// An I/O failure is reported as a decoding [`Error`].
    async fn next(&mut self) -> Option<Result<Bytes, Error>> {
        let mut chunk = Vec::new();

        let size = match self.reader.read_buf(&mut chunk).await {
            Ok(size) => size,
            Err(cause) => return Some(Err(de::Error::custom(format!("io error: {}", cause)))),
        };

        if size == 0 {
            self.terminated = true;
        } else {
            debug_assert_eq!(chunk.len(), size);
        }

        Some(Ok(chunk.into()))
    }

    /// Return `true` once a zero-length read has been observed.
    fn is_terminated(&self) -> bool {
        self.terminated
    }
}
82
#[cfg(feature = "tokio-io")]
impl<R: AsyncRead> From<R> for SourceReader<R> {
    /// Wrap `reader` in a [`BufReader`]; the new source starts out not terminated.
    fn from(reader: R) -> Self {
        let reader = BufReader::new(reader);

        Self {
            reader,
            terminated: false,
        }
    }
}
92
/// An error encountered while decoding a JSON stream.
///
/// Carries only a message; two errors compare equal when their messages are equal.
#[derive(PartialEq)]
pub struct Error {
    // the text written out by the `Display` and `Debug` implementations
    message: String,
}
98
99impl Error {
100    fn invalid_utf8<I: fmt::Display>(info: I) -> Self {
101        de::Error::custom(format!("invalid UTF-8: {}", info))
102    }
103
104    fn unexpected_end() -> Self {
105        de::Error::custom("unexpected end of stream")
106    }
107}
108
// Marker implementation: the trait's default methods are used unchanged.
impl std::error::Error for Error {}
110
111impl de::Error for Error {
112    fn custom<T: fmt::Display>(msg: T) -> Self {
113        Self {
114            message: msg.to_string(),
115        }
116    }
117}
118
119impl fmt::Debug for Error {
120    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
121        fmt::Display::fmt(self, f)
122    }
123}
124
125impl fmt::Display for Error {
126    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
127        fmt::Display::fmt(&self.message, f)
128    }
129}
130
/// Reads the entries of a single JSON map from a [`Decoder`].
struct MapAccess<'a, S> {
    // the decoder whose buffer and source contain the map
    decoder: &'a mut Decoder<S>,
    // the value returned from `size_hint`
    size_hint: Option<usize>,
    // `true` once the closing map delimiter has been consumed
    done: bool,
}
136
137impl<'a, S: Read + 'a> MapAccess<'a, S> {
138    async fn new(
139        decoder: &'a mut Decoder<S>,
140        size_hint: Option<usize>,
141    ) -> Result<MapAccess<'a, S>, Error> {
142        decoder.expect_whitespace().await?;
143
144        decoder.expect_delimiter(MAP_BEGIN).await?;
145        decoder.expect_whitespace().await?;
146
147        let done = decoder.maybe_delimiter(MAP_END).await?;
148
149        Ok(MapAccess {
150            decoder,
151            size_hint,
152            done,
153        })
154    }
155}
156
157impl<'a, S: Read + 'a> de::MapAccess for MapAccess<'a, S> {
158    type Error = Error;
159
160    async fn next_key<K: FromStream>(&mut self, context: K::Context) -> Result<Option<K>, Error> {
161        if self.done {
162            return Ok(None);
163        }
164
165        self.decoder.expect_whitespace().await?;
166        let key = K::from_stream(context, self.decoder).await?;
167
168        self.decoder.expect_whitespace().await?;
169        self.decoder.expect_delimiter(COLON).await?;
170        self.decoder.expect_whitespace().await?;
171
172        Ok(Some(key))
173    }
174
175    async fn next_value<V: FromStream>(&mut self, context: V::Context) -> Result<V, Error> {
176        if self.done {
177            return Err(de::Error::custom(
178                "called MapAccess::next_value but the map has already ended",
179            ));
180        }
181
182        let value = V::from_stream(context, self.decoder).await?;
183
184        self.decoder.expect_whitespace().await?;
185
186        if self.decoder.maybe_delimiter(MAP_END).await? {
187            self.done = true;
188        } else {
189            self.decoder.expect_delimiter(COMMA).await?;
190        }
191
192        Ok(value)
193    }
194
195    fn size_hint(&self) -> Option<usize> {
196        self.size_hint
197    }
198}
199
/// Reads the elements of a single JSON list from a [`Decoder`].
struct SeqAccess<'a, S> {
    // the decoder whose buffer and source contain the list
    decoder: &'a mut Decoder<S>,
    // the value returned from `size_hint`
    size_hint: Option<usize>,
    // `true` once the closing list delimiter has been consumed
    done: bool,
}
205
206impl<'a, S: Read + 'a> SeqAccess<'a, S> {
207    async fn new(
208        decoder: &'a mut Decoder<S>,
209        size_hint: Option<usize>,
210    ) -> Result<SeqAccess<'a, S>, Error> {
211        decoder.expect_whitespace().await?;
212        decoder.expect_delimiter(LIST_BEGIN).await?;
213        decoder.expect_whitespace().await?;
214
215        let done = decoder.maybe_delimiter(LIST_END).await?;
216
217        Ok(SeqAccess {
218            decoder,
219            size_hint,
220            done,
221        })
222    }
223}
224
225impl<'a, S: Read + 'a> de::SeqAccess for SeqAccess<'a, S> {
226    type Error = Error;
227
228    async fn next_element<T: FromStream>(
229        &mut self,
230        context: T::Context,
231    ) -> Result<Option<T>, Self::Error> {
232        if self.done {
233            return Ok(None);
234        }
235
236        self.decoder.expect_whitespace().await?;
237        let value = T::from_stream(context, self.decoder).await?;
238        self.decoder.expect_whitespace().await?;
239
240        if self.decoder.maybe_delimiter(LIST_END).await? {
241            self.done = true;
242        } else {
243            self.decoder.expect_delimiter(COMMA).await?;
244        }
245
246        Ok(Some(value))
247    }
248
249    fn size_hint(&self) -> Option<usize> {
250        self.size_hint
251    }
252}
253
254impl<'a, S: Read + 'a, T: FromStream<Context = ()> + 'a> de::ArrayAccess<T> for SeqAccess<'a, S> {
255    type Error = Error;
256
257    async fn buffer(&mut self, buffer: &mut [T]) -> Result<usize, Self::Error> {
258        let mut i = 0;
259        let len = buffer.len();
260        while i < len {
261            match de::SeqAccess::next_element(self, ()).await {
262                Ok(Some(b)) => {
263                    buffer[i] = b;
264                    i += 1;
265                }
266                Ok(None) => break,
267                Err(cause) => {
268                    let message = match self.decoder.contents(SNIPPET_LEN) {
269                        Ok(snippet) => format!("array decode error: {} at {}...", cause, snippet),
270                        Err(_) => format!("array decode error: {}", cause),
271                    };
272                    return Err(de::Error::custom(message));
273                }
274            }
275        }
276
277        Ok(i)
278    }
279}
280
/// A structure that decodes Rust values from a JSON stream.
pub struct Decoder<S> {
    // where more input is read from
    source: S,
    // input which has been read from `source` but not yet consumed
    buffer: Vec<u8>,
    // the set of bytes treated as part of a number (built from `NUMERIC`)
    numeric: HashSet<u8>,
}
287
#[cfg(feature = "tokio-io")]
impl<A: AsyncRead> Decoder<A>
where
    SourceReader<A>: Read,
{
    /// Construct a new [`Decoder`] which reads from the given `reader`.
    pub fn from_reader(reader: A) -> Decoder<SourceReader<A>> {
        let source = SourceReader::from(reader);
        let numeric = NUMERIC.iter().copied().collect();

        Decoder {
            source,
            buffer: Vec::new(),
            numeric,
        }
    }
}
301
302impl<S> Decoder<S> {
303    fn contents(&self, max_len: usize) -> Result<String, Error> {
304        let len = Ord::min(self.buffer.len(), max_len);
305        String::from_utf8(self.buffer[..len].to_vec()).map_err(Error::invalid_utf8)
306    }
307}
308
309impl<S: Stream> Decoder<SourceStream<S>>
310where
311    SourceStream<S>: Read,
312{
313    /// Construct a new [`Decoder`] from the given source `stream`.
314    pub fn from_stream(stream: S) -> Decoder<SourceStream<S>> {
315        Decoder {
316            source: SourceStream::from(stream),
317            buffer: Vec::new(),
318            numeric: NUMERIC.iter().cloned().collect(),
319        }
320    }
321
322    /// Return `true` if this [`Decoder`] has no more data to be decoded.
323    pub fn is_terminated(&self) -> bool {
324        self.source.is_terminated()
325    }
326}
327
328impl<S: Read> Decoder<S> {
329    async fn buffer(&mut self) -> Result<(), Error> {
330        if let Some(data) = self.source.next().await {
331            self.buffer.extend(data?);
332        }
333
334        Ok(())
335    }
336
337    async fn buffer_string(&mut self) -> Result<Vec<u8>, Error> {
338        self.expect_delimiter(QUOTE).await?;
339
340        let mut i = 0;
341        let mut escaped = false;
342        loop {
343            while i >= self.buffer.len() && !self.source.is_terminated() {
344                self.buffer().await?;
345            }
346
347            if i < self.buffer.len() && &self.buffer[i..i + 1] == QUOTE && !escaped {
348                break;
349            } else if self.source.is_terminated() {
350                return Err(Error::unexpected_end());
351            }
352
353            if escaped {
354                escaped = false;
355            } else if self.buffer[i] == ESCAPE[0] {
356                escaped = true;
357            }
358
359            i += 1;
360        }
361
362        let mut s = Vec::with_capacity(i);
363        let mut escape = false;
364        for byte in self.buffer.drain(0..i) {
365            let as_slice = std::slice::from_ref(&byte);
366            if escape {
367                s.put_u8(byte);
368                escape = false;
369            } else if as_slice == ESCAPE {
370                escape = true;
371            } else {
372                s.put_u8(byte);
373            }
374        }
375
376        self.buffer.remove(0); // process the end quote
377        self.buffer.shrink_to_fit();
378        Ok(s)
379    }
380
381    async fn buffer_while<F: Fn(u8) -> bool>(&mut self, cond: F) -> Result<usize, Error> {
382        let mut i = 0;
383        loop {
384            while i >= self.buffer.len() && !self.source.is_terminated() {
385                self.buffer().await?;
386            }
387
388            if i < self.buffer.len() && cond(self.buffer[i]) {
389                i += 1;
390            } else if self.source.is_terminated() {
391                return Ok(i);
392            } else {
393                break;
394            }
395        }
396
397        Ok(i)
398    }
399
400    async fn peek(&mut self) -> Result<Option<u8>, Error> {
401        while self.buffer.is_empty() && !self.source.is_terminated() {
402            self.buffer().await?;
403        }
404
405        if self.buffer.is_empty() {
406            Ok(None)
407        } else {
408            Ok(Some(self.buffer[0]))
409        }
410    }
411
412    async fn next_char(&mut self) -> Result<Option<u8>, Error> {
413        while self.buffer.is_empty() && !self.source.is_terminated() {
414            self.buffer().await?;
415        }
416
417        match self.buffer.len() {
418            0 => Ok(None),
419            _ => Ok(Some(self.buffer.remove(0))),
420        }
421    }
422
423    async fn eat_char(&mut self) -> Result<(), Error> {
424        self.next_char().await?;
425        Ok(())
426    }
427
428    async fn next_or_eof(&mut self) -> Result<u8, Error> {
429        while self.buffer.is_empty() && !self.source.is_terminated() {
430            self.buffer().await?;
431        }
432        if self.buffer.is_empty() {
433            Err(Error::unexpected_end())
434        } else {
435            Ok(self.buffer.remove(0))
436        }
437    }
438
439    async fn decode_number<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Error> {
440        let mut i = 0;
441        loop {
442            if self.buffer[i] == DECIMAL[0] || self.buffer[i] == E[0] {
443                return de::Decoder::decode_f64(self, visitor).await;
444            } else if !self.numeric.contains(&self.buffer[i]) {
445                return de::Decoder::decode_i64(self, visitor).await;
446            }
447
448            i += 1;
449            while i >= self.buffer.len() && !self.source.is_terminated() {
450                self.buffer().await?;
451            }
452
453            if self.source.is_terminated() {
454                return de::Decoder::decode_i64(self, visitor).await;
455            }
456        }
457    }
458
459    async fn expect_delimiter(&mut self, delimiter: &'static [u8]) -> Result<(), Error> {
460        while self.buffer.is_empty() && !self.source.is_terminated() {
461            self.buffer().await?;
462        }
463
464        if self.buffer.is_empty() {
465            return Err(Error::unexpected_end());
466        }
467
468        if &self.buffer[0..1] == delimiter {
469            self.buffer.remove(0);
470            Ok(())
471        } else {
472            let contents = self.contents(SNIPPET_LEN)?;
473            Err(de::Error::custom(format!(
474                "unexpected delimiter {}, expected {} at `{}`...",
475                self.buffer[0] as char, delimiter[0] as char, contents
476            )))
477        }
478    }
479
480    async fn expect_whitespace(&mut self) -> Result<(), Error> {
481        let i = self.buffer_while(|b| (b as char).is_whitespace()).await?;
482        self.buffer.drain(..i);
483        Ok(())
484    }
485
486    async fn ignore_value(&mut self) -> Result<(), Error> {
487        self.expect_whitespace().await?;
488
489        while self.buffer.is_empty() && !self.source.is_terminated() {
490            self.buffer().await?;
491        }
492
493        if !self.buffer.is_empty() {
494            // Determine the type of JSON value based on the first character in the buffer
495            match self.buffer[0] {
496                b'"' => self.ignore_string().await?,
497                b'-' => {
498                    self.eat_char().await?;
499                    self.ignore_number().await?;
500                }
501                b'0'..=b'9' => self.ignore_number().await?,
502                b't' => self.ignore_exactly("true").await?,
503                b'f' => self.ignore_exactly("false").await?,
504                b'n' => self.ignore_exactly("null").await?,
505                b'[' => self.ignore_array().await?,
506                b'{' => self.ignore_object().await?,
507                // If the first character doesn't match any JSON value type, return an error
508                _ => {
509                    return Err(Error::invalid_utf8(format!(
510                        "unexpected token ignoring value: {}",
511                        self.buffer[0]
512                    )))
513                }
514            }
515        }
516
517        Ok(())
518    }
519
520    async fn ignore_string(&mut self) -> Result<(), Error> {
521        // eat the first char, which is a quote
522        self.eat_char().await?;
523        loop {
524            if self.buffer.is_empty() {
525                self.buffer().await?;
526            }
527
528            if self.buffer.is_empty() && self.source.is_terminated() {
529                return Err(Error::unexpected_end());
530            }
531
532            let ch = self.next_or_eof().await?;
533            if !ESCAPE_CHARS[ch as usize] {
534                continue;
535            }
536
537            match ch {
538                b'"' => {
539                    return Ok(());
540                }
541                b'\\' => {
542                    self.ignore_escaped_char().await?;
543                }
544                ch => {
545                    return Err(Error::invalid_utf8(format!(
546                        "invalid control character in string: {ch}"
547                    )));
548                }
549            }
550        }
551    }
552
553    /// Parses a JSON escape sequence and discards the value. Assumes the previous
554    /// byte read was a backslash.
555    async fn ignore_escaped_char(&mut self) -> Result<(), Error> {
556        let ch = self.next_or_eof().await?;
557
558        match ch {
559            b'"' | b'\\' | b'/' | b'b' | b'f' | b'n' | b'r' | b't' => {}
560            b'u' => {
561                // At this point we don't care if the codepoint is valid. We just
562                // want to consume it. We don't actually know what is valid or not
563                // at this point, because that depends on if this string will
564                // ultimately be parsed into a string or a byte buffer in the "real"
565                // parse.
566
567                self.decode_hex_escape().await?;
568            }
569            _ => {
570                return Err(Error::invalid_utf8("invalid escape character in string"));
571            }
572        }
573
574        Ok(())
575    }
576
577    async fn decode_hex_escape(&mut self) -> Result<u16, Error> {
578        let mut n = 0;
579        for _ in 0..4 {
580            let ch = decode_hex_val(self.next_or_eof().await?);
581            match ch {
582                None => return Err(Error::invalid_utf8("invalid escape decoding hex escape")),
583                Some(val) => {
584                    n = (n << 4) + val;
585                }
586            }
587        }
588        Ok(n)
589    }
590
591    async fn maybe_delimiter(&mut self, delimiter: &'static [u8]) -> Result<bool, Error> {
592        while self.buffer.is_empty() && !self.source.is_terminated() {
593            self.buffer().await?;
594        }
595
596        if self.buffer.is_empty() {
597            Ok(false)
598        } else if self.buffer.starts_with(delimiter) {
599            self.buffer.remove(0);
600            Ok(true)
601        } else {
602            Ok(false)
603        }
604    }
605
606    async fn parse_bool(&mut self) -> Result<bool, Error> {
607        self.expect_whitespace().await?;
608
609        while self.buffer.len() < TRUE.len() && !self.source.is_terminated() {
610            self.buffer().await?;
611        }
612
613        if self.buffer.is_empty() {
614            return Err(Error::unexpected_end());
615        } else if self.buffer.starts_with(TRUE) {
616            self.buffer.drain(0..TRUE.len());
617            return Ok(true);
618        }
619
620        while self.buffer.len() < FALSE.len() && !self.source.is_terminated() {
621            self.buffer().await?;
622        }
623
624        if self.buffer.is_empty() {
625            return Err(Error::unexpected_end());
626        } else if self.buffer.starts_with(FALSE) {
627            self.buffer.drain(0..FALSE.len());
628            return Ok(false);
629        }
630
631        let i = Ord::min(self.buffer.len(), SNIPPET_LEN);
632        let unknown = String::from_utf8(self.buffer[..i].to_vec()).map_err(Error::invalid_utf8)?;
633        Err(de::Error::invalid_value(unknown, "a boolean"))
634    }
635
636    async fn parse_number<N: FromStr>(&mut self) -> Result<N, Error>
637    where
638        <N as FromStr>::Err: fmt::Display,
639    {
640        self.expect_whitespace().await?;
641
642        let numeric = self.numeric.clone();
643        let i = self.buffer_while(|b| numeric.contains(&b)).await?;
644        let n = String::from_utf8(self.buffer[0..i].to_vec()).map_err(Error::invalid_utf8)?;
645
646        match n.parse() {
647            Ok(number) => {
648                self.buffer.drain(..i);
649                Ok(number)
650            }
651            Err(cause) => Err(de::Error::invalid_value(cause, std::any::type_name::<N>())),
652        }
653    }
654
655    async fn parse_string(&mut self) -> Result<String, Error> {
656        let s = self.buffer_string().await?;
657        String::from_utf8(s).map_err(Error::invalid_utf8)
658    }
659
660    async fn parse_unit(&mut self) -> Result<(), Error> {
661        self.expect_whitespace().await?;
662
663        while self.buffer.len() < NULL.len() && !self.source.is_terminated() {
664            self.buffer().await?;
665        }
666
667        if self.buffer.starts_with(NULL) {
668            self.buffer.drain(..NULL.len());
669            Ok(())
670        } else {
671            let i = Ord::min(self.buffer.len(), SNIPPET_LEN);
672            let as_str =
673                String::from_utf8(self.buffer[..i].to_vec()).map_err(Error::invalid_utf8)?;
674
675            Err(de::Error::invalid_type(as_str, "null"))
676        }
677    }
678
679    async fn ignore_exactly(&mut self, s: &str) -> Result<(), Error> {
680        for ch in s.as_bytes() {
681            match self.peek().await?.as_ref() {
682                None => return Err(Error::unexpected_end()),
683                Some(next) if next == ch => self.eat_char().await?,
684                Some(next) => {
685                    return Err(Error::invalid_utf8(format!(
686                        "invalid char {next}, expected {ch}"
687                    )));
688                }
689            }
690        }
691        Ok(())
692    }
693
694    async fn ignore_number(&mut self) -> Result<(), Error> {
695        let ch = self.next_char().await?;
696        match ch {
697            Some(b'0') => {
698                // There cannot be any leading zeroes.
699                // If there is a leading '0', it cannot be followed by more digits.
700                if let Some(b'0'..=b'9') = self.peek().await? {
701                    return Err(Error::invalid_utf8("invalid number, two leading zeroes"));
702                }
703            }
704            Some(b'1'..=b'9') => {
705                while let Some(b'0'..=b'9') = self.peek().await? {
706                    self.eat_char().await?;
707                }
708            }
709            Some(ch) => {
710                return Err(Error::invalid_utf8(format!("invalid number: {}", ch)));
711            }
712            None => return Err(Error::unexpected_end()),
713        }
714
715        match self.peek().await? {
716            Some(b'.') => self.ignore_decimal().await,
717            Some(b'e' | b'E') => self.ignore_exponent().await,
718            _ => Ok(()),
719        }
720    }
721
722    async fn ignore_decimal(&mut self) -> Result<(), Error> {
723        self.eat_char().await?;
724
725        let mut at_least_one_digit = false;
726        while let Some(b'0'..=b'9') = self.peek().await? {
727            self.eat_char().await?;
728            at_least_one_digit = true;
729        }
730
731        if !at_least_one_digit {
732            return Err(Error::invalid_utf8(
733                "invalid number, expected at least one digit after decimal",
734            ));
735        }
736
737        match self.peek().await? {
738            Some(b'e' | b'E') => self.ignore_exponent().await,
739            _ => Ok(()),
740        }
741    }
742
743    async fn ignore_exponent(&mut self) -> Result<(), Error> {
744        self.eat_char().await?;
745
746        if let Some(b'+' | b'-') = self.peek().await? {
747            self.eat_char().await?;
748        }
749
750        // Make sure a digit follows the exponent place.
751        match self.next_char().await? {
752            Some(b'0'..=b'9') => {}
753            Some(ch) => {
754                return Err(Error::invalid_utf8(format!(
755                    "expected a digit to follow the exponent, found {ch}"
756                )));
757            }
758            None => return Err(Error::unexpected_end()),
759        }
760
761        while let Some(b'0'..=b'9') = self.peek().await? {
762            self.eat_char().await?;
763        }
764
765        Ok(())
766    }
767
768    #[async_recursion]
769    async fn ignore_array(&mut self) -> Result<(), Error> {
770        self.eat_char().await?;
771        self.expect_whitespace().await?;
772        if self.peek().await? == Some(b']') {
773            self.eat_char().await?;
774            return Ok(());
775        }
776
777        loop {
778            self.ignore_value().await?;
779            self.expect_whitespace().await?;
780            match self.peek().await? {
781                Some(b',') => self.eat_char().await?,
782                Some(b']') => {
783                    self.eat_char().await?;
784                    return Ok(());
785                }
786                Some(ch) => {
787                    return Err(Error::invalid_utf8(format!(
788                        "invalid char {ch}, expected , or ]"
789                    )))
790                }
791                None => return Err(Error::unexpected_end()),
792            }
793        }
794    }
795
796    #[async_recursion]
797    async fn ignore_object(&mut self) -> Result<(), Error> {
798        self.eat_char().await?; // b'{'
799        self.expect_whitespace().await?;
800        if self.peek().await? == Some(b'}') {
801            self.eat_char().await?;
802            return Ok(());
803        }
804
805        loop {
806            self.expect_whitespace().await?;
807            self.ignore_string().await?; // key
808            self.expect_whitespace().await?;
809            self.ignore_exactly(":").await?;
810            self.ignore_value().await?;
811            self.expect_whitespace().await?;
812            match self.peek().await? {
813                Some(b'}') => {
814                    self.eat_char().await?;
815                    return Ok(());
816                }
817                Some(b',') => self.eat_char().await?,
818                Some(ch) => {
819                    return Err(Error::invalid_utf8(format!(
820                        "invalid char {ch}, expected , or }}"
821                    )))
822                }
823                None => return Err(Error::unexpected_end()),
824            }
825        }
826    }
827}
828
829impl<S: Read> de::Decoder for Decoder<S> {
830    type Error = Error;
831
    /// Decode a value of unknown type by inspecting the start of the input (after any
    /// leading whitespace) and dispatching to the matching `decode_*` method.
    ///
    /// Fails with "unexpected end of stream" if no input remains, or with an invalid-value
    /// error (quoting a snippet of the buffer) if the input matches no known type.
    async fn decode_any<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
        self.expect_whitespace().await?;

        while self.buffer.is_empty() && !self.source.is_terminated() {
            self.buffer().await?;
        }

        if self.buffer.is_empty() {
            Err(Error::unexpected_end())
        } else if self.buffer.starts_with(QUOTE) {
            self.decode_string(visitor).await
        } else if self.buffer.starts_with(LIST_BEGIN) {
            self.decode_seq(visitor).await
        } else if self.buffer.starts_with(MAP_BEGIN) {
            self.decode_map(visitor).await
        } else if self.numeric.contains(&self.buffer[0]) {
            self.decode_number(visitor).await
        } else if (self.buffer.len() >= FALSE.len() && self.buffer.starts_with(FALSE))
            || (self.buffer.len() >= TRUE.len() && self.buffer.starts_with(TRUE))
        {
            // a complete boolean literal is already buffered
            self.decode_bool(visitor).await
        } else if self.buffer.len() >= NULL.len() && self.buffer.starts_with(NULL) {
            // a complete null literal is already buffered
            self.decode_option(visitor).await
        } else {
            // No match yet, possibly because a literal is split across chunks:
            // buffer enough input to recognize `TRUE` (or `NULL`) and check again.
            while self.buffer.len() < TRUE.len() && !self.source.is_terminated() {
                self.buffer().await?;
            }

            if self.buffer.is_empty() {
                Err(Error::unexpected_end())
            } else if self.buffer.starts_with(TRUE) {
                self.decode_bool(visitor).await
            } else if self.buffer.starts_with(NULL) {
                self.decode_option(visitor).await
            } else {
                // last possibility: buffer enough input to recognize `FALSE`
                while self.buffer.len() < FALSE.len() && !self.source.is_terminated() {
                    self.buffer().await?;
                }

                if self.buffer.is_empty() {
                    Err(Error::unexpected_end())
                } else if self.buffer.starts_with(FALSE) {
                    self.decode_bool(visitor).await
                } else {
                    // nothing matched: report the start of the buffer as an invalid value
                    let i = Ord::min(self.buffer.len(), SNIPPET_LEN);
                    let s = String::from_utf8(self.buffer[0..i].to_vec())
                        .map_err(Error::invalid_utf8)?;

                    Err(de::Error::invalid_value(
                        s,
                        std::any::type_name::<V::Value>(),
                    ))
                }
            }
        }
    }
888
889    async fn decode_bool<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
890        let b = self.parse_bool().await?;
891        visitor.visit_bool(b)
892    }
893
894    async fn decode_bytes<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
895        let s = self.parse_string().await?;
896        visitor.visit_string(s)
897    }
898
899    async fn decode_i8<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
900        let i = self.parse_number().await?;
901        visitor.visit_i8(i)
902    }
903
904    async fn decode_i16<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
905        let i = self.parse_number().await?;
906        visitor.visit_i16(i)
907    }
908
909    async fn decode_i32<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
910        let i = self.parse_number().await?;
911        visitor.visit_i32(i)
912    }
913
914    async fn decode_i64<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
915        let i = self.parse_number().await?;
916        visitor.visit_i64(i)
917    }
918
919    async fn decode_u8<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
920        let u = self.parse_number().await?;
921        visitor.visit_u8(u)
922    }
923
924    async fn decode_u16<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
925        let u = self.parse_number().await?;
926        visitor.visit_u16(u)
927    }
928
929    async fn decode_u32<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
930        let u = self.parse_number().await?;
931        visitor.visit_u32(u)
932    }
933
934    async fn decode_u64<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
935        let u = self.parse_number().await?;
936        visitor.visit_u64(u)
937    }
938
939    async fn decode_f32<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
940        let f = self.parse_number().await?;
941        visitor.visit_f32(f)
942    }
943
944    async fn decode_f64<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
945        let f = self.parse_number().await?;
946        visitor.visit_f64(f)
947    }
948
949    async fn decode_array_bool<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
950        let access = SeqAccess::new(self, None).await?;
951        // was getting an error about S not living long enough, so boxed based on
952        // https://github.com/rust-lang/rust/issues/100013#issuecomment-2052045872
953        // once this issue is closed, we can remove the `.boxed()`
954        visitor.visit_array_bool(access).boxed().await
955    }
956
957    async fn decode_array_i8<V: Visitor>(
958        &mut self,
959        visitor: V,
960    ) -> Result<<V as Visitor>::Value, Self::Error> {
961        let access = SeqAccess::new(self, None).await?;
962        visitor.visit_array_bool(access).boxed().await
963    }
964
965    async fn decode_array_i16<V: Visitor>(
966        &mut self,
967        visitor: V,
968    ) -> Result<<V as Visitor>::Value, Self::Error> {
969        let access = SeqAccess::new(self, None).await?;
970        visitor.visit_array_i16(access).boxed().await
971    }
972
973    async fn decode_array_i32<V: Visitor>(
974        &mut self,
975        visitor: V,
976    ) -> Result<<V as Visitor>::Value, Self::Error> {
977        let access = SeqAccess::new(self, None).await?;
978        visitor.visit_array_i32(access).boxed().await
979    }
980
981    async fn decode_array_i64<V: Visitor>(
982        &mut self,
983        visitor: V,
984    ) -> Result<<V as Visitor>::Value, Self::Error> {
985        let access = SeqAccess::new(self, None).await?;
986        visitor.visit_array_i64(access).boxed().await
987    }
988
989    async fn decode_array_u8<V: Visitor>(
990        &mut self,
991        visitor: V,
992    ) -> Result<<V as Visitor>::Value, Self::Error> {
993        let access = SeqAccess::new(self, None).await?;
994        visitor.visit_array_u8(access).boxed().await
995    }
996
997    async fn decode_array_u16<V: Visitor>(
998        &mut self,
999        visitor: V,
1000    ) -> Result<<V as Visitor>::Value, Self::Error> {
1001        let access = SeqAccess::new(self, None).await?;
1002        visitor.visit_array_u16(access).boxed().await
1003    }
1004
1005    async fn decode_array_u32<V: Visitor>(
1006        &mut self,
1007        visitor: V,
1008    ) -> Result<<V as Visitor>::Value, Self::Error> {
1009        let access = SeqAccess::new(self, None).await?;
1010        visitor.visit_array_u32(access).boxed().await
1011    }
1012
1013    async fn decode_array_u64<V: Visitor>(
1014        &mut self,
1015        visitor: V,
1016    ) -> Result<<V as Visitor>::Value, Self::Error> {
1017        let access = SeqAccess::new(self, None).await?;
1018        visitor.visit_array_u64(access).boxed().await
1019    }
1020
1021    async fn decode_array_f32<V: Visitor>(
1022        &mut self,
1023        visitor: V,
1024    ) -> Result<<V as Visitor>::Value, Self::Error> {
1025        let access = SeqAccess::new(self, None).await?;
1026        visitor.visit_array_f32(access).boxed().await
1027    }
1028
1029    async fn decode_array_f64<V: Visitor>(
1030        &mut self,
1031        visitor: V,
1032    ) -> Result<<V as Visitor>::Value, Self::Error> {
1033        let access = SeqAccess::new(self, None).await?;
1034        visitor.visit_array_f64(access).boxed().await
1035    }
1036
1037    async fn decode_string<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
1038        self.expect_whitespace().await?;
1039
1040        let s = self.parse_string().await?;
1041        visitor.visit_string(s)
1042    }
1043
1044    async fn decode_option<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
1045        self.expect_whitespace().await?;
1046
1047        while self.buffer.len() < NULL.len() && !self.source.is_terminated() {
1048            self.buffer().await?;
1049        }
1050
1051        if self.buffer.starts_with(NULL) {
1052            self.buffer.drain(0..NULL.len());
1053            visitor.visit_none()
1054        } else {
1055            visitor.visit_some(self).await
1056        }
1057    }
1058
1059    async fn decode_seq<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
1060        let access = SeqAccess::new(self, None).await?;
1061        visitor.visit_seq(access).boxed().await
1062    }
1063
1064    async fn decode_unit<V: Visitor>(
1065        &mut self,
1066        visitor: V,
1067    ) -> Result<<V as Visitor>::Value, Self::Error> {
1068        self.parse_unit().await?;
1069        visitor.visit_unit()
1070    }
1071
1072    async fn decode_uuid<V: Visitor>(
1073        &mut self,
1074        visitor: V,
1075    ) -> Result<<V as Visitor>::Value, Self::Error> {
1076        let s = self.parse_string().await?;
1077        visitor.visit_string(s)
1078    }
1079
1080    async fn decode_tuple<V: Visitor>(
1081        &mut self,
1082        len: usize,
1083        visitor: V,
1084    ) -> Result<V::Value, Self::Error> {
1085        let access = SeqAccess::new(self, Some(len)).await?;
1086        visitor.visit_seq(access).boxed().await
1087    }
1088
1089    async fn decode_map<V: Visitor>(&mut self, visitor: V) -> Result<V::Value, Self::Error> {
1090        let access = MapAccess::new(self, None).await?;
1091        visitor.visit_map(access).boxed().await
1092    }
1093
1094    async fn decode_ignored_any<V: Visitor>(
1095        &mut self,
1096        visitor: V,
1097    ) -> Result<V::Value, Self::Error> {
1098        self.ignore_value().await?;
1099        visitor.visit_unit()
1100    }
1101}
1102
1103impl<S: Read> From<S> for Decoder<S> {
1104    fn from(source: S) -> Self {
1105        Self {
1106            source,
1107            buffer: vec![],
1108            numeric: NUMERIC.iter().cloned().collect(),
1109        }
1110    }
1111}
1112
1113/// Decode the given JSON-encoded stream of bytes into an instance of `T` using the given context.
1114pub async fn decode<S: Stream<Item = Bytes> + Send + Unpin, T: FromStream>(
1115    context: T::Context,
1116    source: S,
1117) -> Result<T, Error> {
1118    let source = source.map(Result::<Bytes, Error>::Ok);
1119    let mut decoder = Decoder::from(SourceStream::from(source));
1120
1121    let decoded = T::from_stream(context, &mut decoder).await?;
1122    decoder.expect_whitespace().await?;
1123
1124    if decoder.is_terminated() {
1125        Ok(decoded)
1126    } else {
1127        let buffer = decoder.contents(SNIPPET_LEN)?;
1128        Err(de::Error::custom(format!(
1129            "expected end of stream, found `{}...`",
1130            buffer
1131        )))
1132    }
1133}
1134
1135/// Decode the given JSON-encoded stream of bytes into an instance of `T` using the given context.
1136pub async fn try_decode<
1137    E: fmt::Display,
1138    S: Stream<Item = Result<Bytes, E>> + Send + Unpin,
1139    T: FromStream,
1140>(
1141    context: T::Context,
1142    source: S,
1143) -> Result<T, Error> {
1144    let mut decoder = Decoder::from_stream(source.map_err(|e| de::Error::custom(e)));
1145    let decoded = T::from_stream(context, &mut decoder).await?;
1146    decoder.expect_whitespace().await?;
1147
1148    if decoder.is_terminated() {
1149        Ok(decoded)
1150    } else {
1151        let snippet = decoder.contents(SNIPPET_LEN)?;
1152        Err(de::Error::custom(format!(
1153            "expected end of stream, found `{}...`",
1154            snippet
1155        )))
1156    }
1157}
1158
/// Decode JSON-encoded bytes from the given reader into an instance of `T`
/// using the given context.
///
/// This returns as soon as `T::from_stream` completes; no end-of-stream check is
/// performed afterwards.
#[cfg(feature = "tokio-io")]
pub async fn read_from<S: AsyncReadExt + Send + Unpin, T: FromStream>(
    context: T::Context,
    source: S,
) -> Result<T, Error> {
    let mut decoder = Decoder::from(SourceReader::from(source));
    T::from_stream(context, &mut decoder).await
}
1168
1169fn decode_hex_val(val: u8) -> Option<u16> {
1170    let n = HEX[val as usize] as u16;
1171    if n == 255 {
1172        None
1173    } else {
1174        Some(n)
1175    }
1176}
1177
1178#[cfg(test)]
1179mod tests {
1180    use std::cmp::max;
1181
1182    use futures::stream;
1183    use test_case::test_case;
1184
1185    use super::*;
1186
1187    /// next_or_eof should return the next char in the buffer/stream, or
1188    /// if we've hit the EOF, throw an error.
1189    #[tokio::test]
1190    async fn test_next_or_eof() {
1191        let s = b"bar";
1192        for num_chunks in (1..s.len()).rev() {
1193            let source = stream::iter(s.iter().copied())
1194                .chunks(num_chunks)
1195                .map(Bytes::from)
1196                .map(Result::<Bytes, Error>::Ok);
1197
1198            let mut decoder = Decoder::from_stream(source);
1199            for expected in s {
1200                let actual = decoder.next_or_eof().await.unwrap();
1201                assert_eq!(&actual, expected);
1202            }
1203            let res = decoder.next_or_eof().await;
1204            assert!(res.is_err());
1205        }
1206    }
1207
1208    fn test_decoder(
1209        source: &str,
1210    ) -> Decoder<SourceStream<impl Stream<Item = Result<Bytes, Error>> + '_>> {
1211        let chunk_size = max(source.len(), 1);
1212        let source = stream::iter(source.as_bytes().iter().copied())
1213            .chunks(chunk_size)
1214            .map(Bytes::from)
1215            .map(Result::<Bytes, Error>::Ok);
1216
1217        Decoder::from_stream(source)
1218    }
1219
1220    /// ignore_exactly takes a vector of bytes, and consumes exactly those characters.
1221    #[test_case("foo", "foo", true, 0; "ignore foo")]
1222    #[test_case("foobar", "foo", true, 3; "ignore foo not bar")]
1223    #[test_case("foobar", "bar", false, 6; "wrong expected str")]
1224    #[test_case("", "", true, 0; "empty good")]
1225    #[test_case("", "a", false, 0; "empty bad")]
1226    #[tokio::test]
1227    async fn test_ignore_exactly(source: &str, to_ignore: &str, success: bool, chars_left: usize) {
1228        let mut decoder = test_decoder(source);
1229        let res = decoder.ignore_exactly(to_ignore).await;
1230
1231        assert_eq!(res.is_ok(), success);
1232        assert_eq!(decoder.buffer.len(), chars_left);
1233    }
1234
1235    #[test_case(r#""""#, Ok(0); "empty string")]
1236    #[test_case(r#""","#, Ok(1); "empty string then leave char")]
1237    #[test_case("\"foo\"bar", Ok(3); "ends correctly")]
1238    #[test_case("\"test\"", Ok(0); "string value")]
1239    #[test_case("\"\"", Ok(0); "empty")]
1240    #[test_case("\"\\r\"", Ok(0); "carriage return")]
1241    #[test_case("\"hello\"world\"", Ok(6); "multiple quotes")]
1242    #[test_case("\"   hello\"", Ok(0); "whitespace before")]
1243    #[test_case("\"hello   \"   ", Ok(3); "whitespace after")]
1244    #[test_case("\"\\t\\n\\r\"", Ok(0); "whitespace chars")]
1245    #[test_case("\"\"test\\\"", Ok(6); "chars after empty string")]
1246    #[test_case("\"\\\\\\\\\"", Ok(0); "backslashes")]
1247    #[test_case("", Err(Error::unexpected_end()); "eof")]
1248    #[test_case(r#""a"#, Err(Error::unexpected_end()); "unterminatedstring")]
1249    #[test_case("\"\x01\"", Err(Error::invalid_utf8("invalid control character in string: 1")); "invalid control char")]
1250    #[test_case(r#""\u00""#, Err(Error::invalid_utf8("invalid escape decoding hex escape")); "unfinished hex char")]
1251    #[test_case(
1252        r#""\x01""#,
1253        Err(Error::invalid_utf8("invalid escape character in string"))
1254    )]
1255    #[tokio::test]
1256    async fn test_ignore_string(source: &str, expected: Result<usize, Error>) {
1257        let mut decoder = test_decoder(source);
1258
1259        let res = decoder.ignore_string().await;
1260
1261        match expected {
1262            Ok(end_length) => assert_eq!(decoder.buffer.len(), end_length),
1263            Err(e) => assert_eq!(Err(e), res),
1264        }
1265    }
1266
1267    #[test_case("-123", Ok(0); "negative number")]
1268    #[test_case("-123.45", Ok(0); "negative float")]
1269    #[test_case("abc", Err(Error::invalid_utf8("unexpected token ignoring value: 97")); "non number")]
1270    #[test_case("", Ok(0); "empty source")]
1271    #[tokio::test]
1272    async fn test_ignore_value(source: &str, expected: Result<usize, Error>) {
1273        let mut decoder = test_decoder(source);
1274
1275        // `ignore_number` only works on positive numbers.  `ignore_value` will eat that b'-'
1276        let res = decoder.ignore_value().await;
1277
1278        if let Ok(end_length) = expected {
1279            res.unwrap();
1280            assert_eq!(decoder.buffer.len(), end_length);
1281        } else {
1282            assert_eq!(res.unwrap_err(), expected.unwrap_err())
1283        }
1284    }
1285
1286    #[test_case("0", Ok(0); "zero")]
1287    #[test_case("00", Err(Error::invalid_utf8("invalid number, two leading zeroes")); "double zero")]
1288    #[test_case("123", Ok(0); "positive number")]
1289    #[test_case("123.45", Ok(0); "positive float")]
1290    #[test_case("0.0", Ok(0); "zero float")]
1291    #[test_case("123, 45", Ok(4); "parses only one number")]
1292    #[test_case("1e30, 45", Ok(4); "parses exponent")]
1293    #[test_case("1.2e3, 45", Ok(4); "parses decimal exponent")]
1294    #[test_case("abc", Err(Error::invalid_utf8("invalid number: 97")); "unexpected token")]
1295    #[test_case("", Err(Error::unexpected_end()); "unexpected end")]
1296    #[test_case("1.", Err(Error::invalid_utf8("invalid number, expected at least one digit after decimal")); "expected a number after the decimal")]
1297    #[test_case("1.1e-1", Ok(0); "negative exponent")]
1298    #[test_case("1.1e-a", Err(Error::invalid_utf8("expected a digit to follow the exponent, found 97")); "invalid exponent")]
1299    #[test_case("1.1e", Err(Error::unexpected_end()); "unterminated number")]
1300    #[tokio::test]
1301    async fn test_ignore_number(source: &str, expected: Result<usize, Error>) {
1302        let mut decoder = test_decoder(source);
1303        let res = decoder.ignore_number().await;
1304
1305        if let Ok(end_length) = expected {
1306            res.unwrap();
1307            assert_eq!(decoder.buffer.len(), end_length);
1308        } else {
1309            assert_eq!(res.unwrap_err(), expected.unwrap_err())
1310        }
1311    }
1312
1313    #[test_case("[]", Ok(0); "empty array")]
1314    #[test_case("[1]", Ok(0); "single array")]
1315    #[test_case("[ ] ", Ok(1); "whitespace empty array")]
1316    #[test_case("[ 1 ] ", Ok(1); "whitespace single array")]
1317    #[test_case("[1,2]", Ok(0); "multi array")]
1318    #[test_case("[],[]", Ok(3); "ends correctly")]
1319    #[test_case("[\"foo\",\"bar\"]", Ok(0); "string array")]
1320    #[test_case(r#""#, Err(Error::unexpected_end()); "unexpected end")]
1321    #[test_case(r#"["test""test"]"#, Err(Error::invalid_utf8("invalid char 34, expected , or ]")); "no comma")]
1322    #[tokio::test]
1323    async fn test_ignore_array(source: &str, expected: Result<usize, Error>) {
1324        let mut decoder = test_decoder(source);
1325        let res = decoder.ignore_array().await;
1326
1327        match expected {
1328            Ok(end_length) => assert_eq!(decoder.buffer.len(), end_length),
1329            Err(e) => assert_eq!(res.unwrap_err(), e),
1330        }
1331    }
1332
1333    #[test_case("{}", Ok(0); "empty object")]
1334    #[test_case("{},{}", Ok(3); "ends correctly")]
1335    #[test_case(r#"{"k":2, "k":3}"#, Ok(0); "multi object")]
1336    #[test_case(r#"{"k":1}"#, Ok(0); "single object")]
1337    #[test_case(r#"{"foo":"bar"}"#, Ok(0); "string value")]
1338    #[test_case(r#"{ } "#, Ok(1); "whitespace empty object")]
1339    #[test_case(r#"{"k" : 2 , " k " : 3 }"#, Ok(0); "whitespace multi object")]
1340    #[test_case(r#"{ " k " : 1 } "#, Ok(1); "whitespace single object")]
1341    #[test_case(r#"{"k""v"}"#, Err(Error::invalid_utf8("invalid char 34, expected 58")); "missing colon")]
1342    #[test_case(r#"{"k","v"}"#, Err(Error::invalid_utf8("invalid char 44, expected 58")); "comma when expecting colon")]
1343    #[test_case(r#"{,"k":"v"}"#, Err(Error::invalid_utf8("invalid char 107, expected 58")); "comma when expecting value")]
1344    #[test_case(r#"{"k":"v"asdf}"#, Err(Error::invalid_utf8("invalid char 97, expected , or }")); "value when expecting comma")]
1345    #[tokio::test]
1346    async fn test_ignore_object(source: &str, expected: Result<usize, Error>) {
1347        let mut decoder = test_decoder(source);
1348        let res = decoder.ignore_object().await;
1349
1350        match expected {
1351            Err(e) => assert_eq!(Err(e), res),
1352            Ok(end_length) => assert_eq!(decoder.buffer.len(), end_length),
1353        }
1354    }
1355}