use super::{
IterStatus, PositionIterInternal, PyDict, PyDictRef, PyGenericAlias, PyTupleRef, PyType,
PyTypeRef, builtins_iter,
};
use crate::common::lock::LazyLock;
use crate::{
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject,
atomic_func,
class::PyClassImpl,
common::{ascii, hash::PyHash, lock::PyMutex, rc::PyRc, wtf8::Wtf8Buf},
convert::ToPyResult,
dict_inner::{self, DictSize},
function::{ArgIterable, FuncArgs, OptionalArg, PosArgs, PyArithmeticValue, PyComparisonValue},
protocol::{PyIterReturn, PyNumberMethods, PySequenceMethods},
recursion::ReprGuard,
types::AsNumber,
types::{
AsSequence, Comparable, Constructor, DefaultConstructor, Hashable, Initializer, IterNext,
Iterable, PyComparisonOp, Representable, SelfIter,
},
utils::collection_repr,
vm::VirtualMachine,
};
use alloc::fmt;
use core::borrow::Borrow;
use core::ops::Deref;
use rustpython_common::{
atomic::{Ordering, PyAtomic, Radium},
hash,
};
/// Element storage shared by `set` and `frozenset`: an interpreter dict whose values are
/// the unit type, i.e. only the keys matter.
pub type SetContentType = dict_inner::Dict<()>;
// Payload of Python's mutable `set` type. All element storage and set algorithms live in
// `PySetInner`; this wrapper adds the Python-facing methods and slots.
#[pyclass(module = false, name = "set", unhashable = true, traverse)]
#[derive(Default)]
pub struct PySet {
    // `pub(super)` so sibling builtins modules can reach the storage directly.
    pub(super) inner: PySetInner,
}
impl PySet {
#[deprecated(note = "Use `PySet::default().into_ref(ctx)` instead")]
pub fn new_ref(ctx: &Context) -> PyRef<Self> {
Self::default().into_ref(ctx)
}
pub fn elements(&self) -> Vec<PyObjectRef> {
self.inner.elements()
}
fn fold_op(
&self,
others: impl core::iter::Iterator<Item = ArgIterable>,
op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
vm: &VirtualMachine,
) -> PyResult<Self> {
Ok(Self {
inner: self.inner.fold_op(others, op, vm)?,
})
}
fn op(
&self,
other: AnySet,
op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
vm: &VirtualMachine,
) -> PyResult<Self> {
Ok(Self {
inner: self
.inner
.fold_op(core::iter::once(other.into_iterable(vm)?), op, vm)?,
})
}
}
// Payload of Python's immutable `frozenset` type.
// NOTE(review): unlike `PySet`, this pyclass does not request `traverse` even though
// `inner` holds object references; confirm GC traversal is intentionally omitted.
#[pyclass(module = false, name = "frozenset", unhashable = true)]
pub struct PyFrozenSet {
    // Element storage.
    inner: PySetInner,
    // Cached hash. `hash::SENTINEL` means "not computed yet" (see `Hashable for PyFrozenSet`).
    hash: PyAtomic<PyHash>,
}
impl Default for PyFrozenSet {
fn default() -> Self {
Self {
inner: PySetInner::default(),
hash: hash::SENTINEL.into(),
}
}
}
impl PyFrozenSet {
    /// Builds a frozenset from `it`.
    ///
    /// Stops at the first element whose insertion fails and returns that error.
    pub fn from_iter(
        vm: &VirtualMachine,
        it: impl IntoIterator<Item = PyObjectRef>,
    ) -> PyResult<Self> {
        let inner = PySetInner::default();
        it.into_iter().try_for_each(|elem| inner.add(elem, vm))?;
        Ok(Self {
            inner,
            ..Default::default()
        })
    }

    /// Snapshot of the elements currently stored in the frozenset.
    pub fn elements(&self) -> Vec<PyObjectRef> {
        self.inner.elements()
    }

    /// Applies `op` against each iterable in `others`, left to right, starting from a
    /// copy of this frozenset. The result starts with an uncomputed hash.
    fn fold_op(
        &self,
        others: impl core::iter::Iterator<Item = ArgIterable>,
        op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
        vm: &VirtualMachine,
    ) -> PyResult<Self> {
        let inner = self.inner.fold_op(others, op, vm)?;
        Ok(Self {
            inner,
            ..Default::default()
        })
    }

    /// Single-operand form of `fold_op` for an operand already known to be set-like.
    fn op(
        &self,
        other: AnySet,
        op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
        vm: &VirtualMachine,
    ) -> PyResult<Self> {
        let iterable = other.into_iterable(vm)?;
        self.fold_op(core::iter::once(iterable), op, vm)
    }
}
impl fmt::Debug for PySet {
    /// Deliberately terse: prints only the type name, not the elements.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "set")
    }
}
impl fmt::Debug for PyFrozenSet {
    /// Prints `PyFrozenSet ` followed by the elements in `{...}` set notation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let items = self.elements();
        f.write_str("PyFrozenSet ")?;
        f.debug_set().entries(&items).finish()
    }
}
impl PyPayload for PySet {
    /// The built-in `set` type object.
    #[inline]
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.set_type
    }
}
impl PyPayload for PyFrozenSet {
    /// The built-in `frozenset` type object.
    #[inline]
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.frozenset_type
    }
}
/// Storage and algorithms shared by `set` and `frozenset`.
///
/// The derived `Clone` only clones the `PyRc` handle, so a clone shares `content` with
/// the original. Use `copy()` when an independent set is needed.
#[derive(Default, Clone)]
pub(super) struct PySetInner {
    content: PyRc<SetContentType>,
}
// SAFETY: `content` is the struct's only field, and `traverse` forwards to its own
// `traverse` implementation, so every object reference held here is visited.
unsafe impl crate::object::Traverse for PySetInner {
    fn traverse(&self, tracer_fn: &mut crate::object::TraverseFn<'_>) {
        self.content.traverse(tracer_fn)
    }
}
impl PySetInner {
    /// Builds a set from fallible items.
    ///
    /// Stops at the first error, whether raised by the iterator itself or by inserting
    /// an item.
    pub(super) fn from_iter<T>(iter: T, vm: &VirtualMachine) -> PyResult<Self>
    where
        T: IntoIterator<Item = PyResult<PyObjectRef>>,
    {
        let set = Self::default();
        for item in iter {
            set.add(item?, vm)?;
        }
        Ok(set)
    }

