py27_marshal/
read.rs

/// Error type shared by every reader in this module.
#[allow(clippy::wildcard_imports)] // read::errors
pub mod errors {
    use thiserror::Error;

    /// Everything that can go wrong while decoding marshal data.
    #[derive(Debug, Error)]
    pub enum ErrorKind {
        /// The type byte (with the reference flag masked off) is not an
        /// accepted object type; also used for internal placeholder types.
        #[error("invalid type: 0x{0:X}")]
        InvalidType(u8),
        /// The nesting-depth guard refused another level of nesting.
        #[error("recursion limit exceeded")]
        RecursionLimitExceeded,
        /// A digit of a marshalled `long` failed the range check.
        #[error("digit is out of range: {0}")]
        DigitOutOfRange(u16),
        /// A marshalled `long` whose final digit is zero.
        #[error("unnormalized long")]
        UnnormalizedLong,
        /// A null marker appeared where an object was required.
        #[error("null object")]
        IsNull,
        /// A dict key or set element could not be converted to a hashable object.
        #[error("encountered unhashable: {0:?}")]
        Unhashable(crate::Obj),
        /// An object had the wrong kind for where it appeared (e.g. a
        /// non-string where a code object expects a string).
        #[error("type error: {0:?}")]
        TypeError(crate::Obj),
        /// A string back-reference was out of range or resolved to an empty slot.
        #[error("invalid reference")]
        InvalidRef,

