// datafusion_dft/execution/executor/io.rs

use std::{
    cell::RefCell,
    pin::Pin,
    task::{Context, Poll},
};

use futures::{Future, FutureExt};
use tokio::{runtime::Handle, task::JoinHandle};
26
27thread_local! {
28 pub static IO_RUNTIME: RefCell<Option<Handle>> = const { RefCell::new(None) };
30}
31
32pub fn register_io_runtime(handle: Option<Handle>) {
36 IO_RUNTIME.set(handle)
37}
38
39pub async fn spawn_io<Fut>(fut: Fut) -> Fut::Output
45where
46 Fut: Future + Send + 'static,
47 Fut::Output: Send,
48{
49 let h = IO_RUNTIME.with_borrow(|h| h.clone()).expect(
50 "No IO runtime registered. If you hit this panic, it likely \
51 means a DataFusion plan or other CPU bound work is running on the \
52 a tokio threadpool used for IO. Try spawning the work using \
53 `DedicatedExcutor::spawn` or for tests `register_current_runtime_for_io`",
54 );
55 DropGuard(h.spawn(fut)).await
56}
57
58struct DropGuard<T>(JoinHandle<T>);
59impl<T> Drop for DropGuard<T> {
60 fn drop(&mut self) {
61 self.0.abort()
62 }
63}
64
65impl<T> Future for DropGuard<T> {
66 type Output = T;
67
68 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
69 Poll::Ready(match std::task::ready!(self.0.poll_unpin(cx)) {
70 Ok(v) => v,
71 Err(e) if e.is_cancelled() => panic!("IO runtime was shut down"),
72 Err(e) => std::panic::resume_unwind(e.into_panic()),
73 })
74 }
75}
76
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Builds a multi-threaded runtime with a single worker thread, so tasks
    /// spawned on it run on one known thread.
    fn one_worker_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(1)
            .enable_all()
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_happy_path() {
        let io_rt = one_worker_runtime();

        // Discover which thread the IO runtime's only worker runs on, and make
        // sure it is not the thread driving this test.
        let worker_id = io_rt
            .spawn(async { std::thread::current().id() })
            .await
            .unwrap();
        assert_ne!(worker_id, std::thread::current().id());

        register_io_runtime(Some(io_rt.handle().clone()));

        let observed_id = spawn_io(async { std::thread::current().id() }).await;
        assert_eq!(observed_id, worker_id);

        io_rt.shutdown_background();
    }

    #[tokio::test]
    #[should_panic(expected = "IO runtime registered")]
    async fn test_panic_if_no_runtime_registered() {
        let noop = futures::future::ready(());
        spawn_io(noop).await;
    }

    #[tokio::test]
    #[should_panic(expected = "IO runtime was shut down")]
    async fn test_io_runtime_down() {
        let io_rt = one_worker_runtime();
        register_io_runtime(Some(io_rt.handle().clone()));

        // `shutdown_timeout` blocks the calling thread, so run it on the
        // blocking pool instead of the test's async thread.
        tokio::task::spawn_blocking(move || io_rt.shutdown_timeout(Duration::from_secs(1)))
            .await
            .unwrap();

        // The registered handle now refers to a runtime that has shut down.
        spawn_io(futures::future::ready(())).await;
    }
}