py_marshal/
lib.rs

1// Ported from <https://github.com/python/cpython/blob/master/Python/marshal.c>
2use bitflags::bitflags;
3use num_bigint::BigInt;
4use num_complex::Complex;
5use num_derive::{FromPrimitive, ToPrimitive};
6use ordered_float::OrderedFloat;
7use std::{
8    collections::{HashMap, HashSet},
9    convert::TryFrom,
10    fmt,
11    hash::{Hash, Hasher},
12    iter::FromIterator,
13    sync::{Arc, RwLock},
14};
15
/// Shared *and mutable* handle: an `Arc` wrapped around an `RwLock`.
///
/// Convention used throughout this crate: a bare `Arc<T>` marks shared data that is never
/// mutated, while `ArcRwLock<T>` marks shared data that may be mutated through the lock.
pub type ArcRwLock<T> = std::sync::Arc<std::sync::RwLock<T>>;
19
20#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)]
21#[repr(u8)]
22enum Type {
23    Null = b'0',
24    None = b'N',
25    False = b'F',
26    True = b'T',
27    StopIter = b'S',
28    Ellipsis = b'.',
29    Int = b'i',
30    Int64 = b'I',
31    Float = b'f',
32    BinaryFloat = b'g',
33    Complex = b'x',
34    BinaryComplex = b'y',
35    Long = b'l',
36    String = b's',
37    Interned = b't',
38    Ref = b'r',
39    Tuple = b'(',
40    List = b'[',
41    Dict = b'{',
42    Code = b'c',
43    Unicode = b'u',
44    Unknown = b'?',
45    Set = b'<',
46    FrozenSet = b'>',
47    Ascii = b'a',
48    AsciiInterned = b'A',
49    SmallTuple = b')',
50    ShortAscii = b'z',
51    ShortAsciiInterned = b'Z',
52}
53impl Type {
54    const FLAG_REF: u8 = b'\x80';
55}
56
/// Recursion-depth guard.
///
/// Every guard produced by `try_clone` shares one `Arc`, so the strong count equals the number
/// of live guards (i.e. the current nesting depth).
struct Depth(Arc<()>);

impl Depth {
    /// Largest number of live guards that still allows one more level.
    const MAX: usize = 900;

    #[must_use]
    pub fn new() -> Self {
        Depth(Arc::default())
    }

    /// Returns a guard one level deeper, or `None` once more than `MAX` guards are alive.
    pub fn try_clone(&self) -> Option<Self> {
        let live = Arc::strong_count(&self.0);
        if live <= Self::MAX {
            Some(Depth(Arc::clone(&self.0)))
        } else {
            None
        }
    }
}

