python_marshal/
lib.rs

1pub mod code_objects;
2pub mod error;
3pub mod magic;
4mod optimizer;
5mod reader;
6pub mod resolver;
7mod walker;
8mod writer;
9
10use bitflags::bitflags;
11use bstr::BString;
12use error::Error;
13use hashable::HashableHashSet;
14use indexmap::{IndexMap, IndexSet};
15use magic::PyVersion;
16use num_bigint::BigInt;
17use num_complex::Complex;
18use num_derive::{FromPrimitive, ToPrimitive};
19use optimizer::{get_used_references, ReferenceOptimizer, Transformable};
20use ordered_float::OrderedFloat;
21use reader::PyReader;
22use std::io::{Read, Write};
23use writer::PyWriter;
24
/// Type tags that prefix serialized objects in a marshal stream.
///
/// The enum is `#[repr(u8)]` and every discriminant is the tag byte itself, so
/// the derived `FromPrimitive` / `ToPrimitive` impls convert between a raw
/// byte and its `Kind`.
#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive, PartialEq, Eq, Hash)]
#[repr(u8)]
#[rustfmt::skip]
pub enum Kind {
    Null               = b'0',
    None               = b'N',
    False              = b'F',
    True               = b'T',
    StopIteration      = b'S',
    Ellipsis           = b'.',
    Int                = b'i',
    Int64              = b'I', // Only generated in version 0
    Float              = b'f', // Only generated in marshal version 0
    BinaryFloat        = b'g',
    Complex            = b'x', // Only generated in marshal version 0
    BinaryComplex      = b'y',
    Long               = b'l',
    String             = b's',
    Interned           = b't',
    Ref                = b'r',
    Tuple              = b'(',
    List               = b'[',
    Dict               = b'{',
    Code               = b'c',
    Unicode            = b'u',
    Unknown            = b'?',
    Set                = b'<',
    FrozenSet          = b'>',
    ASCII              = b'a',
    ASCIIInterned      = b'A',
    SmallTuple         = b')',
    ShortAscii         = b'z',
    ShortAsciiInterned = b'Z',
    // High bit, not a printable tag. The tests feed tags such as 0xdb
    // (0x80 | b'[') and get a `StoreRef` back, so this is presumably OR-ed
    // onto another tag to request that the object be stored as a reference.
    FlagRef            = 0x80,
}
60
bitflags! {
    /// Bit flags carried by a code object, stored as a `u32`.
    ///
    /// NOTE(review): the names and values look like CPython's `CO_*`
    /// constants — confirm against the interpreter versions being targeted.
    #[derive(Clone, Debug, PartialEq)]
    pub struct CodeFlags: u32 {
        const OPTIMIZED                   = 0x1;
        const NEWLOCALS                   = 0x2;
        const VARARGS                     = 0x4;
        const VARKEYWORDS                 = 0x8;
        const NESTED                     = 0x10;
        const GENERATOR                  = 0x20;

        const NOFREE                     = 0x40; // Removed in 3.10

        const COROUTINE                  = 0x80;
        const ITERABLE_COROUTINE        = 0x100;
        const ASYNC_GENERATOR           = 0x200;

        const GENERATOR_ALLOWED        = 0x1000;

        // `from __future__ import ...` feature bits (judging by the names).
        const FUTURE_DIVISION          = 0x2000;
        const FUTURE_ABSOLUTE_IMPORT   = 0x4000;
        const FUTURE_WITH_STATEMENT    = 0x8000;
        const FUTURE_PRINT_FUNCTION   = 0x10000;
        const FUTURE_UNICODE_LITERALS = 0x20000;

        const FUTURE_BARRY_AS_BDFL    = 0x40000;
        const FUTURE_GENERATOR_STOP   = 0x80000;
        const FUTURE_ANNOTATIONS     = 0x100000;

        const NO_MONITORING_EVENTS    = 0x200000; // Added in 3.13
    }
}
92
/// Code object for every supported Python version.
///
/// The variant records which interpreter version the code object belongs to;
/// 3.11, 3.12 and 3.13 all wrap the same `Code311` struct.
#[derive(Clone, Debug, PartialEq)]
pub enum Code {
    /// Contains the code object for Python 3.10
    V310(code_objects::Code310),
    /// Contains the code object for Python 3.11
    V311(code_objects::Code311),
    /// Contains the code object for Python 3.12 which is exactly the same as 3.11 so we use the same struct
    V312(code_objects::Code311),
    /// Contains the code object for Python 3.13 which is exactly the same as 3.11 so we use the same struct
    V313(code_objects::Code311),
}
105
/// A Python string together with the marshal tag it is (or was) encoded with.
///
/// Equality and hashing take `kind` into account, so two strings with the
/// same bytes but different kinds are distinct values.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PyString {
    /// Raw string bytes; a `BString`, so the content is not required to be valid UTF-8.
    pub value: BString,
    /// String flavour, e.g. `Kind::ShortAscii`, `Kind::ASCII` or `Kind::Unicode`.
    pub kind: Kind,
}
111
112impl From<String> for PyString {
113    fn from(value: String) -> Self {
114        Self {
115            value: value.clone().into(),
116            kind: {
117                if value.is_ascii() {
118                    if value.len() <= 255 {
119                        Kind::ShortAscii
120                    } else {
121                        Kind::ASCII
122                    }
123                } else {
124                    Kind::Unicode
125                }
126            }, // Default kind
127        }
128    }
129}
130
131impl PyString {
132    pub fn new(value: BString, kind: Kind) -> Self {
133        Self { value, kind }
134    }
135}
136
/// In-memory representation of a Python value handled by this crate.
///
/// `LoadRef` / `StoreRef` do not hold a value themselves: they carry an index
/// into the reference table (`Vec<Object>`) that accompanies the root object.
#[rustfmt::skip]
#[derive(Clone, Debug, PartialEq)]
pub enum Object {
    None,
    StopIteration,
    Ellipsis,
    Bool      (bool),
    Long      (BigInt),
    Float     (f64),
    Complex   (Complex<f64>),
    Bytes     (Vec<u8>),
    String    (PyString),
    Tuple     (Vec<Object>),
    List      (Vec<Object>),
    Dict      (IndexMap<ObjectHashable, Object>),
    Set       (IndexSet<ObjectHashable>),
    FrozenSet (IndexSet<ObjectHashable>),
    Code      (Box<Code>),
    LoadRef   (usize), // Index into the reference table
    StoreRef  (usize), // Index into the reference table; presumably marks where the object was first stored — confirm in the reader
}
158
159// impl Eq for Object {} // Required to check if Code objects are equal with float values
160
/// The subset of [`Object`] that can be used as a dict key or set member.
///
/// Floats are wrapped in `OrderedFloat` so the type can derive `Eq` and
/// `Hash`. Convert with `ObjectHashable::try_from(Object)` (rejects
/// references) or `ObjectHashable::from_ref` (validates references against a
/// reference table), and back with `Object::from`.
#[rustfmt::skip]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ObjectHashable {
    None,
    StopIteration,
    Ellipsis,
    Bool      (bool),
    Long      (BigInt),
    Float     (OrderedFloat<f64>),
    Complex   (Complex<OrderedFloat<f64>>),
    Bytes     (Vec<u8>),
    String    (PyString),
    Tuple     (Vec<ObjectHashable>),
    FrozenSet (HashableHashSet<ObjectHashable>),
    LoadRef   (usize), // You need to ensure that the reference is hashable
    StoreRef  (usize), // Same as above
}
178
179impl ObjectHashable {
180    pub fn from_ref(obj: Object, references: &Vec<Object>) -> Result<Self, Error> {
181        // If the object is a reference, resolve it and make sure it's hashable
182        match obj {
183            Object::LoadRef(index) | Object::StoreRef(index) => {
184                if let Some(resolved_obj) = references.get(index) {
185                    let resolved_obj = resolved_obj.clone();
186                    Self::from_ref(resolved_obj.clone(), references)?;
187                    match obj {
188                        Object::LoadRef(index) => Ok(Self::LoadRef(index)),
189                        Object::StoreRef(index) => Ok(Self::StoreRef(index)),
190                        _ => unreachable!(),
191                    }
192                } else {
193                    Err(Error::InvalidReference)
194                }
195            }
196            Object::Tuple(t) => Ok(Self::Tuple(
197                // Tuple can contain references
198                t.iter()
199                    .map(|o| Self::from_ref((*o).clone(), references))
200                    .collect::<Result<Vec<_>, _>>()?,
201            )),
202            _ => Self::try_from(obj),
203        }
204    }
205}
206
207impl TryFrom<Object> for ObjectHashable {
208    type Error = Error;
209
210    fn try_from(obj: Object) -> Result<Self, Self::Error> {
211        match obj {
212            Object::None => Ok(ObjectHashable::None),
213            Object::StopIteration => Ok(ObjectHashable::StopIteration),
214            Object::Ellipsis => Ok(ObjectHashable::Ellipsis),
215            Object::Bool(b) => Ok(ObjectHashable::Bool(b)),
216            Object::Long(i) => Ok(ObjectHashable::Long(i)),
217            Object::Float(f) => Ok(ObjectHashable::Float(f.into())),
218            Object::Complex(c) => Ok(ObjectHashable::Complex(Complex {
219                re: OrderedFloat(c.re),
220                im: OrderedFloat(c.im),
221            })),
222            Object::Bytes(b) => Ok(ObjectHashable::Bytes(b)),
223            Object::String(s) => Ok(ObjectHashable::String(s)),
224            Object::Tuple(t) => Ok(ObjectHashable::Tuple(
225                t.iter()
226                    .map(|o| ObjectHashable::try_from((*o).clone()))
227                    .collect::<Result<Vec<_>, _>>()?,
228            )),
229            Object::FrozenSet(s) => Ok(ObjectHashable::FrozenSet(
230                s.iter()
231                    .map(|o| (*o).clone())
232                    .collect::<HashableHashSet<_>>(),
233            )),
234            _ => Err(Error::InvalidObject(obj)),
235        }
236    }
237}
238
239impl From<ObjectHashable> for Object {
240    fn from(obj: ObjectHashable) -> Self {
241        match obj {
242            ObjectHashable::None => Object::None,
243            ObjectHashable::StopIteration => Object::StopIteration,
244            ObjectHashable::Ellipsis => Object::Ellipsis,
245            ObjectHashable::Bool(b) => Object::Bool(b),
246            ObjectHashable::Long(i) => Object::Long(i),
247            ObjectHashable::Float(f) => Object::Float(f.into_inner()),
248            ObjectHashable::Complex(c) => Object::Complex(Complex {
249                re: c.re.into_inner(),
250                im: c.im.into_inner(),
251            }),
252            ObjectHashable::Bytes(b) => Object::Bytes(b),
253            ObjectHashable::String(s) => Object::String(s),
254            ObjectHashable::Tuple(t) => Object::Tuple(
255                t.iter()
256                    .map(|o| Object::from((*o).clone()))
257                    .collect::<Vec<_>>(),
258            ),
259            ObjectHashable::FrozenSet(s) => {
260                Object::FrozenSet(s.iter().cloned().collect::<IndexSet<_>>())
261            }
262            ObjectHashable::LoadRef(index) => Object::LoadRef(index),
263            ObjectHashable::StoreRef(index) => Object::StoreRef(index),
264        }
265    }
266}
267
/// Parsed contents of a `.pyc` file: the header fields plus the marshalled payload.
#[derive(Debug, Clone)]
pub struct PycFile {
    /// Interpreter version decoded from the leading 4-byte magic number.
    pub python_version: PyVersion,
    /// Header bytes 4..8 read as a little-endian `u32`; `load_pyc` leaves this
    /// as `None` for versions older than 3.7, and `dump_pyc` writes it only when set.
    pub timestamp: Option<u32>, // Only present in Python 3.7 and later
    /// Little-endian `u64` that `load_pyc` reads from header bytes 8..16
    /// (Python 3.7+) or 4..12 (older versions).
    pub hash: u64,
    /// Root object of the marshalled payload that follows the header.
    pub object: Object,
    /// Reference table produced alongside `object`; `LoadRef` / `StoreRef`
    /// indices point into it.
    pub references: Vec<Object>,
}
276
277pub fn optimize_references(object: Object, references: Vec<Object>) -> (Object, Vec<Object>) {
278    // Remove all unused references
279    let mut object = object;
280
281    let usage_counter = get_used_references(&mut object, references.clone());
282
283    let mut optimizer = ReferenceOptimizer::new(references, usage_counter);
284
285    object.transform(&mut optimizer);
286
287    (object, optimizer.new_references)
288}
289
290pub fn load_bytes(data: &[u8], python_version: PyVersion) -> Result<(Object, Vec<Object>), Error> {
291    if python_version < (3, 0) {
292        return Err(Error::UnsupportedPyVersion(python_version));
293    }
294
295    let mut py_reader = PyReader::new(data.to_vec(), python_version);
296
297    let object = py_reader.read_object()?;
298
299    Ok((object, py_reader.references))
300}
301
302pub fn load_pyc(data: impl Read) -> Result<PycFile, Error> {
303    let data = data.bytes().collect::<Result<Vec<u8>, _>>()?;
304
305    let magic_number = u32::from_le_bytes(data[0..4].try_into().map_err(|_| Error::NoMagicNumber)?);
306    let python_version = PyVersion::try_from(magic_number)?;
307
308    let timestamp = if python_version >= (3, 7) {
309        Some(u32::from_le_bytes(
310            data[4..8].try_into().map_err(|_| Error::NoTimeStamp)?,
311        ))
312    } else {
313        None
314    };
315
316    let hash = if python_version >= (3, 7) {
317        u64::from_le_bytes(data[8..16].try_into().map_err(|_| Error::NoHash)?)
318    } else {
319        u64::from_le_bytes(data[4..12].try_into().map_err(|_| Error::NoHash)?)
320    };
321
322    let data = &data[16..];
323
324    let (object, references) = load_bytes(data, python_version)?;
325
326    Ok(PycFile {
327        python_version,
328        timestamp,
329        hash,
330        object,
331        references,
332    })
333}
334
335pub fn dump_pyc(writer: &mut impl Write, pyc_file: PycFile) -> Result<(), Error> {
336    let mut buf = Vec::new();
337    let mut py_writer = PyWriter::new(pyc_file.references, 4);
338
339    buf.extend_from_slice(&u32::to_le_bytes(pyc_file.python_version.to_magic()?));
340    if let Some(timestamp) = pyc_file.timestamp {
341        buf.extend_from_slice(&u32::to_le_bytes(timestamp));
342    }
343    buf.extend_from_slice(&u64::to_le_bytes(pyc_file.hash));
344    buf.extend_from_slice(&py_writer.write_object(Some(pyc_file.object))?);
345
346    std::io::copy(&mut buf.as_slice(), writer)?;
347
348    Ok(())
349}
350
351pub fn dump_bytes(
352    obj: Object,
353    references: Option<Vec<Object>>,
354    python_version: PyVersion,
355    marshal_version: u8,
356) -> Result<Vec<u8>, Error> {
357    if python_version < (3, 0) {
358        return Err(Error::UnsupportedPyVersion(python_version));
359    }
360
361    let mut py_writer = PyWriter::new(references.unwrap_or_default(), marshal_version);
362
363    py_writer.write_object(Some(obj))
364}
365
366#[cfg(test)]
367mod tests {
368    use std::collections::{HashMap, HashSet};
369    use std::io::Write;
370    use std::vec;
371
372    use tempfile::NamedTempFile;
373
374    use error::Error;
375
376    use crate::resolver::{get_recursive_refs, resolve_all_refs};
377
378    use super::*;
379
    // Integers load as `Object::Long`, whether they arrive with the `i` (Int)
    // tag or the `l` (Long) tag.
    #[test]
    fn test_load_long() {
        // 1, encoded as tag `i` followed by the bytes 01 00 00 00
        let data = b"i\x01\x00\x00\x00";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_object!(Some(kind), Object::Long(num) => num, Error::UnexpectedObject).unwrap(),
            BigInt::from(1).into()
        );

        // 4294967295, encoded as tag `l`, a count of 3, then the 2-byte digits
        // 0x7fff, 0x7fff, 0x0003 (0x7fff + 0x7fff * 2^15 + 3 * 2^30)
        let data = b"l\x03\x00\x00\x00\xff\x7f\xff\x7f\x03\x00";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();

        assert_eq!(
            extract_object!(Some(kind), Object::Long(num) => num, Error::UnexpectedObject).unwrap(),
            BigInt::from(4294967295u32).into()
        );
    }
