Skip to main content

rustpython_vm/builtins/
enumerate.rs

1use super::{
2    IterStatus, PositionIterInternal, PyGenericAlias, PyIntRef, PyTupleRef, PyType, PyTypeRef,
3    iter::builtins_reversed,
4};
5use crate::common::lock::{PyMutex, PyRwLock};
6use crate::{
7    AsObject, Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
8    class::PyClassImpl,
9    convert::ToPyObject,
10    function::OptionalArg,
11    protocol::{PyIter, PyIterReturn},
12    raise_if_stop,
13    types::{Constructor, IterNext, Iterable, SelfIter},
14};
15use malachite_bigint::BigInt;
16use num_traits::Zero;
17
18#[pyclass(module = false, name = "enumerate", traverse)]
19#[derive(Debug)]
20pub struct PyEnumerate {
21    #[pytraverse(skip)]
22    counter: PyRwLock<BigInt>,
23    iterable: PyIter,
24}
25
26impl PyPayload for PyEnumerate {
27    #[inline]
28    fn class(ctx: &Context) -> &'static Py<PyType> {
29        ctx.types.enumerate_type
30    }
31}
32
33#[derive(FromArgs)]
34pub struct EnumerateArgs {
35    #[pyarg(any)]
36    iterable: PyIter,
37    #[pyarg(any, optional)]
38    start: OptionalArg<PyIntRef>,
39}
40
41impl Constructor for PyEnumerate {
42    type Args = EnumerateArgs;
43
44    fn py_new(
45        _cls: &Py<PyType>,
46        Self::Args { iterable, start }: Self::Args,
47        _vm: &VirtualMachine,
48    ) -> PyResult<Self> {
49        let counter = start.map_or_else(BigInt::zero, |start| start.as_bigint().clone());
50        Ok(Self {
51            counter: PyRwLock::new(counter),
52            iterable,
53        })
54    }
55}
56
57#[pyclass(with(Py, IterNext, Iterable, Constructor), flags(BASETYPE))]
58impl PyEnumerate {
59    #[pyclassmethod]
60    fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
61        PyGenericAlias::from_args(cls, args, vm)
62    }
63}
64
65#[pyclass]
66impl Py<PyEnumerate> {
67    #[pymethod]
68    fn __reduce__(&self) -> (PyTypeRef, (PyIter, BigInt)) {
69        (
70            self.class().to_owned(),
71            (self.iterable.clone(), self.counter.read().clone()),
72        )
73    }
74}
75
76impl SelfIter for PyEnumerate {}
77
78impl IterNext for PyEnumerate {
79    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
80        let next_obj = raise_if_stop!(zelf.iterable.next(vm)?);
81        let mut counter = zelf.counter.write();
82        let position = counter.clone();
83        *counter += 1;
84        Ok(PyIterReturn::Return((position, next_obj).to_pyobject(vm)))
85    }
86}
87
88#[pyclass(module = false, name = "reversed", traverse)]
89#[derive(Debug)]
90pub struct PyReverseSequenceIterator {
91    internal: PyMutex<PositionIterInternal<PyObjectRef>>,
92}
93
94impl PyPayload for PyReverseSequenceIterator {
95    #[inline]
96    fn class(ctx: &Context) -> &'static Py<PyType> {
97        ctx.types.reverse_iter_type
98    }
99}
100
101#[pyclass(with(IterNext, Iterable))]
102impl PyReverseSequenceIterator {
103    pub const fn new(obj: PyObjectRef, len: usize) -> Self {
104        let position = len.saturating_sub(1);
105        Self {
106            internal: PyMutex::new(PositionIterInternal::new(obj, position)),
107        }
108    }
109
110    #[pymethod]
111    fn __length_hint__(&self, vm: &VirtualMachine) -> PyResult<usize> {
112        let internal = self.internal.lock();
113        if let IterStatus::Active(obj) = &internal.status
114            && internal.position <= obj.length(vm)?
115        {
116            return Ok(internal.position + 1);
117        }
118        Ok(0)
119    }
120
121    #[pymethod]
122    fn __setstate__(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
123        self.internal.lock().set_state(state, |_, pos| pos, vm)
124    }
125
126    #[pymethod]
127    fn __reduce__(&self, vm: &VirtualMachine) -> PyTupleRef {
128        let func = builtins_reversed(vm);
129        self.internal.lock().reduce(
130            func,
131            |x| x.clone(),
132            |vm| vm.ctx.empty_tuple.clone().into(),
133            vm,
134        )
135    }
136}
137
138impl SelfIter for PyReverseSequenceIterator {}
139impl IterNext for PyReverseSequenceIterator {
140    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
141        zelf.internal
142            .lock()
143            .rev_next(|obj, pos| PyIterReturn::from_getitem_result(obj.get_item(&pos, vm), vm))
144    }
145}
146
147pub fn init(context: &'static Context) {
148    PyEnumerate::extend_class(context, context.types.enumerate_type);
149    PyReverseSequenceIterator::extend_class(context, context.types.reverse_iter_type);
150}