    /// Left-folds `op` over `others`, starting from an independent copy of `self`.
    fn fold_op<O>(
        &self,
        others: impl core::iter::Iterator<Item = O>,
        op: fn(&Self, O, &VirtualMachine) -> PyResult<Self>,
        vm: &VirtualMachine,
    ) -> PyResult<Self> {
        let mut res = self.copy();
        for other in others {
            res = op(&res, other, vm)?;
        }
        Ok(res)
    }

    /// Number of elements currently stored.
    fn len(&self) -> usize {
        self.content.len()
    }

    /// Size reported by the underlying table, used by `__sizeof__`.
    fn sizeof(&self) -> usize {
        self.content.sizeof()
    }

    /// Returns a set with its own, independent storage.
    ///
    /// Unlike the derived `Clone`, which shares the `PyRc`, mutating the result never
    /// affects `self`.
    fn copy(&self) -> Self {
        Self {
            content: PyRc::new((*self.content).clone()),
        }
    }

    /// Membership test. An unhashable `set` needle is retried as a `frozenset`.
    fn contains(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
        self.retry_op_with_frozenset(needle, vm, |needle, vm| self.content.contains(vm, needle))
    }

    /// Rich comparison with subset/superset semantics.
    ///
    /// `!=` is the negation of `==`. For the other operators the lengths must first
    /// satisfy `op`. Then every element of the "subset" side (`self` for `<`/`<=`,
    /// `other` otherwise) must be contained in the other side.
    fn compare(&self, other: &Self, op: PyComparisonOp, vm: &VirtualMachine) -> PyResult<bool> {
        if op == PyComparisonOp::Ne {
            return self.compare(other, PyComparisonOp::Eq, vm).map(|eq| !eq);
        }
        if !op.eval_ord(self.len().cmp(&other.len())) {
            return Ok(false);
        }
        let (superset, subset) = match op {
            PyComparisonOp::Lt | PyComparisonOp::Le => (other, self),
            _ => (self, other),
        };
        for key in subset.elements() {
            if !superset.contains(&key, vm)? {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Returns `self | other` as a new set. `self` is left unchanged.
    pub(super) fn union(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<Self> {
        // `copy()` rather than `clone()`: a clone shares `content`, so adding to it would
        // silently mutate `self` as well.
        let set = self.copy();
        for item in other.iter(vm)? {
            set.add(item?, vm)?;
        }
        Ok(set)
    }

    /// Returns the elements of `other` that are also in `self`, as a new set.
    pub(super) fn intersection(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<Self> {
        let set = Self::default();
        for item in other.iter(vm)? {
            let obj = item?;
            if self.contains(&obj, vm)? {
                set.add(obj, vm)?;
            }
        }
        Ok(set)
    }

    /// Returns `self - other` as a new set. `self` is left unchanged.
    pub(super) fn difference(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<Self> {
        let set = self.copy();
        for item in other.iter(vm)? {
            set.content.delete_if_exists(vm, &*item?)?;
        }
        Ok(set)
    }

    /// Returns `self ^ other` as a new set. `self` is left unchanged.
    ///
    /// `other` is first collected into a set, so duplicate items in it toggle only once.
    pub(super) fn symmetric_difference(
        &self,
        other: ArgIterable,
        vm: &VirtualMachine,
    ) -> PyResult<Self> {
        // `copy()` rather than `clone()` for the same aliasing reason as in `union`.
        let new_inner = self.copy();
        let other_set = Self::from_iter(other.iter(vm)?, vm)?;
        for item in other_set.elements() {
            new_inner.content.delete_or_insert(vm, &item, ())?;
        }
        Ok(new_inner)
    }

    /// `true` when every item produced by `other` is contained in `self`.
    fn issuperset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
        for item in other.iter(vm)? {
            if !self.contains(&*item?, vm)? {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// `true` when `self` is a subset of the items produced by `other`.
    fn issubset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
        let other_set = Self::from_iter(other.iter(vm)?, vm)?;
        self.compare(&other_set, PyComparisonOp::Le, vm)
    }

    /// `true` when no item produced by `other` is contained in `self`.
    pub(super) fn isdisjoint(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
        for item in other.iter(vm)? {
            if self.contains(&*item?, vm)? {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Creates an iterator that shares this set's storage.
    ///
    /// The current storage size is recorded so that mutation during iteration can be
    /// detected later.
    fn iter(&self) -> PySetIterator {
        PySetIterator {
            size: self.content.size(),
            internal: PyMutex::new(PositionIterInternal::new(self.content.clone(), 0)),
        }
    }

    /// `{a, b, ...}` repr, wrapped as `Name({...})` when `class_name` is given.
    fn repr(&self, class_name: Option<&str>, vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
        collection_repr(class_name, "{", "}", self.elements().iter(), vm)
    }

    /// Inserts `item`.
    fn add(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
        self.content.insert(vm, &*item, ())
    }

    /// Removes `item`, raising `KeyError(item)` when it is absent.
    fn remove(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
        self.retry_op_with_frozenset(&item, vm, |item, vm| self.content.delete(vm, item))
    }

    /// Removes `item` if present.
    ///
    /// Returns the `bool` reported by `delete_if_exists`.
    fn discard(&self, item: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
        self.retry_op_with_frozenset(item, vm, |item, vm| self.content.delete_if_exists(vm, item))
    }

    /// Removes every element.
    fn clear(&self) {
        self.content.clear()
    }

    /// Snapshot of the current elements.
    fn elements(&self) -> Vec<PyObjectRef> {
        self.content.keys()
    }

    /// Removes and returns an arbitrary element.
    ///
    /// Raises `KeyError('pop from an empty set')` when the set is empty.
    fn pop(&self, vm: &VirtualMachine) -> PyResult {
        if let Some((key, _)) = self.content.pop_back() {
            Ok(key)
        } else {
            let err_msg = vm.ctx.new_str(ascii!("pop from an empty set")).into();
            Err(vm.new_key_error(err_msg))
        }
    }

    /// Adds every item of every iterable in `others`.
    fn update(
        &self,
        others: impl core::iter::Iterator<Item = ArgIterable>,
        vm: &VirtualMachine,
    ) -> PyResult<()> {
        for iterable in others {
            for item in iterable.iter(vm)? {
                self.add(item?, vm)?;
            }
        }
        Ok(())
    }

    /// Adds the elements of `iterable`.
    ///
    /// Has fast paths for `set`/`frozenset` operands and for exact `dict`s (their keys
    /// are added). Any other iterable is iterated generically.
    fn update_internal(&self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
        if let Ok(any_set) = AnySet::try_from_object(vm, iterable.to_owned()) {
            self.merge_set(any_set, vm)
        } else if let Ok(dict) = iterable.to_owned().downcast_exact::<PyDict>(vm) {
            self.merge_dict(dict.into_pyref(), vm)
        } else {
            for item in iterable.try_into_value::<ArgIterable>(vm)?.iter(vm)? {
                self.add(item?, vm)?;
            }
            Ok(())
        }
    }

    /// Adds a snapshot of another set's elements.
    fn merge_set(&self, any_set: AnySet, vm: &VirtualMachine) -> PyResult<()> {
        for item in any_set.as_inner().elements() {
            self.add(item, vm)?;
        }
        Ok(())
    }

    /// Adds every key of `dict`.
    fn merge_dict(&self, dict: PyDictRef, vm: &VirtualMachine) -> PyResult<()> {
        for (key, _value) in dict {
            self.add(key, vm)?;
        }
        Ok(())
    }

    /// In-place intersection with every iterable in `others`.
    ///
    /// The result is computed first, then swapped in, so an error leaves `self` untouched.
    fn intersection_update(
        &self,
        others: impl core::iter::Iterator<Item = ArgIterable>,
        vm: &VirtualMachine,
    ) -> PyResult<()> {
        let temp_inner = self.fold_op(others, Self::intersection, vm)?;
        self.clear();
        for obj in temp_inner.elements() {
            self.add(obj, vm)?;
        }
        Ok(())
    }

    /// In-place difference with every iterable in `others`.
    fn difference_update(
        &self,
        others: impl core::iter::Iterator<Item = ArgIterable>,
        vm: &VirtualMachine,
    ) -> PyResult<()> {
        for iterable in others {
            // Collect first so `s.difference_update(s)` does not mutate what it iterates.
            let items = iterable.iter(vm)?.collect::<Result<Vec<_>, _>>()?;
            for item in items {
                self.content.delete_if_exists(vm, &*item)?;
            }
        }
        Ok(())
    }

    /// In-place symmetric difference with every iterable in `others`.
    fn symmetric_difference_update(
        &self,
        others: impl core::iter::Iterator<Item = ArgIterable>,
        vm: &VirtualMachine,
    ) -> PyResult<()> {
        for iterable in others {
            // De-duplicate (and snapshot) the operand before toggling membership.
            let iterable_set = Self::from_iter(iterable.iter(vm)?, vm)?;
            for item in iterable_set.elements() {
                self.content.delete_or_insert(vm, &item, ())?;
            }
        }
        Ok(())
    }

    /// Order-independent hash.
    ///
    /// XORs shuffled element hashes, mixes in the length, then disperses the bits.
    /// The all-ones pattern (`-1` once cast) is never returned. The constants appear to
    /// mirror CPython's `frozenset_hash`; confirm against Objects/setobject.c before
    /// changing them.
    fn hash(&self, vm: &VirtualMachine) -> PyResult<PyHash> {
        const fn _shuffle_bits(h: u64) -> u64 {
            ((h ^ 89869747) ^ (h.wrapping_shl(16))).wrapping_mul(3644798167)
        }
        let mut hash: u64 = (self.len() as u64 + 1).wrapping_mul(1927868237);
        hash = self.content.try_fold_keys(hash, |h, element| {
            Ok(h ^ _shuffle_bits(element.hash(vm)? as u64))
        })?;
        hash ^= (hash >> 11) ^ (hash >> 25);
        hash = hash.wrapping_mul(69069).wrapping_add(907133923);
        if hash == u64::MAX {
            hash = 590923713;
        }
        Ok(hash as PyHash)
    }

    /// Runs `op(item)`. If that fails with a `TypeError` and `item` is a `set`, retries
    /// with a `frozenset` copy of `item`; this is how `{1} in s` and friends work.
    ///
    /// Like CPython, only a `TypeError` triggers the retry. Any other error (e.g. one
    /// raised by an element's `__eq__`) propagates unchanged instead of being swallowed.
    /// A `KeyError` from the retry is re-raised as `KeyError(item)`, so it names the
    /// caller's original object.
    fn retry_op_with_frozenset<T, F>(
        &self,
        item: &PyObject,
        vm: &VirtualMachine,
        op: F,
    ) -> PyResult<T>
    where
        F: Fn(&PyObject, &VirtualMachine) -> PyResult<T>,
    {
        op(item, vm).or_else(|original_err| {
            let Some(set) = item.downcast_ref::<PySet>() else {
                return Err(original_err);
            };
            if !original_err.fast_isinstance(vm.ctx.exceptions.type_error) {
                return Err(original_err);
            }
            let frozen = PyFrozenSet {
                inner: set.inner.copy(),
                ..Default::default()
            }
            .into_pyobject(vm);
            op(&frozen, vm).map_err(|op_err| {
                if op_err.fast_isinstance(vm.ctx.exceptions.key_error) {
                    vm.new_key_error(item.to_owned())
                } else {
                    op_err
                }
            })
        })
    }
}
fn extract_set(obj: &PyObject) -> Option<&PySetInner> {
match_class!(match obj {
ref set @ PySet => Some(&set.inner),
ref frozen @ PyFrozenSet => Some(&frozen.inner),
_ => None,
})
}
/// Pickle support shared by `set` and `frozenset`.
///
/// Returns `(type(obj), (elements,), obj.__dict__)`. A non-set object contributes an
/// empty element list.
fn reduce_set(
    zelf: &PyObject,
    vm: &VirtualMachine,
) -> PyResult<(PyTypeRef, PyTupleRef, Option<PyDictRef>)> {
    let cls = zelf.class().to_owned();
    let elements = extract_set(zelf).map_or_else(Vec::new, PySetInner::elements);
    let args = vm.new_tuple((elements,));
    let state = zelf.dict();
    Ok((cls, args, state))
}
// Python-visible methods of `set`. Plain `//` comments are used throughout so the
// generated `__doc__` strings are not affected.
#[pyclass(
    with(
        Constructor,
        Initializer,
        AsSequence,
        Comparable,
        Iterable,
        AsNumber,
        Representable
    ),
    flags(BASETYPE, _MATCH_SELF, HAS_WEAKREF)
)]
impl PySet {
    // Element count; backs the sequence `length` slot.
    fn __len__(&self) -> usize {
        PySetInner::len(&self.inner)
    }

    // Membership test; backs the sequence `contains` slot.
    fn __contains__(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
        PySetInner::contains(&self.inner, needle, vm)
    }

    // Payload struct size plus whatever the element table reports.
    #[pymethod]
    fn __sizeof__(&self) -> usize {
        let table_bytes = self.inner.sizeof();
        core::mem::size_of::<Self>() + table_bytes
    }

    // Returns a new set with independent storage.
    #[pymethod]
    fn copy(&self) -> Self {
        let inner = self.inner.copy();
        Self { inner }
    }

    #[pymethod]
    fn union(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
        Self::fold_op(self, others.into_iter(), PySetInner::union, vm)
    }

    #[pymethod]
    fn intersection(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
        Self::fold_op(self, others.into_iter(), PySetInner::intersection, vm)
    }

    #[pymethod]
    fn difference(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
        Self::fold_op(self, others.into_iter(), PySetInner::difference, vm)
    }

    #[pymethod]
    fn symmetric_difference(
        &self,
        others: PosArgs<ArgIterable>,
        vm: &VirtualMachine,
    ) -> PyResult<Self> {
        Self::fold_op(self, others.into_iter(), PySetInner::symmetric_difference, vm)
    }

    #[pymethod]
    fn issubset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
        PySetInner::issubset(&self.inner, other, vm)
    }

    #[pymethod]
    fn issuperset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
        PySetInner::issuperset(&self.inner, other, vm)
    }

    #[pymethod]
    fn isdisjoint(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
        PySetInner::isdisjoint(&self.inner, other, vm)
    }

    // Shared body of the binary operators. Applies `op` when `other` is a set/frozenset
    // and reports `NotImplemented` otherwise.
    fn arith_op(
        &self,
        other: PyObjectRef,
        op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
        vm: &VirtualMachine,
    ) -> PyResult<PyArithmeticValue<Self>> {
        match AnySet::try_from_object(vm, other) {
            Ok(rhs) => Ok(PyArithmeticValue::Implemented(self.op(rhs, op, vm)?)),
            Err(_) => Ok(PyArithmeticValue::NotImplemented),
        }
    }

    fn __or__(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
        self.arith_op(other, PySetInner::union, vm)
    }

    fn __and__(
        &self,
        other: PyObjectRef,
        vm: &VirtualMachine,
    ) -> PyResult<PyArithmeticValue<Self>> {
        self.arith_op(other, PySetInner::intersection, vm)
    }

    fn __sub__(
        &self,
        other: PyObjectRef,
        vm: &VirtualMachine,
    ) -> PyResult<PyArithmeticValue<Self>> {
        self.arith_op(other, PySetInner::difference, vm)
    }

    // Reflected `other - self`; the result is a `set`.
    fn __rsub__(
        zelf: PyRef<Self>,
        other: PyObjectRef,
        vm: &VirtualMachine,
    ) -> PyResult<PyArithmeticValue<Self>> {
        let Ok(lhs) = AnySet::try_from_object(vm, other) else {
            return Ok(PyArithmeticValue::NotImplemented);
        };
        let rhs = ArgIterable::try_from_object(vm, zelf.into())?;
        let inner = lhs.as_inner().difference(rhs, vm)?;
        Ok(PyArithmeticValue::Implemented(Self { inner }))
    }

    fn __xor__(
        &self,
        other: PyObjectRef,
        vm: &VirtualMachine,
    ) -> PyResult<PyArithmeticValue<Self>> {
        self.arith_op(other, PySetInner::symmetric_difference, vm)
    }

    #[pymethod]
    pub fn add(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
        PySetInner::add(&self.inner, item, vm)
    }

    #[pymethod]
    fn remove(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
        PySetInner::remove(&self.inner, item, vm)
    }

    #[pymethod]
    fn discard(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
        // The inner call returns a `bool`, which `set.discard` does not expose.
        self.inner.discard(&item, vm)?;
        Ok(())
    }

    #[pymethod]
    fn clear(&self) {
        PySetInner::clear(&self.inner);
    }

    #[pymethod]
    fn pop(&self, vm: &VirtualMachine) -> PyResult {
        PySetInner::pop(&self.inner, vm)
    }

    // `set |= other`: adds every element of `other` and returns the same object.
    fn __ior__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
        let sources = set.into_iterable_iter(vm)?;
        zelf.inner.update(sources, vm)?;
        Ok(zelf)
    }

    #[pymethod]
    fn update(&self, others: PosArgs<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
        others
            .into_iter()
            .try_for_each(|iterable| self.inner.update_internal(iterable, vm))
    }

    #[pymethod]
    fn intersection_update(
        &self,
        others: PosArgs<ArgIterable>,
        vm: &VirtualMachine,
    ) -> PyResult<()> {
        self.inner.intersection_update(others.into_iter(), vm)
    }

    // `set &= other`. Intersecting a set with itself leaves it unchanged, so that case
    // is skipped.
    fn __iand__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
        if set.is(zelf.as_object()) {
            return Ok(zelf);
        }
        let source = set.into_iterable(vm)?;
        zelf.inner
            .intersection_update(core::iter::once(source), vm)?;
        Ok(zelf)
    }

    #[pymethod]
    fn difference_update(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<()> {
        PySetInner::difference_update(&self.inner, others.into_iter(), vm)
    }

    // `set -= other`. Subtracting a set from itself simply empties it.
    fn __isub__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
        let same_object = set.is(zelf.as_object());
        if same_object {
            zelf.inner.clear();
        } else {
            let sources = set.into_iterable_iter(vm)?;
            zelf.inner.difference_update(sources, vm)?;
        }
        Ok(zelf)
    }

    #[pymethod]
    fn symmetric_difference_update(
        &self,
        others: PosArgs<ArgIterable>,
        vm: &VirtualMachine,
    ) -> PyResult<()> {
        PySetInner::symmetric_difference_update(&self.inner, others.into_iter(), vm)
    }

    // `set ^= other`. XOR-ing a set with itself empties it.
    fn __ixor__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
        let same_object = set.is(zelf.as_object());
        if same_object {
            zelf.inner.clear();
        } else {
            let sources = set.into_iterable_iter(vm)?;
            zelf.inner.symmetric_difference_update(sources, vm)?;
        }
        Ok(zelf)
    }

    #[pymethod]
    fn __reduce__(
        zelf: PyRef<Self>,
        vm: &VirtualMachine,
    ) -> PyResult<(PyTypeRef, PyTupleRef, Option<PyDictRef>)> {
        reduce_set(zelf.as_object(), vm)
    }

    #[pyclassmethod]
    fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
        PyGenericAlias::from_args(cls, args, vm)
    }
}
impl DefaultConstructor for PySet {}

impl Initializer for PySet {
    type Args = OptionalArg<PyObjectRef>;

    /// `set.__init__([iterable])`.
    ///
    /// Always empties the set first, then adds the elements of `iterable` if one was
    /// given.
    fn init(zelf: PyRef<Self>, iterable: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
        zelf.clear();
        match iterable {
            OptionalArg::Present(it) => zelf.update(PosArgs::new(vec![it]), vm),
            OptionalArg::Missing => Ok(()),
        }
    }
}
impl AsSequence for PySet {
    /// Sequence slots for `set`.
    ///
    /// Only `length` and `contains` are provided; every other slot stays
    /// `NOT_IMPLEMENTED`.
    fn as_sequence() -> &'static PySequenceMethods {
        // Built lazily on first use, then shared by every call.
        static AS_SEQUENCE: LazyLock<PySequenceMethods> = LazyLock::new(|| PySequenceMethods {
            length: atomic_func!(|seq, _vm| Ok(PySet::sequence_downcast(seq).__len__())),
            contains: atomic_func!(
                |seq, needle, vm| PySet::sequence_downcast(seq).__contains__(needle, vm)
            ),
            ..PySequenceMethods::NOT_IMPLEMENTED
        });
        &AS_SEQUENCE
    }
}
impl Comparable for PySet {
    /// Compares against any `set`/`frozenset`. Other types get `NotImplemented`.
    fn cmp(
        zelf: &crate::Py<Self>,
        other: &PyObject,
        op: PyComparisonOp,
        vm: &VirtualMachine,
    ) -> PyResult<PyComparisonValue> {
        let Some(other_inner) = extract_set(other) else {
            return Ok(PyComparisonValue::NotImplemented);
        };
        let outcome = zelf.inner.compare(other_inner, op, vm)?;
        Ok(outcome.into())
    }
}
impl Iterable for PySet {
    /// Returns a `set_iterator` that shares this set's storage.
    fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
        let iterator = zelf.inner.iter();
        Ok(iterator.into_pyobject(vm))
    }
}
impl AsNumber for PySet {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
subtract: Some(|a, b, vm| {
if !AnySet::check(a, vm) || !AnySet::check(b, vm) {
return Ok(vm.ctx.not_implemented());
}
if let Some(a) = a.downcast_ref::<PySet>() {
a.__sub__(b.to_owned(), vm).to_pyresult(vm)
} else if let Some(a) = a.downcast_ref::<PyFrozenSet>() {
a.__sub__(b.to_owned(), vm)
.map(|r| {
r.map(|s| PySet {
inner: s.inner.clone(),
})
})
.to_pyresult(vm)
} else {
Ok(vm.ctx.not_implemented())
}
}),
and: Some(|a, b, vm| {
if !AnySet::check(a, vm) || !AnySet::check(b, vm) {
return Ok(vm.ctx.not_implemented());
}
if let Some(a) = a.downcast_ref::<PySet>() {
a.__and__(b.to_owned(), vm).to_pyresult(vm)
} else if let Some(a) = a.downcast_ref::<PyFrozenSet>() {
a.__and__(b.to_owned(), vm)
.map(|r| {
r.map(|s| PySet {
inner: s.inner.clone(),
})
})
.to_pyresult(vm)
} else {
Ok(vm.ctx.not_implemented())
}
}),
xor: Some(|a, b, vm| {
if !AnySet::check(a, vm) || !AnySet::check(b, vm) {
return Ok(vm.ctx.not_implemented());
}
if let Some(a) = a.downcast_ref::<PySet>() {
a.__xor__(b.to_owned(), vm).to_pyresult(vm)
} else if let Some(a) = a.downcast_ref::<PyFrozenSet>() {
a.__xor__(b.to_owned(), vm)
.map(|r| {
r.map(|s| PySet {
inner: s.inner.clone(),
})
})
.to_pyresult(vm)
} else {
Ok(vm.ctx.not_implemented())
}
}),
or: Some(|a, b, vm| {
if !AnySet::check(a, vm) || !AnySet::check(b, vm) {
return Ok(vm.ctx.not_implemented());
}
if let Some(a) = a.downcast_ref::<PySet>() {
a.__or__(b.to_owned(), vm).to_pyresult(vm)
} else if let Some(a) = a.downcast_ref::<PyFrozenSet>() {
a.__or__(b.to_owned(), vm)
.map(|r| {
r.map(|s| PySet {
inner: s.inner.clone(),
})
})
.to_pyresult(vm)
} else {
Ok(vm.ctx.not_implemented())
}
}),
inplace_subtract: Some(|a, b, vm| {
if let Some(a) = a.downcast_ref::<PySet>() {
PySet::__isub__(a.to_owned(), AnySet::try_from_object(vm, b.to_owned())?, vm)
.to_pyresult(vm)
} else {
Ok(vm.ctx.not_implemented())
}
}),
inplace_and: Some(|a, b, vm| {
if let Some(a) = a.downcast_ref::<PySet>() {
PySet::__iand__(a.to_owned(), AnySet::try_from_object(vm, b.to_owned())?, vm)
.to_pyresult(vm)
} else {
Ok(vm.ctx.not_implemented())
}
}),
inplace_xor: Some(|a, b, vm| {
if let Some(a) = a.downcast_ref::<PySet>() {
PySet::__ixor__(a.to_owned(), AnySet::try_from_object(vm, b.to_owned())?, vm)
.to_pyresult(vm)
} else {
Ok(vm.ctx.not_implemented())
}
}),
inplace_or: Some(|a, b, vm| {
if let Some(a) = a.downcast_ref::<PySet>() {
PySet::__ior__(a.to_owned(), AnySet::try_from_object(vm, b.to_owned())?, vm)
.to_pyresult(vm)
} else {
Ok(vm.ctx.not_implemented())
}
}),
..PyNumberMethods::NOT_IMPLEMENTED
};
&AS_NUMBER
}
}
impl Representable for PySet {
    /// Produces the repr of a set.
    ///
    /// - An empty set renders as `Name()`.
    /// - A recursive repr renders as `Name(...)`.
    /// - Otherwise the result is `{...}`, wrapped as `Name({...})` for subclasses
    ///   (any class whose name is not `set`).
    #[inline]
    fn repr_wtf8(zelf: &crate::Py<Self>, vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
        let class = zelf.class();
        let name_guard = class.name();
        let class_name: &str = &name_guard;
        if zelf.inner.len() == 0 {
            return Ok(Wtf8Buf::from(format!("{class_name}()")));
        }
        match ReprGuard::enter(vm, zelf.as_object()) {
            Some(_guard) => {
                let wrapper = if class_name == "set" {
                    None
                } else {
                    Some(class_name)
                };
                zelf.inner.repr(wrapper, vm)
            }
            None => Ok(Wtf8Buf::from(format!("{class_name}(...)"))),
        }
    }
}
impl Constructor for PyFrozenSet {
    type Args = Vec<PyObjectRef>;

