Skip to main content

rustpython_vm/builtins/
zip.rs

1use super::PyType;
2use crate::{
3    AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine,
4    builtins::PyTupleRef,
5    class::PyClassImpl,
6    function::{ArgIntoBool, OptionalArg, PosArgs},
7    protocol::{PyIter, PyIterReturn},
8    types::{Constructor, IterNext, Iterable, SelfIter},
9};
10use rustpython_common::atomic::{self, PyAtomic, Radium};
11
12#[pyclass(module = false, name = "zip", traverse)]
13#[derive(Debug)]
14pub struct PyZip {
15    iterators: Vec<PyIter>,
16    #[pytraverse(skip)]
17    strict: PyAtomic<bool>,
18}
19
20impl PyPayload for PyZip {
21    #[inline]
22    fn class(ctx: &Context) -> &'static Py<PyType> {
23        ctx.types.zip_type
24    }
25}
26
27#[derive(FromArgs)]
28pub struct PyZipNewArgs {
29    #[pyarg(named, optional)]
30    strict: OptionalArg<bool>,
31}
32
33impl Constructor for PyZip {
34    type Args = (PosArgs<PyIter>, PyZipNewArgs);
35
36    fn py_new(
37        _cls: &Py<PyType>,
38        (iterators, args): Self::Args,
39        _vm: &VirtualMachine,
40    ) -> PyResult<Self> {
41        let iterators = iterators.into_vec();
42        let strict = Radium::new(args.strict.unwrap_or(false));
43        Ok(Self { iterators, strict })
44    }
45}
46
47#[pyclass(with(IterNext, Iterable, Constructor), flags(BASETYPE))]
48impl PyZip {
49    #[pymethod]
50    fn __reduce__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
51        let cls = zelf.class().to_owned();
52        let iterators = zelf
53            .iterators
54            .iter()
55            .map(|obj| obj.clone().into())
56            .collect::<Vec<_>>();
57        let tuple_iter = vm.ctx.new_tuple(iterators);
58        Ok(if zelf.strict.load(atomic::Ordering::Acquire) {
59            vm.new_tuple((cls, tuple_iter, true))
60        } else {
61            vm.new_tuple((cls, tuple_iter))
62        })
63    }
64
65    #[pymethod]
66    fn __setstate__(zelf: PyRef<Self>, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
67        if let Ok(obj) = ArgIntoBool::try_from_object(vm, state) {
68            zelf.strict.store(obj.into(), atomic::Ordering::Release);
69        }
70        Ok(())
71    }
72}
73
74impl SelfIter for PyZip {}
75impl IterNext for PyZip {
76    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
77        if zelf.iterators.is_empty() {
78            return Ok(PyIterReturn::StopIteration(None));
79        }
80        let mut next_objs = Vec::new();
81        for (idx, iterator) in zelf.iterators.iter().enumerate() {
82            let item = match iterator.next(vm)? {
83                PyIterReturn::Return(obj) => obj,
84                PyIterReturn::StopIteration(v) => {
85                    if zelf.strict.load(atomic::Ordering::Acquire) {
86                        if idx > 0 {
87                            let plural = if idx == 1 { " " } else { "s 1-" };
88                            return Err(vm.new_value_error(format!(
89                                "zip() argument {} is shorter than argument{}{}",
90                                idx + 1,
91                                plural,
92                                idx
93                            )));
94                        }
95                        for (idx, iterator) in zelf.iterators[1..].iter().enumerate() {
96                            if let PyIterReturn::Return(_obj) = iterator.next(vm)? {
97                                let plural = if idx == 0 { " " } else { "s 1-" };
98                                return Err(vm.new_value_error(format!(
99                                    "zip() argument {} is longer than argument{}{}",
100                                    idx + 2,
101                                    plural,
102                                    idx + 1
103                                )));
104                            }
105                        }
106                    }
107                    return Ok(PyIterReturn::StopIteration(v));
108                }
109            };
110            next_objs.push(item);
111        }
112        Ok(PyIterReturn::Return(vm.ctx.new_tuple(next_objs).into()))
113    }
114}
115
116pub fn init(ctx: &'static Context) {
117    PyZip::extend_class(ctx, ctx.types.zip_type);
118}