datafusion_dft/execution/executor/io.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{
19    cell::RefCell,
20    pin::Pin,
21    task::{Context, Poll},
22};
23
24use futures::{Future, FutureExt};
25use tokio::{runtime::Handle, task::JoinHandle};
26
thread_local! {
    /// Tokio runtime `Handle` for doing network (I/O) operations, see [`spawn_io`].
    ///
    /// This is per-thread state: it starts out as `None` on every thread and is
    /// replaced via [`register_io_runtime`]. While it is `None` on the current
    /// thread, [`spawn_io`] panics.
    pub static IO_RUNTIME: RefCell<Option<Handle>> = const { RefCell::new(None) };
}
31
32/// Registers `handle` as the IO runtime for this thread
33///
34/// See [`spawn_io`]
35pub fn register_io_runtime(handle: Option<Handle>) {
36    IO_RUNTIME.set(handle)
37}
38
39/// Runs `fut` on the runtime registered by [`register_io_runtime`] if any,
40/// otherwise awaits on the current thread
41///
42/// # Panic
43/// Needs a IO runtime [registered](register_io_runtime).
44pub async fn spawn_io<Fut>(fut: Fut) -> Fut::Output
45where
46    Fut: Future + Send + 'static,
47    Fut::Output: Send,
48{
49    let h = IO_RUNTIME.with_borrow(|h| h.clone()).expect(
50        "No IO runtime registered. If you hit this panic, it likely \
51            means a DataFusion plan or other CPU bound work is running on the \
52            a tokio threadpool used for IO. Try spawning the work using \
53            `DedicatedExcutor::spawn` or for tests `register_current_runtime_for_io`",
54    );
55    DropGuard(h.spawn(fut)).await
56}
57
58struct DropGuard<T>(JoinHandle<T>);
59impl<T> Drop for DropGuard<T> {
60    fn drop(&mut self) {
61        self.0.abort()
62    }
63}
64
65impl<T> Future for DropGuard<T> {
66    type Output = T;
67
68    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
69        Poll::Ready(match std::task::ready!(self.0.poll_unpin(cx)) {
70            Ok(v) => v,
71            Err(e) if e.is_cancelled() => panic!("IO runtime was shut down"),
72            Err(e) => std::panic::resume_unwind(e.into_panic()),
73        })
74    }
75}
76
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Builds a separate multi-threaded runtime with exactly one worker
    /// thread, standing in for the dedicated IO runtime.
    fn build_io_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(1)
            .enable_all()
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_happy_path() {
        let io_rt = build_io_runtime();

        // With a single worker, the thread id observed inside any task on
        // `io_rt` identifies that runtime.
        let io_thread = io_rt
            .spawn(async move { std::thread::current().id() })
            .await
            .unwrap();
        let test_thread = std::thread::current().id();
        assert_ne!(io_thread, test_thread);

        register_io_runtime(Some(io_rt.handle().clone()));

        let observed_thread = spawn_io(async move { std::thread::current().id() }).await;
        assert_eq!(observed_thread, io_thread);

        // Shut down without blocking: a normal drop would block, which is
        // not permitted from inside this async test.
        io_rt.shutdown_background();
    }

    #[tokio::test]
    #[should_panic(expected = "IO runtime registered")]
    async fn test_panic_if_no_runtime_registered() {
        spawn_io(futures::future::ready(())).await;
    }

    #[tokio::test]
    #[should_panic(expected = "IO runtime was shut down")]
    async fn test_io_runtime_down() {
        let io_rt = build_io_runtime();

        register_io_runtime(Some(io_rt.handle().clone()));

        // `shutdown_timeout` blocks, so run it off the async test thread.
        tokio::task::spawn_blocking(move || io_rt.shutdown_timeout(Duration::from_secs(1)))
            .await
            .unwrap();

        spawn_io(futures::future::ready(())).await;
    }
}