    /// `frozenset.__new__([iterable])`.
    ///
    /// For the exact `frozenset` type, two shortcuts avoid allocation:
    /// - an exact frozenset argument is returned as-is;
    /// - an empty result is the shared `empty_frozenset`.
    ///
    /// Subclasses always get a fresh instance.
    fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        let iterable: OptionalArg<PyObjectRef> = args.bind(vm)?;
        let exact = cls.is(vm.ctx.types.frozenset_type);

        let elements: Vec<PyObjectRef> = match iterable {
            OptionalArg::Present(input) => {
                if exact && let Ok(fs) = input.clone().downcast_exact::<PyFrozenSet>(vm) {
                    return Ok(fs.into_pyref().into());
                }
                input.try_to_value(vm)?
            }
            OptionalArg::Missing => Vec::new(),
        };

        if exact && elements.is_empty() {
            return Ok(vm.ctx.empty_frozenset.clone().into());
        }
        let payload = Self::py_new(&cls, elements, vm)?;
        payload.into_ref_with_type(vm, cls).map(Into::into)
    }

    /// Builds the payload from already-collected elements.
    fn py_new(_cls: &Py<PyType>, elements: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
        Self::from_iter(vm, elements)
    }
}
// Python-visible methods of `frozenset`. Plain `//` comments are used throughout so the
// generated `__doc__` strings are not affected.
#[pyclass(
    flags(BASETYPE, _MATCH_SELF, HAS_WEAKREF),
    with(
        Constructor,
        AsSequence,
        Hashable,
        Comparable,
        Iterable,
        AsNumber,
        Representable
    )
)]
impl PyFrozenSet {
    // Element count; backs the sequence `length` slot.
    fn __len__(&self) -> usize {
        PySetInner::len(&self.inner)
    }

