use std::{io, sync::Arc, time::Duration};
use futures_util::Future;
#[cfg(all(feature = "tokio-comp", feature = "smol-comp"))]
use std::sync::OnceLock;
use super::RedisRuntime;
#[cfg(feature = "smol-comp")]
use super::smol as crate_smol;
#[cfg(feature = "tokio-comp")]
use super::tokio as crate_tokio;
use crate::errors::RedisError;
#[cfg(feature = "smol-comp")]
use smol_timeout::TimeoutExt;
#[derive(Clone, Copy, Debug)]
pub(crate) enum Runtime {
#[cfg(feature = "tokio-comp")]
Tokio,
#[cfg(feature = "smol-comp")]
Smol,
}
pub(crate) enum TaskHandle {
#[cfg(feature = "tokio-comp")]
Tokio(tokio::task::JoinHandle<()>),
#[cfg(feature = "smol-comp")]
Smol(smol::Task<()>),
}
impl TaskHandle {
#[cfg(feature = "connection-manager")]
pub(crate) fn detach(self) {
match self {
#[cfg(feature = "smol-comp")]
TaskHandle::Smol(task) => task.detach(),
#[cfg(feature = "tokio-comp")]
_ => {}
}
}
}
pub(crate) struct HandleContainer(Option<TaskHandle>);
impl HandleContainer {
pub(crate) fn new(handle: TaskHandle) -> Self {
Self(Some(handle))
}
}
impl Drop for HandleContainer {
fn drop(&mut self) {
match self.0.take() {
None => {}
#[cfg(feature = "tokio-comp")]
Some(TaskHandle::Tokio(handle)) => handle.abort(),
#[cfg(feature = "smol-comp")]
Some(TaskHandle::Smol(task)) => drop(task),
}
}
}
#[derive(Clone)]
#[allow(dead_code)]
pub(crate) struct SharedHandleContainer(Arc<HandleContainer>);
impl SharedHandleContainer {
pub(crate) fn new(handle: TaskHandle) -> Self {
Self(Arc::new(HandleContainer::new(handle)))
}
}
#[cfg(all(feature = "tokio-comp", feature = "smol-comp"))]
static CHOSEN_RUNTIME: OnceLock<Runtime> = OnceLock::new();
#[cfg(all(feature = "tokio-comp", feature = "smol-comp"))]
fn set_runtime(runtime: Runtime) -> Result<(), RedisError> {
const PREFER_RUNTIME_ERROR: &str = "Another runtime preference was already set. Please call this function before any other runtime preference is set.";
CHOSEN_RUNTIME
.set(runtime)
.map_err(|_| RedisError::from((crate::ErrorKind::Client, PREFER_RUNTIME_ERROR)))
}
#[cfg(all(feature = "smol-comp", feature = "tokio-comp"))]
pub fn prefer_smol() -> Result<(), RedisError> {
set_runtime(Runtime::Smol)
}
#[cfg(all(feature = "smol-comp", feature = "tokio-comp"))]
pub fn prefer_tokio() -> Result<(), RedisError> {
set_runtime(Runtime::Tokio)
}
impl Runtime {
pub(crate) fn locate() -> Self {
#[cfg(all(feature = "smol-comp", feature = "tokio-comp"))]
if let Some(runtime) = CHOSEN_RUNTIME.get() {
return *runtime;
}
#[cfg(all(feature = "tokio-comp", not(feature = "smol-comp")))]
{
Runtime::Tokio
}
#[cfg(all(not(feature = "tokio-comp"), feature = "smol-comp",))]
{
Runtime::Smol
}
cfg_if::cfg_if! {
if #[cfg(all(feature = "tokio-comp", feature = "smol-comp"))] {
if ::tokio::runtime::Handle::try_current().is_ok() {
Runtime::Tokio
} else {
Runtime::Smol
}
}
}
#[cfg(all(not(feature = "tokio-comp"), not(feature = "smol-comp")))]
{
compile_error!("tokio-comp or smol-comp features required for aio feature")
}
}
#[must_use]
pub(crate) fn spawn(&self, f: impl Future<Output = ()> + Send + 'static) -> TaskHandle {
match self {
#[cfg(feature = "tokio-comp")]
Runtime::Tokio => crate_tokio::Tokio::spawn(f),
#[cfg(feature = "smol-comp")]
Runtime::Smol => crate_smol::Smol::spawn(f),
}
}
pub(crate) async fn timeout<F: Future>(
&self,
duration: Duration,
future: F,
) -> Result<F::Output, Elapsed> {
match self {
#[cfg(feature = "tokio-comp")]
Runtime::Tokio => tokio::time::timeout(duration, future)
.await
.map_err(|_| Elapsed(())),
#[cfg(feature = "smol-comp")]
Runtime::Smol => future.timeout(duration).await.ok_or(Elapsed(())),
}
}
#[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
pub(crate) async fn sleep(&self, duration: Duration) {
match self {
#[cfg(feature = "tokio-comp")]
Runtime::Tokio => {
tokio::time::sleep(duration).await;
}
#[cfg(feature = "smol-comp")]
Runtime::Smol => {
smol::Timer::after(duration).await;
}
}
}
#[cfg(feature = "cluster-async")]
pub(crate) async fn locate_and_sleep(duration: Duration) {
Self::locate().sleep(duration).await
}
}
#[derive(Debug)]
pub(crate) struct Elapsed(());
impl From<Elapsed> for RedisError {
fn from(_: Elapsed) -> Self {
io::Error::from(io::ErrorKind::TimedOut).into()
}
}