Skip to main content

rustpython_vm/
exception_group.rs

1//! ExceptionGroup implementation for Python 3.11+
2//!
3//! This module implements BaseExceptionGroup and ExceptionGroup with multiple inheritance support.
4
5use crate::builtins::{PyList, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef};
6use crate::function::{ArgIterable, FuncArgs};
7use crate::types::{PyTypeFlags, PyTypeSlots};
8use crate::{
9    AsObject, Context, Py, PyObject, PyObjectRef, PyRef, PyResult, TryFromObject, VirtualMachine,
10};
11use core::fmt::Write;
12use rustpython_common::wtf8::{Wtf8, Wtf8Buf};
13
14use crate::exceptions::types::PyBaseException;
15
16/// Create dynamic ExceptionGroup type with multiple inheritance
17fn create_exception_group(ctx: &Context) -> PyRef<PyType> {
18    let excs = &ctx.exceptions;
19    let exception_group_slots = PyTypeSlots {
20        flags: PyTypeFlags::heap_type_flags() | PyTypeFlags::HAS_DICT,
21        ..Default::default()
22    };
23    PyType::new_heap(
24        "ExceptionGroup",
25        vec![
26            excs.base_exception_group.to_owned(),
27            excs.exception_type.to_owned(),
28        ],
29        Default::default(),
30        exception_group_slots,
31        ctx.types.type_type.to_owned(),
32        ctx,
33    )
34    .expect("Failed to create ExceptionGroup type with multiple inheritance")
35}
36
37pub fn exception_group() -> &'static Py<PyType> {
38    ::rustpython_vm::common::static_cell! {
39        static CELL: ::rustpython_vm::builtins::PyTypeRef;
40    }
41    CELL.get_or_init(|| create_exception_group(Context::genesis()))
42}
43
44pub(super) mod types {
45    use super::*;
46    use crate::PyPayload;
47    use crate::builtins::PyGenericAlias;
48    use crate::types::{Constructor, Initializer};
49
50    #[pyexception(name, base = PyBaseException, ctx = "base_exception_group")]
51    #[derive(Debug)]
52    #[repr(transparent)]
53    pub struct PyBaseExceptionGroup(PyBaseException);
54
55    #[pyexception(with(Constructor, Initializer))]
56    impl PyBaseExceptionGroup {
57        #[pyclassmethod]
58        fn __class_getitem__(
59            cls: PyTypeRef,
60            args: PyObjectRef,
61            vm: &VirtualMachine,
62        ) -> PyGenericAlias {
63            PyGenericAlias::from_args(cls, args, vm)
64        }
65
66        #[pymethod]
67        fn derive(
68            zelf: PyRef<PyBaseException>,
69            excs: PyObjectRef,
70            vm: &VirtualMachine,
71        ) -> PyResult {
72            let message = zelf.get_arg(0).unwrap_or_else(|| vm.ctx.new_str("").into());
73            vm.invoke_exception(
74                vm.ctx.exceptions.base_exception_group.to_owned(),
75                vec![message, excs],
76            )
77            .map(|e| e.into())
78        }
79
80        #[pymethod]
81        fn subgroup(
82            zelf: PyRef<PyBaseException>,
83            condition: PyObjectRef,
84            vm: &VirtualMachine,
85        ) -> PyResult {
86            let matcher = get_condition_matcher(&condition, vm)?;
87
88            // If self matches the condition entirely, return self
89            let zelf_obj: PyObjectRef = zelf.clone().into();
90            if matcher.check(&zelf_obj, vm)? {
91                return Ok(zelf_obj);
92            }
93
94            let exceptions = get_exceptions_tuple(&zelf, vm)?;
95            let mut matching: Vec<PyObjectRef> = Vec::new();
96            let mut modified = false;
97
98            for exc in exceptions {
99                if is_base_exception_group(&exc, vm) {
100                    // Recursive call for nested groups
101                    let subgroup_result = vm.call_method(&exc, "subgroup", (condition.clone(),))?;
102                    if !vm.is_none(&subgroup_result) {
103                        matching.push(subgroup_result.clone());
104                    }
105                    if !subgroup_result.is(&exc) {
106                        modified = true;
107                    }
108                } else if matcher.check(&exc, vm)? {
109                    matching.push(exc);
110                } else {
111                    modified = true;
112                }
113            }
114
115            if !modified {
116                return Ok(zelf.clone().into());
117            }
118
119            if matching.is_empty() {
120                return Ok(vm.ctx.none());
121            }
122
123            // Create new group with matching exceptions and copy metadata
124            derive_and_copy_attributes(&zelf, matching, vm)
125        }
126
127        #[pymethod]
128        fn split(
129            zelf: PyRef<PyBaseException>,
130            condition: PyObjectRef,
131            vm: &VirtualMachine,
132        ) -> PyResult<PyTupleRef> {
133            let matcher = get_condition_matcher(&condition, vm)?;
134
135            // If self matches the condition entirely
136            let zelf_obj: PyObjectRef = zelf.clone().into();
137            if matcher.check(&zelf_obj, vm)? {
138                return Ok(vm.ctx.new_tuple(vec![zelf_obj, vm.ctx.none()]));
139            }
140
141            let exceptions = get_exceptions_tuple(&zelf, vm)?;
142            let mut matching: Vec<PyObjectRef> = Vec::new();
143            let mut rest: Vec<PyObjectRef> = Vec::new();
144
145            for exc in exceptions {
146                if is_base_exception_group(&exc, vm) {
147                    let result = vm.call_method(&exc, "split", (condition.clone(),))?;
148                    let result_tuple: PyTupleRef = result.try_into_value(vm)?;
149                    let match_part = result_tuple
150                        .first()
151                        .cloned()
152                        .unwrap_or_else(|| vm.ctx.none());
153                    let rest_part = result_tuple
154                        .get(1)
155                        .cloned()
156                        .unwrap_or_else(|| vm.ctx.none());
157
158                    if !vm.is_none(&match_part) {
159                        matching.push(match_part);
160                    }
161                    if !vm.is_none(&rest_part) {
162                        rest.push(rest_part);
163                    }
164                } else if matcher.check(&exc, vm)? {
165                    matching.push(exc);
166                } else {
167                    rest.push(exc);
168                }
169            }
170
171            let match_group = if matching.is_empty() {
172                vm.ctx.none()
173            } else {
174                derive_and_copy_attributes(&zelf, matching, vm)?
175            };
176
177            let rest_group = if rest.is_empty() {
178                vm.ctx.none()
179            } else {
180                derive_and_copy_attributes(&zelf, rest, vm)?
181            };
182
183            Ok(vm.ctx.new_tuple(vec![match_group, rest_group]))
184        }
185
186        #[pymethod]
187        fn __str__(zelf: &Py<PyBaseException>, vm: &VirtualMachine) -> PyResult<PyStrRef> {
188            let message = zelf.get_arg(0).map(|m| m.str(vm)).transpose()?;
189
190            let num_excs = zelf
191                .get_arg(1)
192                .and_then(|obj| obj.downcast_ref::<PyTuple>().map(|t| t.len()))
193                .unwrap_or(0);
194
195            let suffix = if num_excs == 1 { "" } else { "s" };
196            let mut result = match message {
197                Some(s) => s.as_wtf8().to_owned(),
198                None => Wtf8Buf::new(),
199            };
200            write!(result, " ({num_excs} sub-exception{suffix})").unwrap();
201            Ok(vm.ctx.new_str(result))
202        }
203
204        #[pyslot]
205        fn slot_repr(zelf: &PyObject, vm: &VirtualMachine) -> PyResult<PyStrRef> {
206            let zelf = zelf
207                .downcast_ref::<PyBaseException>()
208                .expect("exception group must be BaseException");
209            let class_name = zelf.class().name().to_owned();
210            let message = zelf.get_arg(0).map(|m| m.repr(vm)).transpose()?;
211
212            let mut result = Wtf8Buf::new();
213            write!(result, "{class_name}(").unwrap();
214            let message_wtf8: &Wtf8 = message.as_ref().map_or("''".as_ref(), |s| s.as_wtf8());
215            result.push_wtf8(message_wtf8);
216            result.push_str(", [");
217            if let Some(exceptions_obj) = zelf.get_arg(1) {
218                let iter: ArgIterable<PyObjectRef> =
219                    ArgIterable::try_from_object(vm, exceptions_obj.clone())?;
220                let mut first = true;
221                for exc in iter.iter(vm)? {
222                    if !first {
223                        result.push_str(", ");
224                    }
225                    first = false;
226                    result.push_wtf8(exc?.repr(vm)?.as_wtf8());
227                }
228            }
229            result.push_str("])");
230
231            Ok(vm.ctx.new_str(result))
232        }
233    }
234
235    impl Constructor for PyBaseExceptionGroup {
236        type Args = crate::function::PosArgs;
237
238        fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
239            let args: Self::Args = args.bind(vm)?;
240            let args = args.into_vec();
241            // Validate exactly 2 positional arguments
242            if args.len() != 2 {
243                return Err(vm.new_type_error(format!(
244                    "BaseExceptionGroup.__new__() takes exactly 2 positional arguments ({} given)",
245                    args.len()
246                )));
247            }
248
249            // Validate message is str
250            let message = args[0].clone();
251            if !message.fast_isinstance(vm.ctx.types.str_type) {
252                return Err(vm.new_type_error(format!(
253                    "argument 1 must be str, not {}",
254                    message.class().name()
255                )));
256            }
257
258            // Validate exceptions is a sequence (not set or None)
259            let exceptions_arg = &args[1];
260
261            // Check for set/frozenset (not a sequence - unordered)
262            if exceptions_arg.fast_isinstance(vm.ctx.types.set_type)
263                || exceptions_arg.fast_isinstance(vm.ctx.types.frozenset_type)
264            {
265                return Err(vm.new_type_error("second argument (exceptions) must be a sequence"));
266            }
267
268            // Check for None
269            if exceptions_arg.is(&vm.ctx.none) {
270                return Err(vm.new_type_error("second argument (exceptions) must be a sequence"));
271            }
272
273            let exceptions: Vec<PyObjectRef> = exceptions_arg.try_to_value(vm).map_err(|_| {
274                vm.new_type_error("second argument (exceptions) must be a sequence")
275            })?;
276
277            // Validate non-empty
278            if exceptions.is_empty() {
279                return Err(
280                    vm.new_value_error("second argument (exceptions) must be a non-empty sequence")
281                );
282            }
283
284            // Validate all items are BaseException instances
285            let mut has_non_exception = false;
286            for (i, exc) in exceptions.iter().enumerate() {
287                if !exc.fast_isinstance(vm.ctx.exceptions.base_exception_type) {
288                    return Err(vm.new_value_error(format!(
289                        "Item {} of second argument (exceptions) is not an exception",
290                        i
291                    )));
292                }
293                // Check if any exception is not an Exception subclass
294                // With dynamic ExceptionGroup (inherits from both BaseExceptionGroup and Exception),
295                // ExceptionGroup instances are automatically instances of Exception
296                if !exc.fast_isinstance(vm.ctx.exceptions.exception_type) {
297                    has_non_exception = true;
298                }
299            }
300
301            // Get the dynamic ExceptionGroup type
302            let exception_group_type = crate::exception_group::exception_group();
303
304            // Determine the actual class to use
305            let actual_cls = if cls.is(exception_group_type) {
306                // ExceptionGroup cannot contain BaseExceptions that are not Exception
307                if has_non_exception {
308                    return Err(
309                        vm.new_type_error("Cannot nest BaseExceptions in an ExceptionGroup")
310                    );
311                }
312                cls
313            } else if cls.is(vm.ctx.exceptions.base_exception_group) {
314                // Auto-convert to ExceptionGroup if all are Exception subclasses
315                if !has_non_exception {
316                    exception_group_type.to_owned()
317                } else {
318                    cls
319                }
320            } else {
321                // User-defined subclass
322                if has_non_exception && cls.fast_issubclass(vm.ctx.exceptions.exception_type) {
323                    return Err(vm.new_type_error(format!(
324                        "Cannot nest BaseExceptions in '{}'",
325                        cls.name()
326                    )));
327                }
328                cls
329            };
330
331            // Create the exception with (message, exceptions_tuple) as args
332            let exceptions_tuple = vm.ctx.new_tuple(exceptions);
333            let init_args = vec![message, exceptions_tuple.into()];
334            PyBaseException::new(init_args, vm)
335                .into_ref_with_type(vm, actual_cls)
336                .map(Into::into)
337        }
338
339        fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
340            unimplemented!("use slot_new")
341        }
342    }
343
344    impl Initializer for PyBaseExceptionGroup {
345        type Args = FuncArgs;
346
347        fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> {
348            // BaseExceptionGroup_init: no kwargs allowed
349            if !args.kwargs.is_empty() {
350                return Err(vm.new_type_error(format!(
351                    "{} does not take keyword arguments",
352                    zelf.class().name()
353                )));
354            }
355            // Do NOT call PyBaseException::slot_init here.
356            // slot_new already set args to (message, exceptions_tuple).
357            // Calling base init would overwrite with original args (message, exceptions_list).
358            let _ = (zelf, args, vm);
359            Ok(())
360        }
361
362        fn init(_zelf: PyRef<Self>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> {
363            unreachable!("slot_init is overridden")
364        }
365    }
366
367    // Helper functions for ExceptionGroup
368    fn is_base_exception_group(obj: &PyObject, vm: &VirtualMachine) -> bool {
369        obj.fast_isinstance(vm.ctx.exceptions.base_exception_group)
370    }
371
372    fn get_exceptions_tuple(
373        exc: &Py<PyBaseException>,
374        vm: &VirtualMachine,
375    ) -> PyResult<Vec<PyObjectRef>> {
376        let obj = exc
377            .get_arg(1)
378            .ok_or_else(|| vm.new_type_error("exceptions must be a tuple"))?;
379        let tuple = obj
380            .downcast_ref::<PyTuple>()
381            .ok_or_else(|| vm.new_type_error("exceptions must be a tuple"))?;
382        Ok(tuple.to_vec())
383    }
384
385    enum ConditionMatcher {
386        Type(PyTypeRef),
387        Types(Vec<PyTypeRef>),
388        Callable(PyObjectRef),
389    }
390
391    fn get_condition_matcher(
392        condition: &PyObject,
393        vm: &VirtualMachine,
394    ) -> PyResult<ConditionMatcher> {
395        // If it's a type and subclass of BaseException
396        if let Some(typ) = condition.downcast_ref::<PyType>()
397            && typ.fast_issubclass(vm.ctx.exceptions.base_exception_type)
398        {
399            return Ok(ConditionMatcher::Type(typ.to_owned()));
400        }
401
402        // If it's a tuple of types
403        if let Some(tuple) = condition.downcast_ref::<PyTuple>() {
404            let mut types = Vec::new();
405            for item in tuple.iter() {
406                let typ: PyTypeRef = item.clone().try_into_value(vm).map_err(|_| {
407                    vm.new_type_error(
408                        "expected a function, exception type or tuple of exception types",
409                    )
410                })?;
411                if !typ.fast_issubclass(vm.ctx.exceptions.base_exception_type) {
412                    return Err(vm.new_type_error(
413                        "expected a function, exception type or tuple of exception types",
414                    ));
415                }
416                types.push(typ);
417            }
418            if !types.is_empty() {
419                return Ok(ConditionMatcher::Types(types));
420            }
421        }
422
423        // If it's callable (but not a type)
424        if condition.is_callable() && condition.downcast_ref::<PyType>().is_none() {
425            return Ok(ConditionMatcher::Callable(condition.to_owned()));
426        }
427
428        Err(vm.new_type_error("expected a function, exception type or tuple of exception types"))
429    }
430
431    impl ConditionMatcher {
432        fn check(&self, exc: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
433            match self {
434                ConditionMatcher::Type(typ) => Ok(exc.fast_isinstance(typ)),
435                ConditionMatcher::Types(types) => Ok(types.iter().any(|t| exc.fast_isinstance(t))),
436                ConditionMatcher::Callable(func) => {
437                    let result = func.call((exc.to_owned(),), vm)?;
438                    result.try_to_bool(vm)
439                }
440            }
441        }
442    }
443
444    fn derive_and_copy_attributes(
445        orig: &Py<PyBaseException>,
446        excs: Vec<PyObjectRef>,
447        vm: &VirtualMachine,
448    ) -> PyResult<PyObjectRef> {
449        // Call derive method to create new group
450        let excs_seq = vm.ctx.new_list(excs);
451        let new_group = vm.call_method(orig.as_object(), "derive", (excs_seq,))?;
452
453        // Verify derive returned a BaseExceptionGroup
454        if !is_base_exception_group(&new_group, vm) {
455            return Err(vm.new_type_error("derive must return an instance of BaseExceptionGroup"));
456        }
457
458        // Copy traceback
459        if let Some(tb) = orig.__traceback__() {
460            new_group.set_attr("__traceback__", tb, vm)?;
461        }
462
463        // Copy context
464        if let Some(ctx) = orig.__context__() {
465            new_group.set_attr("__context__", ctx, vm)?;
466        }
467
468        // Copy cause
469        if let Some(cause) = orig.__cause__() {
470            new_group.set_attr("__cause__", cause, vm)?;
471        }
472
473        // Copy notes (if present) - make a copy of the list
474        if let Ok(notes) = orig.as_object().get_attr("__notes__", vm)
475            && let Some(notes_list) = notes.downcast_ref::<PyList>()
476        {
477            let notes_copy = vm.ctx.new_list(notes_list.borrow_vec().to_vec());
478            new_group.set_attr("__notes__", notes_copy, vm)?;
479        }
480
481        Ok(new_group)
482    }
483}