1use std::any::Any;
2use std::panic::{UnwindSafe, catch_unwind};
3use std::sync::atomic::{AtomicU64, Ordering};
4
/// Error value returned by [`catch_keyboard_interrupt`] when the wrapped
/// computation was aborted by a keyboard interrupt, or when an interrupt was
/// already pending at the moment the catcher tried to register.
pub struct KeyboardInterrupt;
9
/// Marker carried in the panic payload that signals a keyboard interrupt.
///
/// `is_keyboard_interrupt` matches it as a substring, so a `String` payload
/// that merely *contains* this marker is also treated as an interrupt.
static POLARS_KEYBOARD_INTERRUPT_STRING: &str = "__POLARS_KEYBOARD_INTERRUPT";

/// Packed interrupt state shared by the `SIGINT` handler and the catchers.
///
/// - Bit 0: an interrupt is pending (set by the `SIGINT` handler, tested by
///   `try_raise_keyboard_interrupt`).
/// - Bits 1..: number of currently registered catchers (`state >> 1`); it is
///   changed in steps of 2 by `try_register_catcher` / `unregister_catcher`.
static INTERRUPT_STATE: AtomicU64 = AtomicU64::new(0);
16
17fn is_keyboard_interrupt(p: &dyn Any) -> bool {
18 if let Some(s) = p.downcast_ref::<&str>() {
19 s.contains(POLARS_KEYBOARD_INTERRUPT_STRING)
20 } else if let Some(s) = p.downcast_ref::<String>() {
21 s.contains(POLARS_KEYBOARD_INTERRUPT_STRING)
22 } else {
23 false
24 }
25}
26
27pub fn register_polars_keyboard_interrupt_hook() {
28 let default_hook = std::panic::take_hook();
29 std::panic::set_hook(Box::new(move |p| {
30 let num_catchers = INTERRUPT_STATE.load(Ordering::Relaxed) >> 1;
33 let suppress = num_catchers > 0 && is_keyboard_interrupt(p.payload());
34 if !suppress {
35 default_hook(p);
36 }
37 }));
38
39 #[cfg(not(target_family = "wasm"))]
41 unsafe {
42 signal_hook::low_level::register(signal_hook::consts::signal::SIGINT, move || {
44 INTERRUPT_STATE
46 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| {
47 let num_catchers = state >> 1;
48 if num_catchers > 0 {
49 Some(state | 1)
50 } else {
51 None
52 }
53 })
54 .ok();
55 })
56 .unwrap();
57 }
58}
59
60#[inline(always)]
63pub fn try_raise_keyboard_interrupt() {
64 if INTERRUPT_STATE.load(Ordering::Relaxed) & 1 != 0 {
65 try_raise_keyboard_interrupt_slow()
66 }
67}
68
69#[inline(never)]
70#[cold]
71fn try_raise_keyboard_interrupt_slow() {
72 std::panic::panic_any(POLARS_KEYBOARD_INTERRUPT_STRING);
73}
74
75pub fn catch_keyboard_interrupt<R, F: FnOnce() -> R + UnwindSafe>(
78 try_fn: F,
79) -> Result<R, KeyboardInterrupt> {
80 try_register_catcher()?;
83 let ret = catch_unwind(try_fn);
84 unregister_catcher();
85 ret.map_err(|p| {
86 if is_keyboard_interrupt(&*p) {
87 KeyboardInterrupt
88 } else {
89 std::panic::resume_unwind(p)
90 }
91 })
92}
93
94fn try_register_catcher() -> Result<(), KeyboardInterrupt> {
95 let old_state = INTERRUPT_STATE.fetch_add(2, Ordering::Relaxed);
96 if old_state & 1 != 0 {
97 unregister_catcher();
98 return Err(KeyboardInterrupt);
99 }
100 Ok(())
101}
102
103fn unregister_catcher() {
104 INTERRUPT_STATE
105 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| {
106 let num_catchers = state >> 1;
107 if num_catchers > 1 {
108 Some(state - 2)
109 } else {
110 Some(0)
112 }
113 })
114 .ok();
115}