    // Membership test; backs the sequence `contains` slot.
    fn __contains__(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
        PySetInner::contains(&self.inner, needle, vm)
    }

    // Payload struct size plus whatever the element table reports.
    #[pymethod]
    fn __sizeof__(&self) -> usize {
        let table_bytes = self.inner.sizeof();
        core::mem::size_of::<Self>() + table_bytes
    }

    // An exact frozenset is returned as-is. A subclass instance is copied into a plain
    // `frozenset`.
    #[pymethod]
    fn copy(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyRef<Self> {
        if !zelf.class().is(vm.ctx.types.frozenset_type) {
            let copied = Self {
                inner: zelf.inner.copy(),
                ..Default::default()
            };
            return copied.into_ref(&vm.ctx);
        }
        zelf
    }

    #[pymethod]
    fn union(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
        Self::fold_op(self, others.into_iter(), PySetInner::union, vm)
    }

    #[pymethod]
    fn intersection(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
        Self::fold_op(self, others.into_iter(), PySetInner::intersection, vm)
    }

    #[pymethod]
    fn difference(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
        Self::fold_op(self, others.into_iter(), PySetInner::difference, vm)
    }

    #[pymethod]
    fn symmetric_difference(
        &self,
        others: PosArgs<ArgIterable>,
        vm: &VirtualMachine,
    ) -> PyResult<Self> {
        Self::fold_op(self, others.into_iter(), PySetInner::symmetric_difference, vm)
    }

    #[pymethod]
    fn issubset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
        PySetInner::issubset(&self.inner, other, vm)
    }