impl fmt::Debug for Depth {
    /// Prints `Depth(<number of live guards>)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        let live = Arc::strong_count(&self.0);
        f.debug_tuple("Depth").field(&live).finish()
    }
}
81
82bitflags! {
83    pub struct CodeFlags: u32 {
84        const OPTIMIZED                   = 0x1;
85        const NEWLOCALS                   = 0x2;
86        const VARARGS                     = 0x4;
87        const VARKEYWORDS                 = 0x8;
88        const NESTED                     = 0x10;
89        const GENERATOR                  = 0x20;
90        const NOFREE                     = 0x40;
91        const COROUTINE                  = 0x80;
92        const ITERABLE_COROUTINE        = 0x100;
93        const ASYNC_GENERATOR           = 0x200;
94        // TODO: old versions
95        const GENERATOR_ALLOWED        = 0x1000;
96        const FUTURE_DIVISION          = 0x2000;
97        const FUTURE_ABSOLUTE_IMPORT   = 0x4000;
98        const FUTURE_WITH_STATEMENT    = 0x8000;
99        const FUTURE_PRINT_FUNCTION   = 0x10000;
100        const FUTURE_UNICODE_LITERALS = 0x20000;
101        const FUTURE_BARRY_AS_BDFL    = 0x40000;
102        const FUTURE_GENERATOR_STOP   = 0x80000;
103        #[allow(clippy::unreadable_literal)]
104        const FUTURE_ANNOTATIONS     = 0x100000;
105    }
106}
107
108#[rustfmt::skip]
109#[derive(Clone, Debug)]
110pub struct Code {
111    pub argcount:        u32,
112    pub posonlyargcount: u32,
113    pub kwonlyargcount:  u32,
114    pub nlocals:         u32,
115    pub stacksize:       u32,
116    pub flags:           CodeFlags,
117    pub code:            Arc<Vec<u8>>,
118    pub consts:          Arc<Vec<Obj>>,
119    pub names:           Vec<Arc<String>>,
120    pub varnames:        Vec<Arc<String>>,
121    pub freevars:        Vec<Arc<String>>,
122    pub cellvars:        Vec<Arc<String>>,
123    pub filename:        Arc<String>,
124    pub name:            Arc<String>,
125    pub firstlineno:     u32,
126    pub lnotab:          Arc<Vec<u8>>,
127}
128
129#[rustfmt::skip]
130#[derive(Clone)]
131pub enum Obj {
132    None,
133    StopIteration,
134    Ellipsis,
135    Bool     (bool),
136    Long     (Arc<BigInt>),
137    Float    (f64),
138    Complex  (Complex<f64>),
139    Bytes    (Arc<Vec<u8>>),
140    String   (Arc<String>),
141    Tuple    (Arc<Vec<Obj>>),
142    List     (ArcRwLock<Vec<Obj>>),
143    Dict     (ArcRwLock<HashMap<ObjHashable, Obj>>),
144    Set      (ArcRwLock<HashSet<ObjHashable>>),
145    FrozenSet(Arc<HashSet<ObjHashable>>),
146    Code     (Arc<Code>),
147    // etc.
148}
// Generates `pub fn $extract_fn(self) -> Result<$ret, Self>` that unwraps one enum variant.
//
// The first four arms accept `name(Variant) -> Type` and rewrite it into the general form
// handled by the last arm, which takes an explicit `Variant(pattern) => expression` mapping.
// A `-> ()` target means a payload-less variant.
macro_rules! define_extract {
    ($extract_fn:ident($variant:ident) -> ()) => {
        define_extract! { $extract_fn -> () { $variant => () } }
    };
    ($extract_fn:ident($variant:ident) -> Arc<$ret:ty>) => {
        define_extract! { $extract_fn -> Arc<$ret> { $variant(x) => x } }
    };
    ($extract_fn:ident($variant:ident) -> ArcRwLock<$ret:ty>) => {
        define_extract! { $extract_fn -> ArcRwLock<$ret> { $variant(x) => x } }
    };
    ($extract_fn:ident($variant:ident) -> $ret:ty) => {
        define_extract! { $extract_fn -> $ret { $variant(x) => x } }
    };
    ($extract_fn:ident -> $ret:ty { $variant:ident$(($($pat:pat),+))? => $expr:expr }) => {
        /// Consumes `self` and returns the payload of the matching variant.
        ///
        /// # Errors
        /// Returns `self` back, unchanged, if it is a different variant.
        pub fn $extract_fn(self) -> Result<$ret, Self> {
            if let Self::$variant$(($($pat),+))? = self {
                Ok($expr)
            } else {
                Err(self)
            }
        }
    }
}
// Generates `pub fn $is_fn(&self) -> bool` testing whether `self` is a given variant.
// Payload-carrying variants are written with wildcard patterns, e.g. `Bool(_)`.
macro_rules! define_is {
    ($is_fn:ident($variant:ident$(($($pat:pat),+))?)) => {
        /// Returns `true` if `self` is the variant this method is named after.
        #[must_use]
        pub fn $is_fn(&self) -> bool {
            if let Self::$variant$(($($pat),+))? = self {
                true
            } else {
                false
            }
        }
    }
}
188impl Obj {
189    define_extract! { extract_none          (None)          -> ()                                    }
190    define_extract! { extract_stop_iteration(StopIteration) -> ()                                    }
191    define_extract! { extract_bool          (Bool)          -> bool                                  }
192    define_extract! { extract_long          (Long)          -> Arc<BigInt>                           }
193    define_extract! { extract_float         (Float)         -> f64                                   }
194    define_extract! { extract_bytes         (Bytes)         -> Arc<Vec<u8>>                          }
195    define_extract! { extract_string        (String)        -> Arc<String>                           }
196    define_extract! { extract_tuple         (Tuple)         -> Arc<Vec<Self>>                        }
197    define_extract! { extract_list          (List)          -> ArcRwLock<Vec<Self>>                  }
198    define_extract! { extract_dict          (Dict)          -> ArcRwLock<HashMap<ObjHashable, Self>> }
199    define_extract! { extract_set           (Set)           -> ArcRwLock<HashSet<ObjHashable>>       }
200    define_extract! { extract_frozenset     (FrozenSet)     -> Arc<HashSet<ObjHashable>>             }
201    define_extract! { extract_code          (Code)          -> Arc<Code>                             }
202
203    define_is! { is_none          (None)          }
204    define_is! { is_stop_iteration(StopIteration) }
205    define_is! { is_bool          (Bool(_))       }
206    define_is! { is_long          (Long(_))       }
207    define_is! { is_float         (Float(_))      }
208    define_is! { is_bytes         (Bytes(_))      }
209    define_is! { is_string        (String(_))     }
210    define_is! { is_tuple         (Tuple(_))      }
211    define_is! { is_list          (List(_))       }
212    define_is! { is_dict          (Dict(_))       }
213    define_is! { is_set           (Set(_))        }
214    define_is! { is_frozenset     (FrozenSet(_))  }
215    define_is! { is_code          (Code(_))       }
216}
217impl From<&ObjHashable> for Obj {
218    fn from(orig: &ObjHashable) -> Self {
219        match orig {
220            ObjHashable::None => Self::None,
221            ObjHashable::StopIteration => Self::StopIteration,
222            ObjHashable::Ellipsis => Self::Ellipsis,
223            ObjHashable::Bool(x) => Self::Bool(*x),
224            ObjHashable::Long(x) => Self::Long(Arc::clone(x)),
225            ObjHashable::Float(x) => Self::Float(x.into_inner()),
226            ObjHashable::Complex(Complex { re, im }) => Self::Complex(Complex {
227                re: re.into_inner(),
228                im: im.into_inner(),
229            }),
230            ObjHashable::Bytes(x) => Self::Bytes(Arc::clone(x)),
231            ObjHashable::String(x) => Self::String(Arc::clone(x)),
232            ObjHashable::Tuple(x) => Self::Tuple(Arc::new(x.iter().map(Self::from).collect())),
233            ObjHashable::FrozenSet(x) => Self::FrozenSet(Arc::new(x.inner().clone())),
234        }
235    }
236}
237/// Should mostly match Python's repr
238///
239/// # Float, Complex
240/// - Uses `float('...')` instead of `...` for nan, inf, and -inf.
241/// - Uses Rust's float-to-decimal conversion.
242///
243/// # Bytes, String
244/// - Always uses double-quotes
245/// - Escapes both kinds of quotes
246///
247/// # Code
248/// - Uses named arguments for readability
249/// - lnotab is formatted as bytes(...) with a list of integers, instead of a bytes literal
250impl fmt::Debug for Obj {
251    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
252        match self {
253            Self::None => write!(f, "None"),
254            Self::StopIteration => write!(f, "StopIteration"),
255            Self::Ellipsis => write!(f, "Ellipsis"),
256            Self::Bool(true) => write!(f, "True"),
257            Self::Bool(false) => write!(f, "False"),
258            Self::Long(x) => write!(f, "{}", x),
259            &Self::Float(x) => python_float_repr_full(f, x),
260            &Self::Complex(x) => python_complex_repr(f, x),
261            Self::Bytes(x) => python_bytes_repr(f, x),
262            Self::String(x) => python_string_repr(f, x),
263            Self::Tuple(x) => python_tuple_repr(f, x),
264            Self::List(x) => f.debug_list().entries(x.read().unwrap().iter()).finish(),
265            Self::Dict(x) => f.debug_map().entries(x.read().unwrap().iter()).finish(),
266            Self::Set(x) => f.debug_set().entries(x.read().unwrap().iter()).finish(),
267            Self::FrozenSet(x) => python_frozenset_repr(f, x),
268            Self::Code(x) => python_code_repr(f, x),
269        }
270    }
271}
272fn python_float_repr_full(f: &mut fmt::Formatter, x: f64) -> fmt::Result {
273    python_float_repr_core(f, x)?;
274    if x.fract() == 0. {
275        write!(f, ".0")?;
276    };
277    Ok(())
278}
/// Writes a float the way Python spells it, minus the `.0` suffix for integral values.
///
/// NaN becomes `float('nan')` (never signed), infinities become `[-]float('inf')`, and finite
/// values use Rust's shortest round-trip decimal form. The sign is written explicitly from the
/// sign bit so that `-0.0` renders as `-0`.
fn python_float_repr_core(f: &mut fmt::Formatter, x: f64) -> fmt::Result {
    if x.is_nan() {
        return f.write_str("float('nan')");
    }
    if x.is_sign_negative() {
        f.write_str("-")?;
    }
    let magnitude = x.abs();
    if magnitude.is_infinite() {
        f.write_str("float('inf')")
    } else {
        write!(f, "{}", magnitude)
    }
}
296fn python_complex_repr(f: &mut fmt::Formatter, x: Complex<f64>) -> fmt::Result {
297    if x.re == 0. && x.re.is_sign_positive() {
298        python_float_repr_core(f, x.im)?;
299        write!(f, "j")?;
300    } else {
301        write!(f, "(")?;
302        python_float_repr_core(f, x.re)?;
303        if x.im >= 0. || x.im.is_nan() {
304            write!(f, "+")?;
305        }
306        python_float_repr_core(f, x.im)?;
307        write!(f, "j)")?;
308    };
309    Ok(())
310}
/// Python-style `bytes` repr, always double-quoted: `b"..."`.
///
/// - Tab, newline and carriage return use their short escapes.
/// - Both quote characters and the backslash are backslash-escaped.
/// - Other printable ASCII is written as is; every remaining byte becomes `\xNN`.
fn python_bytes_repr(f: &mut fmt::Formatter, x: &[u8]) -> fmt::Result {
    f.write_str("b\"")?;
    x.iter().try_for_each(|&b| match b {
        b'\t' => f.write_str("\\t"),
        b'\n' => f.write_str("\\n"),
        b'\r' => f.write_str("\\r"),
        b'\'' | b'"' | b'\\' => write!(f, "\\{}", char::from(b)),
        b' '..=b'~' => write!(f, "{}", char::from(b)),
        other => write!(f, "\\x{:02x}", other),
    })?;
    f.write_str("\"")
}
/// Python-style `str` repr (double-quoted), built from Rust's `str` debug output.
///
/// Rust's `\u{…}` escapes are rewritten into Python's forms, chosen by hex-digit count:
/// - up to 2 digits: `\xNN`
/// - 3–4 digits: `\uNNNN`
/// - more: `\UNNNNNNNN`
///
/// Every other escape Rust emits (`\\`, `\"`, `\n`, `\0`, …) is copied unchanged. Escapes are
/// consumed one at a time, so a literal backslash followed by `u{` in the text (debug-printed
/// as `\\u{`) is not mistaken for a unicode escape.
///
/// # Errors
/// Returns `fmt::Error` if the formatter fails, or if a `\u{` escape has no closing brace
/// (Rust never produces that).
fn python_string_repr(f: &mut fmt::Formatter, x: &str) -> fmt::Result {
    let original = format!("{:?}", x);
    let mut rest = original.as_str();
    while let Some(pos) = rest.find('\\') {
        f.write_str(&rest[..pos])?;
        let escape = &rest[pos..];
        if escape.starts_with("\\u{") {
            let body = &escape[3..];
            let close = body.find('}').ok_or(fmt::Error)?;
            let hex = &body[..close];
            match hex.len() {
                0..=2 => write!(f, "\\x{:0>2}", hex)?,
                3..=4 => write!(f, "\\u{:0>4}", hex)?,
                _ => write!(f, "\\U{:0>8}", hex)?,
            }
            rest = &body[close + 1..];
        } else {
            // Any other escape is the backslash plus exactly one character.
            let len = escape[1..].chars().next().map_or(1, |c| 1 + c.len_utf8());
            f.write_str(&escape[..len])?;
            rest = &escape[len..];
        }
    }
    f.write_str(rest)
}
345fn python_tuple_repr(f: &mut fmt::Formatter, x: &[Obj]) -> fmt::Result {
346    if x.is_empty() {
347        f.write_str("()") // Otherwise this would get formatted into an empty string
348    } else {
349        let mut debug_tuple = f.debug_tuple("");
350        for o in x.iter() {
351            debug_tuple.field(&o);
352        }
353        debug_tuple.finish()
354    }
355}
356fn python_frozenset_repr(f: &mut fmt::Formatter, x: &HashSet<ObjHashable>) -> fmt::Result {
357    f.write_str("frozenset(")?;
358    if !x.is_empty() {
359        f.debug_set().entries(x.iter()).finish()?;
360    }
361    f.write_str(")")?;
362    Ok(())
363}
364fn python_code_repr(f: &mut fmt::Formatter, x: &Code) -> fmt::Result {
365    write!(f, "code(argcount={:?}, posonlyargcount={:?}, kwonlyargcount={:?}, nlocals={:?}, stacksize={:?}, flags={:?}, code={:?}, consts={:?}, names={:?}, varnames={:?}, freevars={:?}, cellvars={:?}, filename={:?}, name={:?}, firstlineno={:?}, lnotab=bytes({:?}))", x.argcount, x.posonlyargcount, x.kwonlyargcount, x.nlocals, x.stacksize, x.flags, Obj::Bytes(Arc::clone(&x.code)), x.consts, x.names, x.varnames, x.freevars, x.cellvars, x.filename, x.name, x.firstlineno, &x.lnotab)
366}
367
/// A `HashSet` newtype that also implements `Hash`, so that a set can itself be hashed
/// (as required for `ObjHashable::FrozenSet`).
#[derive(Debug)]
pub struct HashableHashSet<T>(HashSet<T>);

