polars_python/
timeout.rs

1//! A global process-aborting timeout system, mainly intended for testing.
2
3use std::cmp::Reverse;
4use std::collections::BinaryHeap;
5use std::sync::LazyLock;
6use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
7use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, channel};
8use std::time::Duration;
9
10use polars::prelude::{InitHashMaps, PlHashSet};
11use polars_utils::priority::Priority;
12
13static TIMEOUT_REQUEST_HANDLER: LazyLock<Sender<TimeoutRequest>> = LazyLock::new(|| {
14    let (send, recv) = channel();
15    std::thread::Builder::new()
16        .name("polars-timeout".to_string())
17        .spawn(move || timeout_thread(recv))
18        .unwrap();
19    send
20});
21
/// Message sent over the channel to the timeout thread.
enum TimeoutRequest {
    /// Register a timeout: the first field is the timeout duration, the
    /// second is the id under which it can later be cancelled.
    Start(Duration, u64),
    /// Cancel the timeout that was registered under this id.
    Cancel(u64),
}
26
/// Returns the timeout configured through the `POLARS_TIMEOUT_MS` environment
/// variable, interpreted as a whole number of milliseconds.
///
/// Returns `None` when the variable cannot be read (unset or not valid
/// Unicode) or when its value fails to parse. Both outcomes are remembered:
/// every later call returns `None` without consulting the environment again,
/// so a malformed value is reported on stderr once instead of on every call.
/// A value that parses successfully is re-read on each call.
pub fn get_timeout() -> Option<Duration> {
    static TIMEOUT_DISABLED: AtomicBool = AtomicBool::new(false);

    // Fast path so we don't have to keep checking environment variables. Make
    // sure that if you want to use POLARS_TIMEOUT_MS it is set (to a valid
    // value) before the first polars call.
    if TIMEOUT_DISABLED.load(Ordering::Relaxed) {
        return None;
    }

    let Ok(timeout) = std::env::var("POLARS_TIMEOUT_MS") else {
        TIMEOUT_DISABLED.store(true, Ordering::Relaxed);
        return None;
    };

    match timeout.parse() {
        Ok(ms) => Some(Duration::from_millis(ms)),
        Err(e) => {
            eprintln!("failed to parse POLARS_TIMEOUT_MS: {e:?}");
            // Latch the failure as well; otherwise every subsequent call
            // would re-read the variable and print the same error again.
            TIMEOUT_DISABLED.store(true, Ordering::Relaxed);
            None
        },
    }
}
50
51fn timeout_thread(recv: Receiver<TimeoutRequest>) {
52    let mut active_timeouts: PlHashSet<u64> = PlHashSet::new();
53    let mut shortest_timeout: BinaryHeap<Priority<Reverse<Duration>, u64>> = BinaryHeap::new();
54    loop {
55        // Remove cancelled requests.
56        while let Some(Priority(_, id)) = shortest_timeout.peek() {
57            if active_timeouts.contains(id) {
58                break;
59            }
60            shortest_timeout.pop();
61        }
62
63        let request = if let Some(Priority(timeout, _)) = shortest_timeout.peek() {
64            match recv.recv_timeout(timeout.0) {
65                Err(RecvTimeoutError::Timeout) => {
66                    eprintln!("exiting the process, POLARS_TIMEOUT_MS exceeded");
67                    std::thread::sleep(Duration::from_secs_f64(1.0));
68                    std::process::exit(1);
69                },
70                r => r.unwrap(),
71            }
72        } else {
73            recv.recv().unwrap()
74        };
75
76        match request {
77            TimeoutRequest::Start(duration, id) => {
78                shortest_timeout.push(Priority(Reverse(duration), id));
79                active_timeouts.insert(id);
80            },
81            TimeoutRequest::Cancel(id) => {
82                active_timeouts.remove(&id);
83            },
84        }
85    }
86}
87
88pub fn schedule_polars_timeout() -> Option<u64> {
89    static TIMEOUT_ID: AtomicU64 = AtomicU64::new(0);
90
91    let timeout = get_timeout()?;
92    let id = TIMEOUT_ID.fetch_add(1, Ordering::Relaxed);
93    TIMEOUT_REQUEST_HANDLER
94        .send(TimeoutRequest::Start(timeout, id))
95        .unwrap();
96    Some(id)
97}
98
99pub fn cancel_polars_timeout(opt_id: Option<u64>) {
100    if let Some(id) = opt_id {
101        TIMEOUT_REQUEST_HANDLER
102            .send(TimeoutRequest::Cancel(id))
103            .unwrap();
104    }
105}