// redis/aio/runtime.rs

1use std::{io, sync::Arc, time::Duration};
2
3use futures_util::Future;
4
5#[cfg(all(feature = "tokio-comp", feature = "smol-comp"))]
6use std::sync::OnceLock;
7
8#[cfg(feature = "smol-comp")]
9use super::smol as crate_smol;
10#[cfg(feature = "tokio-comp")]
11use super::tokio as crate_tokio;
12use super::RedisRuntime;
13use crate::errors::RedisError;
14#[cfg(feature = "smol-comp")]
15use smol_timeout::TimeoutExt;
16
/// The async runtime used to spawn tasks and to drive timers (timeouts and sleeps).
///
/// Each variant only exists when its corresponding cargo feature is enabled.
#[derive(Debug, Clone, Copy)]
pub(crate) enum Runtime {
    /// Tokio; available with the `tokio-comp` feature.
    #[cfg(feature = "tokio-comp")]
    Tokio,
    /// smol; available with the `smol-comp` feature.
    #[cfg(feature = "smol-comp")]
    Smol,
}
24
/// Handle to a task spawned via `Runtime::spawn`, tagged with the runtime that owns it.
pub(crate) enum TaskHandle {
    /// A task spawned on Tokio.
    #[cfg(feature = "tokio-comp")]
    Tokio(::tokio::task::JoinHandle<()>),
    /// A task spawned on smol.
    #[cfg(feature = "smol-comp")]
    Smol(::smol::Task<()>),
}
31
32impl TaskHandle {
33    #[cfg(feature = "connection-manager")]
34    pub(crate) fn detach(self) {
35        match self {
36            #[cfg(feature = "smol-comp")]
37            TaskHandle::Smol(task) => task.detach(),
38            #[cfg(feature = "tokio-comp")]
39            _ => {}
40        }
41    }
42}
43
44pub(crate) struct HandleContainer(Option<TaskHandle>);
45
46impl HandleContainer {
47    pub(crate) fn new(handle: TaskHandle) -> Self {
48        Self(Some(handle))
49    }
50}
51
52impl Drop for HandleContainer {
53    fn drop(&mut self) {
54        match self.0.take() {
55            None => {}
56            #[cfg(feature = "tokio-comp")]
57            Some(TaskHandle::Tokio(handle)) => handle.abort(),
58            #[cfg(feature = "smol-comp")]
59            Some(TaskHandle::Smol(task)) => drop(task),
60        }
61    }
62}
63
64#[derive(Clone)]
65// we allow dead code here because the container isn't used directly, only in the derived drop.
66#[allow(dead_code)]
67pub(crate) struct SharedHandleContainer(Arc<HandleContainer>);
68
69impl SharedHandleContainer {
70    pub(crate) fn new(handle: TaskHandle) -> Self {
71        Self(Arc::new(HandleContainer::new(handle)))
72    }
73}
74
/// Runtime preference recorded by `prefer_tokio` / `prefer_smol`.
///
/// Unset until one of them succeeds. Only exists when both runtime features are compiled in.
#[cfg(all(feature = "tokio-comp", feature = "smol-comp"))]
static CHOSEN_RUNTIME: OnceLock<Runtime> = OnceLock::new();