    #[pymethod]
    fn issuperset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
        PySetInner::issuperset(&self.inner, other, vm)
    }

    #[pymethod]
    fn isdisjoint(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
        PySetInner::isdisjoint(&self.inner, other, vm)
    }

    // Shared body of the binary operators. Applies `op` when `other` is a set/frozenset
    // and reports `NotImplemented` otherwise.
    fn arith_op(
        &self,
        other: PyObjectRef,
        op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
        vm: &VirtualMachine,
    ) -> PyResult<PyArithmeticValue<Self>> {
        match AnySet::try_from_object(vm, other) {
            Ok(rhs) => Ok(PyArithmeticValue::Implemented(self.op(rhs, op, vm)?)),
            Err(_) => Ok(PyArithmeticValue::NotImplemented),
        }
    }

    fn __or__(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
        self.arith_op(other, PySetInner::union, vm)
    }

    fn __and__(
        &self,
        other: PyObjectRef,
        vm: &VirtualMachine,
    ) -> PyResult<PyArithmeticValue<Self>> {
        self.arith_op(other, PySetInner::intersection, vm)
    }

    fn __sub__(
        &self,
        other: PyObjectRef,
        vm: &VirtualMachine,
    ) -> PyResult<PyArithmeticValue<Self>> {
        self.arith_op(other, PySetInner::difference, vm)
    }

    // Reflected `other - self`; the result is a `frozenset`.
    fn __rsub__(
        zelf: PyRef<Self>,
        other: PyObjectRef,
        vm: &VirtualMachine,
    ) -> PyResult<PyArithmeticValue<Self>> {
        let Ok(lhs) = AnySet::try_from_object(vm, other) else {
            return Ok(PyArithmeticValue::NotImplemented);
        };
        let rhs = ArgIterable::try_from_object(vm, zelf.into())?;
        let inner = lhs.as_inner().difference(rhs, vm)?;
        Ok(PyArithmeticValue::Implemented(Self {
            inner,
            ..Default::default()
        }))
    }

    fn __xor__(
        &self,
        other: PyObjectRef,
        vm: &VirtualMachine,
    ) -> PyResult<PyArithmeticValue<Self>> {
        self.arith_op(other, PySetInner::symmetric_difference, vm)
    }

    #[pymethod]
    fn __reduce__(
        zelf: PyRef<Self>,
        vm: &VirtualMachine,
    ) -> PyResult<(PyTypeRef, PyTupleRef, Option<PyDictRef>)> {
        reduce_set(zelf.as_object(), vm)
    }

    #[pyclassmethod]
    fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
        PyGenericAlias::from_args(cls, args, vm)
    }
}
impl AsSequence for PyFrozenSet {
    /// Sequence slots for `frozenset`.
    ///
    /// Only `length` and `contains` are provided; every other slot stays
    /// `NOT_IMPLEMENTED`.
    fn as_sequence() -> &'static PySequenceMethods {
        // Built lazily on first use, then shared by every call.
        static AS_SEQUENCE: LazyLock<PySequenceMethods> = LazyLock::new(|| PySequenceMethods {
            length: atomic_func!(|seq, _vm| Ok(PyFrozenSet::sequence_downcast(seq).__len__())),
            contains: atomic_func!(
                |seq, needle, vm| PyFrozenSet::sequence_downcast(seq).__contains__(needle, vm)
            ),
            ..PySequenceMethods::NOT_IMPLEMENTED
        });
        &AS_SEQUENCE
    }
}
impl Hashable for PyFrozenSet {
    /// Returns the cached hash, computing and publishing it on first use.
    ///
    /// Publication is a compare-exchange against `hash::SENTINEL`. If the slot was
    /// filled concurrently, the value already stored there is returned instead.
    #[inline]
    fn hash(zelf: &crate::Py<Self>, vm: &VirtualMachine) -> PyResult<PyHash> {
        let cached = zelf.hash.load(Ordering::Relaxed);
        if cached != hash::SENTINEL {
            return Ok(cached);
        }
        let computed = zelf.inner.hash(vm)?;
        let published = Radium::compare_exchange(
            &zelf.hash,
            hash::SENTINEL,
            hash::fix_sentinel(computed),
            Ordering::Relaxed,
            Ordering::Relaxed,
        );
        Ok(published.map_or_else(|previous| previous, |_| computed))
    }
}
impl Comparable for PyFrozenSet {
    /// Compares against any `set`/`frozenset`. Other types get `NotImplemented`.
    fn cmp(
        zelf: &crate::Py<Self>,
        other: &PyObject,
        op: PyComparisonOp,
        vm: &VirtualMachine,
    ) -> PyResult<PyComparisonValue> {
        let Some(other_inner) = extract_set(other) else {
            return Ok(PyComparisonValue::NotImplemented);
        };
        let outcome = zelf.inner.compare(other_inner, op, vm)?;
        Ok(outcome.into())
    }
}
impl Iterable for PyFrozenSet {
    /// Returns a `set_iterator` that shares this frozenset's storage.
    fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
        let iterator = zelf.inner.iter();
        Ok(iterator.into_pyobject(vm))
    }
}
impl AsNumber for PyFrozenSet {
    /// Binary `-`, `&`, `^`, `|` slots for `frozenset`; no in-place slots are provided.
    ///
    /// Each slot returns `NotImplemented` unless both operands are `set`/`frozenset`
    /// instances. The left operand's own method computes the result.
    fn as_number() -> &'static PyNumberMethods {
        static AS_NUMBER: PyNumberMethods = PyNumberMethods {
            subtract: Some(|a, b, vm| {
                if !(AnySet::check(a, vm) && AnySet::check(b, vm)) {
                    return Ok(vm.ctx.not_implemented());
                }
                if let Some(lhs) = a.downcast_ref::<PyFrozenSet>() {
                    return lhs.__sub__(b.to_owned(), vm).to_pyresult(vm);
                }
                match a.downcast_ref::<PySet>() {
                    Some(lhs) => lhs.__sub__(b.to_owned(), vm).to_pyresult(vm),
                    None => Ok(vm.ctx.not_implemented()),
                }
            }),
            and: Some(|a, b, vm| {
                if !(AnySet::check(a, vm) && AnySet::check(b, vm)) {
                    return Ok(vm.ctx.not_implemented());
                }
                if let Some(lhs) = a.downcast_ref::<PyFrozenSet>() {
                    return lhs.__and__(b.to_owned(), vm).to_pyresult(vm);
                }
                match a.downcast_ref::<PySet>() {
                    Some(lhs) => lhs.__and__(b.to_owned(), vm).to_pyresult(vm),
                    None => Ok(vm.ctx.not_implemented()),
                }
            }),
            xor: Some(|a, b, vm| {
                if !(AnySet::check(a, vm) && AnySet::check(b, vm)) {
                    return Ok(vm.ctx.not_implemented());
                }
                if let Some(lhs) = a.downcast_ref::<PyFrozenSet>() {
                    return lhs.__xor__(b.to_owned(), vm).to_pyresult(vm);
                }
                match a.downcast_ref::<PySet>() {
                    Some(lhs) => lhs.__xor__(b.to_owned(), vm).to_pyresult(vm),
                    None => Ok(vm.ctx.not_implemented()),
                }
            }),
            or: Some(|a, b, vm| {
                if !(AnySet::check(a, vm) && AnySet::check(b, vm)) {
                    return Ok(vm.ctx.not_implemented());
                }
                if let Some(lhs) = a.downcast_ref::<PyFrozenSet>() {
                    return lhs.__or__(b.to_owned(), vm).to_pyresult(vm);
                }
                match a.downcast_ref::<PySet>() {
                    Some(lhs) => lhs.__or__(b.to_owned(), vm).to_pyresult(vm),
                    None => Ok(vm.ctx.not_implemented()),
                }
            }),
            ..PyNumberMethods::NOT_IMPLEMENTED
        };
        &AS_NUMBER
    }
}
impl Representable for PyFrozenSet {
    /// Produces the repr of a frozenset.
    ///
    /// - An empty frozenset renders as `Name()`.
    /// - A recursive repr renders as `Name(...)`.
    /// - Otherwise the result is always wrapped as `Name({...})`.
    #[inline]
    fn repr_wtf8(zelf: &crate::Py<Self>, vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
        let class = zelf.class();
        let name_guard = class.name();
        let class_name: &str = &name_guard;
        let inner = &zelf.inner;
        if inner.len() == 0 {
            return Ok(Wtf8Buf::from(format!("{class_name}()")));
        }
        match ReprGuard::enter(vm, zelf.as_object()) {
            Some(_guard) => inner.repr(Some(class_name), vm),
            None => Ok(Wtf8Buf::from(format!("{class_name}(...)"))),
        }
    }
}
/// Wrapper for an object accepted as "set-like" by its `TryFromObject` impl.
///
/// The accepted objects are instances of `set`, `frozenset`, or a subclass of either.
struct AnySet {
    object: PyObjectRef,
}
// Exposes the wrapped object as a `&PyObject` borrow.
impl Borrow<PyObject> for AnySet {
    #[inline(always)]
    fn borrow(&self) -> &PyObject {
        &self.object
    }
}
impl AnySet {
    /// `true` when `obj` is an instance of `set` or `frozenset`.
    fn check(obj: &PyObject, vm: &VirtualMachine) -> bool {
        let types = &vm.ctx.types;
        [types.set_type, types.frozenset_type]
            .into_iter()
            .any(|ty| obj.fast_isinstance(ty))
    }

