use std::{
fmt::Debug,
mem::ManuallyDrop,
ops::{Deref, DerefMut},
sync::Arc,
};
use tokio::{
runtime::{Handle, Runtime},
task::JoinHandle,
};
use crate::error::{Error, ErrorKind, Result};
/// A wrapper around a Tokio [`Runtime`] with a custom `Drop` implementation.
///
/// On drop, the wrapped runtime is shut down with `Runtime::shutdown_background`
/// (or simply dropped when built with `cfg(madsim)`).
///
/// The runtime is stored in [`ManuallyDrop`] so that the `Drop` implementation can
/// move it out by value.
pub struct BackgroundShutdownRuntime(ManuallyDrop<Runtime>);
impl Debug for BackgroundShutdownRuntime {
    /// Writes only the type name; the wrapped runtime is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A field-less debug tuple renders as just its name, so write the name directly.
        f.write_str("BackgroundShutdownRuntime")
    }
}
impl Drop for BackgroundShutdownRuntime {
    /// Takes the runtime out of its `ManuallyDrop` slot and shuts it down.
    ///
    /// Under `cfg(madsim)` the runtime is dropped normally; otherwise
    /// `Runtime::shutdown_background` is called on it.
    fn drop(&mut self) {
        // SAFETY: `drop` is the last use of `self`, and nothing reads `self.0` after
        // this point, so the slot is never observed once its value has been taken.
        let rt = unsafe { ManuallyDrop::take(&mut self.0) };

        #[cfg(madsim)]
        drop(rt);

        #[cfg(not(madsim))]
        rt.shutdown_background();
    }
}
impl Deref for BackgroundShutdownRuntime {
    type Target = Runtime;

    /// Borrows the wrapped [`Runtime`].
    fn deref(&self) -> &Runtime {
        // Explicit reborrow through `ManuallyDrop`'s own `Deref`.
        &*self.0
    }
}
impl DerefMut for BackgroundShutdownRuntime {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl From<Runtime> for BackgroundShutdownRuntime {
    /// Wraps an owned [`Runtime`] so that the wrapper's `Drop` implementation
    /// controls how it is shut down.
    fn from(runtime: Runtime) -> Self {
        let slot = ManuallyDrop::new(runtime);
        Self(slot)
    }
}
/// Handle to a spawned task, as returned by `Spawner::spawn` and
/// `Spawner::spawn_blocking`.
///
/// It implements `Future`: awaiting it yields the task's output wrapped in the
/// crate's `Result`, with a Tokio join failure reported as `ErrorKind::Join`.
#[derive(Debug)]
pub struct SpawnHandle<T> {
/// The underlying Tokio join handle that is polled for completion.
inner: JoinHandle<T>,
}
impl<T> std::future::Future for SpawnHandle<T> {
type Output = Result<T>;
fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
match std::pin::Pin::new(&mut self.inner).poll(cx) {
std::task::Poll::Ready(res) => match res {
Ok(v) => std::task::Poll::Ready(Ok(v)),
Err(e) => std::task::Poll::Ready(Err(Error::new(ErrorKind::Join, "tokio join error").with_source(e))),
},
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
}
/// A cloneable entry point for spawning tasks onto a Tokio runtime.
///
/// It either shares ownership of a runtime or holds a handle to one.
#[derive(Debug, Clone)]
pub enum Spawner {
/// A shared, reference-counted runtime wrapped in [`BackgroundShutdownRuntime`].
Runtime(Arc<BackgroundShutdownRuntime>),
/// A Tokio runtime [`Handle`].
Handle(Handle),
}
impl From<Runtime> for Spawner {
    /// Takes ownership of `runtime`, wraps it for background shutdown, and
    /// stores it behind an [`Arc`] in the `Runtime` variant.
    fn from(runtime: Runtime) -> Self {
        let wrapped = runtime.into();
        let shared = Arc::new(wrapped);
        Self::Runtime(shared)
    }
}
impl From<Handle> for Spawner {
fn from(handle: Handle) -> Self {
Self::Handle(handle)
}
}
impl Spawner {
    /// Spawns `future` as an asynchronous task on the underlying runtime or handle.
    ///
    /// Returns a [`SpawnHandle`] that resolves to the future's output.
    pub fn spawn<F>(&self, future: F) -> SpawnHandle<<F as std::future::Future>::Output>
    where
        F: std::future::Future + Send + 'static,
        F::Output: Send + 'static,
    {
        match self {
            Self::Runtime(runtime) => SpawnHandle {
                inner: runtime.spawn(future),
            },
            Self::Handle(handle) => SpawnHandle {
                inner: handle.spawn(future),
            },
        }
    }

    /// Runs the blocking closure `func` via `spawn_blocking` on the underlying
    /// runtime or handle.
    ///
    /// Returns a [`SpawnHandle`] that resolves to the closure's return value.
    pub fn spawn_blocking<F, R>(&self, func: F) -> SpawnHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        match self {
            Self::Runtime(runtime) => SpawnHandle {
                inner: runtime.spawn_blocking(func),
            },
            Self::Handle(handle) => SpawnHandle {
                inner: handle.spawn_blocking(func),
            },
        }
    }

    /// Creates a spawner from `Handle::current()`.
    ///
    /// # Panics
    ///
    /// Per the Tokio documentation, `Handle::current()` panics when called
    /// outside the context of a Tokio runtime.
    pub fn current() -> Self {
        let handle = Handle::current();
        Self::Handle(handle)
    }
}