pub(crate) use _functools::module_def;
#[pymodule]
mod _functools {
use crate::{
Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::{PyBoundMethod, PyDict, PyGenericAlias, PyTuple, PyType, PyTypeRef},
common::lock::PyRwLock,
function::{FuncArgs, KwArgs, OptionalOption},
object::AsObject,
protocol::PyIter,
pyclass,
recursion::ReprGuard,
types::{Callable, Constructor, GetDescriptor, Representable},
};
use indexmap::IndexMap;
use rustpython_common::wtf8::Wtf8Buf;
// Arguments accepted by `reduce`, extracted from the call via the `FromArgs` derive.
#[derive(FromArgs)]
struct ReduceArgs {
// Callable that `reduce` invokes with `(accumulator, next_item)`.
function: PyObjectRef,
// Source of the items; `reduce` walks it through `iter_without_hint`.
iterator: PyIter,
// Optional starting value, exposed under the name "initial". `reduce` handles
// "not supplied" (take the first item) and "supplied" (possibly as None)
// differently, which is why both layers of optionality are kept.
#[pyarg(any, optional, name = "initial")]
initial: OptionalOption<PyObjectRef>,
}
// Left fold of `function` over the items of `iterator`.
//
// The seed is the supplied `initial` value (an explicit None counts as a
// value) or, when `initial` is absent, the first item of the iterator. An
// exhausted iterator with no `initial` raises TypeError. Any error raised by
// the iterator or by `function` is propagated unchanged.
#[pyfunction]
fn reduce(args: ReduceArgs, vm: &VirtualMachine) -> PyResult {
    let ReduceArgs {
        function,
        iterator,
        initial,
    } = args;
    let mut items = iterator.iter_without_hint(vm)?;

    let seed = match initial.into_option() {
        // `initial` was passed; an inner `None` stands for Python's None.
        Some(explicit) => explicit.unwrap_or_else(|| vm.ctx.none()),
        // `initial` was omitted: the first item becomes the seed.
        None => match items.next().transpose()? {
            Some(first) => first,
            None => {
                return Err(vm.new_exception_msg(
                    vm.ctx.exceptions.type_error.to_owned(),
                    "reduce() of empty sequence with no initial value".into(),
                ));
            }
        },
    };

    // Stops at the first failing item or failing call.
    items.try_fold(seed, |acc, item| function.call((acc, item?), vm))
}
// Builds the `Placeholder` module attribute and records it on its own type
// under `_instance`, where `PyPlaceholderType::slot_new` looks for it.
#[pyattr]
#[allow(non_snake_case)]
fn Placeholder(vm: &VirtualMachine) -> PyObjectRef {
    let singleton = PyPlaceholderType.into_pyobject(vm);
    singleton
        .class()
        .set_attr(vm.ctx.intern_str("_instance"), singleton.clone());
    singleton
}
// Payload type behind the `Placeholder` object. It is exposed to Python as
// `_PlaceholderType` in module `functools` and carries no data of its own.
#[pyattr]
#[pyclass(name = "_PlaceholderType", module = "functools")]
#[derive(Debug, PyPayload)]
pub struct PyPlaceholderType;
impl Constructor for PyPlaceholderType {
    type Args = FuncArgs;

    // Rejects any positional or keyword argument with TypeError, then returns
    // the object stored under the type's `_instance` attribute when there is
    // one; otherwise a fresh placeholder object is created.
    fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        let has_arguments = !(args.args.is_empty() && args.kwargs.is_empty());
        if has_arguments {
            return Err(vm.new_type_error("_PlaceholderType takes no arguments"));
        }
        let existing = cls.get_attr(vm.ctx.intern_str("_instance"));
        Ok(existing.unwrap_or_else(|| PyPlaceholderType.into_pyobject(vm)))
    }

