foyer_common/
spawn.rs

1// Copyright 2026 foyer Project Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    fmt::Debug,
17    mem::ManuallyDrop,
18    ops::{Deref, DerefMut},
19    sync::Arc,
20};
21
22use tokio::{
23    runtime::{Handle, Runtime},
24    task::JoinHandle,
25};
26
27use crate::error::{Error, ErrorKind, Result};
28
29/// A wrapper around [`Runtime`] that shuts down the runtime in the background when dropped.
30///
31/// This is necessary because directly dropping a nested runtime is not allowed in a parent runtime.
32pub struct BackgroundShutdownRuntime(ManuallyDrop<Runtime>);
33
34impl Debug for BackgroundShutdownRuntime {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_tuple("BackgroundShutdownRuntime").finish()
37    }
38}
39
40impl Drop for BackgroundShutdownRuntime {
41    fn drop(&mut self) {
42        // Safety: The runtime is only dropped once here.
43        let runtime = unsafe { ManuallyDrop::take(&mut self.0) };
44
45        #[cfg(madsim)]
46        drop(runtime);
47        #[cfg(not(madsim))]
48        runtime.shutdown_background();
49    }
50}
51
52impl Deref for BackgroundShutdownRuntime {
53    type Target = Runtime;
54
55    fn deref(&self) -> &Self::Target {
56        &self.0
57    }
58}
59
60impl DerefMut for BackgroundShutdownRuntime {
61    fn deref_mut(&mut self) -> &mut Self::Target {
62        &mut self.0
63    }
64}
65
66impl From<Runtime> for BackgroundShutdownRuntime {
67    fn from(runtime: Runtime) -> Self {
68        Self(ManuallyDrop::new(runtime))
69    }
70}
71
72/// A wrapper for [`JoinHandle`].
73#[derive(Debug)]
74pub struct SpawnHandle<T> {
75    inner: JoinHandle<T>,
76}
77
78impl<T> std::future::Future for SpawnHandle<T> {
79    type Output = Result<T>;
80
81    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
82        match std::pin::Pin::new(&mut self.inner).poll(cx) {
83            std::task::Poll::Ready(res) => match res {
84                Ok(v) => std::task::Poll::Ready(Ok(v)),
85                Err(e) => std::task::Poll::Ready(Err(Error::new(ErrorKind::Join, "tokio join error").with_source(e))),
86            },
87            std::task::Poll::Pending => std::task::Poll::Pending,
88        }
89    }
90}
91
92/// A wrapper around a dedicated tokio runtime or handle to spawn tasks.
93#[derive(Debug, Clone)]
94pub enum Spawner {
95    /// A dedicated runtime to spawn tasks.
96    Runtime(Arc<BackgroundShutdownRuntime>),
97    /// A handle to spawn tasks.
98    Handle(Handle),
99}
100
101impl From<Runtime> for Spawner {
102    fn from(runtime: Runtime) -> Self {
103        Self::Runtime(Arc::new(runtime.into()))
104    }
105}
106
107impl From<Handle> for Spawner {
108    fn from(handle: Handle) -> Self {
109        Self::Handle(handle)
110    }
111}
112
113impl Spawner {
114    /// Wrapper for [`Runtime::spawn`] or [`Handle::spawn`].
115    pub fn spawn<F>(&self, future: F) -> SpawnHandle<<F as std::future::Future>::Output>
116    where
117        F: std::future::Future + Send + 'static,
118        F::Output: Send + 'static,
119    {
120        let inner = match self {
121            Spawner::Runtime(rt) => rt.spawn(future),
122            Spawner::Handle(h) => h.spawn(future),
123        };
124        SpawnHandle { inner }
125    }
126
127    /// Wrapper for [`Runtime::spawn_blocking`] or [`Handle::spawn_blocking`].
128    pub fn spawn_blocking<F, R>(&self, func: F) -> SpawnHandle<R>
129    where
130        F: FnOnce() -> R + Send + 'static,
131        R: Send + 'static,
132    {
133        let inner = match self {
134            Spawner::Runtime(rt) => rt.spawn_blocking(func),
135            Spawner::Handle(h) => h.spawn_blocking(func),
136        };
137        SpawnHandle { inner }
138    }
139
140    /// Get the current spawner.
141    pub fn current() -> Self {
142        Spawner::Handle(Handle::current())
143    }
144}