1use crate::builtins::{PyList, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef};
6use crate::function::{ArgIterable, FuncArgs};
7use crate::types::{PyTypeFlags, PyTypeSlots};
8use crate::{
9 AsObject, Context, Py, PyObject, PyObjectRef, PyRef, PyResult, TryFromObject, VirtualMachine,
10};
11use core::fmt::Write;
12use rustpython_common::wtf8::{Wtf8, Wtf8Buf};
13
14use crate::exceptions::types::PyBaseException;
15
16fn create_exception_group(ctx: &Context) -> PyRef<PyType> {
18 let excs = &ctx.exceptions;
19 let exception_group_slots = PyTypeSlots {
20 flags: PyTypeFlags::heap_type_flags() | PyTypeFlags::HAS_DICT,
21 ..Default::default()
22 };
23 PyType::new_heap(
24 "ExceptionGroup",
25 vec![
26 excs.base_exception_group.to_owned(),
27 excs.exception_type.to_owned(),
28 ],
29 Default::default(),
30 exception_group_slots,
31 ctx.types.type_type.to_owned(),
32 ctx,
33 )
34 .expect("Failed to create ExceptionGroup type with multiple inheritance")
35}
36
37pub fn exception_group() -> &'static Py<PyType> {
38 ::rustpython_vm::common::static_cell! {
39 static CELL: ::rustpython_vm::builtins::PyTypeRef;
40 }
41 CELL.get_or_init(|| create_exception_group(Context::genesis()))
42}
43
44pub(super) mod types {
45 use super::*;
46 use crate::PyPayload;
47 use crate::builtins::PyGenericAlias;
48 use crate::types::{Constructor, Initializer};
49
// Payload type behind Python's `BaseExceptionGroup`.
// `#[repr(transparent)]` gives it exactly the layout of `PyBaseException`; the
// group's message and sub-exceptions live in the base exception's `args`
// (index 0 and 1), as written by `slot_new` and read via `get_arg` below.
// Plain `//` comments are used on purpose: doc comments on `#[pyexception]`
// items may be exported as the Python-level `__doc__`.
#[pyexception(name, base = PyBaseException, ctx = "base_exception_group")]
#[derive(Debug)]
#[repr(transparent)]
pub struct PyBaseExceptionGroup(PyBaseException);
54
55 #[pyexception(with(Constructor, Initializer))]
56 impl PyBaseExceptionGroup {
57 #[pyclassmethod]
58 fn __class_getitem__(
59 cls: PyTypeRef,
60 args: PyObjectRef,
61 vm: &VirtualMachine,
62 ) -> PyGenericAlias {
63 PyGenericAlias::from_args(cls, args, vm)
64 }
65
66 #[pymethod]
67 fn derive(
68 zelf: PyRef<PyBaseException>,
69 excs: PyObjectRef,
70 vm: &VirtualMachine,
71 ) -> PyResult {
72 let message = zelf.get_arg(0).unwrap_or_else(|| vm.ctx.new_str("").into());
73 vm.invoke_exception(
74 vm.ctx.exceptions.base_exception_group.to_owned(),
75 vec![message, excs],
76 )
77 .map(|e| e.into())
78 }
79
80 #[pymethod]
81 fn subgroup(
82 zelf: PyRef<PyBaseException>,
83 condition: PyObjectRef,
84 vm: &VirtualMachine,
85 ) -> PyResult {
86 let matcher = get_condition_matcher(&condition, vm)?;
87
88 let zelf_obj: PyObjectRef = zelf.clone().into();
90 if matcher.check(&zelf_obj, vm)? {
91 return Ok(zelf_obj);
92 }
93
94 let exceptions = get_exceptions_tuple(&zelf, vm)?;
95 let mut matching: Vec<PyObjectRef> = Vec::new();
96 let mut modified = false;
97
98 for exc in exceptions {
99 if is_base_exception_group(&exc, vm) {
100 let subgroup_result = vm.call_method(&exc, "subgroup", (condition.clone(),))?;
102 if !vm.is_none(&subgroup_result) {
103 matching.push(subgroup_result.clone());
104 }
105 if !subgroup_result.is(&exc) {
106 modified = true;
107 }
108 } else if matcher.check(&exc, vm)? {
109 matching.push(exc);
110 } else {
111 modified = true;
112 }
113 }
114
115 if !modified {
116 return Ok(zelf.clone().into());
117 }
118
119 if matching.is_empty() {
120 return Ok(vm.ctx.none());
121 }
122
123 derive_and_copy_attributes(&zelf, matching, vm)
125 }
126
127 #[pymethod]
128 fn split(
129 zelf: PyRef<PyBaseException>,
130 condition: PyObjectRef,
131 vm: &VirtualMachine,
132 ) -> PyResult<PyTupleRef> {
133 let matcher = get_condition_matcher(&condition, vm)?;
134
135 let zelf_obj: PyObjectRef = zelf.clone().into();
137 if matcher.check(&zelf_obj, vm)? {
138 return Ok(vm.ctx.new_tuple(vec![zelf_obj, vm.ctx.none()]));
139 }
140
141 let exceptions = get_exceptions_tuple(&zelf, vm)?;
142 let mut matching: Vec<PyObjectRef> = Vec::new();
143 let mut rest: Vec<PyObjectRef> = Vec::new();
144
145 for exc in exceptions {
146 if is_base_exception_group(&exc, vm) {
147 let result = vm.call_method(&exc, "split", (condition.clone(),))?;
148 let result_tuple: PyTupleRef = result.try_into_value(vm)?;
149 let match_part = result_tuple
150 .first()
151 .cloned()
152 .unwrap_or_else(|| vm.ctx.none());
153 let rest_part = result_tuple
154 .get(1)
155 .cloned()
156 .unwrap_or_else(|| vm.ctx.none());
157
158 if !vm.is_none(&match_part) {
159 matching.push(match_part);
160 }
161 if !vm.is_none(&rest_part) {
162 rest.push(rest_part);
163 }
164 } else if matcher.check(&exc, vm)? {
165 matching.push(exc);
166 } else {
167 rest.push(exc);
168 }
169 }
170
171 let match_group = if matching.is_empty() {
172 vm.ctx.none()
173 } else {
174 derive_and_copy_attributes(&zelf, matching, vm)?
175 };
176
177 let rest_group = if rest.is_empty() {
178 vm.ctx.none()
179 } else {
180 derive_and_copy_attributes(&zelf, rest, vm)?
181 };
182
183 Ok(vm.ctx.new_tuple(vec![match_group, rest_group]))
184 }
185
186 #[pymethod]
187 fn __str__(zelf: &Py<PyBaseException>, vm: &VirtualMachine) -> PyResult<PyStrRef> {
188 let message = zelf.get_arg(0).map(|m| m.str(vm)).transpose()?;
189
190 let num_excs = zelf
191 .get_arg(1)
192 .and_then(|obj| obj.downcast_ref::<PyTuple>().map(|t| t.len()))
193 .unwrap_or(0);
194
195 let suffix = if num_excs == 1 { "" } else { "s" };
196 let mut result = match message {
197 Some(s) => s.as_wtf8().to_owned(),
198 None => Wtf8Buf::new(),
199 };
200 write!(result, " ({num_excs} sub-exception{suffix})").unwrap();
201 Ok(vm.ctx.new_str(result))
202 }
203
204 #[pyslot]
205 fn slot_repr(zelf: &PyObject, vm: &VirtualMachine) -> PyResult<PyStrRef> {
206 let zelf = zelf
207 .downcast_ref::<PyBaseException>()
208 .expect("exception group must be BaseException");
209 let class_name = zelf.class().name().to_owned();
210 let message = zelf.get_arg(0).map(|m| m.repr(vm)).transpose()?;
211
212 let mut result = Wtf8Buf::new();
213 write!(result, "{class_name}(").unwrap();
214 let message_wtf8: &Wtf8 = message.as_ref().map_or("''".as_ref(), |s| s.as_wtf8());
215 result.push_wtf8(message_wtf8);
216 result.push_str(", [");
217 if let Some(exceptions_obj) = zelf.get_arg(1) {
218 let iter: ArgIterable<PyObjectRef> =
219 ArgIterable::try_from_object(vm, exceptions_obj.clone())?;
220 let mut first = true;
221 for exc in iter.iter(vm)? {
222 if !first {
223 result.push_str(", ");
224 }
225 first = false;
226 result.push_wtf8(exc?.repr(vm)?.as_wtf8());
227 }
228 }
229 result.push_str("])");
230
231 Ok(vm.ctx.new_str(result))
232 }
233 }
234
235 impl Constructor for PyBaseExceptionGroup {
236 type Args = crate::function::PosArgs;
237
238 fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
239 let args: Self::Args = args.bind(vm)?;
240 let args = args.into_vec();
241 if args.len() != 2 {
243 return Err(vm.new_type_error(format!(
244 "BaseExceptionGroup.__new__() takes exactly 2 positional arguments ({} given)",
245 args.len()
246 )));
247 }
248
249 let message = args[0].clone();
251 if !message.fast_isinstance(vm.ctx.types.str_type) {
252 return Err(vm.new_type_error(format!(
253 "argument 1 must be str, not {}",
254 message.class().name()
255 )));
256 }
257
258 let exceptions_arg = &args[1];
260
261 if exceptions_arg.fast_isinstance(vm.ctx.types.set_type)
263 || exceptions_arg.fast_isinstance(vm.ctx.types.frozenset_type)
264 {
265 return Err(vm.new_type_error("second argument (exceptions) must be a sequence"));
266 }
267
268 if exceptions_arg.is(&vm.ctx.none) {
270 return Err(vm.new_type_error("second argument (exceptions) must be a sequence"));
271 }
272
273 let exceptions: Vec<PyObjectRef> = exceptions_arg.try_to_value(vm).map_err(|_| {
274 vm.new_type_error("second argument (exceptions) must be a sequence")
275 })?;
276
277 if exceptions.is_empty() {
279 return Err(
280 vm.new_value_error("second argument (exceptions) must be a non-empty sequence")
281 );
282 }
283
284 let mut has_non_exception = false;
286 for (i, exc) in exceptions.iter().enumerate() {
287 if !exc.fast_isinstance(vm.ctx.exceptions.base_exception_type) {
288 return Err(vm.new_value_error(format!(
289 "Item {} of second argument (exceptions) is not an exception",
290 i
291 )));
292 }
293 if !exc.fast_isinstance(vm.ctx.exceptions.exception_type) {
297 has_non_exception = true;
298 }
299 }
300
301 let exception_group_type = crate::exception_group::exception_group();
303
304 let actual_cls = if cls.is(exception_group_type) {
306 if has_non_exception {
308 return Err(
309 vm.new_type_error("Cannot nest BaseExceptions in an ExceptionGroup")
310 );
311 }
312 cls
313 } else if cls.is(vm.ctx.exceptions.base_exception_group) {
314 if !has_non_exception {
316 exception_group_type.to_owned()
317 } else {
318 cls
319 }
320 } else {
321 if has_non_exception && cls.fast_issubclass(vm.ctx.exceptions.exception_type) {
323 return Err(vm.new_type_error(format!(
324 "Cannot nest BaseExceptions in '{}'",
325 cls.name()
326 )));
327 }
328 cls
329 };
330
331 let exceptions_tuple = vm.ctx.new_tuple(exceptions);
333 let init_args = vec![message, exceptions_tuple.into()];
334 PyBaseException::new(init_args, vm)
335 .into_ref_with_type(vm, actual_cls)
336 .map(Into::into)
337 }
338
339 fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
340 unimplemented!("use slot_new")
341 }
342 }
343
344 impl Initializer for PyBaseExceptionGroup {
345 type Args = FuncArgs;
346
347 fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> {
348 if !args.kwargs.is_empty() {
350 return Err(vm.new_type_error(format!(
351 "{} does not take keyword arguments",
352 zelf.class().name()
353 )));
354 }
355 let _ = (zelf, args, vm);
359 Ok(())
360 }
361
362 fn init(_zelf: PyRef<Self>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> {
363 unreachable!("slot_init is overridden")
364 }
365 }
366
367 fn is_base_exception_group(obj: &PyObject, vm: &VirtualMachine) -> bool {
369 obj.fast_isinstance(vm.ctx.exceptions.base_exception_group)
370 }
371
372 fn get_exceptions_tuple(
373 exc: &Py<PyBaseException>,
374 vm: &VirtualMachine,
375 ) -> PyResult<Vec<PyObjectRef>> {
376 let obj = exc
377 .get_arg(1)
378 .ok_or_else(|| vm.new_type_error("exceptions must be a tuple"))?;
379 let tuple = obj
380 .downcast_ref::<PyTuple>()
381 .ok_or_else(|| vm.new_type_error("exceptions must be a tuple"))?;
382 Ok(tuple.to_vec())
383 }
384
/// Parsed form of the `condition` argument accepted by `split` and `subgroup`.
enum ConditionMatcher {
    /// A single exception type (a subclass of `BaseException`).
    Type(PyTypeRef),
    /// Several exception types; an object matches if it is an instance of any.
    Types(Vec<PyTypeRef>),
    /// A callable that is not a type, called with the object as a predicate.
    Callable(PyObjectRef),
}
390
391 fn get_condition_matcher(
392 condition: &PyObject,
393 vm: &VirtualMachine,
394 ) -> PyResult<ConditionMatcher> {
395 if let Some(typ) = condition.downcast_ref::<PyType>()
397 && typ.fast_issubclass(vm.ctx.exceptions.base_exception_type)
398 {
399 return Ok(ConditionMatcher::Type(typ.to_owned()));
400 }
401
402 if let Some(tuple) = condition.downcast_ref::<PyTuple>() {
404 let mut types = Vec::new();
405 for item in tuple.iter() {
406 let typ: PyTypeRef = item.clone().try_into_value(vm).map_err(|_| {
407 vm.new_type_error(
408 "expected a function, exception type or tuple of exception types",
409 )
410 })?;
411 if !typ.fast_issubclass(vm.ctx.exceptions.base_exception_type) {
412 return Err(vm.new_type_error(
413 "expected a function, exception type or tuple of exception types",
414 ));
415 }
416 types.push(typ);
417 }
418 if !types.is_empty() {
419 return Ok(ConditionMatcher::Types(types));
420 }
421 }
422
423 if condition.is_callable() && condition.downcast_ref::<PyType>().is_none() {
425 return Ok(ConditionMatcher::Callable(condition.to_owned()));
426 }
427
428 Err(vm.new_type_error("expected a function, exception type or tuple of exception types"))
429 }
430
431 impl ConditionMatcher {
432 fn check(&self, exc: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
433 match self {
434 ConditionMatcher::Type(typ) => Ok(exc.fast_isinstance(typ)),
435 ConditionMatcher::Types(types) => Ok(types.iter().any(|t| exc.fast_isinstance(t))),
436 ConditionMatcher::Callable(func) => {
437 let result = func.call((exc.to_owned(),), vm)?;
438 result.try_to_bool(vm)
439 }
440 }
441 }
442 }
443
444 fn derive_and_copy_attributes(
445 orig: &Py<PyBaseException>,
446 excs: Vec<PyObjectRef>,
447 vm: &VirtualMachine,
448 ) -> PyResult<PyObjectRef> {
449 let excs_seq = vm.ctx.new_list(excs);
451 let new_group = vm.call_method(orig.as_object(), "derive", (excs_seq,))?;
452
453 if !is_base_exception_group(&new_group, vm) {
455 return Err(vm.new_type_error("derive must return an instance of BaseExceptionGroup"));
456 }
457
458 if let Some(tb) = orig.__traceback__() {
460 new_group.set_attr("__traceback__", tb, vm)?;
461 }
462
463 if let Some(ctx) = orig.__context__() {
465 new_group.set_attr("__context__", ctx, vm)?;
466 }
467
468 if let Some(cause) = orig.__cause__() {
470 new_group.set_attr("__cause__", cause, vm)?;
471 }
472
473 if let Ok(notes) = orig.as_object().get_attr("__notes__", vm)
475 && let Some(notes_list) = notes.downcast_ref::<PyList>()
476 {
477 let notes_copy = vm.ctx.new_list(notes_list.borrow_vec().to_vec());
478 new_group.set_attr("__notes__", notes_copy, vm)?;
479 }
480
481 Ok(new_group)
482 }
483}