csv/
deserializer.rs

1use std::{error::Error as StdError, fmt, iter, num, str};
2
3use serde_core::{
4    de::value::BorrowedBytesDeserializer,
5    de::{
6        Deserialize, DeserializeSeed, Deserializer, EnumAccess,
7        Error as SerdeError, IntoDeserializer, MapAccess, SeqAccess,
8        Unexpected, VariantAccess, Visitor,
9    },
10};
11
12use crate::{
13    byte_record::{ByteRecord, ByteRecordIter},
14    error::{Error, ErrorKind},
15    string_record::{StringRecord, StringRecordIter},
16};
17
18use self::DeserializeErrorKind as DEK;
19
20pub fn deserialize_string_record<'de, D: Deserialize<'de>>(
21    record: &'de StringRecord,
22    headers: Option<&'de StringRecord>,
23) -> Result<D, Error> {
24    let mut deser = DeRecordWrap(DeStringRecord {
25        it: record.iter().peekable(),
26        headers: headers.map(|r| r.iter()),
27        field: 0,
28    });
29    D::deserialize(&mut deser).map_err(|err| {
30        Error::new(ErrorKind::Deserialize {
31            pos: record.position().cloned(),
32            err,
33        })
34    })
35}
36
37pub fn deserialize_byte_record<'de, D: Deserialize<'de>>(
38    record: &'de ByteRecord,
39    headers: Option<&'de ByteRecord>,
40) -> Result<D, Error> {
41    let mut deser = DeRecordWrap(DeByteRecord {
42        it: record.iter().peekable(),
43        headers: headers.map(|r| r.iter()),
44        field: 0,
45    });
46    D::deserialize(&mut deser).map_err(|err| {
47        Error::new(ErrorKind::Deserialize {
48            pos: record.position().cloned(),
49            err,
50        })
51    })
52}
53
54/// An over-engineered internal trait that permits writing a single Serde
55/// deserializer that works on both ByteRecord and StringRecord.
56///
57/// We *could* implement a single deserializer on `ByteRecord` and simply
58/// convert `StringRecord`s to `ByteRecord`s, but then the implementation
59/// would be required to redo UTF-8 validation checks in certain places.
60///
61/// How does this work? We create a new `DeRecordWrap` type that wraps
62/// either a `StringRecord` or a `ByteRecord`. We then implement
63/// `DeRecord` for `DeRecordWrap<ByteRecord>` and `DeRecordWrap<StringRecord>`.
64/// Finally, we impl `serde::Deserialize` for `DeRecordWrap<T>` where
65/// `T: DeRecord`. That is, the `DeRecord` type corresponds to the differences
66/// between deserializing into a `ByteRecord` and deserializing into a
67/// `StringRecord`.
68///
69/// The lifetime `'r` refers to the lifetime of the underlying record.
70trait DeRecord<'r> {
71    /// Returns true if and only if this deserialize has access to headers.
72    fn has_headers(&self) -> bool;
73
74    /// Extracts the next string header value from the underlying record.
75    fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError>;
76
77    /// Extracts the next raw byte header value from the underlying record.
78    fn next_header_bytes(
79        &mut self,
80    ) -> Result<Option<&'r [u8]>, DeserializeError>;
81
82    /// Extracts the next string field from the underlying record.
83    fn next_field(&mut self) -> Result<&'r str, DeserializeError>;
84
85    /// Extracts the next raw byte field from the underlying record.
86    fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError>;
87
88    /// Peeks at the next field from the underlying record.
89    fn peek_field(&mut self) -> Option<&'r [u8]>;
90
91    /// Returns an error corresponding to the most recently extracted field.
92    fn error(&self, kind: DeserializeErrorKind) -> DeserializeError;
93
94    /// Infer the type of the next field and deserialize it.
95    fn infer_deserialize<'de, V: Visitor<'de>>(
96        &mut self,
97        visitor: V,
98    ) -> Result<V::Value, DeserializeError>;
99}
100
101struct DeRecordWrap<T>(T);
102
103impl<'r, T: DeRecord<'r>> DeRecord<'r> for DeRecordWrap<T> {
104    #[inline]
105    fn has_headers(&self) -> bool {
106        self.0.has_headers()
107    }
108
109    #[inline]
110    fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError> {
111        self.0.next_header()
112    }
113
114    #[inline]
115    fn next_header_bytes(
116        &mut self,
117    ) -> Result<Option<&'r [u8]>, DeserializeError> {
118        self.0.next_header_bytes()
119    }
120
121    #[inline]
122    fn next_field(&mut self) -> Result<&'r str, DeserializeError> {
123        self.0.next_field()
124    }
125
126    #[inline]
127    fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError> {
128        self.0.next_field_bytes()
129    }
130
131    #[inline]
132    fn peek_field(&mut self) -> Option<&'r [u8]> {
133        self.0.peek_field()
134    }
135
136    #[inline]
137    fn error(&self, kind: DeserializeErrorKind) -> DeserializeError {
138        self.0.error(kind)
139    }
140
141    #[inline]
142    fn infer_deserialize<'de, V: Visitor<'de>>(
143        &mut self,
144        visitor: V,
145    ) -> Result<V::Value, DeserializeError> {
146        self.0.infer_deserialize(visitor)
147    }
148}
149
150struct DeStringRecord<'r> {
151    it: iter::Peekable<StringRecordIter<'r>>,
152    headers: Option<StringRecordIter<'r>>,
153    field: u64,
154}
155
156impl<'r> DeRecord<'r> for DeStringRecord<'r> {
157    #[inline]
158    fn has_headers(&self) -> bool {
159        self.headers.is_some()
160    }
161
162    #[inline]
163    fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError> {
164        Ok(self.headers.as_mut().and_then(|it| it.next()))
165    }
166
167    #[inline]
168    fn next_header_bytes(
169        &mut self,
170    ) -> Result<Option<&'r [u8]>, DeserializeError> {
171        Ok(self.next_header()?.map(|s| s.as_bytes()))
172    }
173
174    #[inline]
175    fn next_field(&mut self) -> Result<&'r str, DeserializeError> {
176        match self.it.next() {
177            Some(field) => {
178                self.field += 1;
179                Ok(field)
180            }
181            None => Err(DeserializeError {
182                field: None,
183                kind: DEK::UnexpectedEndOfRow,
184            }),
185        }
186    }
187
188    #[inline]
189    fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError> {
190        self.next_field().map(|s| s.as_bytes())
191    }
192
193    #[inline]
194    fn peek_field(&mut self) -> Option<&'r [u8]> {
195        self.it.peek().map(|s| s.as_bytes())
196    }
197
198    fn error(&self, kind: DeserializeErrorKind) -> DeserializeError {
199        DeserializeError { field: Some(self.field.saturating_sub(1)), kind }
200    }
201
202    fn infer_deserialize<'de, V: Visitor<'de>>(
203        &mut self,
204        visitor: V,
205    ) -> Result<V::Value, DeserializeError> {
206        let x = self.next_field()?;
207        if x == "true" {
208            visitor.visit_bool(true)
209        } else if x == "false" {
210            visitor.visit_bool(false)
211        } else if let Some(n) = try_positive_integer64(x) {
212            visitor.visit_u64(n)
213        } else if let Some(n) = try_negative_integer64(x) {
214            visitor.visit_i64(n)
215        } else if let Some(n) = try_positive_integer128(x) {
216            visitor.visit_u128(n)
217        } else if let Some(n) = try_negative_integer128(x) {
218            visitor.visit_i128(n)
219        } else if let Some(n) = try_float(x) {
220            visitor.visit_f64(n)
221        } else {
222            visitor.visit_str(x)
223        }
224    }
225}
226
227struct DeByteRecord<'r> {
228    it: iter::Peekable<ByteRecordIter<'r>>,
229    headers: Option<ByteRecordIter<'r>>,
230    field: u64,
231}
232
233impl<'r> DeRecord<'r> for DeByteRecord<'r> {
234    #[inline]
235    fn has_headers(&self) -> bool {
236        self.headers.is_some()
237    }
238
239    #[inline]
240    fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError> {
241        match self.next_header_bytes() {
242            Ok(Some(field)) => Ok(Some(
243                str::from_utf8(field)
244                    .map_err(|err| self.error(DEK::InvalidUtf8(err)))?,
245            )),
246            Ok(None) => Ok(None),
247            Err(err) => Err(err),
248        }
249    }
250
251    #[inline]
252    fn next_header_bytes(
253        &mut self,
254    ) -> Result<Option<&'r [u8]>, DeserializeError> {
255        Ok(self.headers.as_mut().and_then(|it| it.next()))
256    }
257
258    #[inline]
259    fn next_field(&mut self) -> Result<&'r str, DeserializeError> {
260        self.next_field_bytes().and_then(|field| {
261            str::from_utf8(field)
262                .map_err(|err| self.error(DEK::InvalidUtf8(err)))
263        })
264    }
265
266    #[inline]
267    fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError> {
268        match self.it.next() {
269            Some(field) => {
270                self.field += 1;
271                Ok(field)
272            }
273            None => Err(DeserializeError {
274                field: None,
275                kind: DEK::UnexpectedEndOfRow,
276            }),
277        }
278    }
279
280    #[inline]
281    fn peek_field(&mut self) -> Option<&'r [u8]> {
282        self.it.peek().copied()
283    }
284
285    fn error(&self, kind: DeserializeErrorKind) -> DeserializeError {
286        DeserializeError { field: Some(self.field.saturating_sub(1)), kind }
287    }
288
289    fn infer_deserialize<'de, V: Visitor<'de>>(
290        &mut self,
291        visitor: V,
292    ) -> Result<V::Value, DeserializeError> {
293        let x = self.next_field_bytes()?;
294        if x == b"true" {
295            visitor.visit_bool(true)
296        } else if x == b"false" {
297            visitor.visit_bool(false)
298        } else if let Some(n) = try_positive_integer64_bytes(x) {
299            visitor.visit_u64(n)
300        } else if let Some(n) = try_negative_integer64_bytes(x) {
301            visitor.visit_i64(n)
302        } else if let Some(n) = try_positive_integer128_bytes(x) {
303            visitor.visit_u128(n)
304        } else if let Some(n) = try_negative_integer128_bytes(x) {
305            visitor.visit_i128(n)
306        } else if let Some(n) = try_float_bytes(x) {
307            visitor.visit_f64(n)
308        } else if let Ok(s) = str::from_utf8(x) {
309            visitor.visit_str(s)
310        } else {
311            visitor.visit_bytes(x)
312        }
313    }
314}
315
/// Generates a `Deserializer` method for one integer type.
///
/// The generated method consumes the next field and parses it as
/// `$inttype`: a field starting with `0x` is parsed as hexadecimal,
/// anything else with the type's `FromStr` implementation. A parse failure
/// becomes a `ParseInt` error for that field.
macro_rules! deserialize_int {
    ($method:ident, $visit:ident, $inttype:ty) => {
        fn $method<V: Visitor<'de>>(
            self,
            visitor: V,
        ) -> Result<V::Value, Self::Error> {
            let text = self.next_field()?;
            let parsed = match text.strip_prefix("0x") {
                Some(hex) => <$inttype>::from_str_radix(hex, 16),
                None => text.parse::<$inttype>(),
            };
            match parsed {
                Ok(value) => visitor.$visit(value),
                Err(err) => Err(self.error(DEK::ParseInt(err))),
            }
        }
    };
}
332
333impl<'a, 'de: 'a, T: DeRecord<'de>> Deserializer<'de>
334    for &'a mut DeRecordWrap<T>
335{
336    type Error = DeserializeError;
337
338    fn deserialize_any<V: Visitor<'de>>(
339        self,
340        visitor: V,
341    ) -> Result<V::Value, Self::Error> {
342        self.infer_deserialize(visitor)
343    }
344
345    fn deserialize_bool<V: Visitor<'de>>(
346        self,
347        visitor: V,
348    ) -> Result<V::Value, Self::Error> {
349        visitor.visit_bool(
350            self.next_field()?
351                .parse()
352                .map_err(|err| self.error(DEK::ParseBool(err)))?,
353        )
354    }
355
356    deserialize_int!(deserialize_u8, visit_u8, u8);
357    deserialize_int!(deserialize_u16, visit_u16, u16);
358    deserialize_int!(deserialize_u32, visit_u32, u32);
359    deserialize_int!(deserialize_u64, visit_u64, u64);
360    deserialize_int!(deserialize_u128, visit_u128, u128);
361    deserialize_int!(deserialize_i8, visit_i8, i8);
362    deserialize_int!(deserialize_i16, visit_i16, i16);
363    deserialize_int!(deserialize_i32, visit_i32, i32);
364    deserialize_int!(deserialize_i64, visit_i64, i64);
365    deserialize_int!(deserialize_i128, visit_i128, i128);
366
367    fn deserialize_f32<V: Visitor<'de>>(
368        self,
369        visitor: V,
370    ) -> Result<V::Value, Self::Error> {
371        visitor.visit_f32(
372            self.next_field()?
373                .parse()
374                .map_err(|err| self.error(DEK::ParseFloat(err)))?,
375        )
376    }
377
378    fn deserialize_f64<V: Visitor<'de>>(
379        self,
380        visitor: V,
381    ) -> Result<V::Value, Self::Error> {
382        visitor.visit_f64(
383            self.next_field()?
384                .parse()
385                .map_err(|err| self.error(DEK::ParseFloat(err)))?,
386        )
387    }
388
389    fn deserialize_char<V: Visitor<'de>>(
390        self,
391        visitor: V,
392    ) -> Result<V::Value, Self::Error> {
393        let field = self.next_field()?;
394        let len = field.chars().count();
395        if len != 1 {
396            return Err(self.error(DEK::Message(format!(
397                "expected single character but got {} characters in '{}'",
398                len, field
399            ))));
400        }
401        visitor.visit_char(field.chars().next().unwrap())
402    }
403
404    fn deserialize_str<V: Visitor<'de>>(
405        self,
406        visitor: V,
407    ) -> Result<V::Value, Self::Error> {
408        self.next_field().and_then(|f| visitor.visit_borrowed_str(f))
409    }
410
411    fn deserialize_string<V: Visitor<'de>>(
412        self,
413        visitor: V,
414    ) -> Result<V::Value, Self::Error> {
415        self.next_field().and_then(|f| visitor.visit_str(f))
416    }
417
418    fn deserialize_bytes<V: Visitor<'de>>(
419        self,
420        visitor: V,
421    ) -> Result<V::Value, Self::Error> {
422        self.next_field_bytes().and_then(|f| visitor.visit_borrowed_bytes(f))
423    }
424
425    fn deserialize_byte_buf<V: Visitor<'de>>(
426        self,
427        visitor: V,
428    ) -> Result<V::Value, Self::Error> {
429        self.next_field_bytes()
430            .and_then(|f| visitor.visit_byte_buf(f.to_vec()))
431    }
432
433    fn deserialize_option<V: Visitor<'de>>(
434        self,
435        visitor: V,
436    ) -> Result<V::Value, Self::Error> {
437        match self.peek_field() {
438            None => visitor.visit_none(),
439            Some([]) => {
440                self.next_field().expect("empty field");
441                visitor.visit_none()
442            }
443            Some(_) => visitor.visit_some(self),
444        }
445    }
446
447    fn deserialize_unit<V: Visitor<'de>>(
448        self,
449        visitor: V,
450    ) -> Result<V::Value, Self::Error> {
451        visitor.visit_unit()
452    }
453
454    fn deserialize_unit_struct<V: Visitor<'de>>(
455        self,
456        _name: &'static str,
457        visitor: V,
458    ) -> Result<V::Value, Self::Error> {
459        visitor.visit_unit()
460    }
461
462    fn deserialize_newtype_struct<V: Visitor<'de>>(
463        self,
464        _name: &'static str,
465        visitor: V,
466    ) -> Result<V::Value, Self::Error> {
467        visitor.visit_newtype_struct(self)
468    }
469
470    fn deserialize_seq<V: Visitor<'de>>(
471        self,
472        visitor: V,
473    ) -> Result<V::Value, Self::Error> {
474        visitor.visit_seq(self)
475    }
476
477    fn deserialize_tuple<V: Visitor<'de>>(
478        self,
479        _len: usize,
480        visitor: V,
481    ) -> Result<V::Value, Self::Error> {
482        visitor.visit_seq(self)
483    }
484
485    fn deserialize_tuple_struct<V: Visitor<'de>>(
486        self,
487        _name: &'static str,
488        _len: usize,
489        visitor: V,
490    ) -> Result<V::Value, Self::Error> {
491        visitor.visit_seq(self)
492    }
493
494    fn deserialize_map<V: Visitor<'de>>(
495        self,
496        visitor: V,
497    ) -> Result<V::Value, Self::Error> {
498        if !self.has_headers() {
499            visitor.visit_seq(self)
500        } else {
501            visitor.visit_map(self)
502        }
503    }
504
505    fn deserialize_struct<V: Visitor<'de>>(
506        self,
507        _name: &'static str,
508        _fields: &'static [&'static str],
509        visitor: V,
510    ) -> Result<V::Value, Self::Error> {
511        if !self.has_headers() {
512            visitor.visit_seq(self)
513        } else {
514            visitor.visit_map(self)
515        }
516    }
517
518    fn deserialize_identifier<V: Visitor<'de>>(
519        self,
520        _visitor: V,
521    ) -> Result<V::Value, Self::Error> {
522        Err(self.error(DEK::Unsupported("deserialize_identifier".into())))
523    }
524
525    fn deserialize_enum<V: Visitor<'de>>(
526        self,
527        _name: &'static str,
528        _variants: &'static [&'static str],
529        visitor: V,
530    ) -> Result<V::Value, Self::Error> {
531        visitor.visit_enum(self)
532    }
533
534    fn deserialize_ignored_any<V: Visitor<'de>>(
535        self,
536        visitor: V,
537    ) -> Result<V::Value, Self::Error> {
538        // Read and drop the next field.
539        // This code is reached, e.g., when trying to deserialize a header
540        // that doesn't exist in the destination struct.
541        let _ = self.next_field_bytes()?;
542        visitor.visit_unit()
543    }
544}
545
546impl<'a, 'de: 'a, T: DeRecord<'de>> EnumAccess<'de>
547    for &'a mut DeRecordWrap<T>
548{
549    type Error = DeserializeError;
550    type Variant = Self;
551
552    fn variant_seed<V: DeserializeSeed<'de>>(
553        self,
554        seed: V,
555    ) -> Result<(V::Value, Self::Variant), Self::Error> {
556        let variant_name = self.next_field()?;
557        seed.deserialize(variant_name.into_deserializer()).map(|v| (v, self))
558    }
559}
560
561impl<'a, 'de: 'a, T: DeRecord<'de>> VariantAccess<'de>
562    for &'a mut DeRecordWrap<T>
563{
564    type Error = DeserializeError;
565
566    fn unit_variant(self) -> Result<(), Self::Error> {
567        Ok(())
568    }
569
570    fn newtype_variant_seed<U: DeserializeSeed<'de>>(
571        self,
572        _seed: U,
573    ) -> Result<U::Value, Self::Error> {
574        let unexp = Unexpected::UnitVariant;
575        Err(DeserializeError::invalid_type(unexp, &"newtype variant"))
576    }
577
578    fn tuple_variant<V: Visitor<'de>>(
579        self,
580        _len: usize,
581        _visitor: V,
582    ) -> Result<V::Value, Self::Error> {
583        let unexp = Unexpected::UnitVariant;
584        Err(DeserializeError::invalid_type(unexp, &"tuple variant"))
585    }
586
587    fn struct_variant<V: Visitor<'de>>(
588        self,
589        _fields: &'static [&'static str],
590        _visitor: V,
591    ) -> Result<V::Value, Self::Error> {
592        let unexp = Unexpected::UnitVariant;
593        Err(DeserializeError::invalid_type(unexp, &"struct variant"))
594    }
595}
596
597impl<'a, 'de: 'a, T: DeRecord<'de>> SeqAccess<'de>
598    for &'a mut DeRecordWrap<T>
599{
600    type Error = DeserializeError;
601
602    fn next_element_seed<U: DeserializeSeed<'de>>(
603        &mut self,
604        seed: U,
605    ) -> Result<Option<U::Value>, Self::Error> {
606        if self.peek_field().is_none() {
607            Ok(None)
608        } else {
609            seed.deserialize(&mut **self).map(Some)
610        }
611    }
612}
613
614impl<'a, 'de: 'a, T: DeRecord<'de>> MapAccess<'de>
615    for &'a mut DeRecordWrap<T>
616{
617    type Error = DeserializeError;
618
619    fn next_key_seed<K: DeserializeSeed<'de>>(
620        &mut self,
621        seed: K,
622    ) -> Result<Option<K::Value>, Self::Error> {
623        assert!(self.has_headers());
624        let field = match self.next_header_bytes()? {
625            None => return Ok(None),
626            Some(field) => field,
627        };
628        seed.deserialize(BorrowedBytesDeserializer::new(field)).map(Some)
629    }
630
631    fn next_value_seed<K: DeserializeSeed<'de>>(
632        &mut self,
633        seed: K,
634    ) -> Result<K::Value, Self::Error> {
635        seed.deserialize(&mut **self)
636    }
637}
638
639/// An Serde deserialization error.
640#[derive(Clone, Debug, Eq, PartialEq)]
641pub struct DeserializeError {
642    field: Option<u64>,
643    kind: DeserializeErrorKind,
644}
645
/// The kind of a Serde deserialization error.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DeserializeErrorKind {
    /// A generic Serde deserialization error carrying a message.
    Message(String),
    /// A deserializer method that is not supported was called; the payload
    /// names the method.
    Unsupported(String),
    /// A Rust type expected to decode another field from a row, but the row
    /// had no more fields.
    UnexpectedEndOfRow,
    /// UTF-8 validation of a field failed. UTF-8 validation is only
    /// performed when the Rust type requires it (e.g., a `String` or `&str`
    /// type).
    InvalidUtf8(str::Utf8Error),
    /// A boolean value failed to parse.
    ParseBool(str::ParseBoolError),
    /// An integer value failed to parse.
    ParseInt(num::ParseIntError),
    /// A float value failed to parse.
    ParseFloat(num::ParseFloatError),
}
667
668impl SerdeError for DeserializeError {
669    fn custom<T: fmt::Display>(msg: T) -> DeserializeError {
670        DeserializeError { field: None, kind: DEK::Message(msg.to_string()) }
671    }
672}
673
674impl StdError for DeserializeError {
675    fn description(&self) -> &str {
676        self.kind.description()
677    }
678}
679
680impl fmt::Display for DeserializeError {
681    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
682        if let Some(field) = self.field {
683            write!(f, "field {}: {}", field, self.kind)
684        } else {
685            write!(f, "{}", self.kind)
686        }
687    }
688}
689
690impl fmt::Display for DeserializeErrorKind {
691    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
692        use self::DeserializeErrorKind::*;
693
694        match *self {
695            Message(ref msg) => write!(f, "{}", msg),
696            Unsupported(ref which) => {
697                write!(f, "unsupported deserializer method: {}", which)
698            }
699            UnexpectedEndOfRow => write!(f, "{}", self.description()),
700            InvalidUtf8(ref err) => err.fmt(f),
701            ParseBool(ref err) => err.fmt(f),
702            ParseInt(ref err) => err.fmt(f),
703            ParseFloat(ref err) => err.fmt(f),
704        }
705    }
706}
707
708impl DeserializeError {
709    /// Return the field index (starting at 0) of this error, if available.
710    pub fn field(&self) -> Option<u64> {
711        self.field
712    }
713
714    /// Return the underlying error kind.
715    pub fn kind(&self) -> &DeserializeErrorKind {
716        &self.kind
717    }
718}
719
720impl DeserializeErrorKind {
721    #[allow(deprecated)]
722    fn description(&self) -> &str {
723        use self::DeserializeErrorKind::*;
724
725        match *self {
726            Message(_) => "deserialization error",
727            Unsupported(_) => "unsupported deserializer method",
728            UnexpectedEndOfRow => "expected field, but got end of row",
729            InvalidUtf8(ref err) => err.description(),
730            ParseBool(ref err) => err.description(),
731            ParseInt(ref err) => err.description(),
732            ParseFloat(ref err) => err.description(),
733        }
734    }
735}
736
/// Parses `s` as a `u128`, returning `None` if it is not one.
fn try_positive_integer128(s: &str) -> Option<u128> {
    match s.parse::<u128>() {
        Ok(n) => Some(n),
        Err(_) => None,
    }
}
740
/// Parses `s` as an `i128`, returning `None` if it is not one.
fn try_negative_integer128(s: &str) -> Option<i128> {
    match s.parse::<i128>() {
        Ok(n) => Some(n),
        Err(_) => None,
    }
}
744
/// Parses `s` as a `u64`, returning `None` if it is not one.
fn try_positive_integer64(s: &str) -> Option<u64> {
    match s.parse::<u64>() {
        Ok(n) => Some(n),
        Err(_) => None,
    }
}
748
/// Parses `s` as an `i64`, returning `None` if it is not one.
fn try_negative_integer64(s: &str) -> Option<i64> {
    match s.parse::<i64>() {
        Ok(n) => Some(n),
        Err(_) => None,
    }
}
752
/// Parses `s` as an `f64`, returning `None` if it is not one.
fn try_float(s: &str) -> Option<f64> {
    match s.parse::<f64>() {
        Ok(x) => Some(x),
        Err(_) => None,
    }
}
756
/// Parses `s` as a `u64`, returning `None` if the bytes are not valid UTF-8
/// or do not form a `u64`.
fn try_positive_integer64_bytes(s: &[u8]) -> Option<u64> {
    let text = str::from_utf8(s).ok()?;
    text.parse().ok()
}
760
/// Parses `s` as an `i64`, returning `None` if the bytes are not valid UTF-8
/// or do not form an `i64`.
fn try_negative_integer64_bytes(s: &[u8]) -> Option<i64> {
    let text = str::from_utf8(s).ok()?;
    text.parse().ok()
}
764
/// Parses `s` as a `u128`, returning `None` if the bytes are not valid UTF-8
/// or do not form a `u128`.
fn try_positive_integer128_bytes(s: &[u8]) -> Option<u128> {
    let text = str::from_utf8(s).ok()?;
    text.parse().ok()
}
768
/// Parses `s` as an `i128`, returning `None` if the bytes are not valid
/// UTF-8 or do not form an `i128`.
fn try_negative_integer128_bytes(s: &[u8]) -> Option<i128> {
    let text = str::from_utf8(s).ok()?;
    text.parse().ok()
}
772
/// Parses `s` as an `f64`, returning `None` if the bytes are not valid UTF-8
/// or do not form an `f64`.
fn try_float_bytes(s: &[u8]) -> Option<f64> {
    let text = str::from_utf8(s).ok()?;
    text.parse().ok()
}
776
777#[cfg(test)]
778mod tests {
779    use std::collections::HashMap;
780
781    use {
782        bstr::BString,
783        serde::{de::DeserializeOwned, Deserialize},
784    };
785
786    use crate::{
787        byte_record::ByteRecord, error::Error, string_record::StringRecord,
788    };
789
790    use super::{deserialize_byte_record, deserialize_string_record};
791
792    fn de<D: DeserializeOwned>(fields: &[&str]) -> Result<D, Error> {
793        let record = StringRecord::from(fields);
794        deserialize_string_record(&record, None)
795    }
796
797    fn de_headers<D: DeserializeOwned>(
798        headers: &[&str],
799        fields: &[&str],
800    ) -> Result<D, Error> {
801        let headers = StringRecord::from(headers);
802        let record = StringRecord::from(fields);
803        deserialize_string_record(&record, Some(&headers))
804    }
805
806    fn b<T: AsRef<[u8]> + ?Sized>(bytes: &T) -> &[u8] {
807        bytes.as_ref()
808    }
809
810    #[test]
811    fn with_header() {
812        #[derive(Deserialize, Debug, PartialEq)]
813        struct Foo {
814            z: f64,
815            y: i32,
816            x: String,
817        }
818
819        let got: Foo =
820            de_headers(&["x", "y", "z"], &["hi", "42", "1.3"]).unwrap();
821        assert_eq!(got, Foo { x: "hi".into(), y: 42, z: 1.3 });
822    }
823
824    #[test]
825    fn with_header_unknown() {
826        #[derive(Deserialize, Debug, PartialEq)]
827        #[serde(deny_unknown_fields)]
828        struct Foo {
829            z: f64,
830            y: i32,
831            x: String,
832        }
833        assert!(de_headers::<Foo>(
834            &["a", "x", "y", "z"],
835            &["foo", "hi", "42", "1.3"],
836        )
837        .is_err());
838    }
839
840    #[test]
841    fn with_header_missing() {
842        #[derive(Deserialize, Debug, PartialEq)]
843        struct Foo {
844            z: f64,
845            y: i32,
846            x: String,
847        }
848        assert!(de_headers::<Foo>(&["y", "z"], &["42", "1.3"],).is_err());
849    }
850
851    #[test]
852    fn with_header_missing_ok() {
853        #[derive(Deserialize, Debug, PartialEq)]
854        struct Foo {
855            z: f64,
856            y: i32,
857            x: Option<String>,
858        }
859
860        let got: Foo = de_headers(&["y", "z"], &["42", "1.3"]).unwrap();
861        assert_eq!(got, Foo { x: None, y: 42, z: 1.3 });
862    }
863
864    #[test]
865    fn with_header_no_fields() {
866        #[derive(Deserialize, Debug, PartialEq)]
867        struct Foo {
868            z: f64,
869            y: i32,
870            x: Option<String>,
871        }
872
873        let got = de_headers::<Foo>(&["y", "z"], &[]);
874        assert!(got.is_err());
875    }
876
877    #[test]
878    fn with_header_empty() {
879        #[derive(Deserialize, Debug, PartialEq)]
880        struct Foo {
881            z: f64,
882            y: i32,
883            x: Option<String>,
884        }
885
886        let got = de_headers::<Foo>(&[], &[]);
887        assert!(got.is_err());
888    }
889
890    #[test]
891    fn with_header_empty_ok() {
892        #[derive(Deserialize, Debug, PartialEq)]
893        struct Foo;
894
895        #[derive(Deserialize, Debug, PartialEq)]
896        struct Bar {}
897
898        let got = de_headers::<Foo>(&[], &[]);
899        assert_eq!(got.unwrap(), Foo);
900
901        let got = de_headers::<Bar>(&[], &[]);
902        assert_eq!(got.unwrap(), Bar {});
903
904        let got = de_headers::<()>(&[], &[]);
905        assert_eq!(got.unwrap(), ());
906    }
907
908    #[test]
909    fn without_header() {
910        #[derive(Deserialize, Debug, PartialEq)]
911        struct Foo {
912            z: f64,
913            y: i32,
914            x: String,
915        }
916
917        let got: Foo = de(&["1.3", "42", "hi"]).unwrap();
918        assert_eq!(got, Foo { x: "hi".into(), y: 42, z: 1.3 });
919    }
920
921    #[test]
922    fn no_fields() {
923        assert!(de::<String>(&[]).is_err());
924    }
925
926    #[test]
927    fn one_field() {
928        let got: i32 = de(&["42"]).unwrap();
929        assert_eq!(got, 42);
930    }
931
932    #[test]
933    fn one_field_128() {
934        let got: i128 = de(&["2010223372036854775808"]).unwrap();
935        assert_eq!(got, 2010223372036854775808);
936    }
937
938    #[test]
939    fn two_fields() {
940        let got: (i32, bool) = de(&["42", "true"]).unwrap();
941        assert_eq!(got, (42, true));
942
943        #[derive(Deserialize, Debug, PartialEq)]
944        struct Foo(i32, bool);
945
946        let got: Foo = de(&["42", "true"]).unwrap();
947        assert_eq!(got, Foo(42, true));
948    }
949
950    #[test]
951    fn two_fields_too_many() {
952        let got: (i32, bool) = de(&["42", "true", "z", "z"]).unwrap();
953        assert_eq!(got, (42, true));
954    }
955
956    #[test]
957    fn two_fields_too_few() {
958        assert!(de::<(i32, bool)>(&["42"]).is_err());
959    }
960
961    #[test]
962    fn one_char() {
963        let got: char = de(&["a"]).unwrap();
964        assert_eq!(got, 'a');
965    }
966
967    #[test]
968    fn no_chars() {
969        assert!(de::<char>(&[""]).is_err());
970    }
971
972    #[test]
973    fn too_many_chars() {
974        assert!(de::<char>(&["ab"]).is_err());
975    }
976
977    #[test]
978    fn simple_seq() {
979        let got: Vec<i32> = de(&["1", "5", "10"]).unwrap();
980        assert_eq!(got, vec![1, 5, 10]);
981    }
982
983    #[test]
984    fn simple_hex_seq() {
985        let got: Vec<i32> = de(&["0x7F", "0xA9", "0x10"]).unwrap();
986        assert_eq!(got, vec![0x7F, 0xA9, 0x10]);
987    }
988
989    #[test]
990    fn mixed_hex_seq() {
991        let got: Vec<i32> = de(&["0x7F", "0xA9", "10"]).unwrap();
992        assert_eq!(got, vec![0x7F, 0xA9, 10]);
993    }
994
995    #[test]
996    fn bad_hex_seq() {
997        assert!(de::<Vec<u8>>(&["7F", "0xA9", "10"]).is_err());
998    }
999
1000    #[test]
1001    fn seq_in_struct() {
1002        #[derive(Deserialize, Debug, PartialEq)]
1003        struct Foo {
1004            xs: Vec<i32>,
1005        }
1006        let got: Foo = de(&["1", "5", "10"]).unwrap();
1007        assert_eq!(got, Foo { xs: vec![1, 5, 10] });
1008    }
1009
1010    #[test]
1011    fn seq_in_struct_tail() {
1012        #[derive(Deserialize, Debug, PartialEq)]
1013        struct Foo {
1014            label: String,
1015            xs: Vec<i32>,
1016        }
1017        let got: Foo = de(&["foo", "1", "5", "10"]).unwrap();
1018        assert_eq!(got, Foo { label: "foo".into(), xs: vec![1, 5, 10] });
1019    }
1020
1021    #[test]
1022    fn map_headers() {
1023        let got: HashMap<String, i32> =
1024            de_headers(&["a", "b", "c"], &["1", "5", "10"]).unwrap();
1025        assert_eq!(got.len(), 3);
1026        assert_eq!(got["a"], 1);
1027        assert_eq!(got["b"], 5);
1028        assert_eq!(got["c"], 10);
1029    }
1030
1031    #[test]
1032    fn map_no_headers() {
1033        let got = de::<HashMap<String, i32>>(&["1", "5", "10"]);
1034        assert!(got.is_err());
1035    }
1036
1037    #[test]
1038    fn bytes() {
1039        let got: Vec<u8> = de::<BString>(&["foobar"]).unwrap().into();
1040        assert_eq!(got, b"foobar".to_vec());
1041    }
1042
1043    #[test]
1044    fn adjacent_fixed_arrays() {
1045        let got: ([u32; 2], [u32; 2]) = de(&["1", "5", "10", "15"]).unwrap();
1046        assert_eq!(got, ([1, 5], [10, 15]));
1047    }
1048
1049    #[test]
1050    fn enum_label_simple_tagged() {
1051        #[derive(Deserialize, Debug, PartialEq)]
1052        struct Row {
1053            label: Label,
1054            x: f64,
1055        }
1056
1057        #[derive(Deserialize, Debug, PartialEq)]
1058        #[serde(rename_all = "snake_case")]
1059        enum Label {
1060            Foo,
1061            Bar,
1062            Baz,
1063        }
1064
1065        let got: Row = de_headers(&["label", "x"], &["bar", "5"]).unwrap();
1066        assert_eq!(got, Row { label: Label::Bar, x: 5.0 });
1067    }
1068
1069    #[test]
1070    fn enum_untagged() {
1071        #[derive(Deserialize, Debug, PartialEq)]
1072        struct Row {
1073            x: Boolish,
1074            y: Boolish,
1075            z: Boolish,
1076        }
1077
1078        #[derive(Deserialize, Debug, PartialEq)]
1079        #[serde(rename_all = "snake_case")]
1080        #[serde(untagged)]
1081        enum Boolish {
1082            Bool(bool),
1083            Number(i64),
1084            String(String),
1085        }
1086
1087        let got: Row =
1088            de_headers(&["x", "y", "z"], &["true", "null", "1"]).unwrap();
1089        assert_eq!(
1090            got,
1091            Row {
1092                x: Boolish::Bool(true),
1093                y: Boolish::String("null".into()),
1094                z: Boolish::Number(1),
1095            }
1096        );
1097    }
1098
1099    #[test]
1100    fn option_empty_field() {
1101        #[derive(Deserialize, Debug, PartialEq)]
1102        struct Foo {
1103            a: Option<i32>,
1104            b: String,
1105            c: Option<i32>,
1106        }
1107
1108        let got: Foo =
1109            de_headers(&["a", "b", "c"], &["", "foo", "5"]).unwrap();
1110        assert_eq!(got, Foo { a: None, b: "foo".into(), c: Some(5) });
1111    }
1112
1113    #[test]
1114    fn option_invalid_field() {
1115        #[derive(Deserialize, Debug, PartialEq)]
1116        struct Foo {
1117            #[serde(deserialize_with = "crate::invalid_option")]
1118            a: Option<i32>,
1119            #[serde(deserialize_with = "crate::invalid_option")]
1120            b: Option<i32>,
1121            #[serde(deserialize_with = "crate::invalid_option")]
1122            c: Option<i32>,
1123        }
1124
1125        let got: Foo =
1126            de_headers(&["a", "b", "c"], &["xyz", "", "5"]).unwrap();
1127        assert_eq!(got, Foo { a: None, b: None, c: Some(5) });
1128    }
1129
1130    #[test]
1131    fn borrowed() {
1132        #[derive(Deserialize, Debug, PartialEq)]
1133        struct Foo<'a, 'c> {
1134            a: &'a str,
1135            b: i32,
1136            c: &'c str,
1137        }
1138
1139        let headers = StringRecord::from(vec!["a", "b", "c"]);
1140        let record = StringRecord::from(vec!["foo", "5", "bar"]);
1141        let got: Foo =
1142            deserialize_string_record(&record, Some(&headers)).unwrap();
1143        assert_eq!(got, Foo { a: "foo", b: 5, c: "bar" });
1144    }
1145
1146    #[test]
1147    fn borrowed_map() {
1148        use std::collections::HashMap;
1149
1150        let headers = StringRecord::from(vec!["a", "b", "c"]);
1151        let record = StringRecord::from(vec!["aardvark", "bee", "cat"]);
1152        let got: HashMap<&str, &str> =
1153            deserialize_string_record(&record, Some(&headers)).unwrap();
1154
1155        let expected: HashMap<&str, &str> =
1156            headers.iter().zip(&record).collect();
1157        assert_eq!(got, expected);
1158    }
1159
1160    #[test]
1161    fn borrowed_map_bytes() {
1162        use std::collections::HashMap;
1163
1164        let headers = ByteRecord::from(vec![b"a", b"\xFF", b"c"]);
1165        let record = ByteRecord::from(vec!["aardvark", "bee", "cat"]);
1166        let got: HashMap<&[u8], &[u8]> =
1167            deserialize_byte_record(&record, Some(&headers)).unwrap();
1168
1169        let expected: HashMap<&[u8], &[u8]> =
1170            headers.iter().zip(&record).collect();
1171        assert_eq!(got, expected);
1172    }
1173
1174    #[test]
1175    fn flatten() {
1176        #[derive(Deserialize, Debug, PartialEq)]
1177        struct Input {
1178            x: f64,
1179            y: f64,
1180        }
1181
1182        #[derive(Deserialize, Debug, PartialEq)]
1183        struct Properties {
1184            prop1: f64,
1185            prop2: f64,
1186        }
1187
1188        #[derive(Deserialize, Debug, PartialEq)]
1189        struct Row {
1190            #[serde(flatten)]
1191            input: Input,
1192            #[serde(flatten)]
1193            properties: Properties,
1194        }
1195
1196        let header = StringRecord::from(vec!["x", "y", "prop1", "prop2"]);
1197        let record = StringRecord::from(vec!["1", "2", "3", "4"]);
1198        let got: Row = record.deserialize(Some(&header)).unwrap();
1199        assert_eq!(
1200            got,
1201            Row {
1202                input: Input { x: 1.0, y: 2.0 },
1203                properties: Properties { prop1: 3.0, prop2: 4.0 },
1204            }
1205        );
1206    }
1207
1208    #[test]
1209    fn partially_invalid_utf8() {
1210        #[derive(Debug, Deserialize, PartialEq)]
1211        struct Row {
1212            h1: String,
1213            h2: BString,
1214            h3: String,
1215        }
1216
1217        let headers = ByteRecord::from(vec![b"h1", b"h2", b"h3"]);
1218        let record =
1219            ByteRecord::from(vec![b(b"baz"), b(b"foo\xFFbar"), b(b"quux")]);
1220        let got: Row =
1221            deserialize_byte_record(&record, Some(&headers)).unwrap();
1222        assert_eq!(
1223            got,
1224            Row {
1225                h1: "baz".to_string(),
1226                h2: BString::from(b"foo\xFFbar".to_vec()),
1227                h3: "quux".to_string(),
1228            }
1229        );
1230    }
1231}