    // Produces the (stateless) payload; all arguments are ignored here.
    fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
        Ok(Self)
    }
}
#[pyclass(with(Constructor, Representable))]
impl PyPlaceholderType {
    // Returns the string "Placeholder". A string result from `__reduce__` is
    // interpreted by Python's pickle protocol as the name of a global.
    #[pymethod]
    fn __reduce__(&self) -> &'static str {
        const GLOBAL_NAME: &str = "Placeholder";
        GLOBAL_NAME
    }

    // Always fails with TypeError: the placeholder type cannot be subclassed.
    #[pymethod]
    fn __init_subclass__(_cls: PyTypeRef, vm: &VirtualMachine) -> PyResult<()> {
        let refusal = vm.new_type_error("cannot subclass '_PlaceholderType'");
        Err(refusal)
    }
}
impl Representable for PyPlaceholderType {
    // The repr is the fixed text "Placeholder", independent of the instance.
    #[inline]
    fn repr_str(_zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
        Ok(String::from("Placeholder"))
    }
}
fn is_placeholder(obj: &PyObjectRef) -> bool {
&*obj.class().name() == "_PlaceholderType"
}
// Number of elements of `args` for which `is_placeholder` holds.
fn count_placeholders(args: &[PyObjectRef]) -> usize {
    args.iter()
        .map(|candidate| usize::from(is_placeholder(candidate)))
        .sum()
}
// Python-visible `partial` type (module `functools`).
// All state lives behind a single `PyRwLock`, which lets `__setstate__`
// replace it through a shared `&Py<Self>` reference.
#[pyattr]
#[pyclass(name = "partial", module = "functools")]
#[derive(Debug, PyPayload)]
pub struct PyPartial {
inner: PyRwLock<PyPartialInner>,
}
// Lock-protected state of a `PyPartial`.
#[derive(Debug)]
struct PyPartialInner {
// The wrapped callable.
func: PyObjectRef,
// Stored positional arguments; entries may be placeholders.
args: PyRef<PyTuple>,
// Stored keyword arguments.
keywords: PyRef<PyDict>,
// Number of placeholders in `args`, as computed by `count_placeholders`.
phcount: usize,
}
#[pyclass(
with(Constructor, Callable, GetDescriptor, Representable),
flags(BASETYPE, HAS_DICT, HAS_WEAKREF)
)]
impl PyPartial {
// Getter for the wrapped callable.
#[pygetset]
fn func(&self) -> PyObjectRef {
    let state = self.inner.read();
    state.func.clone()
}
// Getter for the stored positional-argument tuple.
#[pygetset]
fn args(&self) -> PyRef<PyTuple> {
    let state = self.inner.read();
    state.args.clone()
}
// Getter for the stored keyword dict (the same dict object, not a copy).
#[pygetset]
fn keywords(&self) -> PyRef<PyDict> {
    let state = self.inner.read();
    state.keywords.clone()
}
// Pickle support. Produces the 3-tuple
//   (type(self), (func,), (func, args, keywords, dict_or_None))
// where the last element of the state is the instance `__dict__`, or None
// when the instance has no dict or the dict is empty.
#[pymethod]
fn __reduce__(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult {
    let guard = zelf.inner.read();

    let extra_attrs = match zelf.as_object().dict() {
        Some(attrs) if !attrs.is_empty() => attrs.into(),
        _ => vm.ctx.none(),
    };

    let state = vm.ctx.new_tuple(vec![
        guard.func.clone(),
        guard.args.clone().into(),
        guard.keywords.clone().into(),
        extra_attrs,
    ]);
    let constructor_args = vm.ctx.new_tuple(vec![guard.func.clone()]);

    let reduced = vm.ctx.new_tuple(vec![
        zelf.class().to_owned().into(),
        constructor_args.into(),
        state.into(),
    ]);
    Ok(reduced.into())
}
// Restores the partial from `state`, which must be the 4-tuple
// `(func, args, keywords, dict)` in the shape produced by `__reduce__`.
//
// Raises TypeError when `state` is not a tuple of length 4, when `func` is
// not callable, when `args` is not a tuple, when `keywords` or `dict` is
// neither None nor a dict, or when `args` ends with a Placeholder.
//
// Every element is validated BEFORE anything is written, so a rejected state
// leaves the object untouched. The write lock is held only for the field
// assignments and is released before the instance dict is filled, because
// inserting into a dict hashes the keys and may therefore run arbitrary
// Python code that could call back into this partial.
#[pymethod]
fn __setstate__(zelf: &Py<Self>, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
    let state_tuple = state
        .downcast::<PyTuple>()
        .map_err(|_| vm.new_type_error("argument to __setstate__ must be a tuple"))?;
    if state_tuple.len() != 4 {
        return Err(vm.new_type_error(format!(
            "expected 4 items in state, got {}",
            state_tuple.len()
        )));
    }
    let func = &state_tuple[0];
    let args = &state_tuple[1];
    let kwds = &state_tuple[2];
    let dict = &state_tuple[3];

    if !func.is_callable() {
        return Err(vm.new_type_error("invalid partial state"));
    }
    if !args.fast_isinstance(vm.ctx.types.tuple_type) {
        return Err(vm.new_type_error("invalid partial state"));
    }

    // Store an exact tuple: instances of tuple subclasses are converted.
    let args_tuple = match args.clone().downcast::<PyTuple>() {
        Ok(tuple) if tuple.class().is(vm.ctx.types.tuple_type) => tuple,
        _ => {
            let elements: Vec<PyObjectRef> = args.try_to_value(vm)?;
            vm.ctx.new_tuple(elements)
        }
    };

    // Store an exact dict: None becomes an empty dict and instances of dict
    // subclasses are copied into a plain dict.
    let keywords_dict = if kwds.is(&vm.ctx.none) {
        vm.ctx.new_dict()
    } else {
        let given = kwds
            .clone()
            .downcast::<PyDict>()
            .map_err(|_| vm.new_type_error("invalid partial state"))?;
        if given.class().is(vm.ctx.types.dict_type) {
            given
        } else {
            let plain = vm.ctx.new_dict();
            for (key, value) in given {
                plain.set_item(&*key, value, vm)?;
            }
            plain
        }
    };

    // Validate the `__dict__` element now, while nothing has been mutated yet.
    let new_attrs = if dict.is(&vm.ctx.none) {
        None
    } else {
        let attrs = dict
            .clone()
            .downcast::<PyDict>()
            .map_err(|_| vm.new_type_error("invalid partial state"))?;
        Some(attrs)
    };

    let args_slice = args_tuple.as_slice();
    if args_slice.last().is_some_and(|last| is_placeholder(last)) {
        return Err(vm.new_type_error("trailing Placeholders are not allowed"));
    }
    let phcount = count_placeholders(args_slice);

    // Commit. The guard is scoped so the lock is gone before the dict work.
    {
        let mut inner = zelf.inner.write();
        inner.func = func.clone();
        inner.args = args_tuple;
        inner.keywords = keywords_dict;
        inner.phcount = phcount;
    }

    let Some(instance_dict) = zelf.as_object().dict() else {
        return Ok(());
    };
    match new_attrs {
        None => instance_dict.clear(),
        // The state carries the instance's own dict (this is what `__reduce__`
        // hands out). Clearing it first would erase the very entries that are
        // supposed to be restored, so it is left as it is.
        Some(attrs) if attrs.is(&instance_dict) => {}
        Some(attrs) => {
            instance_dict.clear();
            for (key, value) in attrs {
                instance_dict.set_item(&*key, value, vm)?;
            }
        }
    }
    Ok(())
}
// Supports subscripting the class (`partial[...]`): returns a
// `PyGenericAlias` built from the class and the subscript arguments.
#[pyclassmethod]
fn __class_getitem__(
cls: PyTypeRef,
args: PyObjectRef,
vm: &VirtualMachine,
) -> PyGenericAlias {
PyGenericAlias::from_args(cls, args, vm)
}
}
impl Constructor for PyPartial {
    type Args = FuncArgs;

