1pub(crate) use decl::module_def;
3
4#[pymodule(name = "marshal")]
5mod decl {
6 use crate::builtins::code::{CodeObject, Literal, PyObjBag};
7 use crate::class::StaticType;
8 use crate::common::wtf8::Wtf8;
9 use crate::{
10 PyObjectRef, PyResult, TryFromObject, VirtualMachine,
11 builtins::{
12 PyBool, PyByteArray, PyBytes, PyCode, PyComplex, PyDict, PyEllipsis, PyFloat,
13 PyFrozenSet, PyInt, PyList, PyNone, PySet, PyStopIteration, PyStr, PyTuple,
14 },
15 convert::ToPyObject,
16 function::{ArgBytesLike, OptionalArg},
17 object::{AsObject, PyPayload},
18 protocol::PyBuffer,
19 };
20 use malachite_bigint::BigInt;
21 use num_traits::Zero;
22 use rustpython_compiler_core::marshal;
23
24 #[pyattr(name = "version")]
25 use marshal::FORMAT_VERSION;
26
    /// Error produced by the `marshal::Dumpable` implementation for
    /// `PyObjectRef` when the object is not one of the types it knows how to
    /// describe to the serializer.
    pub struct DumpError;
28
29 impl marshal::Dumpable for PyObjectRef {
30 type Error = DumpError;
31 type Constant = Literal;
32
33 fn with_dump<R>(
34 &self,
35 f: impl FnOnce(marshal::DumpableValue<'_, Self>) -> R,
36 ) -> Result<R, Self::Error> {
37 use marshal::DumpableValue::*;
38 if self.is(PyStopIteration::static_type()) {
39 return Ok(f(StopIter));
40 }
41 let ret = match_class!(match self {
42 PyNone => f(None),
43 PyEllipsis => f(Ellipsis),
44 ref pyint @ PyInt => {
45 if self.class().is(PyBool::static_type()) {
46 f(Boolean(!pyint.as_bigint().is_zero()))
47 } else {
48 f(Integer(pyint.as_bigint()))
49 }
50 }
51 ref pyfloat @ PyFloat => {
52 f(Float(pyfloat.to_f64()))
53 }
54 ref pycomplex @ PyComplex => {
55 f(Complex(pycomplex.to_complex64()))
56 }
57 ref pystr @ PyStr => {
58 f(Str(pystr.as_wtf8()))
59 }
60 ref pylist @ PyList => {
61 f(List(&pylist.borrow_vec()))
62 }
63 ref pyset @ PySet => {
64 let elements = pyset.elements();
65 f(Set(&elements))
66 }
67 ref pyfrozen @ PyFrozenSet => {
68 let elements = pyfrozen.elements();
69 f(Frozenset(&elements))
70 }
71 ref pytuple @ PyTuple => {
72 f(Tuple(pytuple.as_slice()))
73 }
74 ref pydict @ PyDict => {
75 let entries = pydict.into_iter().collect::<Vec<_>>();
76 f(Dict(&entries))
77 }
78 ref bytes @ PyBytes => {
79 f(Bytes(bytes.as_bytes()))
80 }
81 ref bytes @ PyByteArray => {
82 f(Bytes(&bytes.borrow_buf()))
83 }
84 ref co @ PyCode => {
85 f(Code(co))
86 }
87 _ => return Err(DumpError),
88 });
89 Ok(ret)
90 }
91 }
92
    // Arguments accepted by `dumps` (and built by `dump` when it delegates).
    #[derive(FromArgs)]
    struct DumpsArgs {
        // Object to serialize.
        value: PyObjectRef,
        // Requested format version; `dumps` falls back to
        // `marshal::FORMAT_VERSION` when it is not supplied.
        #[pyarg(any, optional)]
        _version: OptionalArg<i32>,
        // Named argument defaulting to true; when false, `dumps` runs
        // `check_no_code` on the value before serializing it.
        #[pyarg(named, default = true)]
        allow_code: bool,
    }
101
102 #[pyfunction]
103 fn dumps(args: DumpsArgs, vm: &VirtualMachine) -> PyResult<PyBytes> {
104 let DumpsArgs {
105 value,
106 allow_code,
107 _version,
108 } = args;
109 let version = _version.unwrap_or(marshal::FORMAT_VERSION as i32);
110 if !allow_code {
111 check_no_code(&value, vm)?;
112 }
113 check_exact_type(&value, vm)?;
114 let mut buf = Vec::new();
115 let mut refs = if version >= 3 {
116 Some(WriterRefTable::new())
117 } else {
118 None
119 };
120 write_object(&mut buf, &value, &mut refs, version, vm)?;
121 Ok(PyBytes::from(buf))
122 }
123
124 struct WriterRefTable {
125 map: std::collections::HashMap<usize, u32>,
126 next_idx: u32,
127 }
128
129 impl WriterRefTable {
130 fn new() -> Self {
131 Self {
132 map: std::collections::HashMap::new(),
133 next_idx: 0,
134 }
135 }
136 fn try_ref(&mut self, buf: &mut Vec<u8>, obj: &PyObjectRef) -> bool {
137 use marshal::Write;
138 let id = obj.get_id();
139 if let Some(&idx) = self.map.get(&id) {
140 buf.write_u8(b'r');
141 buf.write_u32(idx);
142 true
143 } else {
144 false
145 }
146 }
147 fn reserve(&mut self, obj: &PyObjectRef) -> u32 {
148 let idx = self.next_idx;
149 self.map.insert(obj.get_id(), idx);
150 self.next_idx += 1;
151 idx
152 }
153 }
154
155 fn write_object(
156 buf: &mut Vec<u8>,
157 obj: &PyObjectRef,
158 refs: &mut Option<WriterRefTable>,
159 version: i32,
160 vm: &VirtualMachine,
161 ) -> PyResult<()> {
162 write_object_depth(
163 buf,
164 obj,
165 refs,
166 version,
167 vm,
168 marshal::MAX_MARSHAL_STACK_DEPTH,
169 )
170 }
171
    /// Appends the marshal encoding of `obj` to `buf`, recursing into
    /// containers with `depth - 1`.
    ///
    /// * `refs` - back-reference table; when it is `Some`, every non-singleton
    ///   object is registered and its tag byte gets `marshal::FLAG_REF` set,
    ///   and repeated objects are written as a reference instead.
    /// * `version` - selects the string tag bytes and gates slice support.
    /// * `depth` - remaining nesting budget.
    ///
    /// # Errors
    ///
    /// Returns a `ValueError` when `depth` is 0, when a slice is written with
    /// `version < 5`, or when the object matches none of the handled types.
    fn write_object_depth(
        buf: &mut Vec<u8>,
        obj: &PyObjectRef,
        refs: &mut Option<WriterRefTable>,
        version: i32,
        vm: &VirtualMachine,
        depth: usize,
    ) -> PyResult<()> {
        use marshal::Write;
        if depth == 0 {
            return Err(vm.new_value_error("object too deeply nested to marshal".to_string()));
        }

        // None, booleans, the StopIteration class and Ellipsis never take part
        // in the reference table.
        let is_singleton = vm.is_none(obj)
            || obj.class().is(PyBool::static_type())
            || obj.is(PyStopIteration::static_type())
            || obj.downcast_ref::<crate::builtins::PyEllipsis>().is_some();

        // Already registered: `try_ref` has written the back-reference, so
        // there is nothing more to emit for this object.
        if !is_singleton
            && let Some(rt) = refs.as_mut()
            && rt.try_ref(buf, obj)
        {
            return Ok(());
        }
        // Position of the tag byte about to be written; FLAG_REF is OR-ed into
        // it at the very end.
        let type_pos = buf.len();
        let use_ref = refs.is_some() && !is_singleton;
        if use_ref {
            // Registered before any children are written, so a nested
            // occurrence of `obj` is found by `try_ref`.
            refs.as_mut().unwrap().reserve(obj);
        }

        if vm.is_none(obj) {
            buf.write_u8(b'N');
        } else if obj.is(PyStopIteration::static_type()) {
            buf.write_u8(b'S');
        } else if obj.class().is(PyBool::static_type()) {
            // Must precede the PyInt branch: a bool also downcasts to PyInt.
            let val = obj
                .downcast_ref::<PyInt>()
                .is_some_and(|i| !i.as_bigint().is_zero());
            buf.write_u8(if val { b'T' } else { b'F' });
        } else if obj.downcast_ref::<crate::builtins::PyEllipsis>().is_some() {
            buf.write_u8(b'.');
        } else if let Some(i) = obj.downcast_ref::<PyInt>() {
            if let Ok(val) = i32::try_from(i.as_bigint()) {
                // Fits in 32 bits: tag b'i' plus the value as a u32 bit pattern.
                buf.write_u8(b'i');
                buf.write_u32(val as u32);
            } else {
                // Larger integers: tag b'l', a signed digit count, then the
                // magnitude split into 15-bit digits, least significant first.
                buf.write_u8(b'l');
                let (sign, raw) = i.as_bigint().to_bytes_le();
                let mut digits = Vec::new();
                let mut accum: u32 = 0;
                let mut bits = 0u32;
                // Repack the little-endian bytes into 15-bit groups.
                for &byte in &raw {
                    accum |= (byte as u32) << bits;
                    bits += 8;
                    while bits >= 15 {
                        digits.push((accum & 0x7fff) as u16);
                        accum >>= 15;
                        bits -= 15;
                    }
                }
                // Flush the leftover bits (or guarantee at least one digit).
                if accum > 0 || digits.is_empty() {
                    digits.push(accum as u16);
                }
                // Drop high-order zero digits, keeping at least one.
                while digits.len() > 1 && *digits.last().unwrap() == 0 {
                    digits.pop();
                }
                // The sign of the integer is carried by the sign of the count.
                let n = digits.len() as i32;
                let n = if sign == malachite_bigint::Sign::Minus {
                    -n
                } else {
                    n
                };
                buf.write_u32(n as u32);
                for d in &digits {
                    buf.write_u16(*d);
                }
            }
        } else if let Some(f) = obj.downcast_ref::<PyFloat>() {
            // Tag b'g' followed by the raw 64-bit pattern of the float.
            buf.write_u8(b'g');
            buf.write_u64(f.to_f64().to_bits());
        } else if let Some(c) = obj.downcast_ref::<PyComplex>() {
            // Tag b'y' followed by the real and imaginary bit patterns.
            buf.write_u8(b'y');
            let cv = c.to_complex64();
            buf.write_u64(cv.re.to_bits());
            buf.write_u64(cv.im.to_bits());
        } else if let Some(s) = obj.downcast_ref::<PyStr>() {
            let bytes = s.as_wtf8().as_bytes();
            // From version 3 on, the upper-case / b't' tag variants are used.
            let interned = version >= 3;
            if bytes.len() < 256 && bytes.is_ascii() {
                // Short ASCII form: one-byte length.
                buf.write_u8(if interned { b'Z' } else { b'z' });
                buf.write_u8(bytes.len() as u8);
            } else {
                // General form: four-byte length.
                buf.write_u8(if interned { b't' } else { b'u' });
                buf.write_u32(bytes.len() as u32);
            }
            buf.write_slice(bytes);
        } else if let Some(b) = obj.downcast_ref::<PyBytes>() {
            buf.write_u8(b's');
            let data = b.as_bytes();
            buf.write_u32(data.len() as u32);
            buf.write_slice(data);
        } else if let Some(b) = obj.downcast_ref::<PyByteArray>() {
            // A bytearray is written with the same b's' tag as bytes.
            buf.write_u8(b's');
            let data = b.borrow_buf();
            buf.write_u32(data.len() as u32);
            buf.write_slice(&data);
        } else if let Some(t) = obj.downcast_ref::<PyTuple>() {
            // Sequences: tag, element count, then each element.
            buf.write_u8(b'(');
            buf.write_u32(t.len() as u32);
            for elem in t.as_slice() {
                write_object_depth(buf, elem, refs, version, vm, depth - 1)?;
            }
        } else if let Some(l) = obj.downcast_ref::<PyList>() {
            buf.write_u8(b'[');
            let items = l.borrow_vec();
            buf.write_u32(items.len() as u32);
            for elem in items.iter() {
                write_object_depth(buf, elem, refs, version, vm, depth - 1)?;
            }
        } else if let Some(d) = obj.downcast_ref::<PyDict>() {
            // Dicts carry no count: key/value pairs are followed by a b'0'
            // terminator byte.
            buf.write_u8(b'{');
            for (k, v) in d.into_iter() {
                write_object_depth(buf, &k, refs, version, vm, depth - 1)?;
                write_object_depth(buf, &v, refs, version, vm, depth - 1)?;
            }
            buf.write_u8(b'0');
        } else if let Some(s) = obj.downcast_ref::<PySet>() {
            buf.write_u8(b'<');
            let elems = s.elements();
            buf.write_u32(elems.len() as u32);
            for elem in &elems {
                write_object_depth(buf, elem, refs, version, vm, depth - 1)?;
            }
        } else if let Some(s) = obj.downcast_ref::<PyFrozenSet>() {
            buf.write_u8(b'>');
            let elems = s.elements();
            buf.write_u32(elems.len() as u32);
            for elem in &elems {
                write_object_depth(buf, elem, refs, version, vm, depth - 1)?;
            }
        } else if let Some(co) = obj.downcast_ref::<PyCode>() {
            // Code objects are delegated to the compiler-core serializer.
            buf.write_u8(b'c');
            marshal::serialize_code(buf, &co.code);
        } else if let Some(sl) = obj.downcast_ref::<crate::builtins::PySlice>() {
            // Slices are only written for format version 5 and above.
            if version < 5 {
                return Err(vm.new_value_error("unmarshallable object".to_string()));
            }
            buf.write_u8(b':');
            // Missing start/step are written as None.
            let none: PyObjectRef = vm.ctx.none();
            write_object_depth(
                buf,
                sl.start.as_ref().unwrap_or(&none),
                refs,
                version,
                vm,
                depth - 1,
            )?;
            write_object_depth(buf, &sl.stop, refs, version, vm, depth - 1)?;
            write_object_depth(
                buf,
                sl.step.as_ref().unwrap_or(&none),
                refs,
                version,
                vm,
                depth - 1,
            )?;
        } else if let Ok(bytes_like) = ArgBytesLike::try_from_object(vm, obj.clone()) {
            // Last resort: anything convertible to a bytes-like argument is
            // written like bytes.
            buf.write_u8(b's');
            let data = bytes_like.borrow_buf();
            buf.write_u32(data.len() as u32);
            buf.write_slice(&data);
        } else {
            return Err(vm.new_value_error("unmarshallable object".to_string()));
        }

        // Mark the tag byte of a registered object so it counts as a reference
        // target.
        if use_ref {
            buf[type_pos] |= marshal::FLAG_REF;
        }
        Ok(())
    }
355
    // Arguments accepted by `dump`.
    #[derive(FromArgs)]
    struct DumpArgs {
        // Object to serialize.
        value: PyObjectRef,
        // Destination object; `dump` calls its `write` method with the
        // serialized bytes.
        f: PyObjectRef,
        // Requested format version, forwarded unchanged to `dumps`.
        #[pyarg(any, optional)]
        _version: OptionalArg<i32>,
        // Named argument defaulting to true, forwarded unchanged to `dumps`.
        #[pyarg(named, default = true)]
        allow_code: bool,
    }
365
366 #[pyfunction]
367 fn dump(args: DumpArgs, vm: &VirtualMachine) -> PyResult<()> {
368 let dumped = dumps(
369 DumpsArgs {
370 value: args.value,
371 _version: args._version,
372 allow_code: args.allow_code,
373 },
374 vm,
375 )?;
376 vm.call_method(&args.f, "write", (dumped,))?;
377 Ok(())
378 }
379
380 #[derive(Copy, Clone)]
381 struct PyMarshalBag<'a>(&'a VirtualMachine);
382
383 impl<'a> marshal::MarshalBag for PyMarshalBag<'a> {
384 type Value = PyObjectRef;
385 type ConstantBag = PyObjBag<'a>;
386
387 fn make_bool(&self, value: bool) -> Self::Value {
388 self.0.ctx.new_bool(value).into()
389 }
390 fn make_none(&self) -> Self::Value {
391 self.0.ctx.none()
392 }
393 fn make_ellipsis(&self) -> Self::Value {
394 self.0.ctx.ellipsis.clone().into()
395 }
396 fn make_float(&self, value: f64) -> Self::Value {
397 self.0.ctx.new_float(value).into()
398 }
399 fn make_complex(&self, value: num_complex::Complex64) -> Self::Value {
400 self.0.ctx.new_complex(value).into()
401 }
402 fn make_str(&self, value: &Wtf8) -> Self::Value {
403 self.0.ctx.new_str(value).into()
404 }
405 fn make_bytes(&self, value: &[u8]) -> Self::Value {
406 self.0.ctx.new_bytes(value.to_vec()).into()
407 }
408 fn make_int(&self, value: BigInt) -> Self::Value {
409 self.0.ctx.new_int(value).into()
410 }
411 fn make_tuple(&self, elements: impl Iterator<Item = Self::Value>) -> Self::Value {
412 self.0.ctx.new_tuple(elements.collect()).into()
413 }
414 fn make_code(&self, code: CodeObject) -> Self::Value {
415 self.0.ctx.new_code(code).into()
416 }
417 fn make_stop_iter(&self) -> Result<Self::Value, marshal::MarshalError> {
418 Ok(self.0.ctx.exceptions.stop_iteration.to_owned().into())
419 }
420 fn make_list(
421 &self,
422 it: impl Iterator<Item = Self::Value>,
423 ) -> Result<Self::Value, marshal::MarshalError> {
424 Ok(self.0.ctx.new_list(it.collect()).into())
425 }
426 fn make_set(
427 &self,
428 it: impl Iterator<Item = Self::Value>,
429 ) -> Result<Self::Value, marshal::MarshalError> {
430 let set = PySet::default().into_ref(&self.0.ctx);
431 for elem in it {
432 set.add(elem, self.0).unwrap()
433 }
434 Ok(set.into())
435 }
436 fn make_frozenset(
437 &self,
438 it: impl Iterator<Item = Self::Value>,
439 ) -> Result<Self::Value, marshal::MarshalError> {
440 Ok(PyFrozenSet::from_iter(self.0, it)
441 .unwrap()
442 .to_pyobject(self.0))
443 }
444 fn make_dict(
445 &self,
446 it: impl Iterator<Item = (Self::Value, Self::Value)>,
447 ) -> Result<Self::Value, marshal::MarshalError> {
448 let dict = self.0.ctx.new_dict();
449 for (k, v) in it {
450 dict.set_item(&*k, v, self.0).unwrap()
451 }
452 Ok(dict.into())
453 }
454 fn make_slice(
455 &self,
456 start: Self::Value,
457 stop: Self::Value,
458 step: Self::Value,
459 ) -> Result<Self::Value, marshal::MarshalError> {
460 use crate::builtins::PySlice;
461 let vm = self.0;
462 Ok(PySlice {
463 start: if vm.is_none(&start) {
464 None
465 } else {
466 Some(start)
467 },
468 stop,
469 step: if vm.is_none(&step) { None } else { Some(step) },
470 }
471 .into_ref(&vm.ctx)
472 .into())
473 }
474 fn constant_bag(self) -> Self::ConstantBag {
475 PyObjBag(&self.0.ctx)
476 }
477 }
478
    // Arguments accepted by `loads`.
    #[derive(FromArgs)]
    struct LoadsArgs {
        // Buffer holding the marshal data; `loads` requires it to be contiguous.
        #[pyarg(any)]
        data: PyBuffer,
        // Named argument defaulting to true; when false, `loads` runs
        // `check_no_code` on the decoded value.
        #[pyarg(named, default = true)]
        allow_code: bool,
    }
486
487 #[pyfunction]
488 fn loads(args: LoadsArgs, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
489 let LoadsArgs {
490 data: pybuffer,
491 allow_code,
492 } = args;
493 let buf = pybuffer.as_contiguous().ok_or_else(|| {
494 vm.new_buffer_error("Buffer provided to marshal.loads() is not contiguous")
495 })?;
496
497 let result =
498 marshal::deserialize_value(&mut &buf[..], PyMarshalBag(vm)).map_err(|e| match e {
499 marshal::MarshalError::Eof => vm.new_exception_msg(
500 vm.ctx.exceptions.eof_error.to_owned(),
501 "marshal data too short".into(),
502 ),
503 _ => vm.new_value_error("bad marshal data"),
504 })?;
505 if !allow_code {
506 check_no_code(&result, vm)?;
507 }
508 Ok(result)
509 }
510
    // Arguments accepted by `load`.
    #[derive(FromArgs)]
    struct LoadArgs {
        // Source object; `load` calls its `tell`, `read` and `seek` methods.
        f: PyObjectRef,
        // Named argument defaulting to true; when false, `load` runs
        // `check_no_code` on the decoded value.
        #[pyarg(named, default = true)]
        allow_code: bool,
    }
517
518 #[pyfunction]
519 fn load(args: LoadArgs, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
520 let tell_before = vm
524 .call_method(&args.f, "tell", ())?
525 .try_into_value::<i64>(vm)?;
526 let read_res = vm.call_method(&args.f, "read", ())?;
527 let bytes = ArgBytesLike::try_from_object(vm, read_res)?;
528 let buf = bytes.borrow_buf();
529
530 let mut rdr: &[u8] = &buf;
531 let len_before = rdr.len();
532 let result =
533 marshal::deserialize_value(&mut rdr, PyMarshalBag(vm)).map_err(|e| match e {
534 marshal::MarshalError::Eof => vm.new_exception_msg(
535 vm.ctx.exceptions.eof_error.to_owned(),
536 "marshal data too short".into(),
537 ),
538 _ => vm.new_value_error("bad marshal data"),
539 })?;
540 let consumed = len_before - rdr.len();
541
542 let new_pos = tell_before + consumed as i64;
544 vm.call_method(&args.f, "seek", (new_pos,))?;
545
546 if !args.allow_code {
547 check_no_code(&result, vm)?;
548 }
549 Ok(result)
550 }
551
552 fn check_no_code(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
555 if obj.downcast_ref::<PyCode>().is_some() {
556 return Err(vm.new_value_error("unmarshalling code objects is disallowed".to_string()));
557 }
558 if let Some(tup) = obj.downcast_ref::<PyTuple>() {
559 for elem in tup.as_slice() {
560 check_no_code(elem, vm)?;
561 }
562 } else if let Some(list) = obj.downcast_ref::<PyList>() {
563 for elem in list.borrow_vec().iter() {
564 check_no_code(elem, vm)?;
565 }
566 } else if let Some(set) = obj.downcast_ref::<PySet>() {
567 for elem in set.elements() {
568 check_no_code(&elem, vm)?;
569 }
570 } else if let Some(fset) = obj.downcast_ref::<PyFrozenSet>() {
571 for elem in fset.elements() {
572 check_no_code(&elem, vm)?;
573 }
574 } else if let Some(dict) = obj.downcast_ref::<PyDict>() {
575 for (k, v) in dict.into_iter() {
576 check_no_code(&k, vm)?;
577 check_no_code(&v, vm)?;
578 }
579 }
580 Ok(())
581 }
582
583 fn check_exact_type(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
584 let cls = obj.class();
585 if cls.is(PyBool::static_type()) {
587 return Ok(());
588 }
589 for base in [
590 PyInt::static_type(),
591 PyFloat::static_type(),
592 PyComplex::static_type(),
593 PyTuple::static_type(),
594 PyList::static_type(),
595 PyDict::static_type(),
596 PySet::static_type(),
597 PyFrozenSet::static_type(),
598 ] {
599 if cls.fast_issubclass(base) && !cls.is(base) {
600 return Err(vm.new_value_error("unmarshallable object".to_string()));
601 }
602 }
603 Ok(())
604 }
605}