impl<T> HashableHashSet<T> {
    /// Borrows the wrapped set.
    fn inner(&self) -> &HashSet<T> {
        &self.0
    }
}

impl<T: Hash> Hash for HashableHashSet<T> {
    /// Hashes every element with its own fresh `DefaultHasher` and XORs the results.
    /// XOR is order-independent, so equal sets hash equally regardless of iteration order.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let combined = self.0.iter().fold(0_u64, |acc, element| {
            let mut element_hasher = std::collections::hash_map::DefaultHasher::new();
            element.hash(&mut element_hasher);
            acc ^ element_hasher.finish()
        });
        state.write_u64(combined);
    }
}

impl<T: Eq + Hash> PartialEq for HashableHashSet<T> {
    fn eq(&self, other: &Self) -> bool {
        self.inner() == other.inner()
    }
}

impl<T: Eq + Hash> Eq for HashableHashSet<T> {}

impl<T: Eq + Hash> FromIterator<T> for HashableHashSet<T> {
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        HashableHashSet(HashSet::from_iter(iter))
    }
}
410
411#[derive(PartialEq, Eq, Hash, Clone)]
412pub enum ObjHashable {
413    None,
414    StopIteration,
415    Ellipsis,
416    Bool(bool),
417    Long(Arc<BigInt>),
418    Float(OrderedFloat<f64>),
419    Complex(Complex<OrderedFloat<f64>>),
420    Bytes(Arc<Vec<u8>>),
421    String(Arc<String>),
422    Tuple(Arc<Vec<ObjHashable>>),
423    FrozenSet(Arc<HashableHashSet<ObjHashable>>),
424    // etc.
425}
426impl TryFrom<&Obj> for ObjHashable {
427    type Error = Obj;
428
429    fn try_from(orig: &Obj) -> Result<Self, Obj> {
430        match orig {
431            Obj::None => Ok(Self::None),
432            Obj::StopIteration => Ok(Self::StopIteration),
433            Obj::Ellipsis => Ok(Self::Ellipsis),
434            Obj::Bool(x) => Ok(Self::Bool(*x)),
435            Obj::Long(x) => Ok(Self::Long(Arc::clone(x))),
436            Obj::Float(x) => Ok(Self::Float(OrderedFloat(*x))),
437            Obj::Complex(Complex { re, im }) => Ok(Self::Complex(Complex {
438                re: OrderedFloat(*re),
439                im: OrderedFloat(*im),
440            })),
441            Obj::Bytes(x) => Ok(Self::Bytes(Arc::clone(x))),
442            Obj::String(x) => Ok(Self::String(Arc::clone(x))),
443            Obj::Tuple(x) => Ok(Self::Tuple(Arc::new(
444                x.iter()
445                    .map(Self::try_from)
446                    .collect::<Result<Vec<Self>, Obj>>()?,
447            ))),
448            Obj::FrozenSet(x) => Ok(Self::FrozenSet(Arc::new(
449                x.iter().cloned().collect::<HashableHashSet<Self>>(),
450            ))),
451            x => Err(x.clone()),
452        }
453    }
454}
455impl fmt::Debug for ObjHashable {
456    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
457        match self {
458            Self::None => write!(f, "None"),
459            Self::StopIteration => write!(f, "StopIteration"),
460            Self::Ellipsis => write!(f, "Ellipsis"),
461            Self::Bool(true) => write!(f, "True"),
462            Self::Bool(false) => write!(f, "False"),
463            Self::Long(x) => write!(f, "{}", x),
464            Self::Float(x) => python_float_repr_full(f, x.0),
465            Self::Complex(x) => python_complex_repr(
466                f,
467                Complex {
468                    re: x.re.0,
469                    im: x.im.0,
470                },
471            ),
472            Self::Bytes(x) => python_bytes_repr(f, x),
473            Self::String(x) => python_string_repr(f, x),
474            Self::Tuple(x) => python_tuple_hashable_repr(f, x),
475            Self::FrozenSet(x) => python_frozenset_repr(f, &x.0),
476        }
477    }
478}
479fn python_tuple_hashable_repr(f: &mut fmt::Formatter, x: &[ObjHashable]) -> fmt::Result {
480    if x.is_empty() {
481        f.write_str("()") // Otherwise this would get formatted into an empty string
482    } else {
483        let mut debug_tuple = f.debug_tuple("");
484        for o in x.iter() {
485            debug_tuple.field(&o);
486        }
487        debug_tuple.finish()
488    }
489}
490
// Tests pinning the Python-style `Debug` output (the crate's "repr") of `Obj`.
#[cfg(test)]
mod test {
    use super::{Code, CodeFlags, Obj, ObjHashable};
    use num_bigint::BigInt;
    use num_complex::Complex;
    use std::{
        collections::{HashMap, HashSet},
        sync::{Arc, RwLock},
    };

    /// Singletons, integers, every container type, and a full code object.
    #[test]
    fn test_debug_repr() {
        assert_eq!(format!("{:?}", Obj::None), "None");
        assert_eq!(format!("{:?}", Obj::StopIteration), "StopIteration");
        assert_eq!(format!("{:?}", Obj::Ellipsis), "Ellipsis");
        assert_eq!(format!("{:?}", Obj::Bool(true)), "True");
        assert_eq!(format!("{:?}", Obj::Bool(false)), "False");
        assert_eq!(
            format!("{:?}", Obj::Long(Arc::new(BigInt::from(-123)))),
            "-123"
        );
        // Tuples: empty, one element (trailing comma), several elements.
        assert_eq!(format!("{:?}", Obj::Tuple(Arc::new(vec![]))), "()");
        assert_eq!(
            format!("{:?}", Obj::Tuple(Arc::new(vec![Obj::Bool(true)]))),
            "(True,)"
        );
        assert_eq!(
            format!(
                "{:?}",
                Obj::Tuple(Arc::new(vec![Obj::Bool(true), Obj::None]))
            ),
            "(True, None)"
        );
        assert_eq!(
            format!(
                "{:?}",
                Obj::List(Arc::new(RwLock::new(vec![Obj::Bool(true)])))
            ),
            "[True]"
        );
        assert_eq!(
            format!(
                "{:?}",
                Obj::Dict(Arc::new(RwLock::new(
                    vec![(
                        ObjHashable::Bool(true),
                        Obj::Bytes(Arc::new(Vec::from(b"a" as &[u8])))
                    )]
                    .into_iter()
                    .collect::<HashMap<_, _>>()
                )))
            ),
            "{True: b\"a\"}"
        );
        assert_eq!(
            format!(
                "{:?}",
                Obj::Set(Arc::new(RwLock::new(
                    vec![ObjHashable::Bool(true)]
                        .into_iter()
                        .collect::<HashSet<_>>()
                )))
            ),
            "{True}"
        );
        assert_eq!(
            format!(
                "{:?}",
                Obj::FrozenSet(Arc::new(
                    vec![ObjHashable::Bool(true)]
                        .into_iter()
                        .collect::<HashSet<_>>()
                ))
            ),
            "frozenset({True})"
        );
        // Code objects: named fields, bitflags for `flags`, integer list for `lnotab`.
        assert_eq!(format!("{:?}", Obj::Code(Arc::new(Code {
            argcount: 0,
            posonlyargcount: 1,
            kwonlyargcount: 2,
            nlocals: 3,
            stacksize: 4,
            flags: CodeFlags::NESTED | CodeFlags::COROUTINE,
            code: Arc::new(Vec::from(b"abc" as &[u8])),
            consts: Arc::new(vec![Obj::Bool(true)]),
            names: vec![],
            varnames: vec![Arc::new("a".to_owned())],
            freevars: vec![Arc::new("b".to_owned()), Arc::new("c".to_owned())],
            cellvars: vec![Arc::new("de".to_owned())],
            filename: Arc::new("xyz.py".to_owned()),
            name: Arc::new("fgh".to_owned()),
            firstlineno: 5,
            lnotab: Arc::new(vec![255, 0, 45, 127, 0, 73]),
        }))), "code(argcount=0, posonlyargcount=1, kwonlyargcount=2, nlocals=3, stacksize=4, flags=NESTED | COROUTINE, code=b\"abc\", consts=[True], names=[], varnames=[\"a\"], freevars=[\"b\", \"c\"], cellvars=[\"de\"], filename=\"xyz.py\", name=\"fgh\", firstlineno=5, lnotab=bytes([255, 0, 45, 127, 0, 73]))");
    }

