1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
use std::panic::{catch_unwind, resume_unwind};

/// Run `task`, invoking `cleanup` only when `task` does not succeed.
///
/// * If `task` returns `Ok`, that value is returned and `cleanup` is never
///   called.
/// * If `task` returns `Err`, `cleanup` is called and the task's error is
///   returned.
/// * If `task` panics, `cleanup` is called and the task's panic is resumed.
///
/// A failure inside `cleanup` itself (an `Err` or a panic) never replaces the
/// outcome of `task`: it is reported through `log::error!` and then dropped.
/// Only an `Err` from `cleanup` has its value included in the log message; a
/// panic payload from `cleanup` is discarded.
///
/// Panics are intercepted with [`catch_unwind`], which is why both closures
/// must be [`std::panic::UnwindSafe`].
pub fn with_cleanup<TASK, T, E, CLEANUP, CT, CE>(cleanup: CLEANUP, task: TASK) -> Result<T, E>
where
    TASK: std::panic::UnwindSafe + FnOnce() -> Result<T, E>,
    CLEANUP: std::panic::UnwindSafe + FnOnce() -> Result<CT, CE>,
    E: std::fmt::Display,
    CE: std::fmt::Display + Into<E>,
{
    // Classify how the task went wrong: `Ok(error)` means it returned an
    // error, `Err(payload)` means it panicked. A successful task returns
    // immediately, so the cleanup below is skipped entirely.
    let failure = match catch_unwind(task) {
        Ok(Ok(value)) => return Ok(value),
        Ok(Err(error)) => Ok(error),
        Err(payload) => Err(payload),
    };

    // The task failed one way or the other: run the cleanup once, catching
    // its panic so that it cannot mask the task's own failure.
    let cleanup_outcome = catch_unwind(cleanup);

    match failure {
        Ok(error) => {
            match cleanup_outcome {
                Ok(Ok(_)) => {}
                Ok(Err(ce)) => {
                    log::error!("Task failed & cleaning-up also failed: {ce}");
                }
                Err(_) => {
                    log::error!("Task failed & cleaning-up panicked (suppressed)");
                }
            }
            Err(error)
        }
        Err(payload) => {
            match cleanup_outcome {
                Ok(Ok(_)) => {}
                Ok(Err(ce)) => {
                    log::error!("Task panicked & cleaning-up failed: {ce}");
                }
                Err(_) => {
                    log::error!("Task panicked & cleaning-up also panicked (suppressed)");
                }
            }
            // Re-raise the task's original panic with its original payload.
            resume_unwind(payload)
        }
    }
}

#[cfg(test)]
mod tests {
    // Note the argument order used throughout: `with_cleanup(cleanup, task)`.
    //
    // Cleanup closures that record their invocation capture an `AtomicBool`
    // by reference; atomics are `RefUnwindSafe`, which keeps the closure
    // `UnwindSafe` as `with_cleanup` requires.
    use super::with_cleanup;
    use std::panic::catch_unwind;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// A successful task yields the task's value, not the cleanup's.
    #[test]
    fn test_with_cleanup() {
        let result: Result<&'static str, &'static str> = with_cleanup(
            || Ok::<&'static str, &'static str>("Ok/cleanup"),
            || Ok("Ok/task"),
        );
        assert!(matches!(result, Ok("Ok/task")));
    }