399
    // A `g` (BinaryFloat) tag followed by 8 bytes loads as `Object::Float`.
    #[test]
    fn test_load_float() {
        // 1.0 (the 8 bytes are 0x3ff0000000000000 in little-endian order; `?` is 0x3f)
        let data = b"g\x00\x00\x00\x00\x00\x00\xf0?";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_object!(Some(kind), Object::Float(num) => num, Error::UnexpectedObject)
                .unwrap(),
            1.0
        );
    }
411
    // A `y` (BinaryComplex) tag followed by two 8-byte floats (real part
    // first) loads as `Object::Complex`.
    #[test]
    fn test_load_complex() {
        // 3 + 4j
        let data = b"y\x00\x00\x00\x00\x00\x00\x08@\x00\x00\x00\x00\x00\x00\x10@";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_object!(Some(kind), Object::Complex(num) => num, Error::UnexpectedObject)
                .unwrap(),
            Complex::new(3.0, 4.0)
        );
    }
423
    // The `s` (String) tag carries raw bytes and loads as `Object::Bytes`,
    // not as `Object::String`.
    #[test]
    fn test_load_bytes() {
        // b"test": tag `s`, 4-byte length 4, then the payload
        let data = b"s\x04\x00\x00\x00test";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_object!(Some(kind), Object::Bytes(bytes) => bytes, Error::UnexpectedObject)
                .unwrap(),
            "test".as_bytes().to_vec()
        );
    }
