Skip to main content

rustpython_vm/builtins/
union.rs

1use super::{genericalias, type_};
2use crate::common::lock::LazyLock;
3use crate::{
4    AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
5    atomic_func,
6    builtins::{PyFrozenSet, PySet, PyStr, PyTuple, PyTupleRef, PyType},
7    class::PyClassImpl,
8    common::hash,
9    convert::ToPyObject,
10    function::PyComparisonValue,
11    protocol::{PyMappingMethods, PyNumberMethods},
12    stdlib::_typing::{TypeAliasType, call_typing_func_object},
13    types::{AsMapping, AsNumber, Comparable, GetAttr, Hashable, PyComparisonOp, Representable},
14};
15use alloc::fmt;
16
/// Attribute names that `PyUnion`'s `GetAttr::getattro` resolves through
/// `generic_getattr`; every other name goes through `get_attr`.
const CLS_ATTRS: &[&str] = &["__module__"];
18
// Payload behind the Python-visible `typing.Union` object. Instances are built
// by `make_union` from an already flattened and deduplicated argument list.
#[pyclass(module = "typing", name = "Union", traverse)]
pub struct PyUnion {
    /// Unique member args, in order; exposed through `__args__` and `args()`
    args: PyTupleRef,
    /// Frozenset of the hashable args, or None if no arg was hashable
    hashable_args: Option<PyRef<PyFrozenSet>>,
    /// Tuple of initially unhashable args, or None if all args were hashable
    unhashable_args: Option<PyTupleRef>,
    /// Tuple exposed through `__parameters__`, derived from `args` at construction
    parameters: PyTupleRef,
}
28
29impl fmt::Debug for PyUnion {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        f.write_str("UnionObject")
32    }
33}
34
35impl PyPayload for PyUnion {
36    #[inline]
37    fn class(ctx: &Context) -> &'static Py<PyType> {
38        ctx.types.union_type
39    }
40}
41
42impl PyUnion {
43    /// Create a new union from dedup result (internal use)
44    fn from_components(result: UnionComponents, vm: &VirtualMachine) -> PyResult<Self> {
45        let parameters = make_parameters(&result.args, vm)?;
46        Ok(Self {
47            args: result.args,
48            hashable_args: result.hashable_args,
49            unhashable_args: result.unhashable_args,
50            parameters,
51        })
52    }
53
54    /// Direct access to args field (_Py_union_args)
55    #[inline]
56    pub fn args(&self) -> &Py<PyTuple> {
57        &self.args
58    }
59
60    fn repr(&self, vm: &VirtualMachine) -> PyResult<String> {
61        fn repr_item(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<String> {
62            if obj.is(vm.ctx.types.none_type) {
63                return Ok("None".to_string());
64            }
65
66            if vm
67                .get_attribute_opt(obj.clone(), identifier!(vm, __origin__))?
68                .is_some()
69                && vm
70                    .get_attribute_opt(obj.clone(), identifier!(vm, __args__))?
71                    .is_some()
72            {
73                return Ok(obj.repr(vm)?.to_string());
74            }
75
76            match (
77                vm.get_attribute_opt(obj.clone(), identifier!(vm, __qualname__))?
78                    .and_then(|o| o.downcast_ref::<PyStr>().map(|n| n.to_string())),
79                vm.get_attribute_opt(obj.clone(), identifier!(vm, __module__))?
80                    .and_then(|o| o.downcast_ref::<PyStr>().map(|m| m.to_string())),
81            ) {
82                (None, _) | (_, None) => Ok(obj.repr(vm)?.to_string()),
83                (Some(qualname), Some(module)) => Ok(if module == "builtins" {
84                    qualname
85                } else {
86                    format!("{module}.{qualname}")
87                }),
88            }
89        }
90
91        Ok(self
92            .args
93            .iter()
94            .map(|o| repr_item(o.clone(), vm))
95            .collect::<PyResult<Vec<_>>>()?
96            .join(" | "))
97    }
98}
99
100#[pyclass(
101    flags(DISALLOW_INSTANTIATION, HAS_WEAKREF),
102    with(Hashable, Comparable, AsMapping, AsNumber, Representable)
103)]
104impl PyUnion {
105    #[pygetset]
106    fn __name__(&self, vm: &VirtualMachine) -> PyObjectRef {
107        vm.ctx.new_str("Union").into()
108    }
109
110    #[pygetset]
111    fn __qualname__(&self, vm: &VirtualMachine) -> PyObjectRef {
112        vm.ctx.new_str("Union").into()
113    }
114
115    #[pygetset]
116    fn __origin__(&self, vm: &VirtualMachine) -> PyObjectRef {
117        vm.ctx.types.union_type.to_owned().into()
118    }
119
120    #[pygetset]
121    fn __parameters__(&self) -> PyObjectRef {
122        self.parameters.clone().into()
123    }
124
125    #[pygetset]
126    fn __args__(&self) -> PyObjectRef {
127        self.args.clone().into()
128    }
129
130    #[pymethod]
131    fn __instancecheck__(
132        zelf: PyRef<Self>,
133        obj: PyObjectRef,
134        vm: &VirtualMachine,
135    ) -> PyResult<bool> {
136        if zelf
137            .args
138            .iter()
139            .any(|x| x.class().is(vm.ctx.types.generic_alias_type))
140        {
141            Err(vm.new_type_error("isinstance() argument 2 cannot be a parameterized generic"))
142        } else {
143            obj.is_instance(zelf.__args__().as_object(), vm)
144        }
145    }
146
147    #[pymethod]
148    fn __subclasscheck__(
149        zelf: PyRef<Self>,
150        obj: PyObjectRef,
151        vm: &VirtualMachine,
152    ) -> PyResult<bool> {
153        if zelf
154            .args
155            .iter()
156            .any(|x| x.class().is(vm.ctx.types.generic_alias_type))
157        {
158            Err(vm.new_type_error("issubclass() argument 2 cannot be a parameterized generic"))
159        } else {
160            obj.is_subclass(zelf.__args__().as_object(), vm)
161        }
162    }
163
164    fn __or__(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
165        type_::or_(zelf, other, vm)
166    }
167
168    #[pymethod]
169    fn __mro_entries__(zelf: PyRef<Self>, _args: PyObjectRef, vm: &VirtualMachine) -> PyResult {
170        Err(vm.new_type_error(format!("Cannot subclass {}", zelf.repr(vm)?)))
171    }
172
173    #[pyclassmethod]
174    fn __class_getitem__(
175        _cls: crate::builtins::PyTypeRef,
176        args: PyObjectRef,
177        vm: &VirtualMachine,
178    ) -> PyResult {
179        // Convert args to tuple if not already
180        let args_tuple = if let Some(tuple) = args.downcast_ref::<PyTuple>() {
181            tuple.to_owned()
182        } else {
183            PyTuple::new_ref(vec![args], &vm.ctx)
184        };
185
186        // Check for empty union
187        if args_tuple.is_empty() {
188            return Err(vm.new_type_error("Cannot create empty Union"));
189        }
190
191        // Create union using make_union to properly handle None -> NoneType conversion
192        make_union(&args_tuple, vm)
193    }
194}
195
196fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool {
197    let cls = obj.class();
198    cls.is(vm.ctx.types.none_type)
199        || obj.downcastable::<PyType>()
200        || cls.fast_issubclass(vm.ctx.types.generic_alias_type)
201        || cls.is(vm.ctx.types.union_type)
202        || obj.downcast_ref::<TypeAliasType>().is_some()
203}
204
205fn type_check(arg: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
206    // Fast path to avoid calling into typing.py
207    if is_unionable(arg.clone(), vm) {
208        return Ok(arg);
209    }
210    let message_str: PyObjectRef = vm
211        .ctx
212        .new_str("Union[arg, ...]: each arg must be a type.")
213        .into();
214    call_typing_func_object(vm, "_type_check", (arg, message_str))
215}
216
217fn has_union_operands(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> bool {
218    let union_type = vm.ctx.types.union_type;
219    a.class().is(union_type) || b.class().is(union_type)
220}
221
222pub fn or_op(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
223    if !has_union_operands(zelf.clone(), other.clone(), vm)
224        && (!is_unionable(zelf.clone(), vm) || !is_unionable(other.clone(), vm))
225    {
226        return Ok(vm.ctx.not_implemented());
227    }
228
229    let left = type_check(zelf, vm)?;
230    let right = type_check(other, vm)?;
231    let tuple = PyTuple::new_ref(vec![left, right], &vm.ctx);
232    make_union(&tuple, vm)
233}
234
235fn make_parameters(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
236    let parameters = genericalias::make_parameters(args, vm);
237    let result = dedup_and_flatten_args(&parameters, vm)?;
238    Ok(result.args)
239}
240
241fn flatten_args(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyTupleRef {
242    let mut total_args = 0;
243    for arg in args {
244        if let Some(pyref) = arg.downcast_ref::<PyUnion>() {
245            total_args += pyref.args.len();
246        } else {
247            total_args += 1;
248        };
249    }
250
251    let mut flattened_args = Vec::with_capacity(total_args);
252    for arg in args {
253        if let Some(pyref) = arg.downcast_ref::<PyUnion>() {
254            flattened_args.extend(pyref.args.iter().cloned());
255        } else if vm.is_none(arg) {
256            flattened_args.push(vm.ctx.types.none_type.to_owned().into());
257        } else if arg.downcast_ref::<PyStr>().is_some() {
258            // Convert string to ForwardRef
259            match string_to_forwardref(arg.clone(), vm) {
260                Ok(fr) => flattened_args.push(fr),
261                Err(_) => flattened_args.push(arg.clone()),
262            }
263        } else {
264            flattened_args.push(arg.clone());
265        };
266    }
267
268    PyTuple::new_ref(flattened_args, &vm.ctx)
269}
270
271fn string_to_forwardref(arg: PyObjectRef, vm: &VirtualMachine) -> PyResult {
272    // Import annotationlib.ForwardRef and create a ForwardRef
273    let annotationlib = vm.import("annotationlib", 0)?;
274    let forwardref_cls = annotationlib.get_attr("ForwardRef", vm)?;
275    forwardref_cls.call((arg,), vm)
276}
277
/// Components for creating a PyUnion after deduplication.
///
/// Produced by `dedup_and_flatten_args` and consumed by
/// `PyUnion::from_components`.
struct UnionComponents {
    /// All unique args in order
    args: PyTupleRef,
    /// Frozenset of hashable args (for fast equality comparison);
    /// None when no arg was hashable
    hashable_args: Option<PyRef<PyFrozenSet>>,
    /// Tuple of unhashable args at creation time (for hash error message);
    /// None when every arg was hashable
    unhashable_args: Option<PyTupleRef>,
}
287
288fn dedup_and_flatten_args(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyResult<UnionComponents> {
289    let args = flatten_args(args, vm);
290
291    // Use set-based deduplication like CPython:
292    // - For hashable elements: use Python's set semantics (hash + equality)
293    // - For unhashable elements: use equality comparison
294    //
295    // This avoids calling __eq__ when hashes differ, so `int | BadType`
296    // doesn't raise even if BadType.__eq__ raises.
297
298    let mut new_args: Vec<PyObjectRef> = Vec::with_capacity(args.len());
299
300    // Track hashable elements using a Python set (uses hash + equality)
301    let hashable_set = PySet::default().into_ref(&vm.ctx);
302    let mut hashable_list: Vec<PyObjectRef> = Vec::new();
303    let mut unhashable_list: Vec<PyObjectRef> = Vec::new();
304
305    for arg in &*args {
306        // Try to hash the element first
307        match arg.hash(vm) {
308            Ok(_) => {
309                // Element is hashable - use set for deduplication
310                // Set membership uses hash first, then equality only if hashes match
311                let contains = vm
312                    .call_method(hashable_set.as_ref(), "__contains__", (arg.clone(),))
313                    .and_then(|r| r.try_to_bool(vm))?;
314                if !contains {
315                    hashable_set.add(arg.clone(), vm)?;
316                    hashable_list.push(arg.clone());
317                    new_args.push(arg.clone());
318                }
319            }
320            Err(_) => {
321                // Element is unhashable - use equality comparison
322                let mut is_duplicate = false;
323                for existing in &unhashable_list {
324                    match existing.rich_compare_bool(arg, PyComparisonOp::Eq, vm) {
325                        Ok(true) => {
326                            is_duplicate = true;
327                            break;
328                        }
329                        Ok(false) => continue,
330                        Err(e) => return Err(e),
331                    }
332                }
333                if !is_duplicate {
334                    unhashable_list.push(arg.clone());
335                    new_args.push(arg.clone());
336                }
337            }
338        }
339    }
340
341    new_args.shrink_to_fit();
342
343    // Create hashable_args frozenset if there are hashable elements
344    let hashable_args = if !hashable_list.is_empty() {
345        Some(PyFrozenSet::from_iter(vm, hashable_list.into_iter())?.into_ref(&vm.ctx))
346    } else {
347        None
348    };
349
350    // Create unhashable_args tuple if there are unhashable elements
351    let unhashable_args = if !unhashable_list.is_empty() {
352        Some(PyTuple::new_ref(unhashable_list, &vm.ctx))
353    } else {
354        None
355    };
356
357    Ok(UnionComponents {
358        args: PyTuple::new_ref(new_args, &vm.ctx),
359        hashable_args,
360        unhashable_args,
361    })
362}
363
364pub fn make_union(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyResult {
365    let result = dedup_and_flatten_args(args, vm)?;
366    Ok(match result.args.len() {
367        1 => result.args[0].to_owned(),
368        _ => PyUnion::from_components(result, vm)?.to_pyobject(vm),
369    })
370}
371
372impl PyUnion {
373    fn getitem(zelf: PyRef<Self>, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
374        let new_args = genericalias::subs_parameters(
375            zelf.to_owned().into(),
376            zelf.args.clone(),
377            zelf.parameters.clone(),
378            needle,
379            vm,
380        )?;
381        let res;
382        if new_args.is_empty() {
383            res = make_union(&new_args, vm)?;
384        } else {
385            let mut tmp = new_args[0].to_owned();
386            for arg in new_args.iter().skip(1) {
387                tmp = vm._or(&tmp, arg)?;
388            }
389            res = tmp;
390        }
391
392        Ok(res)
393    }
394}
395
396impl AsMapping for PyUnion {
397    fn as_mapping() -> &'static PyMappingMethods {
398        static AS_MAPPING: LazyLock<PyMappingMethods> = LazyLock::new(|| PyMappingMethods {
399            subscript: atomic_func!(|mapping, needle, vm| {
400                let zelf = PyUnion::mapping_downcast(mapping);
401                PyUnion::getitem(zelf.to_owned(), needle.to_owned(), vm)
402            }),
403            ..PyMappingMethods::NOT_IMPLEMENTED
404        });
405        &AS_MAPPING
406    }
407}
408
409impl AsNumber for PyUnion {
410    fn as_number() -> &'static PyNumberMethods {
411        static AS_NUMBER: PyNumberMethods = PyNumberMethods {
412            or: Some(|a, b, vm| PyUnion::__or__(a.to_owned(), b.to_owned(), vm)),
413            ..PyNumberMethods::NOT_IMPLEMENTED
414        };
415        &AS_NUMBER
416    }
417}
418
419impl Comparable for PyUnion {
420    fn cmp(
421        zelf: &Py<Self>,
422        other: &PyObject,
423        op: PyComparisonOp,
424        vm: &VirtualMachine,
425    ) -> PyResult<PyComparisonValue> {
426        op.eq_only(|| {
427            let other = class_or_notimplemented!(Self, other);
428
429            // Check if lengths are equal
430            if zelf.args.len() != other.args.len() {
431                return Ok(PyComparisonValue::Implemented(false));
432            }
433
434            // Fast path: if both unions have all hashable args, compare frozensets directly
435            // Always use Eq here since eq_only handles Ne by negating the result
436            if zelf.unhashable_args.is_none()
437                && other.unhashable_args.is_none()
438                && let (Some(a), Some(b)) = (&zelf.hashable_args, &other.hashable_args)
439            {
440                let eq = a
441                    .as_object()
442                    .rich_compare_bool(b.as_object(), PyComparisonOp::Eq, vm)?;
443                return Ok(PyComparisonValue::Implemented(eq));
444            }
445
446            // Slow path: O(n^2) nested loop comparison for unhashable elements
447            // Check if all elements in zelf.args are in other.args
448            for arg_a in &*zelf.args {
449                let mut found = false;
450                for arg_b in &*other.args {
451                    match arg_a.rich_compare_bool(arg_b, PyComparisonOp::Eq, vm) {
452                        Ok(true) => {
453                            found = true;
454                            break;
455                        }
456                        Ok(false) => continue,
457                        Err(e) => return Err(e), // Propagate comparison errors
458                    }
459                }
460                if !found {
461                    return Ok(PyComparisonValue::Implemented(false));
462                }
463            }
464
465            // Check if all elements in other.args are in zelf.args (for symmetry)
466            for arg_b in &*other.args {
467                let mut found = false;
468                for arg_a in &*zelf.args {
469                    match arg_b.rich_compare_bool(arg_a, PyComparisonOp::Eq, vm) {
470                        Ok(true) => {
471                            found = true;
472                            break;
473                        }
474                        Ok(false) => continue,
475                        Err(e) => return Err(e), // Propagate comparison errors
476                    }
477                }
478                if !found {
479                    return Ok(PyComparisonValue::Implemented(false));
480                }
481            }
482
483            Ok(PyComparisonValue::Implemented(true))
484        })
485    }
486}
487
488impl Hashable for PyUnion {
489    #[inline]
490    fn hash(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<hash::PyHash> {
491        // If there are any unhashable args from creation time, the union is unhashable
492        if let Some(ref unhashable_args) = zelf.unhashable_args {
493            let n = unhashable_args.len();
494            // Try to hash each previously unhashable arg to get an error
495            for arg in unhashable_args.iter() {
496                arg.hash(vm)?;
497            }
498            // All previously unhashable args somehow became hashable
499            // But still raise an error to maintain consistent hashing
500            return Err(vm.new_type_error(format!(
501                "union contains {} unhashable element{}",
502                n,
503                if n > 1 { "s" } else { "" }
504            )));
505        }
506
507        // If we have a stored frozenset of hashable args, use that
508        if let Some(ref hashable_args) = zelf.hashable_args {
509            return PyFrozenSet::hash(hashable_args, vm);
510        }
511
512        // Fallback: compute hash from args
513        let mut args_to_hash = Vec::new();
514        for arg in &*zelf.args {
515            match arg.hash(vm) {
516                Ok(_) => args_to_hash.push(arg.clone()),
517                Err(e) => return Err(e),
518            }
519        }
520        let set = PyFrozenSet::from_iter(vm, args_to_hash.into_iter())?;
521        PyFrozenSet::hash(&set.into_ref(&vm.ctx), vm)
522    }
523}
524
525impl GetAttr for PyUnion {
526    fn getattro(zelf: &Py<Self>, attr: &Py<PyStr>, vm: &VirtualMachine) -> PyResult {
527        for &exc in CLS_ATTRS {
528            if *exc == attr.to_string() {
529                return zelf.as_object().generic_getattr(attr, vm);
530            }
531        }
532        zelf.as_object().get_attr(attr, vm)
533    }
534}
535
536impl Representable for PyUnion {
537    #[inline]
538    fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
539        zelf.repr(vm)
540    }
541}
542
543pub fn init(context: &'static Context) {
544    let union_type = &context.types.union_type;
545    PyUnion::extend_class(context, union_type);
546}