1use std::{io, sync::Arc, time::Duration};
2
3use futures_util::Future;
4
5#[cfg(all(feature = "tokio-comp", feature = "smol-comp"))]
6use std::sync::OnceLock;
7
8#[cfg(feature = "smol-comp")]
9use super::smol as crate_smol;
10#[cfg(feature = "tokio-comp")]
11use super::tokio as crate_tokio;
12use super::RedisRuntime;
13use crate::errors::RedisError;
14#[cfg(feature = "smol-comp")]
15use smol_timeout::TimeoutExt;
16
/// The async runtime used to spawn tasks and to drive timeouts and sleeps.
///
/// A variant only exists when its matching cargo feature is enabled.
#[derive(Debug, Clone, Copy)]
pub(crate) enum Runtime {
    /// The Tokio runtime; requires the `tokio-comp` feature.
    #[cfg(feature = "tokio-comp")]
    Tokio,
    /// The Smol runtime; requires the `smol-comp` feature.
    #[cfg(feature = "smol-comp")]
    Smol,
}
24
/// A runtime-specific handle to a spawned task that produces no value.
pub(crate) enum TaskHandle {
    /// Join handle of a task spawned on Tokio.
    #[cfg(feature = "tokio-comp")]
    Tokio(tokio::task::JoinHandle<()>),
    /// Task spawned on Smol.
    #[cfg(feature = "smol-comp")]
    Smol(smol::Task<()>),
}
31
32impl TaskHandle {
33 #[cfg(feature = "connection-manager")]
34 pub(crate) fn detach(self) {
35 match self {
36 #[cfg(feature = "smol-comp")]
37 TaskHandle::Smol(task) => task.detach(),
38 #[cfg(feature = "tokio-comp")]
39 _ => {}
40 }
41 }
42}
43
44pub(crate) struct HandleContainer(Option<TaskHandle>);
45
46impl HandleContainer {
47 pub(crate) fn new(handle: TaskHandle) -> Self {
48 Self(Some(handle))
49 }
50}
51
52impl Drop for HandleContainer {
53 fn drop(&mut self) {
54 match self.0.take() {
55 None => {}
56 #[cfg(feature = "tokio-comp")]
57 Some(TaskHandle::Tokio(handle)) => handle.abort(),
58 #[cfg(feature = "smol-comp")]
59 Some(TaskHandle::Smol(task)) => drop(task),
60 }
61 }
62}
63
64#[derive(Clone)]
65#[allow(dead_code)]
67pub(crate) struct SharedHandleContainer(Arc<HandleContainer>);
68
69impl SharedHandleContainer {
70 pub(crate) fn new(handle: TaskHandle) -> Self {
71 Self(Arc::new(HandleContainer::new(handle)))
72 }
73}
74
/// The explicitly preferred runtime, if one has been set.
///
/// Only present when both runtimes are compiled in; it can be written at most once.
#[cfg(all(feature = "tokio-comp", feature = "smol-comp"))]
static CHOSEN_RUNTIME: OnceLock<Runtime> = OnceLock::new();
77
/// Records `runtime` as the preferred runtime.
///
/// # Errors
///
/// Returns a client error if a preference was already recorded. The earlier
/// preference is kept in that case.
#[cfg(all(feature = "tokio-comp", feature = "smol-comp"))]
fn set_runtime(runtime: Runtime) -> Result<(), RedisError> {
    const ALREADY_SET_MESSAGE: &str =
        "Another runtime preference was already set. Please call this function before any other runtime preference is set.";

    // `OnceLock::set` fails when a value is already stored; the rejected value
    // it hands back is not needed for the error.
    if CHOSEN_RUNTIME.set(runtime).is_err() {
        return Err(RedisError::from((
            crate::ErrorKind::Client,
            ALREADY_SET_MESSAGE,
        )));
    }
    Ok(())
}
87
/// Marks Smol as the preferred runtime.
///
/// Only available when both the `smol-comp` and `tokio-comp` features are enabled.
/// Once a preference is set, it is used instead of detecting the runtime automatically.
///
/// # Errors
///
/// Returns an error if a runtime preference was already set; the existing
/// preference is left unchanged.
#[cfg(all(feature = "smol-comp", feature = "tokio-comp"))]
pub fn prefer_smol() -> Result<(), RedisError> {
    let preferred = Runtime::Smol;
    set_runtime(preferred)
}
97
/// Marks Tokio as the preferred runtime.
///
/// Only available when both the `smol-comp` and `tokio-comp` features are enabled.
/// Once a preference is set, it is used instead of detecting the runtime automatically.
///
/// # Errors
///
/// Returns an error if a runtime preference was already set; the existing
/// preference is left unchanged.
#[cfg(all(feature = "smol-comp", feature = "tokio-comp"))]
pub fn prefer_tokio() -> Result<(), RedisError> {
    let preferred = Runtime::Tokio;
    set_runtime(preferred)
}
107
108impl Runtime {
109 pub(crate) fn locate() -> Self {
110 #[cfg(all(feature = "smol-comp", feature = "tokio-comp"))]
111 if let Some(runtime) = CHOSEN_RUNTIME.get() {
112 return *runtime;
113 }
114
115 #[cfg(all(feature = "tokio-comp", not(feature = "smol-comp")))]
116 {
117 Runtime::Tokio
118 }
119
120 #[cfg(all(not(feature = "tokio-comp"), feature = "smol-comp",))]
121 {
122 Runtime::Smol
123 }
124
125 cfg_if::cfg_if! {
126 if #[cfg(all(feature = "tokio-comp", feature = "smol-comp"))] {
127 if ::tokio::runtime::Handle::try_current().is_ok() {
128 Runtime::Tokio
129 } else {
130 Runtime::Smol
131 }
132 }
133 }
134
135 #[cfg(all(not(feature = "tokio-comp"), not(feature = "smol-comp")))]
136 {
137 compile_error!("tokio-comp or smol-comp features required for aio feature")
138 }
139 }
140
141 #[must_use]
142 pub(crate) fn spawn(&self, f: impl Future<Output = ()> + Send + 'static) -> TaskHandle {
143 match self {
144 #[cfg(feature = "tokio-comp")]
145 Runtime::Tokio => crate_tokio::Tokio::spawn(f),
146 #[cfg(feature = "smol-comp")]
147 Runtime::Smol => crate_smol::Smol::spawn(f),
148 }
149 }
150
151 pub(crate) async fn timeout<F: Future>(
152 &self,
153 duration: Duration,
154 future: F,
155 ) -> Result<F::Output, Elapsed> {
156 match self {
157 #[cfg(feature = "tokio-comp")]
158 Runtime::Tokio => tokio::time::timeout(duration, future)
159 .await
160 .map_err(|_| Elapsed(())),
161 #[cfg(feature = "smol-comp")]
162 Runtime::Smol => future.timeout(duration).await.ok_or(Elapsed(())),
163 }
164 }
165
166 #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
167 pub(crate) async fn sleep(&self, duration: Duration) {
168 match self {
169 #[cfg(feature = "tokio-comp")]
170 Runtime::Tokio => {
171 tokio::time::sleep(duration).await;
172 }
173
174 #[cfg(feature = "smol-comp")]
175 Runtime::Smol => {
176 smol::Timer::after(duration).await;
177 }
178 }
179 }
180
181 #[cfg(feature = "cluster-async")]
182 pub(crate) async fn locate_and_sleep(duration: Duration) {
183 Self::locate().sleep(duration).await
184 }
185}
186
/// Marker error meaning that a time limit was reached before the awaited future completed.
///
/// It carries no data; the unit field is private to the type.
#[derive(Debug)]
pub(crate) struct Elapsed(());
189
190impl From<Elapsed> for RedisError {
191 fn from(_: Elapsed) -> Self {
192 io::Error::from(io::ErrorKind::TimedOut).into()
193 }
194}