435
    // Text tags load as `Object::String`, and the resulting `PyString`
    // remembers which tag it was read from.
    #[test]
    fn test_load_string() {
        // "test": tag `Z` (ShortAsciiInterned) with a 1-byte length
        let data = b"Z\x04test";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_object!(Some(kind), Object::String(string) => string, Error::UnexpectedObject)
                .unwrap(),
            PyString::new("test".into(), Kind::ShortAsciiInterned).into()
        );

        // Tag `u` (Unicode) with the 3-byte payload ED B2 80. That sequence
        // encodes the lone surrogate U+DC80 (it is not "\xe9", which would be
        // C3 A9), so it is not valid UTF-8; the bytes must come back unchanged.
        let data = b"u\x03\x00\x00\x00\xed\xb2\x80";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();

        assert_eq!(
            extract_object!(Some(kind), Object::String(string) => string, Error::UnexpectedObject)
                .unwrap(),
            PyString::new(BString::new([237, 178, 128].to_vec()), Kind::Unicode).into()
        );
    }
457
    // The `)` (SmallTuple) tag with a 1-byte element count loads as
    // `Object::Tuple`.
    #[test]
    fn test_load_tuple() {
        // Empty tuple
        let data = b")\x00";
        let (kind, refs) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_strings_tuple!(
                extract_object!(Some(kind), Object::Tuple(tuple) => tuple, Error::UnexpectedObject)
                    .unwrap(),
                refs
            )
            .unwrap(),
            vec![]
        );

        // Tuple with two elements ("a", "b"), each a `Z` short interned string
        let data = b")\x02Z\x01aZ\x01b";
        let (kind, refs) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_strings_tuple!(
                extract_object!(Some(kind), Object::Tuple(tuple) => tuple, Error::UnexpectedObject)
                    .unwrap(),
                refs
            )
            .unwrap(),
            vec![
                PyString::new("a".into(), Kind::ShortAsciiInterned).into(),
                PyString::new("b".into(), Kind::ShortAsciiInterned).into()
            ]
        );
    }