    /// Converts the wrapped object into an `ArgIterable`.
    fn into_iterable(self, vm: &VirtualMachine) -> PyResult<ArgIterable> {
        let Self { object } = self;
        object.try_into_value(vm)
    }

    /// Like `into_iterable`, but wrapped in a one-item iterator for the `*_update`
    /// helpers.
    fn into_iterable_iter(
        self,
        vm: &VirtualMachine,
    ) -> PyResult<impl core::iter::Iterator<Item = ArgIterable>> {
        self.into_iterable(vm).map(core::iter::once)
    }

    /// Borrows the shared storage of the wrapped `set`/`frozenset`.
    fn as_inner(&self) -> &PySetInner {
        match_class!(match self.object.as_object() {
            ref set @ PySet => &set.inner,
            ref frozen @ PyFrozenSet => &frozen.inner,
            _ => unreachable!("AnySet is always PySet or PyFrozenSet"),
        })
    }
}
impl TryFromObject for AnySet {
    /// Accepts instances of `set`, `frozenset`, or their subclasses.
    ///
    /// Any other object raises `TypeError`.
    fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
        let class = obj.class();
        let set_like = class.fast_issubclass(vm.ctx.types.set_type)
            || class.fast_issubclass(vm.ctx.types.frozenset_type);
        if !set_like {
            return Err(vm.new_type_error(format!("{class} is not a subtype of set or frozenset")));
        }
        Ok(Self { object: obj })
    }
}
// Iterator over the elements of a `set`/`frozenset`; shares the set's storage.
#[pyclass(module = false, name = "set_iterator")]
pub(crate) struct PySetIterator {
    // Storage size captured when the iterator was created. `next` compares against it
    // to detect mutation during iteration.
    size: DictSize,
    // Current entry position plus the shared storage, or `Exhausted` once finished.
    internal: PyMutex<PositionIterInternal<PyRc<SetContentType>>>,
}
impl fmt::Debug for PySetIterator {
    /// Prints only the type name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("set_iterator")
    }
}
impl PyPayload for PySetIterator {
    /// The built-in `set_iterator` type object.
    #[inline]
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.set_iterator_type
    }
}
#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))]
impl PySetIterator {
    // Remaining-length estimate, computed by `length_hint` from the entry-table size
    // captured when the iterator was created.
    #[pymethod]
    fn __length_hint__(&self) -> usize {
        self.internal.lock().length_hint(|_| self.size.entries_size)
    }