    /// Floats: `.0` suffix for integral values, `float('...')` for nan/inf, signed zero.
    #[test]
    fn test_float_debug_repr() {
        assert_eq!(format!("{:?}", Obj::Float(1.23)), "1.23");
        assert_eq!(format!("{:?}", Obj::Float(f64::NAN)), "float('nan')");
        assert_eq!(format!("{:?}", Obj::Float(f64::INFINITY)), "float('inf')");
        assert_eq!(format!("{:?}", Obj::Float(-f64::INFINITY)), "-float('inf')");
        assert_eq!(format!("{:?}", Obj::Float(0.0)), "0.0");
        assert_eq!(format!("{:?}", Obj::Float(-0.0)), "-0.0");
    }

    /// Complex: bare `imj` when the real part is +0.0, parenthesized form otherwise.
    #[test]
    fn test_complex_debug_repr() {
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: 2., im: 1. })),
            "(2+1j)"
        );
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: 0., im: 1. })),
            "1j"
        );
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: 2., im: 0. })),
            "(2+0j)"
        );
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: 0., im: 0. })),
            "0j"
        );
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: -2., im: 1. })),
            "(-2+1j)"
        );
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: -2., im: 0. })),
            "(-2+0j)"
        );
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: 2., im: -1. })),
            "(2-1j)"
        );
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: 0., im: -1. })),
            "-1j"
        );
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: -2., im: -1. })),
            "(-2-1j)"
        );
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: 0., im: -1. })),
            "-1j"
        );
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: -2., im: 0. })),
            "(-2+0j)"
        );
        // A negative-zero real part is not "zero" for the short form.
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: -0., im: 1. })),
            "(-0+1j)"
        );
        assert_eq!(
            format!("{:?}", Obj::Complex(Complex { re: -0., im: -1. })),
            "(-0-1j)"
        );
    }

    /// Escaping of every byte value (bytes) and of the ASCII range (strings).
    #[test]
    fn test_bytes_string_debug_repr() {
        assert_eq!(format!("{:?}", Obj::Bytes(Arc::new(Vec::from(
                            b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe" as &[u8]
                            )))),
        "b\"\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f !\\\"#$%&\\\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b\\x9c\\x9d\\x9e\\x9f\\xa0\\xa1\\xa2\\xa3\\xa4\\xa5\\xa6\\xa7\\xa8\\xa9\\xaa\\xab\\xac\\xad\\xae\\xaf\\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7\\xb8\\xb9\\xba\\xbb\\xbc\\xbd\\xbe\\xbf\\xc0\\xc1\\xc2\\xc3\\xc4\\xc5\\xc6\\xc7\\xc8\\xc9\\xca\\xcb\\xcc\\xcd\\xce\\xcf\\xd0\\xd1\\xd2\\xd3\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0\\xe1\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb\\xfc\\xfd\\xfe\""
        );
        // Strings: `'` stays unescaped, NUL is `\0`, other controls become `\xNN`.
        assert_eq!(format!("{:?}", Obj::String(Arc::new(String::from(
                            "\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f")))),
                            "\"\\0\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f !\\\"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f\"");
    }
}
665
666mod utils {
667    use num_bigint::{BigUint, Sign};
668    use num_traits::Zero;
669    use std::cmp::Ordering;
670
671    /// Based on `_PyLong_AsByteArray` in <https://github.com/python/cpython/blob/master/Objects/longobject.c>
672    #[allow(clippy::cast_possible_truncation)]
673    pub fn biguint_from_pylong_digits(digits: &[u16]) -> BigUint {
674        if digits.is_empty() {
675            return BigUint::zero();
676        };
677        assert!(digits[digits.len() - 1] != 0);
678        let mut accum: u64 = 0;
679        let mut accumbits: u8 = 0;
680        let mut p = Vec::<u32>::new();
681        for (i, &thisdigit) in digits.iter().enumerate() {
682            accum |= u64::from(thisdigit) << accumbits;
683            accumbits += if i == digits.len() - 1 {
684                16 - (thisdigit.leading_zeros() as u8)
685            } else {
686                15
687            };
688
689            // Modified to get u32s instead of u8s.
690            while accumbits >= 32 {
691                p.push(accum as u32);
692                accumbits -= 32;
693                accum >>= 32;
694            }
695        }
696        assert!(accumbits < 32);
697        if accumbits > 0 {
698            p.push(accum as u32);
699        }
700        BigUint::new(p)
701    }
702
703    pub fn sign_of<T: Ord + Zero>(x: &T) -> Sign {
704        match x.cmp(&T::zero()) {
705            Ordering::Less => Sign::Minus,
706            Ordering::Equal => Sign::NoSign,
707            Ordering::Greater => Sign::Plus,
708        }
709    }
710
711    #[cfg(test)]
712    mod test {
713        use super::biguint_from_pylong_digits;
714        use num_bigint::BigUint;
715
716        #[allow(clippy::inconsistent_digit_grouping)]
717        #[test]
718        fn test_biguint_from_pylong_digits() {
719            assert_eq!(
720                biguint_from_pylong_digits(&[
721                    0b000_1101_1100_0100,
722                    0b110_1101_0010_0100,
723                    0b001_0000_1001_1101
724                ]),
725                BigUint::from(0b0100_0010_0111_0111_0110_1001_0010_0000_1101_1100_0100_u64)
726            );
727        }
728    }
729}
730
731#[allow(clippy::wildcard_imports)] // read::errors
732pub mod read {
    /// Error types for the marshal reader, generated by `error_chain`.
    pub mod errors {
        use error_chain::error_chain;

        error_chain! {
            // Library errors that `?` converts into `Error` automatically.
            foreign_links {
                Io(::std::io::Error);
                Utf8(::std::str::Utf8Error);
                FromUtf8(::std::string::FromUtf8Error);
                ParseFloat(::std::num::ParseFloatError);
            }

            // Marshal-specific failures; payloads carry the offending value where there is one.
            errors {
                // NOTE(review): presumably a type byte that matches no `Type` code — confirm in r_object.
                InvalidType(x: u8)
                // NOTE(review): presumably raised when `Depth::try_clone` returns `None` — confirm.
                RecursionLimitExceeded
                // A long digit rejected by the range check in r_pylong.
                DigitOutOfRange(x: u16)
                // A long whose most significant digit is zero (raised by r_pylong).
                UnnormalizedLong
                // NOTE(review): presumably a NULL object where a value was required — confirm.
                IsNull
                // A dict key that has no `ObjHashable` form (raised by r_hashmap).
                Unhashable(x: crate::Obj)
                // NOTE(review): presumably an object of an unexpected type — confirm at call sites.
                TypeError(x: crate::Obj)
                // NOTE(review): presumably a back-reference to an unknown index — confirm.
                InvalidRef
            }

            skip_msg_variant
        }
    }
758
759    use self::errors::*;
760    use crate::{utils, Code, CodeFlags, Depth, Obj, ObjHashable, Type};
761    use num_bigint::BigInt;
762    use num_complex::Complex;
763    use num_traits::{FromPrimitive, Zero};
764    use std::{
765        collections::{HashMap, HashSet},
766        convert::TryFrom,
767        io::Read,
768        str::FromStr,
769        sync::{Arc, RwLock},
770    };
771
772    struct RFile<R: Read> {
773        depth: Depth,
774        readable: R,
775        refs: Vec<Obj>,
776        has_posonlyargcount: bool,
777    }
778
779    macro_rules! define_r {
780        ($ident:ident -> $ty:ty; $n:literal) => {
781            fn $ident(p: &mut RFile<impl Read>) -> Result<$ty> {
782                let mut buf: [u8; $n] = [0; $n];
783                p.readable.read_exact(&mut buf)?;
784                Ok(<$ty>::from_le_bytes(buf))
785            }
786        };
787    }
788
789    define_r! { r_byte      -> u8 ; 1 }
790    define_r! { r_short     -> u16; 2 }
791    define_r! { r_long      -> u32; 4 }
792    define_r! { r_long64    -> u64; 8 }
793    define_r! { r_float_bin -> f64; 8 }
794
795    fn r_bytes(n: usize, p: &mut RFile<impl Read>) -> Result<Vec<u8>> {
796        let mut buf = Vec::new();
797        buf.resize(n, 0);
798        p.readable.read_exact(&mut buf)?;
799        Ok(buf)
800    }
801
802    fn r_string(n: usize, p: &mut RFile<impl Read>) -> Result<String> {
803        let buf = r_bytes(n, p)?;
804        Ok(String::from_utf8(buf)?)
805    }
806
807    fn r_float_str(p: &mut RFile<impl Read>) -> Result<f64> {
808        let n = r_byte(p)?;
809        let s = r_string(n as usize, p)?;
810        Ok(f64::from_str(&s)?)
811    }
812
813    // TODO: test
814    /// May misbehave on 16-bit platforms.
815    fn r_pylong(p: &mut RFile<impl Read>) -> Result<BigInt> {
816        #[allow(clippy::cast_possible_wrap)]
817        let n = r_long(p)? as i32;
818        if n == 0 {
819            return Ok(BigInt::zero());
820        };
821        #[allow(clippy::cast_sign_loss)]
822        let size = n.wrapping_abs() as u32;
823        let mut digits = Vec::<u16>::with_capacity(size as usize);
824        for _ in 0..size {
825            let d = r_short(p)?;
826            if d > (1 << 15) {
827                return Err(ErrorKind::DigitOutOfRange(d).into());
828            }
829            digits.push(d);
830        }
831        if digits[(size - 1) as usize] == 0 {
832            return Err(ErrorKind::UnnormalizedLong.into());
833        }
834        Ok(BigInt::from_biguint(
835            utils::sign_of(&n),
836            utils::biguint_from_pylong_digits(&digits),
837        ))
838    }
839
840    fn r_vec(n: usize, p: &mut RFile<impl Read>) -> Result<Vec<Obj>> {
841        let mut vec = Vec::with_capacity(n);
842        for _ in 0..n {
843            vec.push(r_object_not_null(p)?);
844        }
845        Ok(vec)
846    }
847
848    fn r_hashmap(p: &mut RFile<impl Read>) -> Result<HashMap<ObjHashable, Obj>> {
849        let mut map = HashMap::new();
850        loop {
851            match r_object(p)? {
852                None => break,
853                Some(key) => match r_object(p)? {
854                    None => break,
855                    Some(value) => {
856                        map.insert(
857                            ObjHashable::try_from(&key).map_err(ErrorKind::Unhashable)?,
858                            value,
859                        );
860                    } // TODO
861                },
862            }
863        }
864        Ok(map)
865    }
866
867    fn r_hashset(n: usize, p: &mut RFile<impl Read>) -> Result<HashSet<ObjHashable>> {
868        let mut set = HashSet::new();
869        r_hashset_into(&mut set, n, p)?;
870        Ok(set)
871    }
872    fn r_hashset_into(
873        set: &mut HashSet<ObjHashable>,
874        n: usize,
875        p: &mut RFile<impl Read>,
876    ) -> Result<()> {
877        for _ in 0..n {
878            set.insert(
879                ObjHashable::try_from(&r_object_not_null(p)?).map_err(ErrorKind::Unhashable)?,
880            );
881        }
882        Ok(())
883    }
884
885    #[allow(clippy::too_many_lines)]
886    fn r_object(p: &mut RFile<impl Read>) -> Result<Option<Obj>> {
887        let code: u8 = r_byte(p)?;
888        let _depth_handle = p
889            .depth
890            .try_clone()
891            .map_or(Err(ErrorKind::RecursionLimitExceeded), Ok)?;
892        let (flag, type_) = {
893            let flag: bool = (code & Type::FLAG_REF) != 0;
894            let type_u8: u8 = code & !Type::FLAG_REF;
895            let type_: Type =
896                Type::from_u8(type_u8).map_or(Err(ErrorKind::InvalidType(type_u8)), Ok)?;
897            (flag, type_)
898        };
899        let mut idx: Option<usize> = match type_ {
900            // R_REF/r_ref_reserve before reading contents
901            // See https://github.com/sollyucko/py-marshal/issues/2
902            Type::SmallTuple
903            | Type::Tuple
904            | Type::List
905            | Type::Dict
906            | Type::Set
907            | Type::FrozenSet
908            | Type::Code
909                if flag =>
910            {
911                let i = p.refs.len();
912                p.refs.push(Obj::None);
913                Some(i)
914            }
915            _ => None,
916        };
917        #[allow(clippy::cast_possible_wrap)]
918        let retval = match type_ {
919            Type::Null => None,
920            Type::None => Some(Obj::None),
921            Type::StopIter => Some(Obj::StopIteration),
922            Type::Ellipsis => Some(Obj::Ellipsis),
923            Type::False => Some(Obj::Bool(false)),
924            Type::True => Some(Obj::Bool(true)),
925            Type::Int => Some(Obj::Long(Arc::new(BigInt::from(r_long(p)? as i32)))),
926            Type::Int64 => Some(Obj::Long(Arc::new(BigInt::from(r_long64(p)? as i64)))),
927            Type::Long => Some(Obj::Long(Arc::new(r_pylong(p)?))),
928            Type::Float => Some(Obj::Float(r_float_str(p)?)),
929            Type::BinaryFloat => Some(Obj::Float(r_float_bin(p)?)),
930            Type::Complex => Some(Obj::Complex(Complex {
931                re: r_float_str(p)?,
932                im: r_float_str(p)?,
933            })),
934            Type::BinaryComplex => Some(Obj::Complex(Complex {
935                re: r_float_bin(p)?,
936                im: r_float_bin(p)?,
937            })),
938            Type::String => Some(Obj::Bytes(Arc::new(r_bytes(r_long(p)? as usize, p)?))),
939            Type::AsciiInterned | Type::Ascii | Type::Interned | Type::Unicode => {
940                Some(Obj::String(Arc::new(r_string(r_long(p)? as usize, p)?)))
941            }
942            Type::ShortAsciiInterned | Type::ShortAscii => {
943                Some(Obj::String(Arc::new(r_string(r_byte(p)? as usize, p)?)))
944            }
945            Type::SmallTuple => Some(Obj::Tuple(Arc::new(r_vec(r_byte(p)? as usize, p)?))),
946            Type::Tuple => Some(Obj::Tuple(Arc::new(r_vec(r_long(p)? as usize, p)?))),
947            Type::List => Some(Obj::List(Arc::new(RwLock::new(r_vec(
948                r_long(p)? as usize,
949                p,
950            )?)))),
951            Type::Set => {
952                let set = Arc::new(RwLock::new(HashSet::new()));
953
954                if flag {
955                    idx = Some(p.refs.len());
956                    p.refs.push(Obj::Set(Arc::clone(&set)));
957                }
958
959                r_hashset_into(&mut *set.write().unwrap(), r_long(p)? as usize, p)?;
960                Some(Obj::Set(set))
961            }
962            Type::FrozenSet => Some(Obj::FrozenSet(Arc::new(r_hashset(r_long(p)? as usize, p)?))),
963            Type::Dict => Some(Obj::Dict(Arc::new(RwLock::new(r_hashmap(p)?)))),
964            Type::Code => Some(Obj::Code(Arc::new(Code {
965                argcount: r_long(p)?,
966                posonlyargcount: if p.has_posonlyargcount { r_long(p)? } else { 0 },
967                kwonlyargcount: r_long(p)?,
968                nlocals: r_long(p)?,
969                stacksize: r_long(p)?,
970                flags: CodeFlags::from_bits_truncate(r_long(p)?),
971                code: r_object_extract_bytes(p)?,
972                consts: r_object_extract_tuple(p)?,
973                names: r_object_extract_tuple_string(p)?,
974                varnames: r_object_extract_tuple_string(p)?,
975                freevars: r_object_extract_tuple_string(p)?,
976                cellvars: r_object_extract_tuple_string(p)?,
977                filename: r_object_extract_string(p)?,
978                name: r_object_extract_string(p)?,
979                firstlineno: r_long(p)?,
980                lnotab: r_object_extract_bytes(p)?,
981            }))),
982
983            Type::Ref => {
984                let n = r_long(p)? as usize;
985                let result = p.refs.get(n).ok_or(ErrorKind::InvalidRef)?.clone();
986                if result.is_none() {
987                    return Err(ErrorKind::InvalidRef.into());
988                } else {
989                    Some(result)
990                }
991            }
992            Type::Unknown => return Err(ErrorKind::InvalidType(Type::Unknown as u8).into()),
993        };
994        match (&retval, idx) {
995            (None, _)
996            | (Some(Obj::None), _)
997            | (Some(Obj::StopIteration), _)
998            | (Some(Obj::Ellipsis), _)
999            | (Some(Obj::Bool(_)), _) => {}
1000            (Some(x), Some(i)) if flag => {
1001                p.refs[i] = x.clone();
1002            }
1003            (Some(x), None) if flag => {
1004                p.refs.push(x.clone());
1005            }
1006            (Some(_), _) => {}
1007        };
1008        Ok(retval)
1009    }
1010
1011    fn r_object_not_null(p: &mut RFile<impl Read>) -> Result<Obj> {
1012        Ok(r_object(p)?.ok_or(ErrorKind::IsNull)?)
1013    }
1014    fn r_object_extract_string(p: &mut RFile<impl Read>) -> Result<Arc<String>> {
1015        Ok(r_object_not_null(p)?
1016            .extract_string()
1017            .map_err(ErrorKind::TypeError)?)
1018    }
1019    fn r_object_extract_bytes(p: &mut RFile<impl Read>) -> Result<Arc<Vec<u8>>> {
1020        Ok(r_object_not_null(p)?
1021            .extract_bytes()
1022            .map_err(ErrorKind::TypeError)?)
1023    }
1024    fn r_object_extract_tuple(p: &mut RFile<impl Read>) -> Result<Arc<Vec<Obj>>> {
1025        Ok(r_object_not_null(p)?
1026            .extract_tuple()
1027            .map_err(ErrorKind::TypeError)?)
1028    }
1029    fn r_object_extract_tuple_string(p: &mut RFile<impl Read>) -> Result<Vec<Arc<String>>> {
1030        r_object_extract_tuple(p)?
1031            .iter()
1032            .map(|x| {
1033                x.clone()
1034                    .extract_string()
1035                    .map_err(|o: Obj| Error::from(ErrorKind::TypeError(o)))
1036            })
1037            .collect::<Result<Vec<Arc<String>>>>()
1038    }
1039
1040    fn read_object(p: &mut RFile<impl Read>) -> Result<Obj> {
1041        r_object_not_null(p)
1042    }
1043
1044    #[derive(Copy, Clone, Debug)]
1045    pub struct MarshalLoadExOptions {
1046        pub has_posonlyargcount: bool,
1047    }
1048    /// Assume latest version
1049    impl Default for MarshalLoadExOptions {
1050        fn default() -> Self {
1051            Self {
1052                has_posonlyargcount: true,
1053            }
1054        }
1055    }
1056
1057    /// # Errors
1058    /// See [`ErrorKind`].
1059    pub fn marshal_load_ex(readable: impl Read, opts: MarshalLoadExOptions) -> Result<Obj> {
1060        let mut rf = RFile {
1061            depth: Depth::new(),
1062            readable,
1063            refs: Vec::<Obj>::new(),
1064            has_posonlyargcount: opts.has_posonlyargcount,
1065        };
1066        read_object(&mut rf)
1067    }
1068
1069    /// # Errors
1070    /// See [`ErrorKind`].
1071    pub fn marshal_load(readable: impl Read) -> Result<Obj> {
1072        marshal_load_ex(readable, MarshalLoadExOptions::default())
1073    }
1074
1075    /// Allows coercion from array reference to slice.
1076    /// # Errors
1077    /// See [`ErrorKind`].
1078    pub fn marshal_loads(bytes: &[u8]) -> Result<Obj> {
1079        marshal_load(bytes)
1080    }
1081
    // Ported from <https://github.com/python/cpython/blob/master/Lib/test/test_marshal.py>
    //
    // Inputs are raw marshal byte strings. A type byte with the high bit (0x80) set, e.g.
    // `\xda` (= FLAG_REF | 'Z'), `\xe3` (= FLAG_REF | 'c') or `\xa9` (= FLAG_REF | ')'), asks the
    // reader to register the object for later `r` (TYPE_REF) back-references.
    #[cfg(test)]
    mod test {
        use super::{
            errors, marshal_load, marshal_load_ex, marshal_loads, Code, CodeFlags,
            MarshalLoadExOptions, Obj, ObjHashable,
        };
        use num_bigint::BigInt;
        use num_traits::Pow;
        use std::{
            io::{self, Read},
            ops::Deref,
            sync::Arc,
        };

        // Panics unless `$expr` matches `$pat`; used below to check error kinds by pattern.
        macro_rules! assert_match {
            ($expr:expr, $pat:pat) => {
                match $expr {
                    $pat => {}
                    _ => panic!(),
                }
            };
        }

        // Loads one object with default options, panicking on any error.
        fn load_unwrap(r: impl Read) -> Obj {
            marshal_load(r).unwrap()
        }

        // Slice convenience wrapper around `load_unwrap`.
        fn loads_unwrap(s: &[u8]) -> Obj {
            load_unwrap(s)
        }

        // A 9-digit TYPE_LONG; the expected value is (2**63 - 1)**2.
        #[test]
        fn test_ints() {
            assert_eq!(BigInt::parse_bytes(b"85070591730234615847396907784232501249", 10).unwrap(), *loads_unwrap(b"l\t\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf0\x7f\xff\x7f\xff\x7f\xff\x7f?\x00").extract_long().unwrap());
        }

        // TYPE_INT64 round-trips for a sweep of values obtained by repeatedly halving the
        // extremes of i64, plus a few fixed bit patterns.
        #[allow(clippy::unreadable_literal)]
        #[test]
        fn test_int64() {
            for mut base in [i64::MAX, i64::MIN, -i64::MAX, -(i64::MIN >> 1)]
                .iter()
                .copied()
            {
                while base != 0 {
                    let mut s = Vec::<u8>::new();
                    s.push(b'I');
                    s.extend_from_slice(&base.to_le_bytes());
                    assert_eq!(
                        BigInt::from(base),
                        *loads_unwrap(&s).extract_long().unwrap()
                    );

                    // Arithmetic shift of a negative number stops at -1, so end the loop there.
                    if base == -1 {
                        base = 0
                    } else {
                        base >>= 1
                    }
                }
            }

            assert_eq!(
                BigInt::from(0x1032547698badcfe_i64),
                *loads_unwrap(b"I\xfe\xdc\xba\x98\x76\x54\x32\x10")
                    .extract_long()
                    .unwrap()
            );
            assert_eq!(
                BigInt::from(-0x1032547698badcff_i64),
                *loads_unwrap(b"I\x01\x23\x45\x67\x89\xab\xcd\xef")
                    .extract_long()
                    .unwrap()
            );
            assert_eq!(
                BigInt::from(0x7f6e5d4c3b2a1908_i64),
                *loads_unwrap(b"I\x08\x19\x2a\x3b\x4c\x5d\x6e\x7f")
                    .extract_long()
                    .unwrap()
            );
            assert_eq!(
                BigInt::from(-0x7f6e5d4c3b2a1909_i64),
                *loads_unwrap(b"I\xf7\xe6\xd5\xc4\xb3\xa2\x91\x80")
                    .extract_long()
                    .unwrap()
            );
        }

        // `T` / `F` decode to booleans.
        #[test]
        fn test_bool() {
            assert!(loads_unwrap(b"T").extract_bool().unwrap());
            assert!(!loads_unwrap(b"F").extract_bool().unwrap());
        }

        // TYPE_BINARY_FLOAT (`g`): 8 little-endian IEEE-754 bytes.
        #[allow(clippy::float_cmp, clippy::cast_precision_loss)]
        #[test]
        fn test_floats() {
            assert_eq!(
                (i64::MAX as f64) * 3.7e250,
                loads_unwrap(b"g\x11\x9f6\x98\xd2\xab\xe4w")
                    .extract_float()
                    .unwrap()
            );
        }

        // Text in short-ASCII (`\xda`), unicode (`u`) and long-ASCII (`a`) encodings.
        #[test]
        fn test_unicode() {
            assert_eq!("", *loads_unwrap(b"\xda\x00").extract_string().unwrap());
            assert_eq!(
                "Andr\u{e8} Previn",
                *loads_unwrap(b"u\r\x00\x00\x00Andr\xc3\xa8 Previn")
                    .extract_string()
                    .unwrap()
            );
            assert_eq!(
                "abc",
                *loads_unwrap(b"\xda\x03abc").extract_string().unwrap()
            );
            assert_eq!(
                " ".repeat(10_000),
                *loads_unwrap(&[b"a\x10'\x00\x00" as &[u8], &[b' '; 10_000]].concat())
                    .extract_string()
                    .unwrap()
            );
        }

        // Same as `test_unicode`, but with FLAG_REF set on the `u` (`\xf5`) and `a` (`\xe1`)
        // type bytes.
        #[test]
        fn test_string() {
            assert_eq!("", *loads_unwrap(b"\xda\x00").extract_string().unwrap());
            assert_eq!(
                "Andr\u{e8} Previn",
                *loads_unwrap(b"\xf5\r\x00\x00\x00Andr\xc3\xa8 Previn")
                    .extract_string()
                    .unwrap()
            );
            assert_eq!(
                "abc",
                *loads_unwrap(b"\xda\x03abc").extract_string().unwrap()
            );
            assert_eq!(
                " ".repeat(10_000),
                *loads_unwrap(&[b"\xe1\x10'\x00\x00" as &[u8], &[b' '; 10_000]].concat())
                    .extract_string()
                    .unwrap()
            );
        }

        // Flagged TYPE_STRING (`\xf3`) yields raw bytes without UTF-8 validation.
        #[test]
        fn test_bytes() {
            assert_eq!(
                b"",
                &loads_unwrap(b"\xf3\x00\x00\x00\x00")
                    .extract_bytes()
                    .unwrap()[..]
            );
            assert_eq!(
                b"Andr\xe8 Previn",
                &loads_unwrap(b"\xf3\x0c\x00\x00\x00Andr\xe8 Previn")
                    .extract_bytes()
                    .unwrap()[..]
            );
            assert_eq!(
                b"abc",
                &loads_unwrap(b"\xf3\x03\x00\x00\x00abc")
                    .extract_bytes()
                    .unwrap()[..]
            );
            assert_eq!(
                b" ".repeat(10_000),
                &loads_unwrap(&[b"\xf3\x10'\x00\x00" as &[u8], &[b' '; 10_000]].concat())
                    .extract_bytes()
                    .unwrap()[..]
            );
        }

        // `S` decodes to the StopIteration singleton.
        #[test]
        fn test_exceptions() {
            loads_unwrap(b"S").extract_stop_iteration().unwrap();
        }

        // Checks every field of the `test_exceptions` code object used by the tests below.
        fn assert_test_exceptions_code_valid(code: &Code) {
            assert_eq!(code.argcount, 1);
            assert!(code.cellvars.is_empty());
            assert_eq!(*code.code, b"t\x00\xa0\x01t\x00\xa0\x02t\x03\xa1\x01\xa1\x01}\x01|\x00\xa0\x04t\x03|\x01\xa1\x02\x01\x00d\x00S\x00");
            assert_eq!(code.consts.len(), 1);
            assert!(code.consts[0].is_none());
            assert_eq!(*code.filename, "<string>");
            assert_eq!(code.firstlineno, 3);
            assert_eq!(
                code.flags,
                CodeFlags::NOFREE | CodeFlags::NEWLOCALS | CodeFlags::OPTIMIZED
            );
            assert!(code.freevars.is_empty());
            assert_eq!(code.kwonlyargcount, 0);
            assert_eq!(*code.lnotab, b"\x00\x01\x10\x01");
            assert_eq!(*code.name, "test_exceptions");
            assert!(code.names.iter().map(Deref::deref).eq(vec![
                "marshal",
                "loads",
                "dumps",
                "StopIteration",
                "assertEqual"
            ]
            .iter()));
            assert_eq!(code.nlocals, 2);
            assert_eq!(code.stacksize, 5);
            assert!(code
                .varnames
                .iter()
                .map(Deref::deref)
                .eq(vec!["self", "new"].iter()));
        }

        // A single code object in the layout without `posonlyargcount`. Its `cellvars` is a
        // TYPE_REF (`r\x08...`) to the empty `freevars` tuple registered just before it.
        #[test]
        fn test_code() {
            // ExceptionTestCase.test_exceptions
            // { 'co_argcount': 1, 'co_cellvars': (), 'co_code': b't\x00\xa0\x01t\x00\xa0\x02t\x03\xa1\x01\xa1\x01}\x01|\x00\xa0\x04t\x03|\x01\xa1\x02\x01\x00d\x00S\x00', 'co_consts': (None,), 'co_filename': '<string>', 'co_firstlineno': 3, 'co_flags': 67, 'co_freevars': (), 'co_kwonlyargcount': 0, 'co_lnotab': b'\x00\x01\x10\x01', 'co_name': 'test_exceptions', 'co_names': ('marshal', 'loads', 'dumps', 'StopIteration', 'assertEqual'), 'co_nlocals': 2, 'co_stacksize': 5, 'co_varnames': ('self', 'new') }
            let mut input: &[u8] = b"\xe3\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00C\x00\x00\x00s \x00\x00\x00t\x00\xa0\x01t\x00\xa0\x02t\x03\xa1\x01\xa1\x01}\x01|\x00\xa0\x04t\x03|\x01\xa1\x02\x01\x00d\x00S\x00)\x01N)\x05\xda\x07marshal\xda\x05loads\xda\x05dumps\xda\rStopIteration\xda\x0bassertEqual)\x02\xda\x04self\xda\x03new\xa9\x00r\x08\x00\x00\x00\xda\x08<string>\xda\x0ftest_exceptions\x03\x00\x00\x00s\x04\x00\x00\x00\x00\x01\x10\x01";
            println!("{}", input.len());
            let code_result = marshal_load_ex(
                &mut input,
                MarshalLoadExOptions {
                    has_posonlyargcount: false,
                },
            );
            println!("{}", input.len());
            let code = code_result.unwrap().extract_code().unwrap();
            assert_test_exceptions_code_valid(&code);
        }

        // A 5000-element tuple (`(\x88\x13\x00\x00`): one flagged code object (ref 0) followed
        // by 4999 TYPE_REFs back to it.
        #[test]
        fn test_many_codeobjects() {
            let mut input: &[u8] = &[b"(\x88\x13\x00\x00\xe3\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00C\x00\x00\x00s \x00\x00\x00t\x00\xa0\x01t\x00\xa0\x02t\x03\xa1\x01\xa1\x01}\x01|\x00\xa0\x04t\x03|\x01\xa1\x02\x01\x00d\x00S\x00)\x01N)\x05\xda\x07marshal\xda\x05loads\xda\x05dumps\xda\rStopIteration\xda\x0bassertEqual)\x02\xda\x04self\xda\x03new\xa9\x00r\x08\x00\x00\x00\xda\x08<string>\xda\x0ftest_exceptions\x03\x00\x00\x00s\x04\x00\x00\x00\x00\x01\x10\x01" as &[u8], &b"r\x00\x00\x00\x00".repeat(4999)].concat();
            let result = marshal_load_ex(
                &mut input,
                MarshalLoadExOptions {
                    has_posonlyargcount: false,
                },
            );
            let tuple = result.unwrap().extract_tuple().unwrap();
            for o in &*tuple {
                assert_test_exceptions_code_valid(&o.clone().extract_code().unwrap());
            }
        }

        // Two code objects in a small tuple; the second reuses strings from the first through
        // TYPE_REFs, but each keeps its own filename.
        #[test]
        fn test_different_filenames() {
            let mut input: &[u8] = b")\x02c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00@\x00\x00\x00s\x08\x00\x00\x00e\x00\x01\x00d\x00S\x00)\x01N)\x01\xda\x01x\xa9\x00r\x01\x00\x00\x00r\x01\x00\x00\x00\xda\x02f1\xda\x08<module>\x01\x00\x00\x00\xf3\x00\x00\x00\x00c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00@\x00\x00\x00s\x08\x00\x00\x00e\x00\x01\x00d\x00S\x00)\x01N)\x01\xda\x01yr\x01\x00\x00\x00r\x01\x00\x00\x00r\x01\x00\x00\x00\xda\x02f2r\x03\x00\x00\x00\x01\x00\x00\x00r\x04\x00\x00\x00";
            println!("{}", input.len());
            let result = marshal_load_ex(
                &mut input,
                MarshalLoadExOptions {
                    has_posonlyargcount: false,
                },
            );
            println!("{}", input.len());
            let tuple = result.unwrap().extract_tuple().unwrap();
            assert_eq!(tuple.len(), 2);
            assert_eq!(*tuple[0].clone().extract_code().unwrap().filename, "f1");
            assert_eq!(*tuple[1].clone().extract_code().unwrap().filename, "f2");
        }

        // A dict of mixed value types terminated by `0` (TYPE_NULL); the tuple value is built
        // from repeated TYPE_REFs to a single string.
        #[allow(clippy::float_cmp)]
        #[test]
        fn test_dict() {
            let mut input: &[u8] = b"{\xda\x07astring\xfa\x10foo@bar.baz.spam\xda\x06afloat\xe7H\xe1z\x14ns\xbc@\xda\x05anint\xe9\x00\x00\x10\x00\xda\nashortlong\xe9\x02\x00\x00\x00\xda\x05alist[\x01\x00\x00\x00\xfa\x07.zyx.41\xda\x06atuple\xa9\n\xfa\x07.zyx.41r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00\xda\x08abooleanF\xda\x08aunicode\xf5\r\x00\x00\x00Andr\xc3\xa8 Previn0";
            println!("{}", input.len());
            let result = marshal_load(&mut input);
            println!("{}", input.len());
            let dict_ref = result.unwrap().extract_dict().unwrap();
            let dict = dict_ref.try_read().unwrap();
            assert_eq!(dict.len(), 8);
            assert_eq!(
                *dict[&ObjHashable::String(Arc::new("astring".to_owned()))]
                    .clone()
                    .extract_string()
                    .unwrap(),
                "foo@bar.baz.spam"
            );
            assert_eq!(
                dict[&ObjHashable::String(Arc::new("afloat".to_owned()))]
                    .clone()
                    .extract_float()
                    .unwrap(),
                7283.43_f64
            );
            assert_eq!(
                *dict[&ObjHashable::String(Arc::new("anint".to_owned()))]
                    .clone()
                    .extract_long()
                    .unwrap(),
                BigInt::from(2).pow(20_u8)
            );
            assert_eq!(
                *dict[&ObjHashable::String(Arc::new("ashortlong".to_owned()))]
                    .clone()
                    .extract_long()
                    .unwrap(),
                BigInt::from(2)
            );

            let list_ref = dict[&ObjHashable::String(Arc::new("alist".to_owned()))]
                .clone()
                .extract_list()
                .unwrap();
            let list = list_ref.try_read().unwrap();
            assert_eq!(list.len(), 1);
            assert_eq!(*list[0].clone().extract_string().unwrap(), ".zyx.41");

            let tuple = dict[&ObjHashable::String(Arc::new("atuple".to_owned()))]
                .clone()
                .extract_tuple()
                .unwrap();
            assert_eq!(tuple.len(), 10);
            for o in &*tuple {
                assert_eq!(*o.clone().extract_string().unwrap(), ".zyx.41");
            }
            assert!(!dict[&ObjHashable::String(Arc::new("aboolean".to_owned()))]
                .clone()
                .extract_bool()
                .unwrap());
            assert_eq!(
                *dict[&ObjHashable::String(Arc::new("aunicode".to_owned()))]
                    .clone()
                    .extract_string()
                    .unwrap(),
                "Andr\u{e8} Previn"
            );
        }

        // Both keys are TYPE_STRING (bytes); one value is text (`u`), the other bytes (`s`).
        #[test]
        fn test_dict_with_bytes_key() {
            let mut input: &[u8] = b"{s\x08\x00\x00\x00u_key \xce\xb1u\n\x00\x00\x00unicode \xce\xb1s\x08\x00\x00\x00b_key \xce\xb1s\x08\x00\x00\x00bytes \xce\xb10";
            let dict_ref = marshal_load(&mut input).unwrap().extract_dict().unwrap();
            let dict = dict_ref.try_read().unwrap();
            assert_eq!(dict.len(), 2);
            assert_eq!(
                *dict[&ObjHashable::Bytes(Arc::new("b_key α".as_bytes().to_vec()))]
                    .clone()
                    .extract_bytes()
                    .unwrap(),
                "bytes α".as_bytes().to_vec()
            );
            assert_eq!(
                *dict[&ObjHashable::Bytes(Arc::new("u_key α".as_bytes().to_vec()))]
                    .clone()
                    .extract_string()
                    .unwrap(),
                "unicode α".to_owned()
            );
        }

        // A set holding an int, a str and bytes with the same text as the str; all three
        // elements must stay distinct.
        #[test]
        fn test_set_with_bytes() {
            let mut input: &[u8] = b"<\x03\x00\x00\x00i{\x00\x00\x00u\x06\x00\x00\x00abc \xce\xb1s\x06\x00\x00\x00abc \xce\xb1";
            let set_ref = marshal_load(&mut input).unwrap().extract_set().unwrap();
            let set = set_ref.try_read().unwrap();
            assert_eq!(set.len(), 3);
            assert!(set.contains(&ObjHashable::Bytes(Arc::new("abc α".as_bytes().to_vec()))));
            assert!(set.contains(&ObjHashable::String(Arc::new("abc α".to_owned()))));
            assert!(set.contains(&ObjHashable::Long(Arc::new(BigInt::from(123)))));
        }

        /// Tests hash implementation
        #[test]
        fn test_dict_tuple_key() {
            let dict = loads_unwrap(b"{\xa9\x02\xda\x01a\xda\x01b\xda\x01c0")
                .extract_dict()
                .unwrap();
            assert_eq!(dict.read().unwrap().len(), 1);
            assert_eq!(
                *dict.read().unwrap()[&ObjHashable::Tuple(Arc::new(vec![
                    ObjHashable::String(Arc::new("a".to_owned())),
                    ObjHashable::String(Arc::new("b".to_owned()))
                ]))]
                    .clone()
                    .extract_string()
                    .unwrap(),
                "c"
            );
        }

        // TODO: test_list and test_tuple

        // Unflagged set (`<`) and frozenset (`>`) of eight strings; only sizes are checked.
        #[test]
        fn test_sets() {
            let set = loads_unwrap(b"<\x08\x00\x00\x00\xda\x05alist\xda\x08aboolean\xda\x07astring\xda\x08aunicode\xda\x06afloat\xda\x05anint\xda\x06atuple\xda\nashortlong").extract_set().unwrap();
            assert_eq!(set.read().unwrap().len(), 8);
            let frozenset = loads_unwrap(b">\x08\x00\x00\x00\xda\x06atuple\xda\x08aunicode\xda\x05anint\xda\x08aboolean\xda\x06afloat\xda\x05alist\xda\nashortlong\xda\x07astring").extract_frozenset().unwrap();
            assert_eq!(frozenset.len(), 8);
            // TODO: check values
        }

        // TODO: test_bytearray, test_memoryview, test_array

        // Degenerate inputs: a bare TYPE_NULL is rejected as IsNull. A float with no length
        // byte, and a long whose digits run out, both fail with an UnexpectedEof I/O error.
        #[test]
        fn test_patch_873224() {
            assert_match!(
                marshal_loads(b"0").unwrap_err().kind(),
                errors::ErrorKind::IsNull
            );
            let f_err = marshal_loads(b"f").unwrap_err();
            match f_err.kind() {
                errors::ErrorKind::Io(io_err) => {
                    assert_eq!(io_err.kind(), io::ErrorKind::UnexpectedEof);
                }
                _ => panic!(),
            }
            let int_err =
                marshal_loads(b"l\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 ").unwrap_err();
            match int_err.kind() {
                errors::ErrorKind::Io(io_err) => {
                    assert_eq!(io_err.kind(), io::ErrorKind::UnexpectedEof);
                }
                _ => panic!(),
            }
        }

        // Every single-byte input must return (Ok or Err) without panicking; results are only
        // printed, not checked.
        #[test]
        fn test_fuzz() {
            for i in 0..=u8::MAX {
                println!("{:?}", marshal_loads(&[i]));
            }
        }

        /// Warning: this has to be run on a release build to avoid a stack overflow.
        #[cfg(not(debug_assertions))]
        #[test]
        fn test_loads_recursion() {
            // 100 levels of nesting stay within the depth limit...
            loads_unwrap(&[&b")\x01".repeat(100)[..], b"N"].concat());
            loads_unwrap(&[&b"(\x01\x00\x00\x00".repeat(100)[..], b"N"].concat());
            loads_unwrap(&[&b"[\x01\x00\x00\x00".repeat(100)[..], b"N"].concat());
            loads_unwrap(&[&b"{N".repeat(100)[..], b"N", &b"0".repeat(100)[..]].concat());
            loads_unwrap(&[&b">\x01\x00\x00\x00".repeat(100)[..], b"N"].concat());

            // ...while 2**20 levels must hit RecursionLimitExceeded.
            assert_match!(
                marshal_loads(&[&b")\x01".repeat(1048576)[..], b"N"].concat())
                    .unwrap_err()
                    .kind(),
                errors::ErrorKind::RecursionLimitExceeded
            );
            assert_match!(
                marshal_loads(&[&b"(\x01\x00\x00\x00".repeat(1048576)[..], b"N"].concat())
                    .unwrap_err()
                    .kind(),
                errors::ErrorKind::RecursionLimitExceeded
            );
            assert_match!(
                marshal_loads(&[&b"[\x01\x00\x00\x00".repeat(1048576)[..], b"N"].concat())
                    .unwrap_err()
                    .kind(),
                errors::ErrorKind::RecursionLimitExceeded
            );
            assert_match!(
                marshal_loads(
                    &[&b"{N".repeat(1048576)[..], b"N", &b"0".repeat(1048576)[..]].concat()
                )
                .unwrap_err()
                .kind(),
                errors::ErrorKind::RecursionLimitExceeded
            );
            assert_match!(
                marshal_loads(&[&b">\x01\x00\x00\x00".repeat(1048576)[..], b"N"].concat())
                    .unwrap_err()
                    .kind(),
                errors::ErrorKind::RecursionLimitExceeded
            );
        }

        // A two-digit TYPE_LONG whose most significant digit is zero.
        #[test]
        fn test_invalid_longs() {
            assert_match!(
                marshal_loads(b"l\x02\x00\x00\x00\x00\x00\x00\x00")
                    .unwrap_err()
                    .kind(),
                errors::ErrorKind::UnnormalizedLong
            );
        }

        // See https://github.com/sollyucko/py-marshal/issues/2
        // `\xdb` is a flagged list: it takes ref 0 before its elements are read, so the string
        // "a" is ref 1 and `r\x01...` must resolve to it.
        #[test]
        fn test_issue_2_ref_demarshalling_ordering_previously_broken() {
            let list_ref = marshal_loads(b"\xdb\x02\x00\x00\x00\xda\x01ar\x01\x00\x00\x00")
                .unwrap()
                .extract_list()
                .unwrap();
            let list = list_ref.try_read().unwrap();
            assert_eq!(list.len(), 2);
            assert_eq!(*list[0].clone().extract_string().unwrap(), "a");
            assert_eq!(*list[1].clone().extract_string().unwrap(), "a");
        }
        // Unflagged list: the string "a" is ref 0.
        #[test]
        fn test_issue_2_ref_demarshalling_ordering_previously_working() {
            let list_ref = marshal_loads(b"[\x02\x00\x00\x00\xda\x01ar\x00\x00\x00\x00")
                .unwrap()
                .extract_list()
                .unwrap();
            let list = list_ref.try_read().unwrap();
            assert_eq!(list.len(), 2);
            assert_eq!(*list[0].clone().extract_string().unwrap(), "a");
            assert_eq!(*list[1].clone().extract_string().unwrap(), "a");
        }
    }
1584}