489
    // The `[` (List) tag with a 4-byte element count loads as `Object::List`.
    #[test]
    fn test_load_list() {
        // Empty list
        let data = b"[\x00\x00\x00\x00";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_strings_list!(
                extract_object!(Some(kind), Object::List(list) => list, Error::UnexpectedObject)
                    .unwrap()
            )
            .unwrap(),
            vec![]
        );

        // List with two elements ("a", "b")
        let data = b"[\x02\x00\x00\x00Z\x01aZ\x01b";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_strings_list!(
                extract_object!(Some(kind), Object::List(list) => list, Error::UnexpectedObject)
                    .unwrap()
            )
            .unwrap(),
            vec![
                PyString::new("a".into(), Kind::ShortAsciiInterned).into(),
                PyString::new("b".into(), Kind::ShortAsciiInterned).into()
            ]
        );
    }
519
    // Objects whose tag has the 0x80 (FlagRef) bit set come back as
    // `StoreRef(index)` with the object itself placed in the reference table;
    // `r` (Ref) entries come back as `LoadRef(index)`.
    #[test]
    fn test_reference() {
        // A flagged list (0xdb = 0x80 | `[`) of three items: a flagged int 1
        // (0xe9 = 0x80 | `i`) followed by two `r` references to index 1
        let data = b"\xdb\x03\x00\x00\x00\xe9\x01\x00\x00\x00r\x01\x00\x00\x00r\x01\x00\x00\x00";
        let (kind, refs) = load_bytes(data, (3, 10).into()).unwrap();

        // The list is the outermost flagged object, so it gets index 0
        assert_eq!(
            extract_object!(Some(kind.clone()), Object::StoreRef(index) => index, Error::UnexpectedObject)
                .unwrap(),
            0
        );

        // Resolving the root yields the list; its elements stay unresolved
        assert_eq!(
            resolve_object_ref!(Some(kind.clone()), refs).unwrap(),
            Object::List(
                vec![
                    Object::StoreRef(1).into(),
                    Object::LoadRef(1).into(),
                    Object::LoadRef(1).into()
                ]
                .into()
            )
        );

        assert_eq!(*refs.get(1).unwrap(), Object::Long(BigInt::from(1)).into());

        // Recursive reference: a flagged list whose only element refers to index 0,
        // i.e. the list itself
        let data = b"\xdb\x01\x00\x00\x00r\x00\x00\x00\x00";

        let (kind, refs) = load_bytes(data, (3, 10).into()).unwrap();

        // No assertion yet: this only checks that loading and the recursive-ref
        // scan do not fail.
        dbg!(&kind, &refs);
        dbg!(get_recursive_refs(kind, refs).unwrap());
    }
554
    // `resolve_all_refs` replaces every reference with the object it points
    // to and returns an empty reference table.
    #[test]
    fn test_resolve_refs() {
        // A flagged list holding a flagged int 1 and two references to it
        let data = b"\xdb\x03\x00\x00\x00\xe9\x01\x00\x00\x00r\x01\x00\x00\x00r\x01\x00\x00\x00";
        let (kind, refs) = load_bytes(data, (3, 10).into()).unwrap();

        let (obj, refs) = resolve_all_refs(kind, refs).unwrap();

        assert_eq!(
            extract_object!(Some(obj), Object::List(list) => list, Error::UnexpectedObject)
                .unwrap()
                .iter()
                .map(|o| o.clone())
                .collect::<Vec<_>>(),
            vec![
                Object::Long(BigInt::from(1)),
                Object::Long(BigInt::from(1)),
                Object::Long(BigInt::from(1))
            ]
        );

        assert_eq!(refs.len(), 0);

        // Hand-built table with a chain: entry 1 is itself a reference to
        // entry 2, so resolving the list elements takes two hops
        let kind = Object::StoreRef(0);
        let refs = vec![
            Object::List(vec![Object::StoreRef(1).into(), Object::LoadRef(1).into()]).into(),
            Object::StoreRef(2),
            Object::Long(BigInt::from(1)).into(),
        ];

        let (kind, refs) = resolve_all_refs(kind, refs).unwrap();

        assert_eq!(
            kind,
            Object::List(
                vec![
                    Object::Long(BigInt::from(1)).into(),
                    Object::Long(BigInt::from(1)).into()
                ]
                .into()
            )
        );

        assert_eq!(refs.len(), 0);
    }
