// Source: polars_python/timeout.rs

//! A global process-aborting timeout system, mainly intended for testing.
2
3use std::cmp::Reverse;
4use std::collections::BinaryHeap;
5use std::sync::LazyLock;
6use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, channel};
7use std::time::Duration;
8
9use polars::prelude::{InitHashMaps, PlHashSet};
10use polars_utils::priority::Priority;
11use polars_utils::relaxed_cell::RelaxedCell;
12
13static TIMEOUT_REQUEST_HANDLER: LazyLock<Sender<TimeoutRequest>> = LazyLock::new(|| {
14    let (send, recv) = channel();
15    std::thread::Builder::new()
16        .name("polars-timeout".to_string())
17        .spawn(move || timeout_thread(recv))
18        .unwrap();
19    send
20});
21
/// Messages understood by the timeout thread.
enum TimeoutRequest {
    /// Register a timeout: how long it may run, its unique id, and an optional
    /// traceback that is printed if the timeout fires.
    Start(Duration, u64, Option<String>),
    /// Deactivate the timeout with the given id.
    Cancel(u64),
}
26
27pub fn is_timeout_enabled() -> bool {
28    static TIMEOUT_DISABLED: RelaxedCell<bool> = RelaxedCell::new_bool(false);
29
30    // Fast path so we don't have to keep checking environment variables. Make
31    // sure that if you want to use POLARS_TIMEOUT_MS it is set before the first
32    // polars call.
33    if TIMEOUT_DISABLED.load() {
34        return false;
35    }
36
37    let var = std::env::var("POLARS_TIMEOUT_MS").ok();
38    if var.is_none_or(|v| v.is_empty()) {
39        TIMEOUT_DISABLED.store(true);
40        return false;
41    }
42
43    true
44}
45
46pub fn get_timeout() -> Option<Duration> {
47    if !is_timeout_enabled() {
48        return None;
49    }
50
51    match std::env::var("POLARS_TIMEOUT_MS").unwrap().parse() {
52        Ok(ms) => Some(Duration::from_millis(ms)),
53        Err(e) => {
54            eprintln!("failed to parse POLARS_TIMEOUT_MS: {e:?}");
55            None
56        },
57    }
58}
59
60fn timeout_thread(recv: Receiver<TimeoutRequest>) {
61    let mut active_timeouts: PlHashSet<u64> = PlHashSet::new();
62    #[allow(clippy::type_complexity)]
63    let mut shortest_timeout: BinaryHeap<Priority<Reverse<Duration>, (u64, Option<String>)>> =
64        BinaryHeap::new();
65    loop {
66        // Remove cancelled requests.
67        while let Some(Priority(_, (id, _))) = shortest_timeout.peek() {
68            if active_timeouts.contains(id) {
69                break;
70            }
71            shortest_timeout.pop();
72        }
73
74        let request = if let Some(Priority(timeout, (_, traceback))) = shortest_timeout.peek() {
75            match recv.recv_timeout(timeout.0) {
76                Err(RecvTimeoutError::Timeout) => {
77                    eprint!("exiting the process, POLARS_TIMEOUT_MS exceeded");
78                    if let Some(tb) = traceback {
79                        eprintln!(", traceback:\n{tb}");
80                    } else {
81                        eprintln!(", traceback unavailable");
82                    }
83                    std::thread::sleep(Duration::from_secs_f64(1.0));
84                    std::process::exit(1);
85                },
86                r => r.unwrap(),
87            }
88        } else {
89            recv.recv().unwrap()
90        };
91
92        match request {
93            TimeoutRequest::Start(duration, id, traceback) => {
94                shortest_timeout.push(Priority(Reverse(duration), (id, traceback)));
95                active_timeouts.insert(id);
96            },
97            TimeoutRequest::Cancel(id) => {
98                active_timeouts.remove(&id);
99            },
100        }
101    }
102}
103
104pub fn schedule_polars_timeout(traceback: Option<String>) -> Option<u64> {
105    static TIMEOUT_ID: RelaxedCell<u64> = RelaxedCell::new_u64(0);
106
107    let timeout = get_timeout()?;
108    let id = TIMEOUT_ID.fetch_add(1);
109    TIMEOUT_REQUEST_HANDLER
110        .send(TimeoutRequest::Start(timeout, id, traceback))
111        .unwrap();
112    Some(id)
113}
114
115pub fn cancel_polars_timeout(opt_id: Option<u64>) {
116    if let Some(id) = opt_id {
117        TIMEOUT_REQUEST_HANDLER
118            .send(TimeoutRequest::Cancel(id))
119            .unwrap();
120    }
121}