    // Builds a partial from `partial(func, *args, **kwargs)`.
    //
    // Raises TypeError when no positional argument is given, when `func` is
    // not callable, when a keyword value is a Placeholder, or when the final
    // positional arguments end with a Placeholder.
    //
    // When `func` is itself a partial it is flattened: the new object wraps
    // the inner callable directly, the new positional arguments first fill the
    // inner partial's placeholders (left to right) and then get appended, and
    // the new keywords are layered over a COPY of the inner keywords. The copy
    // matters: sharing the dict would make `partial(p, x=1)` silently add `x`
    // to `p.keywords` as well.
    fn py_new(
        _cls: &crate::Py<crate::builtins::PyType>,
        args: Self::Args,
        vm: &VirtualMachine,
    ) -> PyResult<Self> {
        let (func, args_slice) = args
            .args
            .split_first()
            .ok_or_else(|| vm.new_type_error("partial expected at least 1 argument, got 0"))?;
        if !func.is_callable() {
            return Err(vm.new_type_error("the first argument must be callable"));
        }

        for (key, value) in &args.kwargs {
            if is_placeholder(value) {
                return Err(vm.new_type_error(format!(
                    "Placeholder cannot be passed as a keyword argument to partial(). \
                     Did you mean partial(..., {}=Placeholder, ...)(value)?",
                    key
                )));
            }
        }

        let (final_func, final_args, final_keywords) =
            if let Some(partial) = func.downcast_ref::<Self>() {
                // Take a snapshot and release the inner partial's lock right
                // away; the dict copy below hashes keys and may run Python code.
                let (inner_func, stored_args, stored_keywords) = {
                    let inner = partial.inner.read();
                    (
                        inner.func.clone(),
                        inner.args.clone(),
                        inner.keywords.clone(),
                    )
                };

                let mut new_args_iter = args_slice.iter();
                let mut merged_args = Vec::with_capacity(stored_args.len() + args_slice.len());
                for stored_arg in stored_args.as_slice() {
                    // A placeholder is replaced by the next new argument; once
                    // the new arguments run out the placeholder is kept.
                    let replacement = if is_placeholder(stored_arg) {
                        new_args_iter.next()
                    } else {
                        None
                    };
                    merged_args.push(replacement.unwrap_or(stored_arg).clone());
                }
                merged_args.extend(new_args_iter.cloned());

                // Fresh dict so the inner partial's keywords stay untouched.
                let keywords = vm.ctx.new_dict();
                for (key, value) in &*stored_keywords {
                    keywords.set_item(&*key, value, vm)?;
                }

                (inner_func, merged_args, keywords)
            } else {
                (func.clone(), args_slice.to_vec(), vm.ctx.new_dict())
            };

        if final_args.last().is_some_and(|last| is_placeholder(last)) {
            return Err(vm.new_type_error("trailing Placeholders are not allowed"));
        }
        let phcount = count_placeholders(&final_args);

        // New keywords override inherited ones with the same name.
        for (key, value) in args.kwargs {
            final_keywords.set_item(vm.ctx.intern_str(key.as_str()), value, vm)?;
        }

        Ok(Self {
            inner: PyRwLock::new(PyPartialInner {
                func: final_func,
                args: vm.ctx.new_tuple(final_args),
                keywords: final_keywords,
                phcount,
            }),
        })
    }
}
impl Callable for PyPartial {
    type Args = FuncArgs;

