Skip to main content

tauri/
async_runtime.rs

1// Copyright 2019-2024 Tauri Programme within The Commons Conservancy
2// SPDX-License-Identifier: Apache-2.0
3// SPDX-License-Identifier: MIT
4
//! The singleton async runtime used by Tauri and exposed to users.
//!
//! Tauri uses the [`tokio`] runtime to run initialization code, such as
//! [`Plugin::initialize`](../plugin/trait.Plugin.html#method.initialize) and [`crate::Builder::setup`] hooks.
//! This module also re-exports some common items most developers need from [`tokio`]. If one you
//! need isn't here, you can use the types in [`tokio`] directly.
//! For custom command handlers, it's recommended to use a plain `async fn` command.
12
13pub use tokio::{
14  runtime::{Handle as TokioHandle, Runtime as TokioRuntime},
15  sync::{
16    mpsc::{channel, Receiver, Sender},
17    Mutex, RwLock,
18  },
19  task::JoinHandle as TokioJoinHandle,
20};
21
22use std::{
23  future::Future,
24  pin::Pin,
25  sync::OnceLock,
26  task::{Context, Poll},
27};
28
29static RUNTIME: OnceLock<GlobalRuntime> = OnceLock::new();
30
31struct GlobalRuntime {
32  runtime: Option<Runtime>,
33  handle: RuntimeHandle,
34}
35
36impl GlobalRuntime {
37  fn handle(&self) -> RuntimeHandle {
38    if let Some(r) = &self.runtime {
39      r.handle()
40    } else {
41      self.handle.clone()
42    }
43  }
44
45  #[track_caller]
46  fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
47  where
48    F: Future + Send + 'static,
49    F::Output: Send + 'static,
50  {
51    if let Some(r) = &self.runtime {
52      r.spawn(task)
53    } else {
54      self.handle.spawn(task)
55    }
56  }
57
58  #[track_caller]
59  pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
60  where
61    F: FnOnce() -> R + Send + 'static,
62    R: Send + 'static,
63  {
64    if let Some(r) = &self.runtime {
65      r.spawn_blocking(func)
66    } else {
67      self.handle.spawn_blocking(func)
68    }
69  }
70
71  #[track_caller]
72  fn block_on<F: Future>(&self, task: F) -> F::Output {
73    if let Some(r) = &self.runtime {
74      r.block_on(task)
75    } else {
76      self.handle.block_on(task)
77    }
78  }
79}
80
81/// A runtime used to execute asynchronous tasks.
82pub enum Runtime {
83  /// The tokio runtime.
84  Tokio(TokioRuntime),
85}
86
87impl Runtime {
88  /// Gets a reference to the [`TokioRuntime`].
89  pub fn inner(&self) -> &TokioRuntime {
90    let Self::Tokio(r) = self;
91    r
92  }
93
94  /// Returns a handle of the async runtime.
95  pub fn handle(&self) -> RuntimeHandle {
96    match self {
97      Self::Tokio(r) => RuntimeHandle::Tokio(r.handle().clone()),
98    }
99  }
100
101  #[track_caller]
102  /// Spawns a future onto the runtime.
103  pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
104  where
105    F: Future + Send + 'static,
106    F::Output: Send + 'static,
107  {
108    match self {
109      Self::Tokio(r) => {
110        let _guard = r.enter();
111        JoinHandle::Tokio(tokio::spawn(task))
112      }
113    }
114  }
115
116  #[track_caller]
117  /// Runs the provided function on an executor dedicated to blocking operations.
118  pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
119  where
120    F: FnOnce() -> R + Send + 'static,
121    R: Send + 'static,
122  {
123    match self {
124      Self::Tokio(r) => JoinHandle::Tokio(r.spawn_blocking(func)),
125    }
126  }
127
128  #[track_caller]
129  /// Runs a future to completion on runtime.
130  pub fn block_on<F: Future>(&self, task: F) -> F::Output {
131    match self {
132      Self::Tokio(r) => r.block_on(task),
133    }
134  }
135}
136
137/// An owned permission to join on a task (await its termination).
138#[derive(Debug)]
139pub enum JoinHandle<T> {
140  /// The tokio JoinHandle.
141  Tokio(TokioJoinHandle<T>),
142}
143
144impl<T> JoinHandle<T> {
145  /// Gets a reference to the [`TokioJoinHandle`].
146  pub fn inner(&self) -> &TokioJoinHandle<T> {
147    let Self::Tokio(t) = self;
148    t
149  }
150
151  /// Abort the task associated with the handle.
152  ///
153  /// Awaiting a cancelled task might complete as usual if the task was
154  /// already completed at the time it was cancelled, but most likely it
155  /// will fail with a cancelled `JoinError`.
156  pub fn abort(&self) {
157    match self {
158      Self::Tokio(t) => t.abort(),
159    }
160  }
161}
162
163impl<T> Future for JoinHandle<T> {
164  type Output = crate::Result<T>;
165  fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166    match self.get_mut() {
167      Self::Tokio(t) => Pin::new(t).poll(cx).map_err(Into::into),
168    }
169  }
170}
171
172/// A handle to the async runtime
173#[derive(Clone)]
174pub enum RuntimeHandle {
175  /// The tokio handle.
176  Tokio(TokioHandle),
177}
178
179impl RuntimeHandle {
180  /// Gets a reference to the [`TokioHandle`].
181  pub fn inner(&self) -> &TokioHandle {
182    let Self::Tokio(h) = self;
183    h
184  }
185
186  #[track_caller]
187  /// Runs the provided function on an executor dedicated to blocking operations.
188  pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
189  where
190    F: FnOnce() -> R + Send + 'static,
191    R: Send + 'static,
192  {
193    match self {
194      Self::Tokio(h) => JoinHandle::Tokio(h.spawn_blocking(func)),
195    }
196  }
197
198  #[track_caller]
199  /// Spawns a future onto the runtime.
200  pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
201  where
202    F: Future + Send + 'static,
203    F::Output: Send + 'static,
204  {
205    match self {
206      Self::Tokio(h) => {
207        let _guard = h.enter();
208        JoinHandle::Tokio(tokio::spawn(task))
209      }
210    }
211  }
212
213  #[track_caller]
214  /// Runs a future to completion on runtime.
215  pub fn block_on<F: Future>(&self, task: F) -> F::Output {
216    match self {
217      Self::Tokio(h) => h.block_on(task),
218    }
219  }
220}
221
222fn default_runtime() -> GlobalRuntime {
223  let runtime = Runtime::Tokio(TokioRuntime::new().unwrap());
224  let handle = runtime.handle();
225  GlobalRuntime {
226    runtime: Some(runtime),
227    handle,
228  }
229}
230
231/// Sets the runtime to use to execute asynchronous tasks.
232/// For convenience, this method takes a [`TokioHandle`].
233/// Note that you cannot drop the underlying [`TokioRuntime`].
234///
235/// # Examples
236///
237/// ```rust
238/// #[tokio::main]
239/// async fn main() {
240///   // perform some async task before initializing the app
241///   do_something().await;
242///   // share the current runtime with Tauri
243///   tauri::async_runtime::set(tokio::runtime::Handle::current());
244///
245///   // bootstrap the tauri app...
246///   // tauri::Builder::default().run().unwrap();
247/// }
248///
249/// async fn do_something() {}
250/// ```
251///
252/// # Panics
253///
254/// Panics if the runtime is already set.
255pub fn set(handle: TokioHandle) {
256  RUNTIME
257    .set(GlobalRuntime {
258      runtime: None,
259      handle: RuntimeHandle::Tokio(handle),
260    })
261    .unwrap_or_else(|_| panic!("runtime already initialized"))
262}
263
264/// Returns a handle of the async runtime.
265pub fn handle() -> RuntimeHandle {
266  let runtime = RUNTIME.get_or_init(default_runtime);
267  runtime.handle()
268}
269
270#[track_caller]
271/// Runs a future to completion on runtime.
272pub fn block_on<F: Future>(task: F) -> F::Output {
273  let runtime = RUNTIME.get_or_init(default_runtime);
274  runtime.block_on(task)
275}
276
277#[track_caller]
278/// Spawns a future onto the runtime.
279pub fn spawn<F>(task: F) -> JoinHandle<F::Output>
280where
281  F: Future + Send + 'static,
282  F::Output: Send + 'static,
283{
284  let runtime = RUNTIME.get_or_init(default_runtime);
285  runtime.spawn(task)
286}
287
288#[track_caller]
289/// Runs the provided function on an executor dedicated to blocking operations.
290pub fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
291where
292  F: FnOnce() -> R + Send + 'static,
293  R: Send + 'static,
294{
295  let runtime = RUNTIME.get_or_init(default_runtime);
296  runtime.spawn_blocking(func)
297}
298
299#[track_caller]
300#[allow(dead_code)]
301pub(crate) fn safe_block_on<F>(task: F) -> F::Output
302where
303  F: Future + Send + 'static,
304  F::Output: Send + 'static,
305{
306  if let Ok(handle) = tokio::runtime::Handle::try_current() {
307    let (tx, rx) = std::sync::mpsc::sync_channel(1);
308    let handle_ = handle.clone();
309    handle.spawn_blocking(move || {
310      tx.send(handle_.block_on(task)).unwrap();
311    });
312    rx.recv().unwrap()
313  } else {
314    block_on(task)
315  }
316}
317
#[cfg(test)]
mod tests {
  use super::*;

  #[tokio::test]
  async fn runtime_spawn() {
    let task = spawn(async { 5 });
    let value = task.await.unwrap();
    assert_eq!(value, 5);
  }

  #[test]
  fn runtime_block_on() {
    let value = block_on(async { 0 });
    assert_eq!(value, 0);
  }

  #[tokio::test]
  async fn handle_spawn() {
    let rt = handle();
    let task = rt.spawn(async { 5 });
    assert_eq!(task.await.unwrap(), 5);
  }

  #[test]
  fn handle_block_on() {
    let rt = handle();
    let value = rt.block_on(async { 0 });
    assert_eq!(value, 0);
  }

  #[tokio::test]
  async fn handle_abort() {
    let rt = handle();
    // The task sleeps for a full second so it is still pending when `abort()` is invoked.
    let task = rt.spawn(async {
      tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
      5
    });
    task.abort();
    let crate::Error::JoinError(raw_error) = task.await.unwrap_err() else {
      panic!("Abort did not result in the expected `JoinError`");
    };
    assert!(raw_error.is_cancelled());
  }
}