600
    // The `{` (Dict) tag is followed by alternating keys and values and is
    // terminated by a `0` (Null) tag; it loads as `Object::Dict`.
    #[test]
    fn test_load_dict() {
        // Empty dict
        let data = b"{0";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_strings_dict!(
                extract_object!(Some(kind), Object::Dict(dict) => dict, Error::UnexpectedObject)
                    .unwrap()
            )
            .unwrap(),
            HashMap::new()
        );

        // Dict with two elements {"a": "b", "c": "d"}
        let data = b"{Z\x01aZ\x01bZ\x01cZ\x01d0";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_strings_dict!(
                extract_object!(Some(kind), Object::Dict(dict) => dict, Error::UnexpectedObject)
                    .unwrap()
            )
            .unwrap(),
            {
                let mut map = HashMap::new();
                map.insert(
                    PyString::new("a".into(), Kind::ShortAsciiInterned).into(),
                    PyString::new("b".into(), Kind::ShortAsciiInterned).into(),
                );
                map.insert(
                    PyString::new("c".into(), Kind::ShortAsciiInterned).into(),
                    PyString::new("d".into(), Kind::ShortAsciiInterned).into(),
                );
                map
            }
        );
    }
638
    // The `<` (Set) tag with a 4-byte element count loads as `Object::Set`.
    #[test]
    fn test_load_set() {
        // Empty set
        let data = b"<\x00\x00\x00\x00";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_strings_set!(
                extract_object!(Some(kind), Object::Set(set) => set, Error::UnexpectedObject)
                    .unwrap()
            )
            .unwrap(),
            HashSet::new().into()
        );

        // Set with two elements {"a", "b"}; the stream lists "b" first, and the
        // comparison against a `HashSet` ignores element order
        let data = b"<\x02\x00\x00\x00Z\x01bZ\x01a";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_strings_set!(
                extract_object!(Some(kind), Object::Set(set) => set, Error::UnexpectedObject)
                    .unwrap()
            )
            .unwrap(),
            {
                let mut set = HashSet::new();
                set.insert(PyString::new("a".into(), Kind::ShortAsciiInterned).into());
                set.insert(PyString::new("b".into(), Kind::ShortAsciiInterned).into());
                set
            }
        );
    }
670
    // The `>` (FrozenSet) tag with a 4-byte element count loads as
    // `Object::FrozenSet`.
    #[test]
    fn test_load_frozenset() {
        // Empty frozenset
        let data = b">\x00\x00\x00\x00";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_strings_frozenset!(
                extract_object!(Some(kind), Object::FrozenSet(set) => set, Error::UnexpectedObject)
                    .unwrap()
            )
            .unwrap(),
            HashSet::new()
        );

        // Frozenset with two elements {"a", "b"}; the stream lists "b" first, and
        // the comparison against a `HashSet` ignores element order
        let data = b">\x02\x00\x00\x00Z\x01bZ\x01a";
        let (kind, _) = load_bytes(data, (3, 10).into()).unwrap();
        assert_eq!(
            extract_strings_frozenset!(
                extract_object!(Some(kind), Object::FrozenSet(set) => set, Error::UnexpectedObject)
                    .unwrap()
            )
            .unwrap(),
            {
                let mut set = HashSet::new();
                set.insert(PyString::new("a".into(), Kind::ShortAsciiInterned).into());
                set.insert(PyString::new("b".into(), Kind::ShortAsciiInterned).into());
                set
            }
        );
    }
702
    // Loads a marshalled Python 3.10 code object and checks its scalar fields
    // and the sizes of its nested objects.
    #[test]
    fn test_load_code310() {
        // def f(arg1, arg2=None): print(arg1, arg2)
        let data = b"\xe3\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00C\x00\x00\x00s\x0e\x00\x00\x00t\x00|\x00|\x01\x83\x02\x01\x00d\x00S\x00\xa9\x01N)\x01\xda\x05print)\x02Z\x04arg1Z\x04arg2\xa9\x00r\x03\x00\x00\x00\xfa\x07<stdin>\xda\x01f\x01\x00\x00\x00s\x02\x00\x00\x00\x0e\x00";
        let (kind, refs) = load_bytes(data, (3, 10).into()).unwrap();

        // The root tag is flagged (0xe3 = 0x80 | `c`), so it has to be resolved
        // through the reference table before the code object can be extracted
        let code = extract_object!(Some(resolve_object_ref!(Some(kind), refs).unwrap()), Object::Code(code) => code, Error::UnexpectedObject)
                .unwrap().clone();

        match *code {
            Code::V310(code) => {
                // Every nested field may itself be a reference, so each one is
                // resolved against `refs` before its concrete value is extracted
                let inner_code = extract_object!(Some(resolve_object_ref!(Some((*code.code).clone()), &refs).unwrap()), Object::Bytes(bytes) => bytes, Error::NullInTuple).unwrap();
                let inner_consts = extract_object!(Some(resolve_object_ref!(Some((*code.consts).clone()), &refs).unwrap()), Object::Tuple(objs) => objs, Error::NullInTuple).unwrap();
                let inner_names = extract_strings_tuple!(extract_object!(Some(resolve_object_ref!(Some((*code.names).clone()), &refs).unwrap()), Object::Tuple(objs) => objs, Error::NullInTuple).unwrap(), &refs).unwrap();
                let inner_varnames = extract_strings_tuple!(extract_object!(Some(resolve_object_ref!(Some((*code.varnames).clone()), &refs).unwrap()), Object::Tuple(objs) => objs, Error::NullInTuple).unwrap(), &refs).unwrap();
                let inner_freevars = extract_strings_tuple!(extract_object!(Some(resolve_object_ref!(Some((*code.freevars).clone()), &refs).unwrap()), Object::Tuple(objs) => objs, Error::NullInTuple).unwrap(), &refs).unwrap();
                let inner_cellvars = extract_strings_tuple!(extract_object!(Some(resolve_object_ref!(Some((*code.cellvars).clone()), &refs).unwrap()), Object::Tuple(objs) => objs, Error::NullInTuple).unwrap(), &refs).unwrap();
                let inner_filename = extract_object!(Some(resolve_object_ref!(Some((*code.filename).clone()), &refs).unwrap()), Object::String(string) => string, Error::UnexpectedObject).unwrap();
                let inner_name = extract_object!(Some(resolve_object_ref!(Some((*code.name).clone()), &refs).unwrap()), Object::String(string) => string, Error::UnexpectedObject).unwrap();
                let inner_lnotab = extract_object!(Some(resolve_object_ref!(Some((*code.lnotab).clone()), &refs).unwrap()), Object::Bytes(bytes) => bytes, Error::NullInTuple).unwrap();

                assert_eq!(code.argcount, 2);
                assert_eq!(code.posonlyargcount, 0);
                assert_eq!(code.kwonlyargcount, 0);
                assert_eq!(code.nlocals, 2);
                assert_eq!(code.stacksize, 3);
                // assert_eq!(code.flags, );
                assert_eq!(inner_code.len(), 14);
                assert_eq!(inner_consts.len(), 1);
                assert_eq!(inner_names.len(), 1);
                assert_eq!(inner_varnames.len(), 2);
                assert_eq!(inner_freevars.len(), 0);
                assert_eq!(inner_cellvars.len(), 0);
                assert_eq!(
                    inner_filename,
                    PyString::new("<stdin>".into(), Kind::ShortAscii).into()
                );
                assert_eq!(
                    inner_name,
                    PyString::new("f".into(), Kind::ShortAsciiInterned).into()
                );
                assert_eq!(code.firstlineno, 1);
                assert_eq!(inner_lnotab.len(), 2);
            }
            _ => panic!("Invalid code object"),
        }
    }
