rustpython_vm/builtins/
map.rs1use super::PyType;
2use crate::{
3 AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine,
4 builtins::PyTupleRef,
5 class::PyClassImpl,
6 function::{ArgIntoBool, OptionalArg, PosArgs},
7 protocol::{PyIter, PyIterReturn},
8 types::{Constructor, IterNext, Iterable, SelfIter},
9};
10use rustpython_common::atomic::{self, PyAtomic, Radium};
11
12#[pyclass(module = false, name = "map", traverse)]
13#[derive(Debug)]
14pub struct PyMap {
15 mapper: PyObjectRef,
16 iterators: Vec<PyIter>,
17 #[pytraverse(skip)]
18 strict: PyAtomic<bool>,
19}
20
21impl PyPayload for PyMap {
22 #[inline]
23 fn class(ctx: &Context) -> &'static Py<PyType> {
24 ctx.types.map_type
25 }
26}
27
28#[derive(FromArgs)]
29pub struct PyMapNewArgs {
30 #[pyarg(named, optional)]
31 strict: OptionalArg<bool>,
32}
33
34impl Constructor for PyMap {
35 type Args = (PyObjectRef, PosArgs<PyIter>, PyMapNewArgs);
36
37 fn py_new(
38 _cls: &Py<PyType>,
39 (mapper, iterators, args): Self::Args,
40 _vm: &VirtualMachine,
41 ) -> PyResult<Self> {
42 let iterators = iterators.into_vec();
43 let strict = Radium::new(args.strict.unwrap_or(false));
44 Ok(Self {
45 mapper,
46 iterators,
47 strict,
48 })
49 }
50}
51
52#[pyclass(with(IterNext, Iterable, Constructor), flags(BASETYPE))]
53impl PyMap {
54 #[pymethod]
55 fn __length_hint__(&self, vm: &VirtualMachine) -> PyResult<usize> {
56 self.iterators.iter().try_fold(0, |prev, cur| {
57 let cur = cur.as_ref().to_owned().length_hint(0, vm)?;
58 let max = core::cmp::max(prev, cur);
59 Ok(max)
60 })
61 }
62
63 #[pymethod]
64 fn __reduce__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
65 let cls = zelf.class().to_owned();
66 let mut vec = vec![zelf.mapper.clone()];
67 vec.extend(zelf.iterators.iter().map(|o| o.clone().into()));
68 let tuple_args = vm.ctx.new_tuple(vec);
69 Ok(if zelf.strict.load(atomic::Ordering::Acquire) {
70 vm.new_tuple((cls, tuple_args, true))
71 } else {
72 vm.new_tuple((cls, tuple_args))
73 })
74 }
75
76 #[pymethod]
77 fn __setstate__(zelf: PyRef<Self>, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
78 if let Ok(obj) = ArgIntoBool::try_from_object(vm, state) {
79 zelf.strict.store(obj.into(), atomic::Ordering::Release);
80 }
81 Ok(())
82 }
83}
84
85impl SelfIter for PyMap {}
86
87impl IterNext for PyMap {
88 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
89 let mut next_objs = Vec::new();
90 for (idx, iterator) in zelf.iterators.iter().enumerate() {
91 let item = match iterator.next(vm)? {
92 PyIterReturn::Return(obj) => obj,
93 PyIterReturn::StopIteration(v) => {
94 if zelf.strict.load(atomic::Ordering::Acquire) {
95 if idx > 0 {
96 let plural = if idx == 1 { " " } else { "s 1-" };
97 return Err(vm.new_value_error(format!(
98 "map() argument {} is shorter than argument{}{}",
99 idx + 1,
100 plural,
101 idx,
102 )));
103 }
104 for (idx, iterator) in zelf.iterators[1..].iter().enumerate() {
105 if let PyIterReturn::Return(_) = iterator.next(vm)? {
106 let plural = if idx == 0 { " " } else { "s 1-" };
107 return Err(vm.new_value_error(format!(
108 "map() argument {} is longer than argument{}{}",
109 idx + 2,
110 plural,
111 idx + 1,
112 )));
113 }
114 }
115 }
116 return Ok(PyIterReturn::StopIteration(v));
117 }
118 };
119 next_objs.push(item);
120 }
121
122 PyIterReturn::from_pyresult(zelf.mapper.call(next_objs, vm), vm)
124 }
125}
126
127pub fn init(context: &'static Context) {
128 PyMap::extend_class(context, context.types.map_type);
129}