/// Record `runtime` as the process-wide runtime preference.
///
/// The underlying `OnceLock` can only be initialized once. Any later call fails with an
/// `ErrorKind::Client` error and leaves the earlier preference in place.
#[cfg(all(feature = "tokio-comp", feature = "smol-comp"))]
fn set_runtime(runtime: Runtime) -> Result<(), RedisError> {
    const PREFER_RUNTIME_ERROR: &str = "Another runtime preference was already set. Please call this function before any other runtime preference is set.";

    match CHOSEN_RUNTIME.set(runtime) {
        Ok(()) => Ok(()),
        Err(_rejected) => Err(RedisError::from((
            crate::ErrorKind::Client,
            PREFER_RUNTIME_ERROR,
        ))),
    }
}
87
/// Mark Smol as the preferred runtime.
///
/// Use this when the application only runs on Smol but this crate is compiled with both the
/// `smol-comp` and `tokio-comp` features enabled. Compiling with both is a setup that is best
/// avoided. Once a preference is recorded, it is used instead of auto-detecting the runtime.
///
/// # Errors
///
/// Returns a client-kind `RedisError` if a runtime preference was already set. That earlier
/// preference may be the same runtime. It is left unchanged, because the preference can only
/// be set once per process.
#[cfg(all(feature = "smol-comp", feature = "tokio-comp"))]
pub fn prefer_smol() -> Result<(), RedisError> {
    let preferred = Runtime::Smol;
    set_runtime(preferred)
}
97
/// Mark Tokio as the preferred runtime.
///
/// Use this when the application only runs on Tokio but this crate is compiled with both the
/// `smol-comp` and `tokio-comp` features enabled. Compiling with both is a setup that is best
/// avoided. Once a preference is recorded, it is used instead of auto-detecting the runtime.
///
/// # Errors
///
/// Returns a client-kind `RedisError` if a runtime preference was already set. That earlier
/// preference may be the same runtime. It is left unchanged, because the preference can only
/// be set once per process.
#[cfg(all(feature = "smol-comp", feature = "tokio-comp"))]
pub fn prefer_tokio() -> Result<(), RedisError> {
    let preferred = Runtime::Tokio;
    set_runtime(preferred)
}
107
108impl Runtime {
109    pub(crate) fn locate() -> Self {
110        #[cfg(all(feature = "smol-comp", feature = "tokio-comp"))]
111        if let Some(runtime) = CHOSEN_RUNTIME.get() {
112            return *runtime;
113        }
114
115        #[cfg(all(feature = "tokio-comp", not(feature = "smol-comp")))]
116        {
117            Runtime::Tokio
118        }
119
120        #[cfg(all(not(feature = "tokio-comp"), feature = "smol-comp",))]
121        {
122            Runtime::Smol
123        }
124
125        cfg_if::cfg_if! {
126            if #[cfg(all(feature = "tokio-comp", feature = "smol-comp"))] {
127                if ::tokio::runtime::Handle::try_current().is_ok() {
128                    Runtime::Tokio
129                } else {
130                    Runtime::Smol
131                }
132            }
133        }
134
135        #[cfg(all(not(feature = "tokio-comp"), not(feature = "smol-comp")))]
136        {
137            compile_error!("tokio-comp or smol-comp features required for aio feature")
138        }
139    }
140
141    #[must_use]
142    pub(crate) fn spawn(&self, f: impl Future<Output = ()> + Send + 'static) -> TaskHandle {
143        match self {
144            #[cfg(feature = "tokio-comp")]
145            Runtime::Tokio => crate_tokio::Tokio::spawn(f),
146            #[cfg(feature = "smol-comp")]
147            Runtime::Smol => crate_smol::Smol::spawn(f),
148        }
149    }
150
151    pub(crate) async fn timeout<F: Future>(
152        &self,
153        duration: Duration,
154        future: F,
155    ) -> Result<F::Output, Elapsed> {
156        match self {
157            #[cfg(feature = "tokio-comp")]
158            Runtime::Tokio => tokio::time::timeout(duration, future)
159                .await
160                .map_err(|_| Elapsed(())),
161            #[cfg(feature = "smol-comp")]
162            Runtime::Smol => future.timeout(duration).await.ok_or(Elapsed(())),
163        }
164    }
165
166    #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
167    pub(crate) async fn sleep(&self, duration: Duration) {
168        match self {
169            #[cfg(feature = "tokio-comp")]
170            Runtime::Tokio => {
171                tokio::time::sleep(duration).await;
172            }
173
174            #[cfg(feature = "smol-comp")]
175            Runtime::Smol => {
176                smol::Timer::after(duration).await;
177            }
178        }
179    }
180
181    #[cfg(feature = "cluster-async")]
182    pub(crate) async fn locate_and_sleep(duration: Duration) {
183        Self::locate().sleep(duration).await
184    }
185}
186
187#[derive(Debug)]
188pub(crate) struct Elapsed(());
189
190impl From<Elapsed> for RedisError {
191    fn from(_: Elapsed) -> Self {
192        io::Error::from(io::ErrorKind::TimedOut).into()
193    }
194}