750
    // Loads a marshalled Python 3.11 code object and checks its scalar fields
    // and the sizes of its nested objects.
    #[test]
    fn test_load_code311() {
        // def f(arg1, arg2=None): print(arg1, arg2)
        let data = b"\xe3\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00\xf3&\x00\x00\x00\x97\x00t\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00|\x00|\x01\xa6\x02\x00\x00\xab\x02\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00d\x00S\x00\xa9\x01N)\x01\xda\x05print)\x02\xda\x04arg1\xda\x04arg2s\x02\x00\x00\x00  \xfa\x07<stdin>\xda\x01fr\x07\x00\x00\x00\x01\x00\x00\x00s\x17\x00\x00\x00\x80\x00\x9d\x05\x98d\xa0D\xd1\x18)\xd4\x18)\xd0\x18)\xd0\x18)\xd0\x18)\xf3\x00\x00\x00\x00";
        let (kind, refs) = load_bytes(data, (3, 11).into()).unwrap();

        // The root tag is flagged (0xe3 = 0x80 | `c`), so it has to be resolved
        // through the reference table before the code object can be extracted
        let code = extract_object!(Some(resolve_object_ref!(Some(kind), refs).unwrap()), Object::Code(code) => code, Error::UnexpectedObject)
                .unwrap().clone();

        match *code {
            Code::V311(code) => {
                // Every nested field may itself be a reference, so each one is
                // resolved against `refs` before its concrete value is extracted
                let inner_code = extract_object!(Some(resolve_object_ref!(Some((*code.code).clone()), &refs).unwrap()), Object::Bytes(bytes) => bytes, Error::NullInTuple).unwrap();
                let inner_consts = extract_object!(Some(resolve_object_ref!(Some((*code.consts).clone()), &refs).unwrap()), Object::Tuple(objs) => objs, Error::NullInTuple).unwrap();
                let inner_names = extract_strings_tuple!(extract_object!(Some(resolve_object_ref!(Some((*code.names).clone()), &refs).unwrap()), Object::Tuple(objs) => objs, Error::NullInTuple).unwrap(), &refs).unwrap();
                let inner_localsplusnames = extract_strings_tuple!(extract_object!(Some(resolve_object_ref!(Some((*code.localsplusnames).clone()), &refs).unwrap()), Object::Tuple(objs) => objs, Error::NullInTuple).unwrap(), &refs).unwrap();
                let inner_localspluskinds = extract_object!(Some(resolve_object_ref!(Some((*code.localspluskinds).clone()), &refs).unwrap()), Object::Bytes(bytes) => bytes, Error::NullInTuple).unwrap();
                let inner_filename = extract_object!(Some(resolve_object_ref!(Some((*code.filename).clone()), &refs).unwrap()), Object::String(string) => string, Error::UnexpectedObject).unwrap();
                let inner_name = extract_object!(Some(resolve_object_ref!(Some((*code.name).clone()), &refs).unwrap()), Object::String(string) => string, Error::UnexpectedObject).unwrap();
                let inner_linetable = extract_object!(Some(resolve_object_ref!(Some((*code.linetable).clone()), &refs).unwrap()), Object::Bytes(bytes) => bytes, Error::NullInTuple).unwrap();
                let inner_exceptiontable = extract_object!(Some(resolve_object_ref!(Some((*code.exceptiontable).clone()), &refs).unwrap()), Object::Bytes(bytes) => bytes, Error::NullInTuple).unwrap();

                assert_eq!(code.argcount, 2);
                assert_eq!(code.posonlyargcount, 0);
                assert_eq!(code.kwonlyargcount, 0);
                assert_eq!(code.stacksize, 4);
                // assert_eq!(code.flags, );
                assert_eq!(inner_code.len(), 38);
                assert_eq!(inner_consts.len(), 1);
                assert_eq!(inner_names.len(), 1);
                assert_eq!(inner_localsplusnames.len(), 2);
                assert_eq!(inner_localspluskinds.len(), 2);
                assert_eq!(
                    inner_filename,
                    PyString::new("<stdin>".into(), Kind::ShortAscii).into()
                );
                assert_eq!(
                    inner_name,
                    PyString::new("f".into(), Kind::ShortAsciiInterned).into()
                );
                assert_eq!(code.firstlineno, 1);
                assert_eq!(inner_linetable.len(), 23);
                assert_eq!(inner_exceptiontable.len(), 0);
            }
            _ => panic!("Invalid code object"),
        }
    }
