rustpython_vm/builtins/
enumerate.rs1use super::{
2 IterStatus, PositionIterInternal, PyGenericAlias, PyIntRef, PyTupleRef, PyType, PyTypeRef,
3 iter::builtins_reversed,
4};
5use crate::common::lock::{PyMutex, PyRwLock};
6use crate::{
7 AsObject, Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
8 class::PyClassImpl,
9 convert::ToPyObject,
10 function::OptionalArg,
11 protocol::{PyIter, PyIterReturn},
12 raise_if_stop,
13 types::{Constructor, IterNext, Iterable, SelfIter},
14};
15use malachite_bigint::BigInt;
16use num_traits::Zero;
17
18#[pyclass(module = false, name = "enumerate", traverse)]
19#[derive(Debug)]
20pub struct PyEnumerate {
21 #[pytraverse(skip)]
22 counter: PyRwLock<BigInt>,
23 iterable: PyIter,
24}
25
26impl PyPayload for PyEnumerate {
27 #[inline]
28 fn class(ctx: &Context) -> &'static Py<PyType> {
29 ctx.types.enumerate_type
30 }
31}
32
33#[derive(FromArgs)]
34pub struct EnumerateArgs {
35 #[pyarg(any)]
36 iterable: PyIter,
37 #[pyarg(any, optional)]
38 start: OptionalArg<PyIntRef>,
39}
40
41impl Constructor for PyEnumerate {
42 type Args = EnumerateArgs;
43
44 fn py_new(
45 _cls: &Py<PyType>,
46 Self::Args { iterable, start }: Self::Args,
47 _vm: &VirtualMachine,
48 ) -> PyResult<Self> {
49 let counter = start.map_or_else(BigInt::zero, |start| start.as_bigint().clone());
50 Ok(Self {
51 counter: PyRwLock::new(counter),
52 iterable,
53 })
54 }
55}
56
57#[pyclass(with(Py, IterNext, Iterable, Constructor), flags(BASETYPE))]
58impl PyEnumerate {
59 #[pyclassmethod]
60 fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
61 PyGenericAlias::from_args(cls, args, vm)
62 }
63}
64
65#[pyclass]
66impl Py<PyEnumerate> {
67 #[pymethod]
68 fn __reduce__(&self) -> (PyTypeRef, (PyIter, BigInt)) {
69 (
70 self.class().to_owned(),
71 (self.iterable.clone(), self.counter.read().clone()),
72 )
73 }
74}
75
76impl SelfIter for PyEnumerate {}
77
78impl IterNext for PyEnumerate {
79 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
80 let next_obj = raise_if_stop!(zelf.iterable.next(vm)?);
81 let mut counter = zelf.counter.write();
82 let position = counter.clone();
83 *counter += 1;
84 Ok(PyIterReturn::Return((position, next_obj).to_pyobject(vm)))
85 }
86}
87
88#[pyclass(module = false, name = "reversed", traverse)]
89#[derive(Debug)]
90pub struct PyReverseSequenceIterator {
91 internal: PyMutex<PositionIterInternal<PyObjectRef>>,
92}
93
94impl PyPayload for PyReverseSequenceIterator {
95 #[inline]
96 fn class(ctx: &Context) -> &'static Py<PyType> {
97 ctx.types.reverse_iter_type
98 }
99}
100
101#[pyclass(with(IterNext, Iterable))]
102impl PyReverseSequenceIterator {
103 pub const fn new(obj: PyObjectRef, len: usize) -> Self {
104 let position = len.saturating_sub(1);
105 Self {
106 internal: PyMutex::new(PositionIterInternal::new(obj, position)),
107 }
108 }
109
110 #[pymethod]
111 fn __length_hint__(&self, vm: &VirtualMachine) -> PyResult<usize> {
112 let internal = self.internal.lock();
113 if let IterStatus::Active(obj) = &internal.status
114 && internal.position <= obj.length(vm)?
115 {
116 return Ok(internal.position + 1);
117 }
118 Ok(0)
119 }
120
121 #[pymethod]
122 fn __setstate__(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
123 self.internal.lock().set_state(state, |_, pos| pos, vm)
124 }
125
126 #[pymethod]
127 fn __reduce__(&self, vm: &VirtualMachine) -> PyTupleRef {
128 let func = builtins_reversed(vm);
129 self.internal.lock().reduce(
130 func,
131 |x| x.clone(),
132 |vm| vm.ctx.empty_tuple.clone().into(),
133 vm,
134 )
135 }
136}
137
138impl SelfIter for PyReverseSequenceIterator {}
139impl IterNext for PyReverseSequenceIterator {
140 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
141 zelf.internal
142 .lock()
143 .rev_next(|obj, pos| PyIterReturn::from_getitem_result(obj.get_item(&pos, vm), vm))
144 }
145}
146
147pub fn init(context: &'static Context) {
148 PyEnumerate::extend_class(context, context.types.enumerate_type);
149 PyReverseSequenceIterator::extend_class(context, context.types.reverse_iter_type);
150}