1use std::cmp::Reverse;
4use std::collections::BinaryHeap;
5use std::sync::LazyLock;
6use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, channel};
7use std::time::Duration;
8
9use polars::prelude::{InitHashMaps, PlHashSet};
10use polars_utils::priority::Priority;
11use polars_utils::relaxed_cell::RelaxedCell;
12
13static TIMEOUT_REQUEST_HANDLER: LazyLock<Sender<TimeoutRequest>> = LazyLock::new(|| {
14 let (send, recv) = channel();
15 std::thread::Builder::new()
16 .name("polars-timeout".to_string())
17 .spawn(move || timeout_thread(recv))
18 .unwrap();
19 send
20});
21
/// Message sent over the channel to the timeout thread.
enum TimeoutRequest {
    /// Register a timeout: the timeout duration, the id identifying this
    /// request, and an optional traceback string that is printed if the
    /// timeout fires.
    Start(Duration, u64, Option<String>),
    /// Mark the timeout with the given id as no longer active.
    Cancel(u64),
}
26
27pub fn is_timeout_enabled() -> bool {
28 static TIMEOUT_DISABLED: RelaxedCell<bool> = RelaxedCell::new_bool(false);
29
30 if TIMEOUT_DISABLED.load() {
34 return false;
35 }
36
37 let var = std::env::var("POLARS_TIMEOUT_MS").ok();
38 if var.is_none_or(|v| v.is_empty()) {
39 TIMEOUT_DISABLED.store(true);
40 return false;
41 }
42
43 true
44}
45
46pub fn get_timeout() -> Option<Duration> {
47 if !is_timeout_enabled() {
48 return None;
49 }
50
51 match std::env::var("POLARS_TIMEOUT_MS").unwrap().parse() {
52 Ok(ms) => Some(Duration::from_millis(ms)),
53 Err(e) => {
54 eprintln!("failed to parse POLARS_TIMEOUT_MS: {e:?}");
55 None
56 },
57 }
58}
59
60fn timeout_thread(recv: Receiver<TimeoutRequest>) {
61 let mut active_timeouts: PlHashSet<u64> = PlHashSet::new();
62 #[allow(clippy::type_complexity)]
63 let mut shortest_timeout: BinaryHeap<Priority<Reverse<Duration>, (u64, Option<String>)>> =
64 BinaryHeap::new();
65 loop {
66 while let Some(Priority(_, (id, _))) = shortest_timeout.peek() {
68 if active_timeouts.contains(id) {
69 break;
70 }
71 shortest_timeout.pop();
72 }
73
74 let request = if let Some(Priority(timeout, (_, traceback))) = shortest_timeout.peek() {
75 match recv.recv_timeout(timeout.0) {
76 Err(RecvTimeoutError::Timeout) => {
77 eprint!("exiting the process, POLARS_TIMEOUT_MS exceeded");
78 if let Some(tb) = traceback {
79 eprintln!(", traceback:\n{tb}");
80 } else {
81 eprintln!(", traceback unavailable");
82 }
83 std::thread::sleep(Duration::from_secs_f64(1.0));
84 std::process::exit(1);
85 },
86 r => r.unwrap(),
87 }
88 } else {
89 recv.recv().unwrap()
90 };
91
92 match request {
93 TimeoutRequest::Start(duration, id, traceback) => {
94 shortest_timeout.push(Priority(Reverse(duration), (id, traceback)));
95 active_timeouts.insert(id);
96 },
97 TimeoutRequest::Cancel(id) => {
98 active_timeouts.remove(&id);
99 },
100 }
101 }
102}
103
104pub fn schedule_polars_timeout(traceback: Option<String>) -> Option<u64> {
105 static TIMEOUT_ID: RelaxedCell<u64> = RelaxedCell::new_u64(0);
106
107 let timeout = get_timeout()?;
108 let id = TIMEOUT_ID.fetch_add(1);
109 TIMEOUT_REQUEST_HANDLER
110 .send(TimeoutRequest::Start(timeout, id, traceback))
111 .unwrap();
112 Some(id)
113}
114
115pub fn cancel_polars_timeout(opt_id: Option<u64>) {
116 if let Some(id) = opt_id {
117 TIMEOUT_REQUEST_HANDLER
118 .send(TimeoutRequest::Cancel(id))
119 .unwrap();
120 }
121}