797
798    #[test]
799    fn test_load_pyc() {
800        let data = b"o\r\r\n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xe3\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00@\x00\x00\x00s\x0c\x00\x00\x00e\x00d\x00\x83\x01\x01\x00d\x01S\x00)\x02z\x0ehi from PythonN)\x01\xda\x05print\xa9\x00r\x02\x00\x00\x00r\x02\x00\x00\x00z\x08<string>\xda\x08<module>\x01\x00\x00\x00s\x02\x00\x00\x00\x0c\x00";
801
802        let mut temp_file = NamedTempFile::new().unwrap();
803        temp_file.write_all(data).unwrap();
804
805        let obj = load_pyc(&data[..]).unwrap();
806
807        dbg!(&obj); // TODO: Add assertions
808    }
809
810    #[test]
811    fn test_dump_long() {
812        // 1
813        let data = b"i\x01\x00\x00\x00";
814        let object = Object::Long(BigInt::from(1).into());
815        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
816        assert_eq!(data.to_vec(), dumped);
817
818        // 4294967295
819        let data = b"l\x03\x00\x00\x00\xff\x7f\xff\x7f\x03\x00";
820        let object = Object::Long(BigInt::from(4294967295u32).into());
821        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
822        assert_eq!(data.to_vec(), dumped);
823    }
824
825    #[test]
826    fn test_dump_float() {
827        // 1.0
828        let data = b"g\x00\x00\x00\x00\x00\x00\xf0?";
829        let object = Object::Float(1.0);
830        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
831        assert_eq!(data.to_vec(), dumped);
832    }
833
834    #[test]
835    fn test_dump_complex() {
836        // 3 + 4j
837        let data = b"y\x00\x00\x00\x00\x00\x00\x08@\x00\x00\x00\x00\x00\x00\x10@";
838        let object = Object::Complex(Complex::new(3.0, 4.0));
839        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
840        assert_eq!(data.to_vec(), dumped);
841    }
842
843    #[test]
844    fn test_dump_bytes() {
845        // b"test"
846        let data = b"s\x04\x00\x00\x00test";
847        let object = Object::Bytes("test".as_bytes().to_vec().into());
848        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
849        assert_eq!(data.to_vec(), dumped);
850    }
851
852    #[test]
853    fn test_dump_string() {
854        // "test"
855        let data = b"z\x04test";
856        let object = Object::String(PyString::from("test".to_string()).into());
857        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
858        assert_eq!(data.to_vec(), dumped);
859
860        // "\xe9"
861        let data = b"u\x03\x00\x00\x00\xed\xb2\x80";
862        let object = Object::String(
863            PyString::new(BString::new([237, 178, 128].to_vec()), Kind::Unicode).into(),
864        );
865        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
866        assert_eq!(data.to_vec(), dumped);
867    }
868
869    #[test]
870    fn test_dump_tuple() {
871        // Empty tuple
872        let data = b")\x00";
873        let object = Object::Tuple(vec![].into());
874        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
875        assert_eq!(data.to_vec(), dumped);
876
877        // Tuple with two elements ("a", "b")
878        let data = b")\x02z\x01az\x01b";
879        let object = Object::Tuple(
880            vec![
881                Object::String(PyString::from("a".to_string()).into()).into(),
882                Object::String(PyString::from("b".to_string()).into()).into(),
883            ]
884            .into(),
885        );
886        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
887        assert_eq!(data.to_vec(), dumped);
888    }
889
890    #[test]
891    fn test_dump_list() {
892        // Empty list
893        let data = b"[\x00\x00\x00\x00";
894        let object = Object::List(vec![].into());
895        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
896        assert_eq!(data.to_vec(), dumped);
897
898        // List with two elements ("a", "b")
899        let data = b"[\x02\x00\x00\x00z\x01az\x01b";
900        let object = Object::List(
901            vec![
902                Object::String(PyString::from("a".to_string()).into()).into(),
903                Object::String(PyString::from("b".to_string()).into()).into(),
904            ]
905            .into(),
906        );
907        dbg!(&object);
908        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
909        assert_eq!(data.to_vec(), dumped);
910    }
911
912    #[test]
913    fn test_dump_dict() {
914        // Empty dict
915        let data = b"{0";
916        let object = Object::Dict(IndexMap::new().into());
917        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
918        assert_eq!(data.to_vec(), dumped);
919
920        // Dict with two elements {"a": "b", "c": "d"}
921        let data1 = b"{z\x01az\x01bz\x01cz\x01d0";
922        let data2 = b"{z\x01cz\x01dz\x01az\x01b0"; // Order is not guaranteed
923        let object = Object::Dict({
924            let mut map = IndexMap::new();
925            map.insert(
926                ObjectHashable::String(PyString::from("a".to_string()).into()).into(),
927                Object::String(PyString::from("b".to_string()).into()).into(),
928            );
929            map.insert(
930                ObjectHashable::String(PyString::from("c".to_string()).into()).into(),
931                Object::String(PyString::from("d".to_string()).into()).into(),
932            );
933            map
934        });
935        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
936        assert!(data1.to_vec() == dumped || data2.to_vec() == dumped);
937    }
938
939    #[test]
940    fn test_dump_set() {
941        // Empty set
942        let data = b"<\x00\x00\x00\x00";
943        let object = Object::Set(IndexSet::new().into());
944        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
945        assert_eq!(data.to_vec(), dumped);
946
947        // Set with two elements {"a", "b"}
948        let data1 = b"<\x02\x00\x00\x00z\x01az\x01b";
949        let data2 = b"<\x02\x00\x00\x00z\x01bz\x01a"; // Order is not guaranteed
950        let object = Object::Set({
951            let mut set = IndexSet::new();
952            set.insert(ObjectHashable::String(PyString::from("a".to_string()).into()).into());
953            set.insert(ObjectHashable::String(PyString::from("b".to_string()).into()).into());
954            set
955        });
956        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
957        assert!(data1.to_vec() == dumped || data2.to_vec() == dumped);
958    }
959
960    #[test]
961    fn test_dump_frozenset() {
962        // Empty frozenset
963        let data = b">\x00\x00\x00\x00";
964        let object = Object::FrozenSet(IndexSet::new().into());
965        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
966        assert_eq!(data.to_vec(), dumped);
967
968        // Frozenset with two elements {"a", "b"}
969        let data1 = b">\x02\x00\x00\x00z\x01az\x01b"; // Order is not guaranteed
970        let data2 = b">\x02\x00\x00\x00z\x01bz\x01a";
971        let object = Object::FrozenSet({
972            let mut set = IndexSet::new();
973            set.insert(ObjectHashable::String(PyString::from("a".to_string()).into()).into());
974            set.insert(ObjectHashable::String(PyString::from("b".to_string()).into()).into());
975            set
976        });
977        let dumped = dump_bytes(object, None, (3, 10).into(), 4).unwrap();
978        assert!(data1.to_vec() == dumped || data2.to_vec() == dumped || data2.to_vec() == dumped);
979    }
980
981    #[test]
982    fn test_dump_code() {
983        // def f(arg1, arg2=None): print(arg1, arg2)
984        let data = b"c\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00C\x00\x00\x00s\x0e\x00\x00\x00t\x00|\x00|\x01\x83\x02\x01\x00d\x00S\x00)\x01N)\x01z\x05print)\x02z\x04arg1z\x04arg2)\x00)\x00z\x07<stdin>z\x01f\x01\x00\x00\x00s\x02\x00\x00\x00\x0e\x00";
985
986        let object = Code::V310(code_objects::Code310 {
987            argcount: 2,
988            posonlyargcount: 0,
989            kwonlyargcount: 0,
990            nlocals: 2,
991            stacksize: 3,
992            flags: CodeFlags::from_bits_truncate(0x43),
993            code: Object::Bytes(vec![116, 0, 124, 0, 124, 1, 131, 2, 1, 0, 100, 0, 83, 0].into())
994                .into(),
995            consts: Object::Tuple([Object::None.into()].to_vec().into()).into(),
996            names: Object::Tuple(
997                [Object::String(PyString::from("print".to_string()).into()).into()]
998                    .to_vec()
999                    .into(),
1000            )
1001            .into(),
1002            varnames: Object::Tuple(
1003                [
1004                    Object::String(PyString::from("arg1".to_string()).into()).into(),
1005                    Object::String(PyString::from("arg2".to_string()).into()).into(),
1006                ]
1007                .to_vec()
1008                .into(),
1009            )
1010            .into(),
1011            freevars: Object::Tuple([].to_vec().into()).into(),
1012            cellvars: Object::Tuple([].to_vec().into()).into(),
1013            filename: Object::String(PyString::from("<stdin>".to_string()).into()).into(),
1014            name: Object::String(PyString::from("f".to_string()).into()).into(),
1015            firstlineno: 1,
1016            lnotab: Object::Bytes([14, 0].to_vec().into()).into(),
1017        });
1018        let dumped = dump_bytes(Object::Code(object.into()), None, (3, 10).into(), 4).unwrap();
1019
1020        assert_eq!(data.to_vec(), dumped);
1021    }
1022
1023    #[test]
1024    fn test_recompile() {
1025        let data = b"c\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00C\x00\x00\x00s\x0e\x00\x00\x00t\x00|\x00|\x01\x83\x02\x01\x00d\x00S\x00)\x01N)\x01z\x05print)\x02z\x04arg1z\x04arg2)\x00)\x00z\x07<stdin>z\x01f\x01\x00\x00\x00s\x02\x00\x00\x00\x0e\x00";
1026
1027        let (kind, refs) = load_bytes(data, (3, 10).into()).unwrap();
1028        let dumped = dump_bytes(kind, Some(refs), (3, 10).into(), 4).unwrap();
1029
1030        assert_eq!(data.to_vec(), dumped);
1031    }
1032
1033    #[test]
1034    fn test_optimize_references() {
1035        let data = b"\xdb\x03\x00\x00\x00\xe9\x01\x00\x00\x00r\x01\x00\x00\x00r\x01\x00\x00\x00";
1036        let (kind, refs) = load_bytes(data, (3, 10).into()).unwrap();
1037
1038        let (kind, refs) = optimize_references(kind, refs);
1039
1040        dump_bytes(kind.clone(), Some(refs.clone()), (3, 10).into(), 4).unwrap();
1041
1042        assert_eq!(
1043            kind,
1044            Object::List(
1045                vec![
1046                    Object::StoreRef(0).into(),
1047                    Object::LoadRef(0).into(),
1048                    Object::LoadRef(0).into()
1049                ]
1050                .into()
1051            )
1052        );
1053
1054        assert_eq!(*refs.get(0).unwrap(), Object::Long(BigInt::from(1)).into());
1055
1056        let kind = Object::StoreRef(0);
1057        let refs = vec![
1058            Object::List(vec![Object::StoreRef(1).into(), Object::LoadRef(1).into()]).into(),
1059            Object::StoreRef(2),
1060            Object::Long(BigInt::from(1)).into(),
1061        ];
1062
1063        let (kind, refs) = optimize_references(kind, refs);
1064
1065        dump_bytes(kind.clone(), Some(refs.clone()), (3, 10).into(), 4).unwrap();
1066
1067        assert_eq!(
1068            kind,
1069            Object::List(vec![Object::StoreRef(0).into(), Object::LoadRef(0).into(),].into())
1070        );
1071
1072        assert_eq!(*refs.get(0).unwrap(), Object::Long(BigInt::from(1)).into());
1073    }
1074}