1use super::{genericalias, type_};
2use crate::common::lock::LazyLock;
3use crate::{
4 AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
5 atomic_func,
6 builtins::{PyFrozenSet, PySet, PyStr, PyTuple, PyTupleRef, PyType},
7 class::PyClassImpl,
8 common::hash,
9 convert::ToPyObject,
10 function::PyComparisonValue,
11 protocol::{PyMappingMethods, PyNumberMethods},
12 stdlib::_typing::{TypeAliasType, call_typing_func_object},
13 types::{AsMapping, AsNumber, Comparable, GetAttr, Hashable, PyComparisonOp, Representable},
14};
15use alloc::fmt;
16
// Attribute names that `PyUnion`'s `GetAttr::getattro` handles specially
// instead of using the default per-instance lookup path.
const CLS_ATTRS: &[&str] = &["__module__"];
18
// Runtime object behind `typing.Union[...]` and `X | Y` type expressions.
#[pyclass(module = "typing", name = "Union", traverse)]
pub struct PyUnion {
    // Flattened, deduplicated member types in first-seen order (exposed as `__args__`).
    args: PyTupleRef,
    // Frozenset of the members whose hash succeeded; `None` when there were none.
    hashable_args: Option<PyRef<PyFrozenSet>>,
    // Members whose hash failed, in first-seen order; `None` when there were none.
    unhashable_args: Option<PyTupleRef>,
    // Type parameters collected from `args` (exposed as `__parameters__`).
    parameters: PyTupleRef,
}
28
29impl fmt::Debug for PyUnion {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 f.write_str("UnionObject")
32 }
33}
34
35impl PyPayload for PyUnion {
36 #[inline]
37 fn class(ctx: &Context) -> &'static Py<PyType> {
38 ctx.types.union_type
39 }
40}
41
42impl PyUnion {
43 fn from_components(result: UnionComponents, vm: &VirtualMachine) -> PyResult<Self> {
45 let parameters = make_parameters(&result.args, vm)?;
46 Ok(Self {
47 args: result.args,
48 hashable_args: result.hashable_args,
49 unhashable_args: result.unhashable_args,
50 parameters,
51 })
52 }
53
54 #[inline]
56 pub fn args(&self) -> &Py<PyTuple> {
57 &self.args
58 }
59
60 fn repr(&self, vm: &VirtualMachine) -> PyResult<String> {
61 fn repr_item(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<String> {
62 if obj.is(vm.ctx.types.none_type) {
63 return Ok("None".to_string());
64 }
65
66 if vm
67 .get_attribute_opt(obj.clone(), identifier!(vm, __origin__))?
68 .is_some()
69 && vm
70 .get_attribute_opt(obj.clone(), identifier!(vm, __args__))?
71 .is_some()
72 {
73 return Ok(obj.repr(vm)?.to_string());
74 }
75
76 match (
77 vm.get_attribute_opt(obj.clone(), identifier!(vm, __qualname__))?
78 .and_then(|o| o.downcast_ref::<PyStr>().map(|n| n.to_string())),
79 vm.get_attribute_opt(obj.clone(), identifier!(vm, __module__))?
80 .and_then(|o| o.downcast_ref::<PyStr>().map(|m| m.to_string())),
81 ) {
82 (None, _) | (_, None) => Ok(obj.repr(vm)?.to_string()),
83 (Some(qualname), Some(module)) => Ok(if module == "builtins" {
84 qualname
85 } else {
86 format!("{module}.{qualname}")
87 }),
88 }
89 }
90
91 Ok(self
92 .args
93 .iter()
94 .map(|o| repr_item(o.clone(), vm))
95 .collect::<PyResult<Vec<_>>>()?
96 .join(" | "))
97 }
98}
99
100#[pyclass(
101 flags(DISALLOW_INSTANTIATION, HAS_WEAKREF),
102 with(Hashable, Comparable, AsMapping, AsNumber, Representable)
103)]
104impl PyUnion {
105 #[pygetset]
106 fn __name__(&self, vm: &VirtualMachine) -> PyObjectRef {
107 vm.ctx.new_str("Union").into()
108 }
109
110 #[pygetset]
111 fn __qualname__(&self, vm: &VirtualMachine) -> PyObjectRef {
112 vm.ctx.new_str("Union").into()
113 }
114
115 #[pygetset]
116 fn __origin__(&self, vm: &VirtualMachine) -> PyObjectRef {
117 vm.ctx.types.union_type.to_owned().into()
118 }
119
120 #[pygetset]
121 fn __parameters__(&self) -> PyObjectRef {
122 self.parameters.clone().into()
123 }
124
125 #[pygetset]
126 fn __args__(&self) -> PyObjectRef {
127 self.args.clone().into()
128 }
129
130 #[pymethod]
131 fn __instancecheck__(
132 zelf: PyRef<Self>,
133 obj: PyObjectRef,
134 vm: &VirtualMachine,
135 ) -> PyResult<bool> {
136 if zelf
137 .args
138 .iter()
139 .any(|x| x.class().is(vm.ctx.types.generic_alias_type))
140 {
141 Err(vm.new_type_error("isinstance() argument 2 cannot be a parameterized generic"))
142 } else {
143 obj.is_instance(zelf.__args__().as_object(), vm)
144 }
145 }
146
147 #[pymethod]
148 fn __subclasscheck__(
149 zelf: PyRef<Self>,
150 obj: PyObjectRef,
151 vm: &VirtualMachine,
152 ) -> PyResult<bool> {
153 if zelf
154 .args
155 .iter()
156 .any(|x| x.class().is(vm.ctx.types.generic_alias_type))
157 {
158 Err(vm.new_type_error("issubclass() argument 2 cannot be a parameterized generic"))
159 } else {
160 obj.is_subclass(zelf.__args__().as_object(), vm)
161 }
162 }
163
164 fn __or__(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
165 type_::or_(zelf, other, vm)
166 }
167
168 #[pymethod]
169 fn __mro_entries__(zelf: PyRef<Self>, _args: PyObjectRef, vm: &VirtualMachine) -> PyResult {
170 Err(vm.new_type_error(format!("Cannot subclass {}", zelf.repr(vm)?)))
171 }
172
173 #[pyclassmethod]
174 fn __class_getitem__(
175 _cls: crate::builtins::PyTypeRef,
176 args: PyObjectRef,
177 vm: &VirtualMachine,
178 ) -> PyResult {
179 let args_tuple = if let Some(tuple) = args.downcast_ref::<PyTuple>() {
181 tuple.to_owned()
182 } else {
183 PyTuple::new_ref(vec![args], &vm.ctx)
184 };
185
186 if args_tuple.is_empty() {
188 return Err(vm.new_type_error("Cannot create empty Union"));
189 }
190
191 make_union(&args_tuple, vm)
193 }
194}
195
196fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool {
197 let cls = obj.class();
198 cls.is(vm.ctx.types.none_type)
199 || obj.downcastable::<PyType>()
200 || cls.fast_issubclass(vm.ctx.types.generic_alias_type)
201 || cls.is(vm.ctx.types.union_type)
202 || obj.downcast_ref::<TypeAliasType>().is_some()
203}
204
205fn type_check(arg: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
206 if is_unionable(arg.clone(), vm) {
208 return Ok(arg);
209 }
210 let message_str: PyObjectRef = vm
211 .ctx
212 .new_str("Union[arg, ...]: each arg must be a type.")
213 .into();
214 call_typing_func_object(vm, "_type_check", (arg, message_str))
215}
216
217fn has_union_operands(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> bool {
218 let union_type = vm.ctx.types.union_type;
219 a.class().is(union_type) || b.class().is(union_type)
220}
221
222pub fn or_op(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
223 if !has_union_operands(zelf.clone(), other.clone(), vm)
224 && (!is_unionable(zelf.clone(), vm) || !is_unionable(other.clone(), vm))
225 {
226 return Ok(vm.ctx.not_implemented());
227 }
228
229 let left = type_check(zelf, vm)?;
230 let right = type_check(other, vm)?;
231 let tuple = PyTuple::new_ref(vec![left, right], &vm.ctx);
232 make_union(&tuple, vm)
233}
234
235fn make_parameters(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
236 let parameters = genericalias::make_parameters(args, vm);
237 let result = dedup_and_flatten_args(¶meters, vm)?;
238 Ok(result.args)
239}
240
241fn flatten_args(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyTupleRef {
242 let mut total_args = 0;
243 for arg in args {
244 if let Some(pyref) = arg.downcast_ref::<PyUnion>() {
245 total_args += pyref.args.len();
246 } else {
247 total_args += 1;
248 };
249 }
250
251 let mut flattened_args = Vec::with_capacity(total_args);
252 for arg in args {
253 if let Some(pyref) = arg.downcast_ref::<PyUnion>() {
254 flattened_args.extend(pyref.args.iter().cloned());
255 } else if vm.is_none(arg) {
256 flattened_args.push(vm.ctx.types.none_type.to_owned().into());
257 } else if arg.downcast_ref::<PyStr>().is_some() {
258 match string_to_forwardref(arg.clone(), vm) {
260 Ok(fr) => flattened_args.push(fr),
261 Err(_) => flattened_args.push(arg.clone()),
262 }
263 } else {
264 flattened_args.push(arg.clone());
265 };
266 }
267
268 PyTuple::new_ref(flattened_args, &vm.ctx)
269}
270
271fn string_to_forwardref(arg: PyObjectRef, vm: &VirtualMachine) -> PyResult {
272 let annotationlib = vm.import("annotationlib", 0)?;
274 let forwardref_cls = annotationlib.get_attr("ForwardRef", vm)?;
275 forwardref_cls.call((arg,), vm)
276}
277
/// Result of `dedup_and_flatten_args`: the unique members, plus the same
/// members split by whether hashing them succeeded.
struct UnionComponents {
    /// All unique members, in first-seen order.
    args: PyTupleRef,
    /// Frozenset of the members whose hash succeeded; `None` if there were none.
    hashable_args: Option<PyRef<PyFrozenSet>>,
    /// Members whose hash failed, in first-seen order; `None` if there were none.
    unhashable_args: Option<PyTupleRef>,
}
287
288fn dedup_and_flatten_args(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyResult<UnionComponents> {
289 let args = flatten_args(args, vm);
290
291 let mut new_args: Vec<PyObjectRef> = Vec::with_capacity(args.len());
299
300 let hashable_set = PySet::default().into_ref(&vm.ctx);
302 let mut hashable_list: Vec<PyObjectRef> = Vec::new();
303 let mut unhashable_list: Vec<PyObjectRef> = Vec::new();
304
305 for arg in &*args {
306 match arg.hash(vm) {
308 Ok(_) => {
309 let contains = vm
312 .call_method(hashable_set.as_ref(), "__contains__", (arg.clone(),))
313 .and_then(|r| r.try_to_bool(vm))?;
314 if !contains {
315 hashable_set.add(arg.clone(), vm)?;
316 hashable_list.push(arg.clone());
317 new_args.push(arg.clone());
318 }
319 }
320 Err(_) => {
321 let mut is_duplicate = false;
323 for existing in &unhashable_list {
324 match existing.rich_compare_bool(arg, PyComparisonOp::Eq, vm) {
325 Ok(true) => {
326 is_duplicate = true;
327 break;
328 }
329 Ok(false) => continue,
330 Err(e) => return Err(e),
331 }
332 }
333 if !is_duplicate {
334 unhashable_list.push(arg.clone());
335 new_args.push(arg.clone());
336 }
337 }
338 }
339 }
340
341 new_args.shrink_to_fit();
342
343 let hashable_args = if !hashable_list.is_empty() {
345 Some(PyFrozenSet::from_iter(vm, hashable_list.into_iter())?.into_ref(&vm.ctx))
346 } else {
347 None
348 };
349
350 let unhashable_args = if !unhashable_list.is_empty() {
352 Some(PyTuple::new_ref(unhashable_list, &vm.ctx))
353 } else {
354 None
355 };
356
357 Ok(UnionComponents {
358 args: PyTuple::new_ref(new_args, &vm.ctx),
359 hashable_args,
360 unhashable_args,
361 })
362}
363
364pub fn make_union(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyResult {
365 let result = dedup_and_flatten_args(args, vm)?;
366 Ok(match result.args.len() {
367 1 => result.args[0].to_owned(),
368 _ => PyUnion::from_components(result, vm)?.to_pyobject(vm),
369 })
370}
371
372impl PyUnion {
373 fn getitem(zelf: PyRef<Self>, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
374 let new_args = genericalias::subs_parameters(
375 zelf.to_owned().into(),
376 zelf.args.clone(),
377 zelf.parameters.clone(),
378 needle,
379 vm,
380 )?;
381 let res;
382 if new_args.is_empty() {
383 res = make_union(&new_args, vm)?;
384 } else {
385 let mut tmp = new_args[0].to_owned();
386 for arg in new_args.iter().skip(1) {
387 tmp = vm._or(&tmp, arg)?;
388 }
389 res = tmp;
390 }
391
392 Ok(res)
393 }
394}
395
396impl AsMapping for PyUnion {
397 fn as_mapping() -> &'static PyMappingMethods {
398 static AS_MAPPING: LazyLock<PyMappingMethods> = LazyLock::new(|| PyMappingMethods {
399 subscript: atomic_func!(|mapping, needle, vm| {
400 let zelf = PyUnion::mapping_downcast(mapping);
401 PyUnion::getitem(zelf.to_owned(), needle.to_owned(), vm)
402 }),
403 ..PyMappingMethods::NOT_IMPLEMENTED
404 });
405 &AS_MAPPING
406 }
407}
408
409impl AsNumber for PyUnion {
410 fn as_number() -> &'static PyNumberMethods {
411 static AS_NUMBER: PyNumberMethods = PyNumberMethods {
412 or: Some(|a, b, vm| PyUnion::__or__(a.to_owned(), b.to_owned(), vm)),
413 ..PyNumberMethods::NOT_IMPLEMENTED
414 };
415 &AS_NUMBER
416 }
417}
418
419impl Comparable for PyUnion {
420 fn cmp(
421 zelf: &Py<Self>,
422 other: &PyObject,
423 op: PyComparisonOp,
424 vm: &VirtualMachine,
425 ) -> PyResult<PyComparisonValue> {
426 op.eq_only(|| {
427 let other = class_or_notimplemented!(Self, other);
428
429 if zelf.args.len() != other.args.len() {
431 return Ok(PyComparisonValue::Implemented(false));
432 }
433
434 if zelf.unhashable_args.is_none()
437 && other.unhashable_args.is_none()
438 && let (Some(a), Some(b)) = (&zelf.hashable_args, &other.hashable_args)
439 {
440 let eq = a
441 .as_object()
442 .rich_compare_bool(b.as_object(), PyComparisonOp::Eq, vm)?;
443 return Ok(PyComparisonValue::Implemented(eq));
444 }
445
446 for arg_a in &*zelf.args {
449 let mut found = false;
450 for arg_b in &*other.args {
451 match arg_a.rich_compare_bool(arg_b, PyComparisonOp::Eq, vm) {
452 Ok(true) => {
453 found = true;
454 break;
455 }
456 Ok(false) => continue,
457 Err(e) => return Err(e), }
459 }
460 if !found {
461 return Ok(PyComparisonValue::Implemented(false));
462 }
463 }
464
465 for arg_b in &*other.args {
467 let mut found = false;
468 for arg_a in &*zelf.args {
469 match arg_b.rich_compare_bool(arg_a, PyComparisonOp::Eq, vm) {
470 Ok(true) => {
471 found = true;
472 break;
473 }
474 Ok(false) => continue,
475 Err(e) => return Err(e), }
477 }
478 if !found {
479 return Ok(PyComparisonValue::Implemented(false));
480 }
481 }
482
483 Ok(PyComparisonValue::Implemented(true))
484 })
485 }
486}
487
488impl Hashable for PyUnion {
489 #[inline]
490 fn hash(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<hash::PyHash> {
491 if let Some(ref unhashable_args) = zelf.unhashable_args {
493 let n = unhashable_args.len();
494 for arg in unhashable_args.iter() {
496 arg.hash(vm)?;
497 }
498 return Err(vm.new_type_error(format!(
501 "union contains {} unhashable element{}",
502 n,
503 if n > 1 { "s" } else { "" }
504 )));
505 }
506
507 if let Some(ref hashable_args) = zelf.hashable_args {
509 return PyFrozenSet::hash(hashable_args, vm);
510 }
511
512 let mut args_to_hash = Vec::new();
514 for arg in &*zelf.args {
515 match arg.hash(vm) {
516 Ok(_) => args_to_hash.push(arg.clone()),
517 Err(e) => return Err(e),
518 }
519 }
520 let set = PyFrozenSet::from_iter(vm, args_to_hash.into_iter())?;
521 PyFrozenSet::hash(&set.into_ref(&vm.ctx), vm)
522 }
523}
524
525impl GetAttr for PyUnion {
526 fn getattro(zelf: &Py<Self>, attr: &Py<PyStr>, vm: &VirtualMachine) -> PyResult {
527 for &exc in CLS_ATTRS {
528 if *exc == attr.to_string() {
529 return zelf.as_object().generic_getattr(attr, vm);
530 }
531 }
532 zelf.as_object().get_attr(attr, vm)
533 }
534}
535
536impl Representable for PyUnion {
537 #[inline]
538 fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
539 zelf.repr(vm)
540 }
541}
542
543pub fn init(context: &'static Context) {
544 let union_type = &context.types.union_type;
545 PyUnion::extend_class(context, union_type);
546}