    // Pickle support: `iter(list_of_remaining_elements)`.
    //
    // `internal.position` is the entry-table position returned by `next_entry` (see
    // `IterNext::next`), not a count of elements already yielded. So the remaining
    // elements are collected by walking `next_entry` from that position, which is
    // exactly what further `next()` calls would produce.
    //
    // Skipping `position` items of `keys()` mixes an index with a count; the two
    // presumably only agree while the table has no vacated slots.
    #[pymethod]
    fn __reduce__(
        zelf: PyRef<Self>,
        vm: &VirtualMachine,
    ) -> PyResult<(PyObjectRef, (PyObjectRef,))> {
        let internal = zelf.internal.lock();
        let remaining = match &internal.status {
            IterStatus::Exhausted => vec![],
            IterStatus::Active(dict) => {
                let mut keys = Vec::new();
                let mut position = internal.position;
                while let Some((next_position, key, _)) = dict.next_entry(position) {
                    keys.push(key);
                    position = next_position;
                }
                keys
            }
        };
        Ok((builtins_iter(vm), (vm.ctx.new_list(remaining).into(),)))
    }
}
impl SelfIter for PySetIterator {}

impl IterNext for PySetIterator {
    /// Yields the next element.
    ///
    /// - If the storage size differs from the snapshot taken at creation, the iterator
    ///   is marked exhausted and raises `RuntimeError`. The message matches CPython's
    ///   `setiter_iternext`.
    /// - Reaching the end of the table also marks the iterator exhausted.
    fn next(zelf: &crate::Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
        let mut internal = zelf.internal.lock();
        let next = if let IterStatus::Active(dict) = &internal.status {
            if dict.has_changed_size(&zelf.size) {
                internal.status = IterStatus::Exhausted;
                return Err(vm.new_runtime_error("Set changed size during iteration"));
            }
            match dict.next_entry(internal.position) {
                Some((position, key, _)) => {
                    internal.position = position;
                    PyIterReturn::Return(key)
                }
                None => {
                    internal.status = IterStatus::Exhausted;
                    PyIterReturn::StopIteration(None)
                }
            }
        } else {
            PyIterReturn::StopIteration(None)
        };
        Ok(next)
    }
}
/// Vectorcall fast path for calling `set` (or a subclass).
///
/// Allocates an empty instance of the called type, then runs `set.__init__` on it with
/// the call's arguments.
fn vectorcall_set(
    zelf_obj: &PyObject,
    args: Vec<PyObjectRef>,
    nargs: usize,
    kwnames: Option<&[PyObjectRef]>,
    vm: &VirtualMachine,
) -> PyResult {
    // Stored in a type's vectorcall slot (see `init`), so the callee should be a type.
    let cls: &Py<PyType> = zelf_obj.downcast_ref().unwrap();
    let instance = PySet::default().into_ref_with_type(vm, cls.to_owned())?;
    let call_args = FuncArgs::from_vectorcall_owned(args, nargs, kwnames);
    PySet::slot_init(instance.clone().into(), call_args, vm)?;
    Ok(instance.into())
}
/// Vectorcall fast path for calling `frozenset` (or a subclass).
///
/// Packs the arguments and forwards them to the type's `__new__` slot.
fn vectorcall_frozenset(
    zelf_obj: &PyObject,
    args: Vec<PyObjectRef>,
    nargs: usize,
    kwnames: Option<&[PyObjectRef]>,
    vm: &VirtualMachine,
) -> PyResult {
    // Stored in a type's vectorcall slot (see `init`), so the callee should be a type.
    let cls: &Py<PyType> = zelf_obj.downcast_ref().unwrap();
    let call_args = FuncArgs::from_vectorcall_owned(args, nargs, kwnames);
    let new_slot = cls.slots.new.load().unwrap();
    new_slot(cls.to_owned(), call_args, vm)
}
/// Registers the `set`, `frozenset` and `set_iterator` classes.
///
/// Also installs the vectorcall fast paths for `set(...)` and `frozenset(...)`.
pub fn init(context: &'static Context) {
    let types = &context.types;

    PySet::extend_class(context, types.set_type);
    types.set_type.slots.vectorcall.store(Some(vectorcall_set));

    PyFrozenSet::extend_class(context, types.frozenset_type);
    types
        .frozenset_type
        .slots
        .vectorcall
        .store(Some(vectorcall_frozenset));

    PySetIterator::extend_class(context, types.set_iterator_type);
}