        // Wrapped standard-library errors; `#[from]` lets `?` convert them.
        #[error("IO error occurred: {0}")]
        Io(#[from] std::io::Error),
        #[error("Error occurred while creating a str: {0}")]
        Utf8(#[from] std::str::Utf8Error),
        #[error("Error occurred while creating a string from utf8: {0}")]
        FromUtf8(#[from] std::string::FromUtf8Error),
        #[error("Error occurred while parsing float: {0}")]
        ParseFloat(#[from] std::num::ParseFloatError),
    }
}
34
35use self::errors::*;
36use crate::{utils, Code, CodeFlags, Depth, Obj, ObjHashable, Type};
37use bstr::BString;
38use num_bigint::BigInt;
39use num_complex::Complex;
40use num_traits::{FromPrimitive, Zero};
41use std::{
42    collections::{HashMap, HashSet},
43    convert::TryFrom,
44    io::Read,
45    str::FromStr,
46    sync::{Arc, RwLock},
47};
48
49pub type ParseResult<T> = Result<T, ErrorKind>;
50
51struct RFile<R: Read> {
52    depth: Depth,
53    readable: R,
54    refs: Vec<Obj>,
55    stringrefs: Vec<Obj>,
56    has_posonlyargcount: bool,
57}
58
// Generates `fn $ident(p) -> ParseResult<$ty>`, which pulls exactly `$n`
// bytes from the stream and decodes them as a little-endian `$ty`.
// A short read surfaces as `ErrorKind::Io`.
macro_rules! define_r {
    ($ident:ident -> $ty:ty; $n:literal) => {
        fn $ident(p: &mut RFile<impl Read>) -> ParseResult<$ty> {
            let mut raw = [0_u8; $n];
            p.readable.read_exact(&mut raw)?;
            let value = <$ty>::from_le_bytes(raw);
            Ok(value)
        }
    };
}
68
69define_r! { r_byte      -> u8 ; 1 }
70define_r! { r_short     -> u16; 2 }
71define_r! { r_long      -> u32; 4 }
72define_r! { r_long64    -> u64; 8 }
73define_r! { r_float_bin -> f64; 8 }
74
75fn r_bytes(n: usize, p: &mut RFile<impl Read>) -> ParseResult<Vec<u8>> {
76    let mut buf = Vec::new();
77    buf.resize(n, 0);
78    p.readable.read_exact(&mut buf)?;
79    Ok(buf)
80}
81
82fn r_string(n: usize, p: &mut RFile<impl Read>) -> ParseResult<String> {
83    let buf = r_bytes(n, p)?;
84    Ok(String::from_utf8(buf)?)
85}
86
87fn r_bstring(n: usize, p: &mut RFile<impl Read>) -> ParseResult<BString> {
88    let buf = r_bytes(n, p)?;
89    Ok(BString::from(buf))
90}
91
92fn r_float_str(p: &mut RFile<impl Read>) -> ParseResult<f64> {
93    let n = r_byte(p)?;
94    let s = r_string(n as usize, p)?;
95    Ok(f64::from_str(s.as_ref())?)
96}
97
98// TODO: test
99/// May misbehave on 16-bit platforms.
100fn r_pylong(p: &mut RFile<impl Read>) -> ParseResult<BigInt> {
101    #[allow(clippy::cast_possible_wrap)]
102    let n = r_long(p)? as i32;
103    if n == 0 {
104        return Ok(BigInt::zero());
105    };
106    #[allow(clippy::cast_sign_loss)]
107    let size = n.wrapping_abs() as u32;
108    let mut digits = Vec::<u16>::with_capacity(size as usize);
109    for _ in 0..size {
110        let d = r_short(p)?;
111        if d > (1 << 15) {
112            return Err(ErrorKind::DigitOutOfRange(d).into());
113        }
114        digits.push(d);
115    }
116    if digits[(size - 1) as usize] == 0 {
117        return Err(ErrorKind::UnnormalizedLong.into());
118    }
119    Ok(BigInt::from_biguint(
120        utils::sign_of(&n),
121        utils::biguint_from_pylong_digits(&digits),
122    ))
123}
124
125fn r_vec(n: usize, p: &mut RFile<impl Read>) -> ParseResult<Vec<Obj>> {
126    let mut vec = Vec::with_capacity(n);
127    for _ in 0..n {
128        vec.push(r_object_not_null(p)?);
129    }
130    Ok(vec)
131}
132
133fn r_hashmap(p: &mut RFile<impl Read>) -> ParseResult<HashMap<ObjHashable, Obj>> {
134    let mut map = HashMap::new();
135    loop {
136        match r_object(p)? {
137            None => break,
138            Some(key) => match r_object(p)? {
139                None => break,
140                Some(value) => {
141                    map.insert(
142                        ObjHashable::try_from(&key).map_err(ErrorKind::Unhashable)?,
143                        value,
144                    );
145                } // TODO
146            },
147        }
148    }
149    Ok(map)
150}
151
152fn r_hashset(n: usize, p: &mut RFile<impl Read>) -> ParseResult<HashSet<ObjHashable>> {
153    let mut set = HashSet::new();
154    r_hashset_into(&mut set, n, p)?;
155    Ok(set)
156}
157
158fn r_hashset_into(
159    set: &mut HashSet<ObjHashable>,
160    n: usize,
161    p: &mut RFile<impl Read>,
162) -> ParseResult<()> {
163    for _ in 0..n {
164        set.insert(ObjHashable::try_from(&r_object_not_null(p)?).map_err(ErrorKind::Unhashable)?);
165    }
166    Ok(())
167}
168
/// Reads one object from the stream, dispatching on its leading type byte.
///
/// Returns `Ok(None)` for the null marker (`Type::Null`) and `Ok(Some(_))`
/// for every other accepted type. When the type byte carries
/// `Type::FLAG_REF`, the decoded object is also recorded in `p.refs`;
/// null and the singleton values are never recorded. Interned strings are
/// always appended to `p.stringrefs`, which `Type::StringRef` entries index.
///
/// # Errors
/// - `ErrorKind::RecursionLimitExceeded` when `p.depth.try_clone()` yields
///   `None`.
/// - `ErrorKind::InvalidType` for an unrecognised type byte, the internal
///   `Type::Bytes` placeholder, or `Type::Unknown`.
/// - `ErrorKind::InvalidRef` for a string reference that is out of range or
///   resolves to `Obj::None`.
/// - Any error propagated from the nested readers (I/O, UTF-8, float
///   parsing, null, unhashable, or type errors).
#[allow(clippy::too_many_lines)]
fn r_object(p: &mut RFile<impl Read>) -> ParseResult<Option<Obj>> {
    let code: u8 = r_byte(p)?;
    // Kept alive until this call returns. NOTE(review): presumably `Depth`
    // tracks nesting through outstanding handles and `try_clone` returns
    // `None` at the limit; confirm against `Depth`'s definition.
    let _depth_handle = p
        .depth
        .try_clone()
        .map_or(Err(ErrorKind::RecursionLimitExceeded), Ok)?;
    // Split the raw byte into the "record a reference" flag and the type.
    let (flag, type_) = {
        let flag: bool = (code & Type::FLAG_REF) != 0;
        let type_u8: u8 = code & !Type::FLAG_REF;
        let type_: Type =
            Type::from_u8(type_u8).map_or(Err(ErrorKind::InvalidType(type_u8)), Ok)?;
        if let Type::Bytes = type_ {
            // this is a fake type
            return Err(ErrorKind::InvalidType(type_u8));
        }
        (flag, type_)
    };
    // Flagged immutable containers get their `refs` slot reserved *before*
    // their contents are read, so the slot index reflects when the object
    // began. The `Obj::None` placeholder is overwritten at the end.
    let mut idx: Option<usize> = match type_ {
        // immutable collections
        Type::Tuple | Type::FrozenSet | Type::Code if flag => {
            let i = p.refs.len();
            p.refs.push(Obj::None);
            Some(i)
        }
        _ => None,
    };
    // The `as i32` / `as i64` / `as usize` casts below reinterpret raw
    // unsigned words read from the stream.
    #[allow(clippy::cast_possible_wrap)]
    let retval = match type_ {
        // Rejected above.
        Type::Bytes => unreachable!(),
        Type::Null => None,
        Type::None => Some(Obj::None),
        Type::StopIter => Some(Obj::StopIteration),
        Type::Ellipsis => Some(Obj::Ellipsis),
        Type::False => Some(Obj::Bool(false)),
        Type::True => Some(Obj::Bool(true)),
        Type::Int => Some(Obj::Long(Arc::new(BigInt::from(r_long(p)? as i32)))),
        Type::Int64 => Some(Obj::Long(Arc::new(BigInt::from(r_long64(p)? as i64)))),
        Type::Long => Some(Obj::Long(Arc::new(r_pylong(p)?))),
        Type::Float => Some(Obj::Float(r_float_str(p)?)),
        Type::BinaryFloat => Some(Obj::Float(r_float_bin(p)?)),
        // Real part is read first, then imaginary.
        Type::Complex => Some(Obj::Complex(Complex {
            re: r_float_str(p)?,
            im: r_float_str(p)?,
        })),
        Type::BinaryComplex => Some(Obj::Complex(Complex {
            re: r_float_bin(p)?,
            im: r_float_bin(p)?,
        })),
        // Length-prefixed raw bytes; not validated as UTF-8.
        Type::String | Type::Unicode => {
            let obj = Obj::String(Arc::new(r_bstring(r_long(p)? as usize, p)?));
            Some(obj)
        }
        // Back-reference into the strings previously read as `Type::Interned`.
        Type::StringRef => {
            let n = r_long(p)? as usize;
            let result = p.stringrefs.get(n).ok_or(ErrorKind::InvalidRef)?.clone();
            // Defensive: an `Obj::None` entry is treated as an invalid reference.
            if result.is_none() {
                return Err(ErrorKind::InvalidRef.into());
            } else {
                Some(result)
            }
        }
        // Same encoding as `Type::String`, but remembered for later `StringRef`s.
        Type::Interned => {
            let obj = Obj::String(Arc::new(r_bstring(r_long(p)? as usize, p)?));
            p.stringrefs.push(obj.clone());
            Some(obj)
        }
        Type::Tuple => Some(Obj::Tuple(Arc::new(r_vec(r_long(p)? as usize, p)?))),
        Type::List => Some(Obj::List(Arc::new(RwLock::new(r_vec(
            r_long(p)? as usize,
            p,
        )?)))),
        // Mutable set: when flagged, register the shared handle first and then
        // fill it in place, so the `refs` entry and the returned object share
        // the same `Arc`.
        Type::Set => {
            let set = Arc::new(RwLock::new(HashSet::new()));

            if flag {
                idx = Some(p.refs.len());
                p.refs.push(Obj::Set(Arc::clone(&set)));
            }

            r_hashset_into(&mut *set.write().unwrap(), r_long(p)? as usize, p)?;
            Some(Obj::Set(set))
        }
        Type::FrozenSet => Some(Obj::FrozenSet(Arc::new(r_hashset(r_long(p)? as usize, p)?))),
        // Key/value pairs terminated by a null marker.
        Type::Dict => Some(Obj::Dict(Arc::new(RwLock::new(r_hashmap(p)?)))),
        // Struct-literal fields are evaluated in the order written, which is
        // the order in which the fields are read from the stream.
        Type::Code => Some(Obj::Code(Arc::new(Code {
            argcount: r_long(p)?,
            nlocals: r_long(p)?,
            stacksize: r_long(p)?,
            flags: CodeFlags::from_bits_truncate(r_long(p)?),
            code: r_object_extract_bytes(p)?,
            consts: r_object_extract_tuple(p)?,
            names: r_object_extract_tuple_string(p)?,
            varnames: r_object_extract_tuple_string(p)?,
            freevars: r_object_extract_tuple_string(p)?,
            cellvars: r_object_extract_tuple_string(p)?,
            filename: r_object_extract_string(p)?,
            name: r_object_extract_string(p)?,
            firstlineno: r_long(p)?,
            lnotab: r_object_extract_bytes(p)?,
        }))),
        Type::Unknown => return Err(ErrorKind::InvalidType(Type::Unknown as u8).into()),
    };
    // Record flagged results: fill the reserved slot if there is one,
    // otherwise append. Null and singleton values are never recorded.
    match (&retval, idx) {
        (None, _)
        | (Some(Obj::None), _)
        | (Some(Obj::StopIteration), _)
        | (Some(Obj::Ellipsis), _)
        | (Some(Obj::Bool(_)), _) => {}
        (Some(x), Some(i)) if flag => {
            p.refs[i] = x.clone();
        }
        (Some(x), None) if flag => {
            p.refs.push(x.clone());
        }
        (Some(_), _) => {}
    };
    Ok(retval)
}
288
289fn r_object_not_null(p: &mut RFile<impl Read>) -> ParseResult<Obj> {
290    Ok(r_object(p)?.ok_or(ErrorKind::IsNull)?)
291}
292fn r_object_extract_string(p: &mut RFile<impl Read>) -> ParseResult<Arc<BString>> {
293    Ok(r_object_not_null(p)?
294        .extract_string()
295        .map_err(ErrorKind::TypeError)?)
296}
297fn r_object_extract_bytes(p: &mut RFile<impl Read>) -> ParseResult<Arc<Vec<u8>>> {
298    Ok(r_object_not_null(p)?
299        .extract_string()
300        .map_err(ErrorKind::TypeError)?)
301    // this forces an allocation but makes some operations easier
302    .map(|bytes| Arc::new(bytes.to_vec()))
303}
304fn r_object_extract_tuple(p: &mut RFile<impl Read>) -> ParseResult<Arc<Vec<Obj>>> {
305    Ok(r_object_not_null(p)?
306        .extract_tuple()
307        .map_err(ErrorKind::TypeError)?)
308}
309fn r_object_extract_tuple_string(p: &mut RFile<impl Read>) -> ParseResult<Vec<Arc<BString>>> {
310    Ok(r_object_extract_tuple(p)?
311        .iter()
312        .map(|x| x.clone().extract_string().map_err(ErrorKind::TypeError))
313        .collect::<ParseResult<Vec<Arc<BString>>>>()?)
314}
315
316fn read_object(p: &mut RFile<impl Read>) -> ParseResult<Obj> {
317    r_object_not_null(p)
318}
319
/// Options accepted by [`marshal_load_ex`].
#[derive(Copy, Clone, Debug)]
pub struct MarshalLoadExOptions {
    /// Whether code objects carry a `posonlyargcount` field. The value is
    /// stored in the reader state.
    pub has_posonlyargcount: bool,
}

/// Assumes the latest format version, i.e. `has_posonlyargcount` is `true`.
impl Default for MarshalLoadExOptions {
    fn default() -> Self {
        MarshalLoadExOptions {
            has_posonlyargcount: true,
        }
    }
}
332
333/// # Errors
334/// See [`ErrorKind`].
335pub fn marshal_load_ex(readable: impl Read, opts: MarshalLoadExOptions) -> ParseResult<Obj> {
336    let mut rf = RFile {
337        depth: Depth::new(),
338        readable,
339        refs: Vec::<Obj>::new(),
340        stringrefs: Vec::<Obj>::new(),
341        has_posonlyargcount: opts.has_posonlyargcount,
342    };
343    read_object(&mut rf)
344}
345
346/// # Errors
347/// See [`ErrorKind`].
348pub fn marshal_load(readable: impl Read) -> ParseResult<Obj> {
349    marshal_load_ex(readable, MarshalLoadExOptions::default())
350}
351
352/// Allows coercion from array reference to slice.
353/// # Errors
354/// See [`ErrorKind`].
355pub fn marshal_loads(bytes: &[u8]) -> ParseResult<Obj> {
356    marshal_load(bytes)
357}
358
// Ported from <https://github.com/python/cpython/blob/master/Lib/test/test_marshal.py>
#[cfg(test)]
mod test {
    use super::{
        errors, marshal_load, marshal_load_ex, marshal_loads, Code, CodeFlags,
        MarshalLoadExOptions, Obj, ObjHashable,
    };
    use bstr::BString;
    use num_bigint::BigInt;
    use num_traits::Pow;
    use std::{
        io::{self, Read},
        ops::Deref,
        sync::Arc,
    };

    /// Panics, showing the actual value, unless `$expr` matches `$pat`.
    macro_rules! assert_match {
        ($expr:expr, $pat:pat) => {
            match $expr {
                $pat => {}
                other => panic!("expected `{}`, got {:?}", stringify!($pat), other),
            }
        };
    }

    fn load_unwrap(r: impl Read) -> Obj {
        marshal_load(r).unwrap()
    }

    fn loads_unwrap(s: &[u8]) -> Obj {
        load_unwrap(s)
    }

    #[test]
    fn test_ints() {
        assert_eq!(BigInt::parse_bytes(b"85070591730234615847396907784232501249", 10).unwrap(), *loads_unwrap(b"l\t\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf0\x7f\xff\x7f\xff\x7f\xff\x7f?\x00").extract_long().unwrap());
    }

    #[allow(clippy::unreadable_literal)]
    #[test]
    fn test_int64() {
        for mut base in [i64::MAX, i64::MIN, -i64::MAX, -(i64::MIN >> 1)]
            .iter()
            .copied()
        {
            while base != 0 {
                let mut s = Vec::<u8>::new();
                s.push(b'I');
                s.extend_from_slice(&base.to_le_bytes());
                assert_eq!(
                    BigInt::from(base),
                    *loads_unwrap(&s).extract_long().unwrap()
                );

                if base == -1 {
                    base = 0
                } else {
                    base >>= 1
                }
            }
        }

        assert_eq!(
            BigInt::from(0x1032547698badcfe_i64),
            *loads_unwrap(b"I\xfe\xdc\xba\x98\x76\x54\x32\x10")
                .extract_long()
                .unwrap()
        );
        assert_eq!(
            BigInt::from(-0x1032547698badcff_i64),
            *loads_unwrap(b"I\x01\x23\x45\x67\x89\xab\xcd\xef")
                .extract_long()
                .unwrap()
        );
        assert_eq!(
            BigInt::from(0x7f6e5d4c3b2a1908_i64),
            *loads_unwrap(b"I\x08\x19\x2a\x3b\x4c\x5d\x6e\x7f")
                .extract_long()
                .unwrap()
        );
        assert_eq!(
            BigInt::from(-0x7f6e5d4c3b2a1909_i64),
            *loads_unwrap(b"I\xf7\xe6\xd5\xc4\xb3\xa2\x91\x80")
                .extract_long()
                .unwrap()
        );
    }

    #[test]
    fn test_bool() {
        assert_eq!(true, loads_unwrap(b"T").extract_bool().unwrap());
        assert_eq!(false, loads_unwrap(b"F").extract_bool().unwrap());
    }

    #[allow(clippy::float_cmp, clippy::cast_precision_loss)]
    #[test]
    fn test_floats() {
        assert_eq!(
            (i64::MAX as f64) * 3.7e250,
            loads_unwrap(b"g\x11\x9f6\x98\xd2\xab\xe4w")
                .extract_float()
                .unwrap()
        );
    }

    #[test]
    fn test_unicode() {
        assert_eq!("", *loads_unwrap(b"\xda\x00").extract_string().unwrap());
        assert_eq!(
            "Andr\u{e8} Previn",
            *loads_unwrap(b"u\r\x00\x00\x00Andr\xc3\xa8 Previn")
                .extract_string()
                .unwrap()
        );
        assert_eq!(
            "abc",
            *loads_unwrap(b"\xda\x03abc").extract_string().unwrap()
        );
        assert_eq!(
            " ".repeat(10_000),
            *loads_unwrap(&[b"a\x10'\x00\x00" as &[u8], &[b' '; 10_000]].concat())
                .extract_string()
                .unwrap()
        );
    }

    #[test]
    fn test_string() {
        assert_eq!("", *loads_unwrap(b"\xda\x00").extract_string().unwrap());
        assert_eq!(
            "Andr\u{e8} Previn",
            *loads_unwrap(b"\xf5\r\x00\x00\x00Andr\xc3\xa8 Previn")
                .extract_string()
                .unwrap()
        );
        assert_eq!(
            "abc",
            *loads_unwrap(b"\xda\x03abc").extract_string().unwrap()
        );
        assert_eq!(
            " ".repeat(10_000),
            *loads_unwrap(&[b"\xe1\x10'\x00\x00" as &[u8], &[b' '; 10_000]].concat())
                .extract_string()
                .unwrap()
        );
    }

    #[test]
    fn test_bytes() {
        assert_eq!(
            b"",
            &loads_unwrap(b"\xf3\x00\x00\x00\x00")
                .extract_bytes()
                .unwrap()[..]
        );
        assert_eq!(
            b"Andr\xe8 Previn",
            &loads_unwrap(b"\xf3\x0c\x00\x00\x00Andr\xe8 Previn")
                .extract_bytes()
                .unwrap()[..]
        );
        assert_eq!(
            b"abc",
            &loads_unwrap(b"\xf3\x03\x00\x00\x00abc")
                .extract_bytes()
                .unwrap()[..]
        );
        assert_eq!(
            b" ".repeat(10_000),
            &loads_unwrap(&[b"\xf3\x10'\x00\x00" as &[u8], &[b' '; 10_000]].concat())
                .extract_bytes()
                .unwrap()[..]
        );
    }

    #[test]
    fn test_exceptions() {
        loads_unwrap(b"S").extract_stop_iteration().unwrap();
    }

    /// Checks every field of the `ExceptionTestCase.test_exceptions` code object.
    fn assert_test_exceptions_code_valid(code: &Code) {
        assert_eq!(code.argcount, 1);
        assert!(code.cellvars.is_empty());
        assert_eq!(*code.code, &b"t\x00\xa0\x01t\x00\xa0\x02t\x03\xa1\x01\xa1\x01}\x01|\x00\xa0\x04t\x03|\x01\xa1\x02\x01\x00d\x00S\x00"[..]);
        assert_eq!(code.consts.len(), 1);
        assert!(code.consts[0].is_none());
        assert_eq!(*code.filename, "<string>");
        assert_eq!(code.firstlineno, 3);
        assert_eq!(
            code.flags,
            CodeFlags::NOFREE | CodeFlags::NEWLOCALS | CodeFlags::OPTIMIZED
        );
        assert!(code.freevars.is_empty());
        assert_eq!(*code.lnotab, &b"\x00\x01\x10\x01"[..]);
        assert_eq!(*code.name, "test_exceptions");
        assert!(code.names.iter().map(Deref::deref).eq(vec![
            "marshal",
            "loads",
            "dumps",
            "StopIteration",
            "assertEqual"
        ]
        .iter()));
        assert_eq!(code.nlocals, 2);
        assert_eq!(code.stacksize, 5);
        assert!(code
            .varnames
            .iter()
            .map(Deref::deref)
            .eq(vec!["self", "new"].iter()));
    }

    #[test]
    fn test_code() {
        // ExceptionTestCase.test_exceptions
        // { 'co_argcount': 1, 'co_cellvars': (), 'co_code': b't\x00\xa0\x01t\x00\xa0\x02t\x03\xa1\x01\xa1\x01}\x01|\x00\xa0\x04t\x03|\x01\xa1\x02\x01\x00d\x00S\x00', 'co_consts': (None,), 'co_filename': '<string>', 'co_firstlineno': 3, 'co_flags': 67, 'co_freevars': (), 'co_kwonlyargcount': 0, 'co_lnotab': b'\x00\x01\x10\x01', 'co_name': 'test_exceptions', 'co_names': ('marshal', 'loads', 'dumps', 'StopIteration', 'assertEqual'), 'co_nlocals': 2, 'co_stacksize': 5, 'co_varnames': ('self', 'new') }
        let mut input: &[u8] = b"\xe3\x01\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00C\x00\x00\x00s \x00\x00\x00t\x00\xa0\x01t\x00\xa0\x02t\x03\xa1\x01\xa1\x01}\x01|\x00\xa0\x04t\x03|\x01\xa1\x02\x01\x00d\x00S\x00)\x01N)\x05\xda\x07marshal\xda\x05loads\xda\x05dumps\xda\rStopIteration\xda\x0bassertEqual)\x02\xda\x04self\xda\x03new\xa9\x00r\x08\x00\x00\x00\xda\x08<string>\xda\x0ftest_exceptions\x03\x00\x00\x00s\x04\x00\x00\x00\x00\x01\x10\x01";
        let code_result = marshal_load_ex(
            &mut input,
            MarshalLoadExOptions {
                has_posonlyargcount: false,
            },
        );
        let code = code_result.unwrap().extract_code().unwrap();
        assert_test_exceptions_code_valid(&code);
    }

    #[test]
    fn test_many_codeobjects() {
        let mut input: &[u8] = &[b"(\x88\x13\x00\x00\xe3\x01\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00C\x00\x00\x00s \x00\x00\x00t\x00\xa0\x01t\x00\xa0\x02t\x03\xa1\x01\xa1\x01}\x01|\x00\xa0\x04t\x03|\x01\xa1\x02\x01\x00d\x00S\x00)\x01N)\x05\xda\x07marshal\xda\x05loads\xda\x05dumps\xda\rStopIteration\xda\x0bassertEqual)\x02\xda\x04self\xda\x03new\xa9\x00r\x08\x00\x00\x00\xda\x08<string>\xda\x0ftest_exceptions\x03\x00\x00\x00s\x04\x00\x00\x00\x00\x01\x10\x01" as &[u8], &b"r\x00\x00\x00\x00".repeat(4999)].concat();
        let result = marshal_load_ex(
            &mut input,
            MarshalLoadExOptions {
                has_posonlyargcount: false,
            },
        );
        let tuple = result.unwrap().extract_tuple().unwrap();
        for o in &*tuple {
            assert_test_exceptions_code_valid(&o.clone().extract_code().unwrap());
        }
    }

    #[test]
    fn test_different_filenames() {
        let mut input: &[u8] = b")\x02c\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00@\x00\x00\x00s\x08\x00\x00\x00e\x00\x01\x00d\x00S\x00)\x01N)\x01\xda\x01x\xa9\x00r\x01\x00\x00\x00r\x01\x00\x00\x00\xda\x02f1\xda\x08<module>\x01\x00\x00\x00\xf3\x00\x00\x00\x00c\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00@\x00\x00\x00s\x08\x00\x00\x00e\x00\x01\x00d\x00S\x00)\x01N)\x01\xda\x01yr\x01\x00\x00\x00r\x01\x00\x00\x00r\x01\x00\x00\x00\xda\x02f2r\x03\x00\x00\x00\x01\x00\x00\x00r\x04\x00\x00\x00";
        let result = marshal_load_ex(
            &mut input,
            MarshalLoadExOptions {
                has_posonlyargcount: false,
            },
        );
        let tuple = result.unwrap().extract_tuple().unwrap();
        assert_eq!(tuple.len(), 2);
        assert_eq!(*tuple[0].clone().extract_code().unwrap().filename, "f1");
        assert_eq!(*tuple[1].clone().extract_code().unwrap().filename, "f2");
    }

    #[allow(clippy::float_cmp)]
    #[test]
    fn test_dict() {
        let mut input: &[u8] = b"{\xda\x07astring\xfa\x10foo@bar.baz.spam\xda\x06afloat\xe7H\xe1z\x14ns\xbc@\xda\x05anint\xe9\x00\x00\x10\x00\xda\nashortlong\xe9\x02\x00\x00\x00\xda\x05alist[\x01\x00\x00\x00\xfa\x07.zyx.41\xda\x06atuple\xa9\n\xfa\x07.zyx.41r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00r\x0c\x00\x00\x00\xda\x08abooleanF\xda\x08aunicode\xf5\r\x00\x00\x00Andr\xc3\xa8 Previn0";
        let result = marshal_load(&mut input);
        let dict_ref = result.unwrap().extract_dict().unwrap();
        let dict = dict_ref.try_read().unwrap();
        assert_eq!(dict.len(), 8);
        assert_eq!(
            *dict[&ObjHashable::String(Arc::new(BString::from("astring")))]
                .clone()
                .extract_string()
                .unwrap(),
            "foo@bar.baz.spam"
        );
        assert_eq!(
            dict[&ObjHashable::String(Arc::new(BString::from("afloat")))]
                .clone()
                .extract_float()
                .unwrap(),
            7283.43_f64
        );
        assert_eq!(
            *dict[&ObjHashable::String(Arc::new(BString::from("anint")))]
                .clone()
                .extract_long()
                .unwrap(),
            BigInt::from(2).pow(20_u8)
        );
        assert_eq!(
            *dict[&ObjHashable::String(Arc::new(BString::from("ashortlong")))]
                .clone()
                .extract_long()
                .unwrap(),
            BigInt::from(2)
        );

        let list_ref = dict[&ObjHashable::String(Arc::new(BString::from("alist")))]
            .clone()
            .extract_list()
            .unwrap();
        let list = list_ref.try_read().unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(*list[0].clone().extract_string().unwrap(), ".zyx.41");

        let tuple = dict[&ObjHashable::String(Arc::new(BString::from("atuple")))]
            .clone()
            .extract_tuple()
            .unwrap();
        assert_eq!(tuple.len(), 10);
        for o in &*tuple {
            assert_eq!(*o.clone().extract_string().unwrap(), ".zyx.41");
        }
        assert_eq!(
            dict[&ObjHashable::String(Arc::new(BString::from("aboolean")))]
                .clone()
                .extract_bool()
                .unwrap(),
            false
        );
        assert_eq!(
            *dict[&ObjHashable::String(Arc::new(BString::from("aunicode")))]
                .clone()
                .extract_string()
                .unwrap(),
            "Andr\u{e8} Previn"
        );
    }

    /// Tests hash implementation
    #[test]
    fn test_dict_tuple_key() {
        let dict = loads_unwrap(b"{\xa9\x02\xda\x01a\xda\x01b\xda\x01c0")
            .extract_dict()
            .unwrap();
        assert_eq!(dict.read().unwrap().len(), 1);
        assert_eq!(
            *dict.read().unwrap()[&ObjHashable::Tuple(Arc::new(vec![
                ObjHashable::String(Arc::new(BString::from("a"))),
                ObjHashable::String(Arc::new(BString::from("b")))
            ]))]
                .clone()
                .extract_string()
                .unwrap(),
            "c"
        );
    }

    // TODO: test_list and test_tuple

    #[test]
    fn test_sets() {
        let set = loads_unwrap(b"<\x08\x00\x00\x00\xda\x05alist\xda\x08aboolean\xda\x07astring\xda\x08aunicode\xda\x06afloat\xda\x05anint\xda\x06atuple\xda\nashortlong").extract_set().unwrap();
        assert_eq!(set.read().unwrap().len(), 8);
        let frozenset = loads_unwrap(b">\x08\x00\x00\x00\xda\x06atuple\xda\x08aunicode\xda\x05anint\xda\x08aboolean\xda\x06afloat\xda\x05alist\xda\nashortlong\xda\x07astring").extract_frozenset().unwrap();
        assert_eq!(frozenset.len(), 8);
        // TODO: check values
    }

    // TODO: test_bytearray, test_memoryview, test_array

    #[test]
    fn test_patch_873224() {
        assert_match!(marshal_loads(b"0").unwrap_err(), errors::ErrorKind::IsNull);
        let f_err = marshal_loads(b"f").unwrap_err();
        match f_err {
            errors::ErrorKind::Io(io_err) => {
                assert_eq!(io_err.kind(), io::ErrorKind::UnexpectedEof);
            }
            _ => panic!(),
        }
        let int_err =
            marshal_loads(b"l\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 ").unwrap_err();
        match int_err {
            errors::ErrorKind::Io(io_err) => {
                assert_eq!(io_err.kind(), io::ErrorKind::UnexpectedEof);
            }
            _ => panic!(),
        }
    }

    /// Every single-byte input must return (Ok or Err) without panicking.
    #[test]
    fn test_fuzz() {
        for i in 0..=u8::MAX {
            println!("{:?}", marshal_loads(&[i]));
        }
    }

    /// Warning: this has to be run on a release build to avoid a stack overflow.
    // `ErrorKind` is the error type itself and has no `.kind()` accessor, so
    // the returned error is matched directly.
    #[cfg(not(debug_assertions))]
    #[test]
    fn test_loads_recursion() {
        const DEEP: usize = 1_048_576;

        loads_unwrap(&[&b")\x01".repeat(100)[..], b"N"].concat());
        loads_unwrap(&[&b"(\x01\x00\x00\x00".repeat(100)[..], b"N"].concat());
        loads_unwrap(&[&b"[\x01\x00\x00\x00".repeat(100)[..], b"N"].concat());
        loads_unwrap(&[&b"{N".repeat(100)[..], b"N", &b"0".repeat(100)[..]].concat());
        loads_unwrap(&[&b">\x01\x00\x00\x00".repeat(100)[..], b"N"].concat());

        assert_match!(
            marshal_loads(&[&b")\x01".repeat(DEEP)[..], b"N"].concat()).unwrap_err(),
            errors::ErrorKind::RecursionLimitExceeded
        );
        assert_match!(
            marshal_loads(&[&b"(\x01\x00\x00\x00".repeat(DEEP)[..], b"N"].concat()).unwrap_err(),
            errors::ErrorKind::RecursionLimitExceeded
        );
        assert_match!(
            marshal_loads(&[&b"[\x01\x00\x00\x00".repeat(DEEP)[..], b"N"].concat()).unwrap_err(),
            errors::ErrorKind::RecursionLimitExceeded
        );
        assert_match!(
            marshal_loads(&[&b"{N".repeat(DEEP)[..], b"N", &b"0".repeat(DEEP)[..]].concat())
                .unwrap_err(),
            errors::ErrorKind::RecursionLimitExceeded
        );
        assert_match!(
            marshal_loads(&[&b">\x01\x00\x00\x00".repeat(DEEP)[..], b"N"].concat()).unwrap_err(),
            errors::ErrorKind::RecursionLimitExceeded
        );
    }

    #[test]
    fn test_invalid_longs() {
        assert_match!(
            marshal_loads(b"l\x02\x00\x00\x00\x00\x00\x00\x00").unwrap_err(),
            errors::ErrorKind::UnnormalizedLong
        );
    }
}