    // Calls the wrapped callable with the stored and the call-time arguments
    // combined.
    //
    // Positional: each stored placeholder is replaced, left to right, by the
    // next call-time positional argument; the remaining call-time arguments
    // are appended. Fewer call-time positionals than placeholders raises
    // TypeError.
    // Keywords: stored keywords first, then call-time keywords, which win on a
    // name clash. A stored key that is not a str raises TypeError.
    fn call(zelf: &Py<Self>, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        // Copy the state out so the lock is not held during the call.
        let state = zelf.inner.read();
        let func = state.func.clone();
        let stored_args = state.args.clone();
        let keywords = state.keywords.clone();
        let phcount = state.phcount;
        drop(state);

        let supplied_count = args.args.len();
        // With zero placeholders this can never be true.
        if supplied_count < phcount {
            return Err(vm.new_type_error(format!(
                "missing positional arguments in 'partial' call; expected at least {}, got {}",
                phcount, supplied_count
            )));
        }

        let mut supplied = args.args.iter();
        let mut positional = Vec::with_capacity(stored_args.len() + supplied_count);
        for stored in stored_args.as_slice() {
            let replacement = if is_placeholder(stored) {
                supplied.next()
            } else {
                None
            };
            positional.push(replacement.unwrap_or(stored).clone());
        }
        positional.extend(supplied.cloned());

        let mut merged_kwargs = IndexMap::new();
        for (key, value) in &*keywords {
            let Some(name) = key.downcast_ref::<crate::builtins::PyStr>() else {
                return Err(vm.new_type_error("keywords must be strings"));
            };
            merged_kwargs.insert(name.expect_str().to_owned(), value);
        }
        // Call-time keywords overwrite stored ones of the same name.
        merged_kwargs.extend(args.kwargs);

        func.call(FuncArgs::new(positional, KwArgs::new(merged_kwargs)), vm)
    }
}
impl GetDescriptor for PyPartial {
    // Descriptor `__get__`: when accessed through an instance (`obj` present
    // and not None) the partial is wrapped in a bound method for that
    // instance; otherwise the partial itself is returned.
    fn descr_get(
        zelf: PyObjectRef,
        obj: Option<PyObjectRef>,
        _cls: Option<PyObjectRef>,
        vm: &VirtualMachine,
    ) -> PyResult {
        match obj {
            Some(instance) if !vm.is_none(&instance) => {
                Ok(PyBoundMethod::new(instance, zelf).into_ref(&vm.ctx).into())
            }
            _ => Ok(zelf),
        }
    }
}
impl Representable for PyPartial {
    // Renders `module.QualName(func_repr, arg_repr..., key=value_repr...)`.
    //
    // The module prefix is left out when the module is "builtins", is empty,
    // or is not a str. A re-entrant repr of the same object (detected by
    // `ReprGuard`) yields "..." instead of recursing.
    #[inline]
    fn repr_wtf8(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
        // `_guard` stays alive until the function returns.
        let Some(_guard) = ReprGuard::enter(vm, zelf.as_object()) else {
            return Ok(Wtf8Buf::from("..."));
        };

        // Copy the state out so no lock is held while element reprs run.
        let state = zelf.inner.read();
        let func = state.func.clone();
        let args = state.args.clone();
        let keywords = state.keywords.clone();
        drop(state);

        let qualname = zelf.class().__qualname__(vm);
        let type_name = match qualname.downcast_ref::<crate::builtins::PyStr>() {
            Some(text) => text.as_wtf8().to_owned(),
            None => Wtf8Buf::from(zelf.class().name().to_owned()),
        };
        let module = zelf.class().__module__(vm);

        let mut out = Wtf8Buf::new();
        if let Ok(module_str) = module.downcast::<crate::builtins::PyStr>() {
            let module_name = module_str.as_wtf8();
            let show_module = module_name != "builtins" && !module_name.is_empty();
            if show_module {
                out.push_wtf8(module_name);
                out.push_char('.');
            }
        }
        out.push_wtf8(&type_name);
        out.push_char('(');
        out.push_wtf8(func.repr(vm)?.as_wtf8());

        for positional in args.as_slice() {
            out.push_str(", ");
            out.push_wtf8(positional.repr(vm)?.as_wtf8());
        }

        for (key, value) in &*keywords {
            out.push_str(", ");
            // str keys are used as they are; anything else goes through `str()`.
            let label = match key.clone().downcast::<crate::builtins::PyStr>() {
                Ok(text) => text,
                Err(_) => key.str(vm)?,
            };
            out.push_wtf8(label.as_wtf8());
            out.push_char('=');
            out.push_wtf8(value.repr(vm)?.as_wtf8());
        }

        out.push_char(')');
        Ok(out)
    }
}
}