    /// A successful task must not trigger the cleanup at all.
    #[test]
    fn test_with_cleanup_skips_cleanup_on_success() {
        let cleanup_ran = AtomicBool::new(false);
        let result: Result<(), &'static str> = with_cleanup(
            || {
                cleanup_ran.store(true, Ordering::SeqCst);
                Ok::<(), &'static str>(())
            },
            || Ok(()),
        );
        assert!(matches!(result, Ok(())));
        assert!(!cleanup_ran.load(Ordering::SeqCst));
    }

    /// The task's error is returned when the cleanup succeeds.
    #[test]
    fn test_with_cleanup_error_in_task() {
        let result: Result<(), &'static str> =
            with_cleanup(|| Ok::<(), &'static str>(()), || Err("Err/task")?);
        assert!(matches!(result, Err("Err/task")));
    }

    /// A task error must actually invoke the cleanup.
    #[test]
    fn test_with_cleanup_runs_cleanup_on_task_error() {
        let cleanup_ran = AtomicBool::new(false);
        let result: Result<(), &'static str> = with_cleanup(
            || {
                cleanup_ran.store(true, Ordering::SeqCst);
                Ok::<(), &'static str>(())
            },
            || Err::<(), &'static str>("Err/task"),
        );
        assert!(matches!(result, Err("Err/task")));
        assert!(cleanup_ran.load(Ordering::SeqCst));
    }

    /// The task's panic is propagated when the cleanup succeeds.
    #[test]
    #[should_panic(expected = "Panic/task")]
    fn test_with_cleanup_panic_in_task() {
        let _result: Result<(), &'static str> =
            with_cleanup(|| Ok::<(), &'static str>(()), || panic!("Panic/task"));
    }

    /// A task panic must invoke the cleanup before the panic is resumed, and
    /// the resumed panic must carry the task's original payload.
    #[test]
    fn test_with_cleanup_runs_cleanup_on_task_panic() {
        let cleanup_ran = AtomicBool::new(false);
        // Catch the resumed panic here so the flag can be inspected afterwards.
        let outcome = catch_unwind(|| -> Result<(), &'static str> {
            with_cleanup(
                || {
                    cleanup_ran.store(true, Ordering::SeqCst);
                    Ok::<(), &'static str>(())
                },
                || panic!("Panic/task"),
            )
        });
        let payload = outcome.expect_err("the task's panic must be propagated");
        assert_eq!(payload.downcast_ref::<&str>(), Some(&"Panic/task"));
        assert!(cleanup_ran.load(Ordering::SeqCst));
    }

    /// The task succeeds, so the failing cleanup is never run and the result
    /// is the task's `Ok`.
    #[test]
    fn test_with_cleanup_error_in_cleanup() {
        let result: Result<(), &'static str> =
            with_cleanup(|| Err::<(), &'static str>("Err/cleanup"), || Ok(()));
        assert!(matches!(result, Ok(())));
    }

    /// The task succeeds, so the panicking cleanup is never run and the result
    /// is the task's `Ok`.
    #[test]
    fn test_with_cleanup_panic_in_cleanup() {
        let result: Result<(), &'static str> = with_cleanup(
            || -> Result<(), &'static str> { panic!("Panic/cleanup") },
            || Ok(()),
        );
        assert!(matches!(result, Ok(())));
    }

    /// When both fail with errors, the task's error wins over the cleanup's.
    #[test]
    fn test_with_cleanup_error_in_task_and_cleanup() {
        let result: Result<(), &'static str> = with_cleanup(
            || Err::<(), &'static str>("Err/cleanup"),
            || Err("Err/task")?,
        );
        assert!(matches!(result, Err("Err/task")));
    }

    /// A panic in the cleanup is suppressed and the task's error is returned.
    #[test]
    fn test_with_cleanup_error_in_task_and_panic_in_cleanup() {
        let result: Result<(), &'static str> = with_cleanup(
            || -> Result<(), &'static str> { panic!("Panic/cleanup") },
            || Err::<(), &'static str>("Err/task"),
        );
        assert!(matches!(result, Err("Err/task")));
    }

    /// An error from the cleanup does not replace the task's panic.
    #[test]
    #[should_panic(expected = "Panic/task")]
    fn test_with_cleanup_panic_in_task_and_error_in_cleanup() {
        let _result: Result<(), &'static str> = with_cleanup(
            || Err::<(), &'static str>("Err/cleanup"),
            || panic!("Panic/task"),
        );
    }

    /// When both panic, the task's panic is the one that propagates.
    #[test]
    #[should_panic(expected = "Panic/task")]
    fn test_with_cleanup_panic_in_task_and_cleanup() {
        let _result: Result<(), &'static str> = with_cleanup(
            || -> Result<(), &'static str> { panic!("Panic/cleanup") },
            || panic!("Panic/task"),
        );
    }
}