polars_error/
signals.rs

1use std::any::Any;
2use std::panic::{UnwindSafe, catch_unwind};
3use std::sync::atomic::{AtomicU64, Ordering};
4
/// Marker error returned by [`catch_keyboard_interrupt`] when the wrapped
/// computation was aborted by a keyboard interrupt (SIGINT).
///
/// Python hooks SIGINT to instead generate a `KeyboardInterrupt` exception.
/// Polars does the same: long-running computations are aborted so control can
/// return to Python, where the corresponding Python exception can be raised.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KeyboardInterrupt;
9
// Distinctive marker carried in the panic payload of a keyboard interrupt, so
// such panics can be recognized by their message (and spotted in backtraces).
const POLARS_KEYBOARD_INTERRUPT_STRING: &str = "__POLARS_KEYBOARD_INTERRUPT";
12
// Packed state shared by the SIGINT handler, the panic hook and all catchers:
// - bit 0:      set while a keyboard interrupt is pending;
// - bits 1..64: number of currently registered interrupt catchers, which is
//   why registering/unregistering a catcher adds/subtracts 2.
static INTERRUPT_STATE: AtomicU64 = AtomicU64::new(0);
16
17fn is_keyboard_interrupt(p: &dyn Any) -> bool {
18    if let Some(s) = p.downcast_ref::<&str>() {
19        s.contains(POLARS_KEYBOARD_INTERRUPT_STRING)
20    } else if let Some(s) = p.downcast_ref::<String>() {
21        s.contains(POLARS_KEYBOARD_INTERRUPT_STRING)
22    } else {
23        false
24    }
25}
26
27pub fn register_polars_keyboard_interrupt_hook() {
28    let default_hook = std::panic::take_hook();
29    std::panic::set_hook(Box::new(move |p| {
30        // Suppress output if there is an active catcher and the panic message
31        // contains the keyboard interrupt string.
32        let num_catchers = INTERRUPT_STATE.load(Ordering::Relaxed) >> 1;
33        let suppress = num_catchers > 0 && is_keyboard_interrupt(p.payload());
34        if !suppress {
35            default_hook(p);
36        }
37    }));
38
39    // WASM doesn't support signals, so we just skip installing the hook there.
40    #[cfg(not(target_family = "wasm"))]
41    unsafe {
42        // SAFETY: we only do an atomic op in the signal handler, which is allowed.
43        signal_hook::low_level::register(signal_hook::consts::signal::SIGINT, move || {
44            // Set the interrupt flag, but only if there are active catchers.
45            INTERRUPT_STATE
46                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| {
47                    let num_catchers = state >> 1;
48                    if num_catchers > 0 {
49                        Some(state | 1)
50                    } else {
51                        None
52                    }
53                })
54                .ok();
55        })
56        .unwrap();
57    }
58}
59
60/// Checks if the keyboard interrupt flag is set, and if yes panics as a
61/// keyboard interrupt. This function is very cheap.
62#[inline(always)]
63pub fn try_raise_keyboard_interrupt() {
64    if INTERRUPT_STATE.load(Ordering::Relaxed) & 1 != 0 {
65        try_raise_keyboard_interrupt_slow()
66    }
67}
68
69#[inline(never)]
70#[cold]
71fn try_raise_keyboard_interrupt_slow() {
72    std::panic::panic_any(POLARS_KEYBOARD_INTERRUPT_STRING);
73}
74
75/// Runs the passed function, catching any KeyboardInterrupts if they occur
76/// while running the function.
77pub fn catch_keyboard_interrupt<R, F: FnOnce() -> R + UnwindSafe>(
78    try_fn: F,
79) -> Result<R, KeyboardInterrupt> {
80    // Try to register this catcher (or immediately return if there is an
81    // uncaught interrupt).
82    try_register_catcher()?;
83    let ret = catch_unwind(try_fn);
84    unregister_catcher();
85    ret.map_err(|p| {
86        if is_keyboard_interrupt(&*p) {
87            KeyboardInterrupt
88        } else {
89            std::panic::resume_unwind(p)
90        }
91    })
92}
93
94fn try_register_catcher() -> Result<(), KeyboardInterrupt> {
95    let old_state = INTERRUPT_STATE.fetch_add(2, Ordering::Relaxed);
96    if old_state & 1 != 0 {
97        unregister_catcher();
98        return Err(KeyboardInterrupt);
99    }
100    Ok(())
101}
102
103fn unregister_catcher() {
104    INTERRUPT_STATE
105        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| {
106            let num_catchers = state >> 1;
107            if num_catchers > 1 {
108                Some(state - 2)
109            } else {
110                // Last catcher, clear interrupt flag.
111                Some(0)
112            }
113        })
114        .ok();
115}